package ai.djl.training;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.listener.EpochTrainingListener;
import ai.djl.training.listener.EvaluatorTrainingListener;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.function.Predicate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: classes.dex */
/**
 * Drives training of a {@link Model}: runs forward passes through the model's block, applies
 * parameter updates through a {@link ParameterStore}, records timing metrics, and notifies the
 * configured {@link TrainingListener}s of training begin/end.
 *
 * <p>A {@code Trainer} owns a sub-manager of the model's {@link NDManager}. Call {@link #close()}
 * when training is finished so listeners receive {@code onTrainingEnd}, the parameter store is
 * synced, and the sub-manager is closed.
 */
public class Trainer implements AutoCloseable {

    private static final Logger logger = LoggerFactory.getLogger(Trainer.class);

    private final Model model;
    private final NDManager manager;
    private final Device[] devices;
    private final Loss loss;
    private final List<Evaluator> evaluators;
    private final ParameterStore parameterStore;
    private final List<TrainingListener> listeners;
    private Metrics metrics;
    private boolean gradientsChecked;

    /**
     * Creates a trainer for {@code model} using the settings in {@code trainingConfig}, then calls
     * {@code onTrainingBegin} on every configured listener.
     *
     * @param model the model to train
     * @param trainingConfig supplies the devices, loss, evaluators, optimizer and listeners
     * @throws IllegalArgumentException if {@code trainingConfig} has no loss function
     */
    public Trainer(Model model, TrainingConfig trainingConfig) {
        // Validate before allocating the sub-manager so a bad config does not leak it.
        Loss lossFunction = trainingConfig.getLossFunction();
        if (lossFunction == null) {
            throw new IllegalArgumentException("You must specify a loss for the trainer");
        }
        this.model = model;
        this.loss = lossFunction;
        this.manager = model.getNDManager().newSubManager();
        this.devices = trainingConfig.getDevices();

        // The loss is tracked alongside the user-supplied evaluators.
        this.evaluators = new ArrayList<>(trainingConfig.getEvaluators());
        this.evaluators.add(lossFunction);

        LocalParameterServer parameterServer =
                new LocalParameterServer(trainingConfig.getOptimizer());
        // NOTE(review): meaning of the boolean flag is defined by ParameterStore — confirm.
        this.parameterStore = new ParameterStore(manager, false);
        parameterStore.setParameterServer(parameterServer, devices);

        this.listeners = trainingConfig.getTrainingListeners();
        notifyListeners(listener -> listener.onTrainingBegin(this));
    }

    /**
     * Verifies that at least one gradient element of the trainable parameters (read on the first
     * device) is non-zero. This catches a missing {@code gradientCollector.backward()} call before
     * the first parameter update.
     *
     * @throws IllegalStateException if every gradient element is zero
     */
    private void checkGradients() {
        List<NDArray> grads = new ArrayList<>();
        for (Parameter param : model.getBlock().getParameters().values()) {
            if (param.requireGradient()) {
                grads.add(parameterStore.getValue(param, devices[0]).getGradient());
            }
        }
        if (grads.isEmpty()) {
            // Nothing is trainable, so there is no gradient to validate; stacking an empty list
            // would otherwise fail with an unrelated error.
            gradientsChecked = true;
            return;
        }

        // Sum |grad| rather than grad so positive and negative entries cannot cancel to zero.
        float total = 0f;
        try (NDList perParamSums = new NDList()) {
            for (NDArray grad : grads) {
                try (NDArray magnitude = grad.abs()) {
                    perParamSums.add(magnitude.sum());
                }
            }
            try (NDArray stacked = NDArrays.stack(perParamSums);
                    NDArray overall = stacked.sum()) {
                for (float value : overall.toFloatArray()) {
                    total += value;
                }
            }
        }
        if (total == 0f) {
            throw new IllegalStateException(
                    "Gradient values are all zeros, please call gradientCollector.backward() on "
                            + "your target NDArray (usually loss), before calling step()");
        }
        gradientsChecked = true;
    }

    /**
     * Records the nanoseconds elapsed since {@code begin} under {@code metricName}. Does nothing
     * when no {@link Metrics} have been set or when {@code begin} is not positive.
     *
     * @param metricName the metric key
     * @param begin a {@link System#nanoTime()} reading taken at the start of the measured work
     */
    public void addMetric(String metricName, long begin) {
        if (metrics == null || begin <= 0) {
            return;
        }
        metrics.addMetric(metricName, System.nanoTime() - begin);
    }

    /**
     * Notifies listeners via {@code onTrainingEnd}, syncs the parameter store, and closes this
     * trainer's sub-manager. The sub-manager is closed even if a listener or the sync throws.
     */
    @Override
    public void close() {
        try {
            notifyListeners(listener -> listener.onTrainingEnd(this));
            parameterStore.sync();
        } finally {
            manager.close();
        }
    }

    /**
     * Safety net: closes the trainer if it was not closed explicitly. The warning is only logged
     * when debug logging is enabled (presumably to limit noise — confirm intent).
     */
    @SuppressWarnings("deprecation")
    @Override
    protected void finalize() throws Throwable {
        try {
            if (manager.isOpen()) {
                if (logger.isDebugEnabled()) {
                    logger.warn("Trainer was not closed explicitly: {}", getClass().getSimpleName());
                }
                close();
            }
        } finally {
            super.finalize();
        }
    }

