package ai.djl.training.initializer;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import java.util.Objects;

/**
 * An {@link Initializer} that fills an array using the Xavier (Glorot) scheme.
 *
 * <p>For a target shape {@code (d0, d1, d2, ...)} the two "fans" are computed as {@code fanIn = d1
 * * r} and {@code fanOut = d0 * r}, where {@code r} is the product of all dimensions after the
 * first two ({@code r = 1} for a 2-D shape). A scaling factor is picked from the fans according to
 * {@link FactorType}, and values are drawn with {@code scale = sqrt(magnitude / factor)}:
 *
 * <ul>
 *   <li>{@link RandomType#UNIFORM}: uniformly between {@code -scale} and {@code scale};
 *   <li>{@link RandomType#GAUSSIAN}: normally with mean {@code 0} and scale {@code scale}.
 * </ul>
 *
 * <p>With the defaults ({@code UNIFORM}, {@code AVG}, magnitude {@code 3}) the uniform bound is
 * {@code sqrt(3 / ((fanIn + fanOut) / 2)) = sqrt(6 / (fanIn + fanOut))}.
 *
 * <p>Instances are immutable.
 */
public class XavierInitializer implements Initializer {

    /** Which combination of fan-in and fan-out is used as the scaling factor. */
    public enum FactorType {
        /** The average of fan-in and fan-out. */
        AVG,
        /** Fan-in only ({@code shape.get(1)} times the trailing-dimension product). */
        IN,
        /** Fan-out only ({@code shape.head()} times the trailing-dimension product). */
        OUT
    }

    /** The distribution the initial values are drawn from. */
    public enum RandomType {
        /** Uniform distribution symmetric around zero. */
        UNIFORM,
        /** Normal distribution centred at zero. */
        GAUSSIAN
    }

    private final RandomType randomType;
    private final FactorType factorType;
    private final float magnitude;

    /** Creates an initializer using {@code UNIFORM} sampling, {@code AVG} factor and magnitude 3. */
    public XavierInitializer() {
        this(RandomType.UNIFORM, FactorType.AVG, 3.0f);
    }

    /**
     * Creates an initializer with explicit settings.
     *
     * @param randomType the distribution to sample from
     * @param factorType how fan-in and fan-out are combined into the scaling factor
     * @param magnitude the numerator of {@code scale = sqrt(magnitude / factor)}
     * @throws NullPointerException if {@code randomType} or {@code factorType} is null
     * @throws IllegalArgumentException if {@code magnitude} is negative or NaN, which would make
     *     the computed scale NaN
     */
    public XavierInitializer(RandomType randomType, FactorType factorType, float magnitude) {
        this.randomType = Objects.requireNonNull(randomType, "randomType");
        this.factorType = Objects.requireNonNull(factorType, "factorType");
        // NaN fails "magnitude >= 0", so this single check rejects both NaN and negatives.
        if (!(magnitude >= 0.0f)) {
            throw new IllegalArgumentException(
                    "Xavier initializer magnitude must be non-negative, but was: " + magnitude);
        }
        this.magnitude = magnitude;
    }

    /**
     * Creates an array of the given shape filled with Xavier-scaled random values.
     *
     * @param manager the manager used to create the array; its device is used for the array
     * @param shape the shape of the array to create; must have at least two dimensions
     * @param dataType the data type of the array to create
     * @return the newly created, randomly initialized array
     * @throws IllegalArgumentException if {@code shape} has fewer than two dimensions
     * @throws IllegalStateException if the computed factor is not positive (for example because the
     *     shape contains a zero or negative dimension)
     */
    @Override
    public NDArray initialize(NDManager manager, Shape shape, DataType dataType) {
        long dimension = shape.dimension();
        if (dimension < 2) {
            throw new IllegalArgumentException(
                    "XavierInitializer cannot be applied to Shape with dimension: "
                            + dimension
                            + ", it requires shape to be at least 2D.");
        }

        // Every dimension beyond the first two multiplies both fans.
        float trailing = dimension == 2 ? 1.0f : (float) shape.slice(2).size();
        float fanIn = ((float) shape.get(1)) * trailing;
        float fanOut = ((float) shape.head()) * trailing;

        float factor;
        switch (factorType) {
            case AVG:
                factor = (fanIn + fanOut) / 2.0f;
                break;
            case IN:
                factor = fanIn;
                break;
            case OUT:
                factor = fanOut;
                break;
            default:
                throw new IllegalArgumentException(
                        "Invalid factor type, valid types are: avg, in, out");
        }

        // A zero factor divides by zero, and a negative one makes sqrt return NaN. Reject both
        // instead of silently producing NaN/Infinity weights.
        if (!(factor > 0.0f)) {
            throw new IllegalStateException(
                    "Xavier initializer factor must be positive but was "
                            + factor
                            + ", please check your input shape: "
                            + shape);
        }

        float scale = (float) StrictMath.sqrt(magnitude / factor);
        switch (randomType) {
            case UNIFORM:
                return manager.randomUniform(-scale, scale, shape, dataType, manager.getDevice());
            case GAUSSIAN:
                return manager.randomNormal(0.0f, scale, shape, dataType, manager.getDevice());
            default:
                throw new IllegalArgumentException("Invalid randomType");
        }
    }
}
