package ai.djl.training.optimizer;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.training.optimizer.Adam;
import ai.djl.training.optimizer.Nag;
import ai.djl.training.optimizer.Sgd;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.function.Function;

/**
 * Base class for optimizers.
 *
 * <p>Holds the hyper-parameters shared by every optimizer (gradient rescale factor, gradient
 * clipping value, weight decay and the starting update number), counts how many updates have been
 * recorded per parameter, and offers a helper for keeping per-parameter, per-device state arrays.
 *
 * <p>NOTE(review): this file appears to be decompiled output; the public {@code lambda$...} methods
 * are compiler-synthesized helpers that are retained only so the public surface stays unchanged.
 */
public abstract class Optimizer {

    /** Value a parameter's update count starts from before its first recorded update. */
    private final int beginNumUpdate;

    protected float clipGrad;

    /**
     * Highest update count returned so far across all parameters. Atomic so that concurrent
     * callers of {@link #updateCount(String)} cannot move it backwards.
     */
    private final AtomicInteger numUpdate = new AtomicInteger();

    protected float rescaleGrad;

    /** Number of updates recorded so far, keyed by the id passed to {@link #updateCount}. */
    private final Map<String, Integer> updateCounts = new ConcurrentHashMap<>();

    private final float weightDecays;

    /**
     * Base builder for {@link Optimizer} subclasses.
     *
     * @param <T> the concrete builder type, returned by every setter for fluent chaining
     */
    public abstract static class OptimizerBuilder<T extends OptimizerBuilder> {
        private int beginNumUpdate;
        private float weightDecays;
        private float rescaleGrad = 1.0f;
        private float clipGrad = -1.0f;

        /**
         * Sets the value per-parameter update counts start from (0 by default); the first
         * recorded update of a parameter yields {@code beginNumUpdate + 1}.
         *
         * @param beginNumUpdate the starting update number
         * @return this builder
         */
        public T optBeginNumUpdate(int beginNumUpdate) {
            this.beginNumUpdate = beginNumUpdate;
            return self();
        }

        /**
         * Sets the gradient clipping value (-1 by default).
         *
         * <p>NOTE(review): a negative value presumably means "no clipping" — confirm in the
         * subclasses that read {@code clipGrad}.
         *
         * @param clipGrad the clipping value
         * @return this builder
         */
        public T optClipGrad(float clipGrad) {
            this.clipGrad = clipGrad;
            return self();
        }

        /**
         * Sets the weight decay (0 by default), later exposed through {@link #getWeightDecay()}.
         *
         * @param weightDecays the weight decay
         * @return this builder
         */
        public T optWeightDecays(float weightDecays) {
            this.weightDecays = weightDecays;
            return self();
        }

        /**
         * Returns this builder typed as its concrete subclass.
         *
         * @return this builder
         */
        protected abstract T self();

        /**
         * Sets the gradient rescale factor (1 by default) made available to subclasses.
         *
         * @param rescaleGrad the rescale factor
         * @return this builder
         */
        public T setRescaleGrad(float rescaleGrad) {
            this.rescaleGrad = rescaleGrad;
            return self();
        }
    }

    /**
     * Copies the shared hyper-parameters out of the given builder.
     *
     * @param optimizerBuilder the builder holding the configuration
     */
    public Optimizer(OptimizerBuilder<?> optimizerBuilder) {
        this.rescaleGrad = optimizerBuilder.rescaleGrad;
        this.weightDecays = optimizerBuilder.weightDecays;
        this.clipGrad = optimizerBuilder.clipGrad;
        this.beginNumUpdate = optimizerBuilder.beginNumUpdate;
    }

    /**
     * Returns a new builder for {@link Adam}.
     *
     * @return an {@link Adam.Builder}
     */
    public static Adam.Builder adam() {
        return new Adam.Builder();
    }

    /**
     * Compiler-synthesized helper surfaced by decompilation: builds the initial per-device map
     * for a parameter.
     *
     * @param function produces the initial state array for a key
     * @param device the device the initial array is stored under
     * @param str the state key passed to {@code function}
     * @return a new concurrent map holding the detached initial array under {@code device}
     * @deprecated not intended for direct use; retained only for API compatibility.
     */
    @Deprecated
    public static Map<Device, NDArray> lambda$withDefaultState$1(
            Function<String, NDArray> function, Device device, String str) {
        return newDeviceState(function, device, str);
    }

    /**
     * Returns a new builder for {@link Nag}.
     *
     * @return a {@link Nag.Builder}
     */
    public static Nag.Builder nag() {
        return new Nag.Builder();
    }

    /**
     * Returns a new builder for {@link Sgd}.
     *
     * @return an {@link Sgd.Builder}
     */
    public static Sgd.Builder sgd() {
        return new Sgd.Builder();
    }