    /**
     * Runs a forward pass through the model's block and records its time under "forward".
     * Passes {@code true} as the block's boolean flag — presumably the training-mode switch, as
     * {@link #predict(NDList)} passes {@code false}.
     *
     * @param input the input arrays
     * @return the block's output
     */
    public NDList forward(NDList input) {
        long begin = System.nanoTime();
        try {
            return model.getBlock().forward(parameterStore, input, true);
        } finally {
            addMetric("forward", begin);
        }
    }

    /**
     * Runs a forward pass that also receives the labels, and records its time under "forward".
     *
     * @param data the input arrays
     * @param labels the label arrays
     * @return the block's output
     */
    public NDList forward(NDList data, NDList labels) {
        long begin = System.nanoTime();
        try {
            return model.getBlock()
                    .forward(parameterStore, data, labels, (PairList<String, Object>) null);
        } finally {
            addMetric("forward", begin);
        }
    }

    /**
     * Returns the devices this trainer was configured with (the internal array, not a copy).
     *
     * @return the training devices
     */
    public Device[] getDevices() {
        return devices;
    }

    /**
     * Returns the evaluators, which include the configured loss as the last entry. This is the
     * internal mutable list.
     *
     * @return the evaluators
     */
    public List<Evaluator> getEvaluators() {
        return evaluators;
    }

    /**
     * Returns the loss function.
     *
     * @return the loss
     */
    public Loss getLoss() {
        return loss;
    }

    /**
     * Returns this trainer's sub-manager, which is closed by {@link #close()}.
     *
     * @return the trainer's manager
     */
    public NDManager getManager() {
        return manager;
    }

    /**
     * Returns the metrics sink, or {@code null} if none has been set.
     *
     * @return the metrics, possibly {@code null}
     */
    public Metrics getMetrics() {
        return metrics;
    }

    /**
     * Returns the model being trained.
     *
     * @return the model
     */
    public Model getModel() {
        return model;
    }

    /**
     * Builds a {@link TrainingResult} from the epoch count of any {@link EpochTrainingListener}
     * and the latest evaluations of any {@link EvaluatorTrainingListener}.
     *
     * @return the collected training result
     */
    public TrainingResult getTrainingResult() {
        TrainingResult result = new TrainingResult();
        for (TrainingListener listener : listeners) {
            if (listener instanceof EpochTrainingListener) {
                result.setEpoch(((EpochTrainingListener) listener).getNumEpochs());
            } else if (listener instanceof EvaluatorTrainingListener) {
                result.setEvaluations(
                        ((EvaluatorTrainingListener) listener).getLatestEvaluations());
            }
        }
        return result;
    }

    /**
     * Initializes the model's block for the given input shapes, then requests every parameter
     * from the parameter store once per device (presumably so per-device values are created up
     * front — confirm against ParameterStore).
     *
     * @param shapes the input shapes
     */
    public void initialize(Shape... shapes) {
        model.getBlock().initialize(model.getNDManager(), model.getDataType(), shapes);
        model.getBlock()
                .getParameters()
                .forEach(
                        pair -> {
                            for (Device device : devices) {
                                parameterStore.getValue(pair.getValue(), device);
                            }
                        });
    }

    /**
     * Returns the batches of {@code dataset}, created with this trainer's manager.
     *
     * @param dataset the dataset to iterate
     * @return an iterable over the dataset's batches
     */
    public Iterable<Batch> iterateDataset(Dataset dataset) {
        return dataset.getData(getManager());
    }

    /**
     * Returns a new gradient collector from the engine of this trainer's manager.
     *
     * @return a new gradient collector
     */
    public GradientCollector newGradientCollector() {
        return manager.getEngine().newGradientCollector();
    }

    /**
     * Applies {@code consumer} to every configured listener, in order.
     *
     * @param consumer the notification to deliver
     */
    public void notifyListeners(Consumer<TrainingListener> consumer) {
        listeners.forEach(consumer);
    }

    /**
     * Runs a forward pass with the block's boolean flag set to {@code false}. No timing metric is
     * recorded.
     *
     * @param input the input arrays
     * @return the block's output
     */
    public NDList predict(NDList input) {
        return model.getBlock().forward(parameterStore, input, false);
    }

    /**
     * Sets the metrics sink used by {@link #addMetric(String, long)}; {@code null} disables it.
     *
     * @param metrics the metrics sink
     */
    public void setMetrics(Metrics metrics) {
        this.metrics = metrics;
    }

    /**
     * Updates all parameters through the parameter store and records the time under "step". On
     * the first call, it first verifies that gradients are not all zero.
     *
     * @throws IllegalStateException if the first gradient check finds only zeros
     */
    public void step() {
        if (!gradientsChecked) {
            checkGradients();
        }
        long begin = System.nanoTime();
        parameterStore.updateAllParameters();
        addMetric("step", begin);
    }
}
