package ai.djl.ndarray.index.full;

import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.index.dim.NDIndexAll;
import ai.djl.ndarray.index.dim.NDIndexElement;
import ai.djl.ndarray.index.dim.NDIndexFixed;
import ai.djl.ndarray.index.dim.NDIndexSlice;
import ai.djl.ndarray.types.Shape;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Predicate;
import java.util.function.ToIntFunction;

/* loaded from: classes.dex */
/**
 * A fully resolved slice of an array: one {@code [min, max)} range and one step for every
 * dimension of the target shape.
 *
 * <p>Instances are only produced by {@link #fromIndex(NDIndex, Shape)}, and only when every
 * element of the {@link NDIndex} is an {@link NDIndexAll}, {@link NDIndexFixed} or {@link
 * NDIndexSlice}. Dimensions selected with an {@link NDIndexFixed} keep length 1 in {@link
 * #getShape()} and are listed in {@link #getToSqueeze()}; {@link #getSqueezedShape()} omits
 * them.
 */
public final class NDIndexFullSlice {
    private final long[] min;
    private final long[] max;
    private final long[] step;
    private final int[] toSqueeze;
    private final Shape shape;
    private final Shape squeezedShape;

    private NDIndexFullSlice(
            long[] min,
            long[] max,
            long[] step,
            int[] toSqueeze,
            Shape shape,
            Shape squeezedShape) {
        this.min = min;
        this.max = max;
        this.step = step;
        this.toSqueeze = toSqueeze;
        this.shape = shape;
        this.squeezedShape = squeezedShape;
    }

    /**
     * Resolves an index against a target shape into a per-dimension full slice.
     *
     * <p>Dimensions of {@code target} that the index does not mention are treated as "take
     * everything". They are inserted at the ellipsis position when the index has one, and
     * appended after the explicit elements otherwise.
     *
     * @param index the index to resolve
     * @param target the shape of the array being indexed
     * @return the resolved slice, or {@link Optional#empty()} if the index contains an element
     *     other than {@link NDIndexAll}, {@link NDIndexFixed} or {@link NDIndexSlice}
     * @throws IllegalArgumentException if the index has more dimensions than {@code target},
     *     or if a slice has a step of zero
     */
    public static Optional<NDIndexFullSlice> fromIndex(NDIndex index, Shape target) {
        if (!index.stream().allMatch(NDIndexFullSlice::isFullSliceElement)) {
            return Optional.empty();
        }
        int ellipsisIndex = index.getEllipsisIndex();
        int indDimensions = index.getRank();
        int targetDimensions = target.dimension();
        if (indDimensions > targetDimensions) {
            throw new IllegalArgumentException(
                    "The index has too many dimensions - "
                            + indDimensions
                            + " dimensions for array with "
                            + targetDimensions
                            + " dimensions");
        }

        long[] min = new long[targetDimensions];
        long[] max = new long[targetDimensions];
        long[] step = new long[targetDimensions];
        List<Integer> toSqueeze = new ArrayList<>(targetDimensions);
        long[] shape = new long[targetDimensions];
        List<Long> squeezedShape = new ArrayList<>(targetDimensions);

        // Number of target dimensions with no explicit index element; they are padded with
        // "take everything" starting at padStart (the ellipsis, or the end when there is none).
        int padding = targetDimensions - indDimensions;
        int padStart = ellipsisIndex == -1 ? indDimensions : ellipsisIndex;

        // Walk dimensions strictly in order so squeezedShape/toSqueeze are built in axis order.
        for (int i = 0; i < targetDimensions; i++) {
            if (i < padStart) {
                addSliceInfo(
                        index.get(i), i, target, min, max, step, toSqueeze, shape, squeezedShape);
            } else if (i < padStart + padding) {
                padIndexAll(i, target, min, max, step, shape, squeezedShape);
            } else {
                // Elements after the padded run are shifted back by the padding width.
                addSliceInfo(
                        index.get(i - padding),
                        i,
                        target,
                        min,
                        max,
                        step,
                        toSqueeze,
                        shape,
                        squeezedShape);
            }
        }

        int[] squeezeAxes = toSqueeze.stream().mapToInt(Integer::intValue).toArray();
        return Optional.of(
                new NDIndexFullSlice(
                        min, max, step, squeezeAxes, new Shape(shape), new Shape(squeezedShape)));
    }