    /**
     * Returns the weight decay configured through {@link OptimizerBuilder#optWeightDecays(float)}.
     *
     * <p>NOTE(review): the decompiler reports this was {@code protected} in the original source;
     * it is left {@code public} to preserve the current API.
     *
     * @return the weight decay
     */
    public float getWeightDecay() {
        return weightDecays;
    }

    /**
     * Compiler-synthesized helper surfaced by decompilation: computes the next update count.
     *
     * @param str the parameter key (unused)
     * @param num the previous count, or {@code null} if none was recorded
     * @return {@code beginNumUpdate + 1} when {@code num} is {@code null}, otherwise {@code num + 1}
     * @deprecated not intended for direct use; retained only for API compatibility.
     */
    @Deprecated
    public Integer lambda$updateCount$0$Optimizer(String str, Integer num) {
        return nextUpdateCount(num);
    }

    /**
     * Applies one optimization step for the given parameter.
     *
     * <p>NOTE(review): presumably {@code weight} is the parameter's current value and {@code grad}
     * its gradient — confirm against the concrete subclasses.
     *
     * @param parameterId identifier of the parameter being updated
     * @param weight the array to update
     * @param grad the gradient array
     */
    public abstract void update(String parameterId, NDArray weight, NDArray grad);

    /**
     * Records one more update for {@code parameterId} and returns the highest update count seen
     * across all parameters so far.
     *
     * <p>The first call for a parameter records {@code beginNumUpdate + 1}; each later call adds
     * one to that parameter's count.
     *
     * <p>NOTE(review): the decompiler reports this was {@code protected} in the original source;
     * it is left {@code public} to preserve the current API.
     *
     * @param parameterId the key whose update count is incremented
     * @return the maximum of the previous global count and this parameter's new count
     */
    public int updateCount(String parameterId) {
        int count = updateCounts.compute(parameterId, (key, previous) -> nextUpdateCount(previous));
        // A plain read-max-write here could lose a larger value written by a concurrent caller.
        return numUpdate.accumulateAndGet(count, Math::max);
    }

    /**
     * Returns the state array stored for {@code key} on {@code device}, creating it if needed.
     *
     * <ul>
     *   <li>If {@code state} has no entry for {@code key}, {@code defaultFunction.apply(key)} is
     *       called, the result is detached and stored under {@code device}.
     *   <li>If {@code key} exists but has no array for {@code device}, an existing array for that
     *       key is copied with {@code toDevice(device, true)} and stored under {@code device}.
     * </ul>
     *
     * <p>NOTE(review): the decompiler reports this was {@code protected} in the original source;
     * it is left {@code public} to preserve the current API.
     *
     * @param state per-key, per-device state storage; mutated in place
     * @param key the state key (looked up and passed to {@code defaultFunction})
     * @param device the device the returned array is stored under
     * @param defaultFunction produces the initial array for a key with no stored state
     * @return the array stored for {@code key} on {@code device}
     */
    public NDArray withDefaultState(
            Map<String, Map<Device, NDArray>> state,
            String key,
            Device device,
            Function<String, NDArray> defaultFunction) {
        Map<Device, NDArray> arrayMap =
                state.computeIfAbsent(key, deviceStateFactory(defaultFunction, device));
        // The map is never empty here: it was created with one entry above or on an earlier call.
        return arrayMap.computeIfAbsent(
                device, target -> arrayMap.values().iterator().next().toDevice(target, true));
    }

    /** Returns {@code beginNumUpdate + 1} for a missing count, otherwise {@code previous + 1}. */
    private int nextUpdateCount(Integer previous) {
        return (previous == null ? beginNumUpdate : previous) + 1;
    }

    /**
     * Wraps {@link #newDeviceState} as a mapping function. It lives outside
     * {@link #withDefaultState} so the lambda javac synthesizes for it cannot collide with the
     * retained public {@code lambda$withDefaultState$1}, which has the same erased signature.
     */
    private static Function<String, Map<Device, NDArray>> deviceStateFactory(
            Function<String, NDArray> defaultFunction, Device device) {
        return stateKey -> newDeviceState(defaultFunction, device, stateKey);
    }

    /** Builds a concurrent map holding {@code defaultFunction.apply(key)}, detached, on {@code device}. */
    private static Map<Device, NDArray> newDeviceState(
            Function<String, NDArray> defaultFunction, Device device, String key) {
        Map<Device, NDArray> perDevice = new ConcurrentHashMap<>();
        NDArray initial = defaultFunction.apply(key);
        // NOTE(review): presumably detached so the state outlives the manager that created it —
        // confirm against NDArray#detach.
        initial.detach();
        perDevice.put(device, initial);
        return perDevice;
    }
}
