package ai.djl.translate;

import ai.djl.engine.EngineException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import java.util.Arrays;
import java.util.Iterator;
import java.util.function.IntFunction;
import java.util.function.LongUnaryOperator;
import java.util.stream.LongStream;

/* loaded from: classes.dex */
/**
 * A {@link Batchifier} that merges samples by stacking them along a new leading axis.
 *
 * <p>{@link #batchify} stacks the i-th array of every sample into the i-th array of the batch,
 * {@link #unbatchify} performs the inverse (split along axis 0 and squeeze), and {@link #split}
 * partitions a batch along axis 0 into several smaller batches.
 */
public class StackBatchifier implements Batchifier {

    /**
     * Splits a single array along axis 0 into {@code numOfSlices} parts.
     *
     * @param array the array to split; its axis-0 length is the batch size
     * @param numOfSlices the number of parts to produce; must be positive
     * @param evenSplit if {@code true}, the batch size must be a multiple of {@code numOfSlices};
     *     if {@code false}, every part gets {@code size / numOfSlices} rows and the last part
     *     additionally receives the remainder
     * @return the parts, in order
     * @throws IllegalArgumentException if the batch size is smaller than {@code numOfSlices}, or
     *     if {@code evenSplit} is set and the batch size is not a multiple of {@code numOfSlices}
     */
    private NDList split(NDArray array, int numOfSlices, boolean evenSplit) {
        int size = Math.toIntExact(array.size(0));
        if (size < numOfSlices) {
            throw new IllegalArgumentException(
                    "Batch size(" + size + ") is less than slice number(" + numOfSlices + ").");
        }
        if (evenSplit) {
            if (size % numOfSlices != 0) {
                throw new IllegalArgumentException(
                        "data with shape "
                                + size
                                + " cannot be evenly split into "
                                + numOfSlices
                                + ". Use a batch size that's multiple of "
                                + numOfSlices
                                + " or set even_split=false to allow uneven partitioning of data.");
            }
            return array.split(numOfSlices);
        }
        // Floor division on purpose: the previous `Math.ceil(size / numOfSlices)` applied ceil to an
        // already-truncated int quotient, so floor was the real behavior all along. Floor also keeps
        // every part non-empty: step >= 1 because size >= numOfSlices, and the last split index
        // (numOfSlices - 1) * step is strictly below size. The last part absorbs the remainder.
        final int step = size / numOfSlices;
        long[] indices = LongStream.range(1L, numOfSlices).map(k -> k * step).toArray();
        return array.split(indices);
    }

    /**
     * Stacks the i-th array of every input {@link NDList} into the i-th array of the batch.
     *
     * @param inputs the per-sample lists; each is expected to have the same number of arrays, with
     *     matching shapes and data types position by position
     * @return a list whose i-th element is the stack of the i-th arrays of all inputs; an empty list
     *     when {@code inputs} is empty or its first element contains no arrays
     * @throws IllegalArgumentException if stacking fails because the inputs differ in number of
     *     arrays, shape, or data type (the original failure is attached as the cause)
     */
    @Override
    public NDList batchify(NDList[] inputs) {
        if (inputs.length == 0) {
            return new NDList();
        }
        int batchSize = inputs.length;
        int numInputs = inputs[0].size();
        if (numInputs == 0) {
            return new NDList();
        }
        try {
            NDList batch = new NDList(numInputs);
            for (int i = 0; i < numInputs; i++) {
                NDList column = new NDList(batchSize);
                for (NDList input : inputs) {
                    column.add(input.get(i));
                }
                batch.add(NDArrays.stack(column));
            }
            return batch;
        } catch (EngineException | IndexOutOfBoundsException e) {
            // Translate the low-level failure into a descriptive message when we can tell why the
            // samples are incompatible; otherwise rethrow the original exception untouched.
            checkBatchable(inputs, numInputs, e);
            throw e;
        }
    }

    /**
     * Throws a descriptive {@link IllegalArgumentException} if the inputs cannot be stacked.
     *
     * @param inputs the per-sample lists passed to {@link #batchify}
     * @param numInputs the number of arrays in the first sample
     * @param cause the failure observed while stacking, attached as the cause
     */
    private static void checkBatchable(NDList[] inputs, int numInputs, RuntimeException cause) {
        for (NDList input : inputs) {
            if (input.size() != numInputs) {
                throw new IllegalArgumentException(
                        "You cannot batch data with different numbers of inputs", cause);
            }
        }
        for (int i = 0; i < numInputs; i++) {
            NDArray reference = inputs[0].get(i);
            Shape shape = reference.getShape();
            DataType dataType = reference.getDataType();
            for (NDList input : inputs) {
                NDArray array = input.get(i);
                if (!array.getShape().equals(shape)) {
                    throw new IllegalArgumentException(
                            "You cannot batch data with different input shapes", cause);
                }
                if (!array.getDataType().equals(dataType)) {
                    throw new IllegalArgumentException(
                            "You cannot batch data with different input data types", cause);
                }
            }
        }
    }

    /**
     * Splits a batch along axis 0 into up to {@code numOfSlices} smaller batches.
     *
     * <p>The effective number of slices is {@code min(numOfSlices, batchSize)}, where the batch size
     * is the axis-0 length of the first array. Each resulting array keeps its source array's name.
     *
     * @param list the batch to split
     * @param numOfSlices the requested number of slices; must be positive
     * @param evenSplit whether every slice must have the same size
     * @return the slices; an empty array when {@code list} is empty or the batch size is 0
     * @throws IllegalArgumentException if {@code numOfSlices} is not positive, or if an array cannot
     *     be split as requested
     */
    @Override
    public NDList[] split(NDList list, int numOfSlices, boolean evenSplit) {
        if (numOfSlices <= 0) {
            throw new IllegalArgumentException(
                    "Slice number must be positive, got " + numOfSlices + ".");
        }
        if (list.size() == 0) {
            return new NDList[0];
        }
        int slices = Math.min(numOfSlices, Math.toIntExact(list.head().size(0)));
        if (slices == 0) {
            // Empty batch: nothing to distribute (and avoids a division by zero below).
            return new NDList[0];
        }
        NDList[] result = new NDList[slices];
        Arrays.setAll(result, k -> new NDList());
        for (NDArray array : list) {
            String name = array.getName();
            NDList parts = split(array, slices, evenSplit);
            for (int k = 0; k < slices; k++) {
                NDArray part = parts.get(k);
                part.setName(name);
                result[k].add(part);
            }
        }
        return result;
    }

    /**
     * Reverses {@link #batchify}: splits every array along axis 0 into single items and squeezes
     * away the leading axis.
     *
     * @param inputs the batch; the batch size is the axis-0 length of the first array
     * @return one {@link NDList} per batch item, each array keeping its source array's name; an
     *     empty array when {@code inputs} is empty or the batch size is 0
     */
    @Override
    public NDList[] unbatchify(NDList inputs) {
        if (inputs.size() == 0) {
            return new NDList[0];
        }
        int batchSize = Math.toIntExact(inputs.head().size(0));
        if (batchSize == 0) {
            return new NDList[0];
        }
        NDList[] result = new NDList[batchSize];
        Arrays.setAll(result, k -> new NDList());
        for (NDArray array : inputs) {
            String name = array.getName();
            NDList items = array.split(batchSize);
            for (int k = 0; k < batchSize; k++) {
                NDArray item = items.get(k).squeeze(0);
                item.setName(name);
                result[k].add(item);
            }
        }
        return result;
    }
}
