package ai.djl.training.loss;

import ai.djl.ndarray.NDList;
import ai.djl.training.evaluator.Evaluator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiFunction;

/**
 * Base class for losses: an {@link Evaluator} that keeps a running loss total per accumulator key.
 *
 * <p>For every key two values are tracked: the summed loss (held in {@code totalLoss}) and the
 * number of {@link #updateAccumulator(String, NDList, NDList)} calls (held in the inherited
 * {@code totalInstances} map). {@link #getAccumulator(String)} reports their ratio.
 *
 * <p>The static methods are convenience factories for the built-in loss implementations; each one
 * simply forwards its arguments to the corresponding constructor.
 */
public abstract class Loss extends Evaluator {

    /** Summed loss value per accumulator key; kept in step with {@code totalInstances}. */
    private final Map<String, Float> totalLoss;

    /**
     * Creates a loss with no accumulators registered yet.
     *
     * @param str forwarded unchanged to the {@link Evaluator} constructor (presumably the
     *     evaluator's name — confirm against {@code Evaluator})
     */
    public Loss(String str) {
        super(str);
        this.totalLoss = new ConcurrentHashMap<>();
    }

    /**
     * Creates a {@link HingeLoss} with its no-argument constructor.
     *
     * @return a new {@link HingeLoss}
     */
    public static HingeLoss hingeLoss() {
        return new HingeLoss();
    }

    /**
     * Creates a {@link HingeLoss} via {@code new HingeLoss(String)}.
     *
     * @param str forwarded unchanged to the {@link HingeLoss} constructor
     * @return a new {@link HingeLoss}
     */
    public static HingeLoss hingeLoss(String str) {
        return new HingeLoss(str);
    }

    /**
     * Creates a {@link HingeLoss} via {@code new HingeLoss(String, int, float)}.
     *
     * @param str forwarded unchanged, as the first constructor argument
     * @param i forwarded unchanged, as the second constructor argument
     * @param f forwarded unchanged, as the third constructor argument
     * @return a new {@link HingeLoss}
     */
    public static HingeLoss hingeLoss(String str, int i, float f) {
        return new HingeLoss(str, i, f);
    }

    /**
     * Creates an {@link L1Loss} with its no-argument constructor.
     *
     * @return a new {@link L1Loss}
     */
    public static L1Loss l1Loss() {
        return new L1Loss();
    }

    /**
     * Creates an {@link L1Loss} via {@code new L1Loss(String)}.
     *
     * @param str forwarded unchanged to the {@link L1Loss} constructor
     * @return a new {@link L1Loss}
     */
    public static L1Loss l1Loss(String str) {
        return new L1Loss(str);
    }

    /**
     * Creates an {@link L1Loss} via {@code new L1Loss(String, float)}.
     *
     * @param str forwarded unchanged, as the first constructor argument
     * @param f forwarded unchanged, as the second constructor argument
     * @return a new {@link L1Loss}
     */
    public static L1Loss l1Loss(String str, float f) {
        return new L1Loss(str, f);
    }

    /**
     * Creates an {@link L2Loss} with its no-argument constructor.
     *
     * @return a new {@link L2Loss}
     */
    public static L2Loss l2Loss() {
        return new L2Loss();
    }

    /**
     * Creates an {@link L2Loss} via {@code new L2Loss(String)}.
     *
     * @param str forwarded unchanged to the {@link L2Loss} constructor
     * @return a new {@link L2Loss}
     */
    public static L2Loss l2Loss(String str) {
        return new L2Loss(str);
    }

    /**
     * Creates an {@link L2Loss} via {@code new L2Loss(String, float)}.
     *
     * @param str forwarded unchanged, as the first constructor argument
     * @param f forwarded unchanged, as the second constructor argument
     * @return a new {@link L2Loss}
     */
    public static L2Loss l2Loss(String str, float f) {
        return new L2Loss(str, f);
    }

    /**
     * Always returns {@code 0L}, ignoring both arguments.
     *
     * <p>This is a compiler-generated lambda body that a decompiler exposed as a public method; the
     * previous {@link #resetAccumulator(String)} used it as a {@code compute} remapping function.
     * It is no longer used internally and is kept only so existing references still link.
     *
     * @param str ignored
     * @param l ignored
     * @return {@code 0L}
     * @deprecated not part of the intended API; use {@link #resetAccumulator(String)} instead.
     */
    @Deprecated
    public static Long lambda$resetAccumulator$2(String str, Long l) {
        return 0L;
    }

    /**
     * Creates a {@link MaskedSoftmaxCrossEntropyLoss} with its no-argument constructor.
     *
     * @return a new {@link MaskedSoftmaxCrossEntropyLoss}
     */
    public static MaskedSoftmaxCrossEntropyLoss maskedSoftmaxCrossEntropyLoss() {
        return new MaskedSoftmaxCrossEntropyLoss();
    }

    /**
     * Creates a {@link MaskedSoftmaxCrossEntropyLoss} via its {@code (String)} constructor.
     *
     * @param str forwarded unchanged to the constructor
     * @return a new {@link MaskedSoftmaxCrossEntropyLoss}
     */
    public static MaskedSoftmaxCrossEntropyLoss maskedSoftmaxCrossEntropyLoss(String str) {
        return new MaskedSoftmaxCrossEntropyLoss(str);
    }

