package ai.djl.training.optimizer.learningrate;

import ai.djl.training.optimizer.learningrate.LearningRateTracker;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link LearningRateTracker} that multiplies the learning rate by a constant {@code factor}
 * once every {@code step} updates, clamping it so it never drops below a configured floor.
 *
 * <p>For update numbers below {@code warmUpSteps}, the rate comes from {@code
 * getWarmUpLearningRate} instead.
 *
 * <p>Instances carry mutable decay state ({@code count} and the inherited {@code
 * baseLearningRate}) and perform no synchronization.
 */
public class FactorTracker extends LearningRateTracker {
    private static final Logger logger = LoggerFactory.getLogger(FactorTracker.class);

    /** Number of updates already covered by applied decays; always a multiple of {@code step}. */
    private int count;
    /** Multiplier applied to the learning rate at each decay. */
    private final float factor;
    /** Number of updates between consecutive decays. */
    private final int step;
    /** Lower bound the decayed learning rate is clamped to. */
    private final float stopFactorLearningRate;

    /** Builder for {@link FactorTracker}. {@link #setStep(int)} must be called before building. */
    public static final class Builder extends LearningRateTracker.LrBaseBuilder<Builder> {
        int step;
        float factor = 1.0f;
        float stopFactorLearningRate = 1.0E-8f;

        /**
         * Builds a {@link FactorTracker} from this builder's settings.
         *
         * @return the new tracker
         * @throws IllegalArgumentException if no valid step (at least 1) has been set
         */
        public FactorTracker build() {
            // step defaults to 0; the field is package-visible, so also reject negative values
            // that bypassed setStep (a negative step would make the decay loop run away).
            if (step < 1) {
                throw new IllegalArgumentException(
                        "Step must be set to change learning rate every N steps");
            }
            return new FactorTracker(this);
        }

        /**
         * Sets the multiplier applied to the learning rate at every decay (default {@code 1.0}).
         *
         * @param f the decay factor; must be in {@code [0, 1]}
         * @return this builder
         * @throws IllegalArgumentException if {@code f} is greater than 1, negative, or NaN
         */
        public Builder optFactor(float f) {
            if (f > 1.0f) {
                throw new IllegalArgumentException("factor should be no more than 1");
            }
            // NaN slips past the comparison above and would turn every decayed rate into NaN.
            if (Float.isNaN(f) || f < 0.0f) {
                throw new IllegalArgumentException("factor should be a non-negative number");
            }
            this.factor = f;
            return this;
        }

        /**
         * Sets the floor for the decayed learning rate (default {@code 1e-8}). When a decay would
         * take the rate below this value, the rate is set to exactly this value instead. Note this
         * raises the rate if the base learning rate is configured below the floor.
         *
         * @param f the minimum learning rate reachable through decay
         * @return this builder
         */
        public Builder optStopFactorLearningRate(float f) {
            this.stopFactorLearningRate = f;
            return this;
        }

        /** Returns this builder, typed as the concrete subclass for fluent chaining. */
        @Override
        public Builder self() {
            return this;
        }

        /**
         * Sets how many updates elapse between consecutive decays. Required.
         *
         * @param i the decay interval, in updates
         * @return this builder
         * @throws IllegalArgumentException if {@code i} is less than 1
         */
        public Builder setStep(int i) {
            if (i < 1) {
                throw new IllegalArgumentException("step should be larger or equal to 1");
            }
            this.step = i;
            return this;
        }
    }

    /**
     * Creates a tracker from the given builder's settings.
     *
     * @param builder the configured builder
     */
    public FactorTracker(Builder builder) {
        super(builder);
        this.step = builder.step;
        this.factor = builder.factor;
        this.stopFactorLearningRate = builder.stopFactorLearningRate;
        this.count = 0;
    }

    /**
     * Returns the learning rate for update number {@code i}.
     *
     * <p>Applies one decay for every full {@code step} interval that has elapsed since the last
     * applied decay. Decays are never undone: a call with a smaller {@code i} than a previous call
     * simply returns the current rate.
     *
     * @param i the update number
     * @return the learning rate to use for update {@code i}
     */
    @Override
    public float getNewLearningRate(int i) {
        if (i < warmUpSteps) {
            return getWarmUpLearningRate(i);
        }
        // Compare in long arithmetic: count + step can overflow int for i near Integer.MAX_VALUE,
        // which would wrap negative and trigger billions of spurious decays. Inside the loop
        // count + step < i <= Integer.MAX_VALUE, so the int increment itself cannot overflow.
        while ((long) i > (long) count + step) {
            count += step;
            applyDecay(i);
        }
        checkLearningRate(baseLearningRate);
        return baseLearningRate;
    }

    /** Multiplies the learning rate by {@code factor}, clamps it to the floor, and logs the result. */
    private void applyDecay(int i) {
        // After a clamp the rate equals the floor exactly, so equality detects "already clamped".
        boolean alreadyAtFloor = baseLearningRate == stopFactorLearningRate;
        baseLearningRate *= factor;
        if (baseLearningRate < stopFactorLearningRate) {
            baseLearningRate = stopFactorLearningRate;
            // Report reaching the floor once, not again at every later step boundary.
            if (!alreadyAtFloor && logger.isDebugEnabled()) {
                logger.debug(
                        "Update[{}]: now learning rate arrived at {}, will not change in the future",
                        i,
                        String.format("%.5e", baseLearningRate));
            }
        } else if (logger.isDebugEnabled()) {
            // Guarded because String.format runs eagerly even when debug logging is disabled.
            logger.debug(
                    "Update[{}]: Change learning rate to {}",
                    i,
                    String.format("%.5e", baseLearningRate));
        }
    }
}