    /** Returns whether {@code el} is one of the element kinds a full slice can represent. */
    private static boolean isFullSliceElement(NDIndexElement el) {
        return el instanceof NDIndexAll || el instanceof NDIndexFixed || el instanceof NDIndexSlice;
    }

    /**
     * Records the range, step and resulting length of dimension {@code i} for one index element.
     *
     * <p>Negative fixed indices and negative slice bounds are wrapped with {@link
     * Math#floorMod(long, long)} against the dimension size. Any element kind other than the
     * three handled here is ignored.
     *
     * @throws IllegalArgumentException if {@code el} is a slice whose step is zero
     */
    private static void addSliceInfo(
            NDIndexElement el,
            int i,
            Shape target,
            long[] min,
            long[] max,
            long[] step,
            List<Integer> toSqueeze,
            long[] shape,
            List<Long> squeezedShape) {
        if (el instanceof NDIndexFixed) {
            long rawIndex = ((NDIndexFixed) el).getIndex();
            long fixed = rawIndex < 0 ? Math.floorMod(rawIndex, target.get(i)) : rawIndex;
            min[i] = fixed;
            max[i] = fixed + 1;
            step[i] = 1;
            // A fixed index keeps length 1 here and is marked so the axis can be squeezed out.
            toSqueeze.add(i);
            shape[i] = 1;
        } else if (el instanceof NDIndexSlice) {
            NDIndexSlice slice = (NDIndexSlice) el;
            // NOTE(review): the 0 / size defaults below only make sense for a positive step.
            long rawMin = Optional.ofNullable(slice.getMin()).orElse(0L);
            min[i] = rawMin < 0 ? Math.floorMod(rawMin, target.get(i)) : rawMin;
            long rawMax = Optional.ofNullable(slice.getMax()).orElse(target.size(i));
            max[i] = rawMax < 0 ? Math.floorMod(rawMax, target.get(i)) : rawMax;
            long sliceStep = Optional.ofNullable(slice.getStep()).orElse(1L);
            if (sliceStep == 0) {
                throw new IllegalArgumentException(
                        "The slice step for dimension " + i + " must not be zero");
            }
            step[i] = sliceStep;
            // Elements visited are min, min + step, ... up to (not including) max, so the count
            // is ceil((max - min) / step). Truncating division would drop the last element
            // (e.g. 0:5:2 has 3 elements, not 2). An empty or reversed range yields 0 rather
            // than a negative dimension.
            long count = -Math.floorDiv(min[i] - max[i], sliceStep);
            shape[i] = Math.max(0L, count);
            squeezedShape.add(shape[i]);
        } else if (el instanceof NDIndexAll) {
            padIndexAll(i, target, min, max, step, shape, squeezedShape);
        }
    }

    /** Marks dimension {@code i} as "take everything": range {@code [0, size)} with step 1. */
    private static void padIndexAll(
            int i,
            Shape target,
            long[] min,
            long[] max,
            long[] step,
            long[] shape,
            List<Long> squeezedShape) {
        long size = target.size(i);
        min[i] = 0;
        max[i] = size;
        step[i] = 1;
        shape[i] = size;
        squeezedShape.add(size);
    }

    /**
     * Returns the exclusive upper bound of each dimension's range.
     *
     * <p>The internal array is returned directly, not copied.
     */
    public long[] getMax() {
        return max;
    }

    /**
     * Returns the inclusive lower bound of each dimension's range.
     *
     * <p>The internal array is returned directly, not copied.
     */
    public long[] getMin() {
        return min;
    }

    /** Returns the resulting shape, with length 1 kept for fixed-index dimensions. */
    public Shape getShape() {
        return shape;
    }

    /** Returns the resulting shape with the fixed-index dimensions removed. */
    public Shape getSqueezedShape() {
        return squeezedShape;
    }

    /**
     * Returns the step of each dimension.
     *
     * <p>The internal array is returned directly, not copied.
     */
    public long[] getStep() {
        return step;
    }

    /**
     * Returns the axes, in ascending order, that were selected by a fixed index.
     *
     * <p>The internal array is returned directly, not copied.
     */
    public int[] getToSqueeze() {
        return toSqueeze;
    }
}
