package ai.djl.training.listener;

import ai.djl.Model;
import ai.djl.training.Trainer;
import ai.djl.training.listener.TrainingListener;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.function.Consumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: classes.dex */
/**
 * A {@link TrainingListener} that saves model checkpoints to an output directory.
 *
 * <p>When {@code step} is positive, a checkpoint is saved after every {@code step}-th completed
 * epoch. When training ends, a final checkpoint is saved unless the last completed epoch already
 * triggered a periodic save. When {@code step} is zero or negative, periodic saving is disabled
 * and only the end-of-training checkpoint is written.
 */
public class CheckpointsTrainingListener implements TrainingListener {
    private static final Logger logger = LoggerFactory.getLogger(CheckpointsTrainingListener.class);

    /** Number of epochs completed so far; stored as the model's "Epoch" property on each save. */
    private int epoch;
    /** Optional hook invoked with the trainer right before each save; may be {@code null}. */
    private Consumer<Trainer> onSaveModel;
    /** Directory checkpoints are written to; never {@code null}. */
    private final String outputDir;
    /** Name to save under instead of the model's own name; {@code null} uses the model name. */
    private String overrideModelName;
    /** Epoch interval between periodic saves; values {@code <= 0} disable periodic saves. */
    private final int step;

    /**
     * Creates a listener that saves only when training ends, under the model's own name.
     *
     * @param outputDir the directory to save checkpoints in
     * @throws IllegalArgumentException if {@code outputDir} is {@code null}
     */
    public CheckpointsTrainingListener(String outputDir) {
        this(outputDir, null, -1);
    }

    /**
     * Creates a listener that saves only when training ends.
     *
     * @param outputDir the directory to save checkpoints in
     * @param overrideModelName the name to save under, or {@code null} to use the model's name
     * @throws IllegalArgumentException if {@code outputDir} is {@code null}
     */
    public CheckpointsTrainingListener(String outputDir, String overrideModelName) {
        this(outputDir, overrideModelName, -1);
    }

    /**
     * Creates a listener that saves every {@code step} epochs and once more at the end of training.
     *
     * @param outputDir the directory to save checkpoints in
     * @param overrideModelName the name to save under, or {@code null} to use the model's name
     * @param step the number of epochs between periodic saves; zero or negative disables
     *     periodic saves so that only the end-of-training checkpoint is written
     * @throws IllegalArgumentException if {@code outputDir} is {@code null}
     */
    public CheckpointsTrainingListener(String outputDir, String overrideModelName, int step) {
        if (outputDir == null) {
            throw new IllegalArgumentException(
                    "Can not save checkpoint without specifying an output directory");
        }
        this.outputDir = outputDir;
        this.overrideModelName = overrideModelName;
        this.step = step;
    }

    /**
     * Returns the name checkpoints are saved under, overriding the model's own name.
     *
     * @return the override name, or {@code null} if the model's own name is used
     */
    public String getOverrideModelName() {
        return overrideModelName;
    }

    /**
     * Sets the name checkpoints are saved under, overriding the model's own name.
     *
     * @param overrideModelName the name to use, or {@code null} to use the model's own name
     */
    public void setOverrideModelName(String overrideModelName) {
        this.overrideModelName = overrideModelName;
    }

    /**
     * Sets a callback that receives the trainer immediately before each checkpoint is saved,
     * after the model's "Epoch" property has been updated.
     *
     * @param onSaveModel the callback, or {@code null} to remove it
     */
    public void setSaveModelCallback(Consumer<Trainer> onSaveModel) {
        this.onSaveModel = onSaveModel;
    }

    /** Counts the completed epoch and saves a checkpoint if it falls on a {@code step} boundary. */
    @Override
    public void onEpoch(Trainer trainer) {
        epoch++;
        if (isPeriodicSaveEnabled() && epoch % step == 0) {
            saveModel(trainer);
        }
    }

    /** No-op. */
    @Override
    public void onTrainingBatch(Trainer trainer, TrainingListener.BatchData batchData) {}

    /** No-op. */
    @Override
    public void onValidationBatch(Trainer trainer, TrainingListener.BatchData batchData) {}

    /** No-op. */
    @Override
    public void onTrainingBegin(Trainer trainer) {}

    /**
     * Saves a final checkpoint unless {@link #onEpoch} already saved one for the last epoch.
     */
    @Override
    public void onTrainingEnd(Trainer trainer) {
        // Reuse the same "step > 0" rule as onEpoch: it avoids a division by zero for step == 0
        // and guarantees a final save for every non-positive step, not just -1.
        if (!isPeriodicSaveEnabled() || epoch % step != 0) {
            saveModel(trainer);
        }
    }

    /**
     * Saves the trainer's model to the output directory.
     *
     * <p>The model is saved under the override name if one is set, otherwise under its own name.
     * The current epoch count is stored as the model's "Epoch" property, and the save callback (if
     * any) runs before the save. An {@link IOException} during saving is logged, not rethrown.
     *
     * @param trainer the trainer whose model is saved
     */
    protected void saveModel(Trainer trainer) {
        Model model = trainer.getModel();
        String modelName = overrideModelName != null ? overrideModelName : model.getName();
        try {
            model.setProperty("Epoch", String.valueOf(epoch));
            if (onSaveModel != null) {
                onSaveModel.accept(trainer);
            }
            model.save(Paths.get(outputDir), modelName);
        } catch (IOException e) {
            logger.error("Failed to save checkpoint", e);
        }
    }

    /** Returns whether per-epoch checkpointing is active, i.e. {@code step} is positive. */
    private boolean isPeriodicSaveEnabled() {
        return step > 0;
    }
}