    /**
     * Creates a {@link MaskedSoftmaxCrossEntropyLoss} via its
     * {@code (String, float, int, boolean, boolean)} constructor. All arguments are forwarded
     * unchanged and in the same order; see that constructor for their meaning.
     *
     * @param str first constructor argument
     * @param f second constructor argument
     * @param i third constructor argument
     * @param z fourth constructor argument
     * @param z2 fifth constructor argument
     * @return a new {@link MaskedSoftmaxCrossEntropyLoss}
     */
    public static MaskedSoftmaxCrossEntropyLoss maskedSoftmaxCrossEntropyLoss(String str, float f, int i, boolean z, boolean z2) {
        return new MaskedSoftmaxCrossEntropyLoss(str, f, i, z, z2);
    }

    /**
     * Creates a {@link SigmoidBinaryCrossEntropyLoss} with its no-argument constructor.
     *
     * @return a new {@link SigmoidBinaryCrossEntropyLoss}
     */
    public static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss() {
        return new SigmoidBinaryCrossEntropyLoss();
    }

    /**
     * Creates a {@link SigmoidBinaryCrossEntropyLoss} via its {@code (String)} constructor.
     *
     * @param str forwarded unchanged to the constructor
     * @return a new {@link SigmoidBinaryCrossEntropyLoss}
     */
    public static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss(String str) {
        return new SigmoidBinaryCrossEntropyLoss(str);
    }

    /**
     * Creates a {@link SigmoidBinaryCrossEntropyLoss} via its {@code (String, float, boolean)}
     * constructor. All arguments are forwarded unchanged and in the same order.
     *
     * @param str first constructor argument
     * @param f second constructor argument
     * @param z third constructor argument
     * @return a new {@link SigmoidBinaryCrossEntropyLoss}
     */
    public static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss(String str, float f, boolean z) {
        return new SigmoidBinaryCrossEntropyLoss(str, f, z);
    }

    /**
     * Creates a {@link SoftmaxCrossEntropyLoss} with its no-argument constructor.
     *
     * @return a new {@link SoftmaxCrossEntropyLoss}
     */
    public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss() {
        return new SoftmaxCrossEntropyLoss();
    }

    /**
     * Creates a {@link SoftmaxCrossEntropyLoss} via its {@code (String)} constructor.
     *
     * @param str forwarded unchanged to the constructor
     * @return a new {@link SoftmaxCrossEntropyLoss}
     */
    public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss(String str) {
        return new SoftmaxCrossEntropyLoss(str);
    }

    /**
     * Creates a {@link SoftmaxCrossEntropyLoss} via its
     * {@code (String, float, int, boolean, boolean)} constructor. All arguments are forwarded
     * unchanged and in the same order; see that constructor for their meaning.
     *
     * @param str first constructor argument
     * @param f second constructor argument
     * @param i third constructor argument
     * @param z fourth constructor argument
     * @param z2 fifth constructor argument
     * @return a new {@link SoftmaxCrossEntropyLoss}
     */
    public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss(String str, float f, int i, boolean z, boolean z2) {
        return new SoftmaxCrossEntropyLoss(str, f, i, z, z2);
    }

    /**
     * Registers (or re-registers) the accumulator {@code str}, setting both its update count and
     * its loss total to zero.
     *
     * @param str the accumulator key
     */
    @Override // ai.djl.training.evaluator.Evaluator
    public void addAccumulator(String str) {
        this.totalInstances.put(str, 0L);
        this.totalLoss.put(str, 0.0f);
    }

    /**
     * Returns the mean loss for accumulator {@code str}: the summed loss divided by the number of
     * updates since it was last added or reset.
     *
     * @param str the accumulator key
     * @return the mean loss, or {@link Float#NaN} if no updates have been recorded yet
     * @throws IllegalArgumentException if {@code str} has not been registered
     */
    @Override // ai.djl.training.evaluator.Evaluator
    public float getAccumulator(String str) {
        // Read the count once so the zero check and the division use the same value.
        Long instances = this.totalInstances.get(str);
        if (instances == null) {
            throw new IllegalArgumentException("No loss found at that path");
        }
        if (instances.longValue() == 0) {
            return Float.NaN;
        }
        Float loss = this.totalLoss.get(str);
        if (loss == null) {
            // Count present but no loss entry (e.g. registered only on the Evaluator side):
            // report it like a missing key instead of failing with a NullPointerException.
            throw new IllegalArgumentException("No loss found at that path");
        }
        return loss.floatValue() / (float) instances.longValue();
    }

    /**
     * Resets accumulator {@code str} to zero updates and zero loss. If the key was not registered
     * it is registered now.
     *
     * @param str the accumulator key
     */
    @Override // ai.djl.training.evaluator.Evaluator
    public void resetAccumulator(String str) {
        this.totalInstances.put(str, 0L);
        this.totalLoss.put(str, 0.0f);
    }

    /**
     * Evaluates the loss on one batch and adds it to accumulator {@code str}.
     *
     * <p>The loss returned by {@code evaluate} is summed to a single scalar, added to the running
     * total, and the update count is incremented by one. A key that was never registered with
     * {@link #addAccumulator(String)} is registered implicitly, starting from zero.
     *
     * @param str the accumulator key
     * @param nDList first argument forwarded to {@code evaluate} (presumably the labels — confirm)
     * @param nDList2 second argument forwarded to {@code evaluate} (presumably the predictions —
     *     confirm)
     */
    @Override // ai.djl.training.evaluator.Evaluator
    public void updateAccumulator(String str, NDList nDList, NDList nDList2) {
        float update = evaluate(nDList, nDList2).sum().getFloat();
        // merge() seeds an absent key with the given value; the previous compute() lambdas
        // unboxed a null current value and threw NullPointerException for unregistered keys.
        this.totalInstances.merge(str, 1L, Long::sum);
        this.totalLoss.merge(str, update, Float::sum);
    }
}
