package ai.djl.pytorch.jni;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.nn.pooling.PoolingConvention;
import ai.djl.pytorch.engine.PtDeviceType;
import ai.djl.pytorch.engine.PtNDArray;
import ai.djl.pytorch.engine.PtNDManager;
import ai.djl.pytorch.engine.PtSymbolBlock;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.function.Function;
import java.util.function.IntFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: classes.dex */
public final class JniUtils {
    /** SLF4J logger obtained for this class. */
    private static final Logger logger = LoggerFactory.getLogger(JniUtils.class);

    /** Cached feature set; {@code null} until {@code getFeatures()} stores a value. */
    private static Set<String> configs;

    /** Private no-op constructor; the class cannot be instantiated from outside. */
    private JniUtils() {}

    /**
     * Invokes the native {@code torchAbs} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray abs(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchAbs(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchAcos} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray acos(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchAcos(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Forwards the handles of the four arrays and the seven float values, in order, to the
     * native {@code adamUpdate} routine.
     *
     * @param ptNDArray first array whose handle is passed
     * @param ptNDArray2 second array whose handle is passed
     * @param ptNDArray3 third array whose handle is passed
     * @param ptNDArray4 fourth array whose handle is passed
     * @param f first float argument, forwarded unchanged
     * @param f2 second float argument, forwarded unchanged
     * @param f3 third float argument, forwarded unchanged
     * @param f4 fourth float argument, forwarded unchanged
     * @param f5 fifth float argument, forwarded unchanged
     * @param f6 sixth float argument, forwarded unchanged
     * @param f7 seventh float argument, forwarded unchanged
     */
    public static void adamUpdate(PtNDArray ptNDArray, PtNDArray ptNDArray2, PtNDArray ptNDArray3, PtNDArray ptNDArray4, float f, float f2, float f3, float f4, float f5, float f6, float f7) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        lib.adamUpdate(
                ptNDArray.getHandle(),
                ptNDArray2.getHandle(),
                ptNDArray3.getHandle(),
                ptNDArray4.getHandle(),
                f,
                f2,
                f3,
                f4,
                f5,
                f6,
                f7);
    }

    /**
     * Invokes the native {@code torchAdd} routine with the handles of both arrays.
     *
     * @param ptNDArray first operand; its manager creates the returned array
     * @param ptNDArray2 second operand
     * @return the array created from the native result
     */
    public static PtNDArray add(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        PtNDArray result =
                ptNDArray
                        .getManager()
                        .create(PyTorchLibrary.LIB.torchAdd(ptNDArray.getHandle(), ptNDArray2.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchAddi} routine with the handles of both arrays.
     *
     * @param ptNDArray first array whose handle is passed
     * @param ptNDArray2 second array whose handle is passed
     */
    public static void addi(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        lib.torchAddi(ptNDArray.getHandle(), ptNDArray2.getHandle());
    }

    /**
     * Invokes the native {@code torchAll} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray all(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchAll(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchAny} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray any(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchAny(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Creates an array through the native {@code torchArange} routine.
     *
     * <p>The device is passed as a two-element array: the value of
     * {@code PtDeviceType.toDeviceType(device)} followed by the device id, where {@code -1} is
     * used when the device equals {@code Device.cpu()}. The trailing native flag is always
     * {@code false}.
     *
     * @param ptNDManager manager that creates the returned array
     * @param f first float argument, forwarded unchanged
     * @param f2 second float argument, forwarded unchanged
     * @param f3 third float argument, forwarded unchanged
     * @param dataType data type; its ordinal is passed to the native call
     * @param device target device
     * @param sparseFormat format translated by {@code layoutMapper}
     * @return the array created from the native result
     */
    public static PtNDArray arange(PtNDManager ptNDManager, float f, float f2, float f3, DataType dataType, Device device, SparseFormat sparseFormat) {
        int layout = layoutMapper(sparseFormat);
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        int typeOrdinal = dataType.ordinal();
        int[] deviceSpec = {
            PtDeviceType.toDeviceType(device),
            device.equals(Device.cpu()) ? -1 : device.getDeviceId()
        };
        return ptNDManager.create(lib.torchArange(f, f2, f3, typeOrdinal, layout, deviceSpec, false));
    }

    /**
     * Invokes the single-argument native {@code torchArgMax} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray argMax(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchArgMax(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Invokes the three-argument native {@code torchArgMax} routine.
     *
     * @param ptNDArray the source array
     * @param j long argument, forwarded unchanged
     * @param z boolean argument, forwarded unchanged
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray argMax(PtNDArray ptNDArray, long j, boolean z) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchArgMax(ptNDArray.getHandle(), j, z));
        return result;
    }

    /**
     * Invokes the single-argument native {@code torchArgMin} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray argMin(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchArgMin(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Invokes the three-argument native {@code torchArgMin} routine.
     *
     * @param ptNDArray the source array
     * @param j long argument, forwarded unchanged
     * @param z boolean argument, forwarded unchanged
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray argMin(PtNDArray ptNDArray, long j, boolean z) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchArgMin(ptNDArray.getHandle(), j, z));
        return result;
    }

    /**
     * Invokes the native {@code torchArgSort} routine.
     *
     * @param ptNDArray the source array
     * @param j long argument, forwarded unchanged
     * @param z boolean argument, forwarded unchanged
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray argSort(PtNDArray ptNDArray, long j, boolean z) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchArgSort(ptNDArray.getHandle(), j, z));
        return result;
    }

    /**
     * Invokes the native {@code torchASin} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray asin(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchASin(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchAtan} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray atan(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchAtan(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchAttachGrad} routine on the array's handle.
     *
     * @param ptNDArray the array whose handle is passed
     */
    public static void attachGradient(PtNDArray ptNDArray) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        lib.torchAttachGrad(ptNDArray.getHandle());
    }

    /**
     * Invokes the native {@code torchNNAvgPool} routine.
     *
     * <p>The native call receives the array handle, the dimension count of {@code shape}, the
     * raw values of the three shapes, a flag telling whether {@code poolingConvention} equals
     * {@code PoolingConvention.FULL}, and {@code z}.
     *
     * @param ptNDArray the source array
     * @param shape first shape; also supplies the dimension count
     * @param shape2 second shape
     * @param shape3 third shape
     * @param poolingConvention convention compared against {@code PoolingConvention.FULL}
     * @param z boolean argument, forwarded unchanged
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray avgPool(PtNDArray ptNDArray, Shape shape, Shape shape2, Shape shape3, PoolingConvention poolingConvention, boolean z) {
        boolean isFull = PoolingConvention.FULL.equals(poolingConvention);
        PtNDArray pooled =
                ptNDArray
                        .getManager()
                        .create(
                                PyTorchLibrary.LIB.torchNNAvgPool(
                                        ptNDArray.getHandle(),
                                        shape.dimension(),
                                        shape.getShape(),
                                        shape2.getShape(),
                                        shape3.getShape(),
                                        isFull,
                                        z));
        return pooled;
    }

    /**
     * Invokes the native {@code torchBackward} routine with both array handles and both flags.
     *
     * @param ptNDArray first array whose handle is passed
     * @param ptNDArray2 second array whose handle is passed
     * @param z first boolean argument, forwarded unchanged
     * @param z2 second boolean argument, forwarded unchanged
     */
    public static void backward(PtNDArray ptNDArray, PtNDArray ptNDArray2, boolean z, boolean z2) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        lib.torchBackward(ptNDArray.getHandle(), ptNDArray2.getHandle(), z, z2);
    }

    /**
     * Invokes the native {@code torchNNBatchNorm} routine with the handles of the five arrays
     * followed by the boolean and the two doubles.
     *
     * @param ptNDArray first array; its manager creates the returned array
     * @param ptNDArray2 second array whose handle is passed
     * @param ptNDArray3 third array whose handle is passed
     * @param ptNDArray4 fourth array whose handle is passed
     * @param ptNDArray5 fifth array whose handle is passed
     * @param z boolean argument, forwarded unchanged
     * @param d first double argument, forwarded unchanged
     * @param d2 second double argument, forwarded unchanged
     * @return the array created from the native result
     */
    public static PtNDArray batchNorm(PtNDArray ptNDArray, PtNDArray ptNDArray2, PtNDArray ptNDArray3, PtNDArray ptNDArray4, PtNDArray ptNDArray5, boolean z, double d, double d2) {
        PtNDArray normalized =
                ptNDArray
                        .getManager()
                        .create(
                                PyTorchLibrary.LIB.torchNNBatchNorm(
                                        ptNDArray.getHandle(),
                                        ptNDArray2.getHandle(),
                                        ptNDArray3.getHandle(),
                                        ptNDArray4.getHandle(),
                                        ptNDArray5.getHandle(),
                                        z,
                                        d,
                                        d2));
        return normalized;
    }

    /**
     * Invokes the native {@code torchMaskedSelect} routine with the handles of both arrays.
     *
     * @param ptNDArray first array; its manager creates the returned array
     * @param ptNDArray2 second array whose handle is passed
     * @return the array created from the native result
     */
    public static PtNDArray booleanMask(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        PtNDArray selected =
                ptNDArray
                        .getManager()
                        .create(PyTorchLibrary.LIB.torchMaskedSelect(ptNDArray.getHandle(), ptNDArray2.getHandle()));
        return selected;
    }

    /**
     * Invokes the native {@code torchMaskedPut} routine with the handles of the three arrays.
     *
     * @param ptNDArray first array whose handle is passed
     * @param ptNDArray2 second array whose handle is passed
     * @param ptNDArray3 third array whose handle is passed
     */
    public static void booleanMaskSet(PtNDArray ptNDArray, PtNDArray ptNDArray2, PtNDArray ptNDArray3) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        lib.torchMaskedPut(ptNDArray.getHandle(), ptNDArray2.getHandle(), ptNDArray3.getHandle());
    }

    /**
     * Invokes the native {@code torchExpand} routine with the array handle and the raw values
     * of {@code shape}.
     *
     * @param ptNDArray the source array
     * @param shape shape whose values are passed to the native call
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray broadcast(PtNDArray ptNDArray, Shape shape) {
        PtNDArray expanded =
                ptNDArray
                        .getManager()
                        .create(PyTorchLibrary.LIB.torchExpand(ptNDArray.getHandle(), shape.getShape()));
        return expanded;
    }

    /**
     * Invokes the native {@code torchCat} routine on the handles of all given arrays.
     *
     * <p>Every element is cast to {@code PtNDArray} to obtain its handle. The manager of the
     * first element, cast to {@code PtNDManager}, creates the returned array.
     *
     * @param nDArrayArr arrays whose handles are passed; must contain at least one element
     * @param j long argument, forwarded unchanged to the native call
     * @return the array created from the native result
     */
    public static PtNDArray cat(NDArray[] nDArrayArr, long j) {
        PtNDManager manager = (PtNDManager) nDArrayArr[0].getManager();
        // Typed stream pipeline instead of raw anonymous Function/IntFunction instances.
        Pointer[] handles =
                Arrays.stream(nDArrayArr)
                        .map(array -> ((PtNDArray) array).getHandle())
                        .toArray(Pointer[]::new);
        return manager.create(PyTorchLibrary.LIB.torchCat(handles, j));
    }

    /**
     * Invokes the native {@code torchCeil} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray ceil(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchCeil(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchClamp} routine with the array handle and the handles of
     * two temporary arrays created from {@code number} and {@code number2}.
     *
     * <p>The two temporary arrays are closed as soon as the native call has returned, so they
     * do not remain attached to the manager.
     *
     * @param ptNDArray the source array; its manager creates the temporaries and the result
     * @param number value wrapped in the first temporary array
     * @param number2 value wrapped in the second temporary array
     * @return the array created from the native result
     */
    public static PtNDArray clip(PtNDArray ptNDArray, Number number, Number number2) {
        try (PtNDArray lower = (PtNDArray) ptNDArray.getManager().create(number);
                PtNDArray upper = (PtNDArray) ptNDArray.getManager().create(number2)) {
            return ptNDArray
                    .getManager()
                    .create(
                            PyTorchLibrary.LIB.torchClamp(
                                    ptNDArray.getHandle(), lower.getHandle(), upper.getHandle()));
        }
    }

    /**
     * Invokes the native {@code tensorClone} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray clone(PtNDArray ptNDArray) {
        PtNDArray copy =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.tensorClone(ptNDArray.getHandle()));
        return copy;
    }

    /**
     * Invokes the native {@code contentEqual} routine with the handles of both arrays.
     *
     * @param ptNDArray first array whose handle is passed
     * @param ptNDArray2 second array whose handle is passed
     * @return the boolean returned by the native call
     */
    public static boolean contentEqual(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        boolean same = PyTorchLibrary.LIB.contentEqual(ptNDArray.getHandle(), ptNDArray2.getHandle());
        return same;
    }

    /**
     * Invokes the native {@code torchNNConvNd} routine.
     *
     * <p>When {@code z} is {@code true}, {@code null} is passed in place of the third array's
     * handle and the trailing native flag is {@code false}. Otherwise the third array's handle
     * is passed and the flag is {@code true}.
     *
     * @param ptNDArray first array; its manager creates the returned array
     * @param ptNDArray2 second array whose handle is passed
     * @param ptNDArray3 third array; only dereferenced when {@code z} is {@code false}
     * @param shape first shape; also supplies the dimension count passed first
     * @param shape2 second shape
     * @param shape3 third shape
     * @param i int argument, forwarded unchanged
     * @param z selects whether the third array is omitted
     * @return the array created from the native result
     */
    public static PtNDArray convolution(PtNDArray ptNDArray, PtNDArray ptNDArray2, PtNDArray ptNDArray3, Shape shape, Shape shape2, Shape shape3, int i, boolean z) {
        PtNDArray convolved =
                ptNDArray
                        .getManager()
                        .create(
                                PyTorchLibrary.LIB.torchNNConvNd(
                                        shape.dimension(),
                                        ptNDArray.getHandle(),
                                        ptNDArray2.getHandle(),
                                        !z ? ptNDArray3.getHandle() : null,
                                        shape.getShape(),
                                        shape2.getShape(),
                                        shape3.getShape(),
                                        i,
                                        !z));
        return convolved;
    }

    /**
     * Invokes the native {@code torchCos} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray cos(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchCos(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchCosh} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray cosh(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchCosh(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Creates an array through the native {@code torchEmpty} routine.
     *
     * <p>The device is passed as a two-element array: the value of
     * {@code PtDeviceType.toDeviceType(device)} followed by the device id, where {@code -1} is
     * used when the device equals {@code Device.cpu()}. The trailing native flag is always
     * {@code false}.
     *
     * @param ptNDManager manager that creates the returned array
     * @param shape shape whose raw values are passed
     * @param dataType data type; its ordinal is passed
     * @param device target device
     * @param sparseFormat format translated by {@code layoutMapper}
     * @return the array created from the native result
     */
    public static PtNDArray createEmptyNdArray(PtNDManager ptNDManager, Shape shape, DataType dataType, Device device, SparseFormat sparseFormat) {
        int layout = layoutMapper(sparseFormat);
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        long[] dims = shape.getShape();
        int typeOrdinal = dataType.ordinal();
        int[] deviceSpec = {
            PtDeviceType.toDeviceType(device),
            device.equals(Device.cpu()) ? -1 : device.getDeviceId()
        };
        return ptNDManager.create(lib.torchEmpty(dims, typeOrdinal, layout, deviceSpec, false));
    }

    /**
     * Creates an array through the native {@code torchFromBlob} routine.
     *
     * <p>The device is passed as a two-element array: the value of
     * {@code PtDeviceType.toDeviceType(device)} followed by the device id, where {@code -1} is
     * used when the device equals {@code Device.cpu()}. The trailing native flag is always
     * {@code false}.
     *
     * @param ptNDManager manager that creates the returned array
     * @param byteBuffer buffer handed to the native call
     * @param shape shape whose raw values are passed
     * @param dataType data type; its ordinal is passed
     * @param sparseFormat format translated by {@code layoutMapper}
     * @param device target device
     * @return the array created from the native result
     */
    public static PtNDArray createNdFromByteBuffer(PtNDManager ptNDManager, ByteBuffer byteBuffer, Shape shape, DataType dataType, SparseFormat sparseFormat, Device device) {
        int layout = layoutMapper(sparseFormat);
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        long[] dims = shape.getShape();
        int typeOrdinal = dataType.ordinal();
        int[] deviceSpec = {
            PtDeviceType.toDeviceType(device),
            device.equals(Device.cpu()) ? -1 : device.getDeviceId()
        };
        return ptNDManager.create(
                lib.torchFromBlob(byteBuffer, dims, typeOrdinal, layout, deviceSpec, false));
    }

    /**
     * Creates an array through the native {@code torchOnes} routine.
     *
     * <p>The device is passed as a two-element array: the value of
     * {@code PtDeviceType.toDeviceType(device)} followed by the device id, where {@code -1} is
     * used when the device equals {@code Device.cpu()}. The trailing native flag is always
     * {@code false}.
     *
     * @param ptNDManager manager that creates the returned array
     * @param shape shape whose raw values are passed
     * @param dataType data type; its ordinal is passed
     * @param device target device
     * @param sparseFormat format translated by {@code layoutMapper}
     * @return the array created from the native result
     */
    public static PtNDArray createOnesNdArray(PtNDManager ptNDManager, Shape shape, DataType dataType, Device device, SparseFormat sparseFormat) {
        int layout = layoutMapper(sparseFormat);
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        long[] dims = shape.getShape();
        int typeOrdinal = dataType.ordinal();
        int[] deviceSpec = {
            PtDeviceType.toDeviceType(device),
            device.equals(Device.cpu()) ? -1 : device.getDeviceId()
        };
        return ptNDManager.create(lib.torchOnes(dims, typeOrdinal, layout, deviceSpec, false));
    }

    /**
     * Creates an array through the native {@code torchZeros} routine.
     *
     * <p>The device is passed as a two-element array: the value of
     * {@code PtDeviceType.toDeviceType(device)} followed by the device id, where {@code -1} is
     * used when the device equals {@code Device.cpu()}. The trailing native flag is always
     * {@code false}.
     *
     * @param ptNDManager manager that creates the returned array
     * @param shape shape whose raw values are passed
     * @param dataType data type; its ordinal is passed
     * @param device target device
     * @param sparseFormat format translated by {@code layoutMapper}
     * @return the array created from the native result
     */
    public static PtNDArray createZerosNdArray(PtNDManager ptNDManager, Shape shape, DataType dataType, Device device, SparseFormat sparseFormat) {
        int layout = layoutMapper(sparseFormat);
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        long[] dims = shape.getShape();
        int typeOrdinal = dataType.ordinal();
        int[] deviceSpec = {
            PtDeviceType.toDeviceType(device),
            device.equals(Device.cpu()) ? -1 : device.getDeviceId()
        };
        return ptNDManager.create(lib.torchZeros(dims, typeOrdinal, layout, deviceSpec, false));
    }

    /**
     * Invokes the native {@code torchCumSum} routine.
     *
     * @param ptNDArray the source array
     * @param j long argument, forwarded unchanged
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray cumSum(PtNDArray ptNDArray, long j) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchCumSum(ptNDArray.getHandle(), j));
        return result;
    }

    /**
     * Passes the pointer to the native {@code torchDeleteModule} routine.
     *
     * @param pointer pointer handed to the native call
     */
    public static void deleteModule(Pointer pointer) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        lib.torchDeleteModule(pointer);
    }

    /**
     * Passes the pointer to the native {@code torchDeleteTensor} routine.
     *
     * @param pointer pointer handed to the native call
     */
    public static void deleteNdArray(Pointer pointer) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        lib.torchDeleteTensor(pointer);
    }

    /**
     * Invokes the native {@code torchDetachGrad} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray detachGradient(PtNDArray ptNDArray) {
        PtNDArray detached =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchDetachGrad(ptNDArray.getHandle()));
        return detached;
    }

    /**
     * Invokes the native {@code torchTrueDivide} routine with the handles of both arrays.
     *
     * @param ptNDArray first operand; its manager creates the returned array
     * @param ptNDArray2 second operand
     * @return the array created from the native result
     */
    public static PtNDArray div(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        PtNDArray result =
                ptNDArray
                        .getManager()
                        .create(PyTorchLibrary.LIB.torchTrueDivide(ptNDArray.getHandle(), ptNDArray2.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchTrueDividei} routine with the handles of both arrays.
     *
     * @param ptNDArray first array whose handle is passed
     * @param ptNDArray2 second array whose handle is passed
     */
    public static void divi(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        lib.torchTrueDividei(ptNDArray.getHandle(), ptNDArray2.getHandle());
    }

    /**
     * Invokes the native {@code torchDot} routine when the first array's shape has exactly one
     * dimension, and {@code torchMM} otherwise.
     *
     * @param ptNDArray first operand; its shape selects the native call and its manager creates
     *     the returned array
     * @param ptNDArray2 second operand
     * @return the array created from the native result
     */
    public static PtNDArray dot(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        if (ptNDArray.getShape().dimension() == 1) {
            return ptNDArray
                    .getManager()
                    .create(PyTorchLibrary.LIB.torchDot(ptNDArray.getHandle(), ptNDArray2.getHandle()));
        }
        return ptNDArray
                .getManager()
                .create(PyTorchLibrary.LIB.torchMM(ptNDArray.getHandle(), ptNDArray2.getHandle()));
    }

    /**
     * Invokes the native {@code torchNNDropout} routine.
     *
     * @param ptNDArray the source array
     * @param d double argument, forwarded unchanged
     * @param z boolean argument, forwarded unchanged
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray dropout(PtNDArray ptNDArray, double d, boolean z) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchNNDropout(ptNDArray.getHandle(), d, z));
        return result;
    }

    /**
     * Invokes the native {@code moduleEval} routine on the block's handle.
     *
     * @param ptSymbolBlock the block whose handle is passed
     */
    public static void enableInferenceMode(PtSymbolBlock ptSymbolBlock) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        lib.moduleEval(ptSymbolBlock.getHandle());
    }

    /**
     * Invokes the native {@code moduleTrain} routine on the block's handle.
     *
     * @param ptSymbolBlock the block whose handle is passed
     */
    public static void enableTrainingMode(PtSymbolBlock ptSymbolBlock) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        lib.moduleTrain(ptSymbolBlock.getHandle());
    }

    /**
     * Invokes the native {@code torchEq} routine with the handles of both arrays.
     *
     * @param ptNDArray first operand; its manager creates the returned array
     * @param ptNDArray2 second operand
     * @return the array created from the native result
     */
    public static PtNDArray eq(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        PtNDArray result =
                ptNDArray
                        .getManager()
                        .create(PyTorchLibrary.LIB.torchEq(ptNDArray.getHandle(), ptNDArray2.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchExp} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray exp(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchExp(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Creates an array through the native {@code torchEye} routine.
     *
     * <p>The device is passed as a two-element array: the value of
     * {@code PtDeviceType.toDeviceType(device)} followed by the device id, where {@code -1} is
     * used when the device equals {@code Device.cpu()}. The trailing native flag is always
     * {@code false}.
     *
     * @param ptNDManager manager that creates the returned array
     * @param i first int argument, forwarded unchanged
     * @param i2 second int argument, forwarded unchanged
     * @param dataType data type; its ordinal is passed
     * @param device target device
     * @param sparseFormat format translated by {@code layoutMapper}
     * @return the array created from the native result
     */
    public static PtNDArray eye(PtNDManager ptNDManager, int i, int i2, DataType dataType, Device device, SparseFormat sparseFormat) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        int typeOrdinal = dataType.ordinal();
        int layout = layoutMapper(sparseFormat);
        int[] deviceSpec = {
            PtDeviceType.toDeviceType(device),
            device.equals(Device.cpu()) ? -1 : device.getDeviceId()
        };
        return ptNDManager.create(lib.torchEye(i, i2, typeOrdinal, layout, deviceSpec, false));
    }

    /**
     * Invokes the native {@code torchFlatten} routine.
     *
     * @param ptNDArray the source array
     * @param j first long argument, forwarded unchanged
     * @param j2 second long argument, forwarded unchanged
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray flatten(PtNDArray ptNDArray, long j, long j2) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchFlatten(ptNDArray.getHandle(), j, j2));
        return result;
    }

    /**
     * Invokes the native {@code torchFloor} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray floor(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchFloor(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchNNLinear} routine.
     *
     * <p>When {@code z} is {@code true}, {@code null} is passed in place of the third array's
     * handle and the trailing native flag is {@code false}. Otherwise the third array's handle
     * is passed and the flag is {@code true}.
     *
     * @param ptNDArray first array; its manager creates the returned array
     * @param ptNDArray2 second array whose handle is passed
     * @param ptNDArray3 third array; only dereferenced when {@code z} is {@code false}
     * @param z selects whether the third array is omitted
     * @return the array created from the native result
     */
    public static PtNDArray fullyConnected(PtNDArray ptNDArray, PtNDArray ptNDArray2, PtNDArray ptNDArray3, boolean z) {
        PtNDArray output =
                ptNDArray
                        .getManager()
                        .create(
                                PyTorchLibrary.LIB.torchNNLinear(
                                        ptNDArray.getHandle(),
                                        ptNDArray2.getHandle(),
                                        !z ? ptNDArray3.getHandle() : null,
                                        !z));
        return output;
    }

    /**
     * Wraps the bytes returned by the native {@code torchDataPtr} routine in a
     * {@code ByteBuffer} using the platform's native byte order.
     *
     * <p>If the array's device is not {@code Device.cpu()}, the bytes are read from the result
     * of {@code toDevice(Device.cpu(), false)}. That temporary array is closed once the bytes
     * have been read, unless {@code toDevice} returned the input itself.
     *
     * @param ptNDArray the source array; never closed by this method
     * @return a buffer wrapping the bytes returned by the native call
     */
    public static ByteBuffer getByteBuffer(PtNDArray ptNDArray) {
        if (ptNDArray.getDevice().equals(Device.cpu())) {
            return ByteBuffer.wrap(PyTorchLibrary.LIB.torchDataPtr(ptNDArray.getHandle()))
                    .order(ByteOrder.nativeOrder());
        }
        PtNDArray cpuCopy = ptNDArray.toDevice(Device.cpu(), false);
        try {
            return ByteBuffer.wrap(PyTorchLibrary.LIB.torchDataPtr(cpuCopy.getHandle()))
                    .order(ByteOrder.nativeOrder());
        } finally {
            // Release the temporary copy; never close the array owned by the caller.
            if (cpuCopy != ptNDArray) {
                cpuCopy.close();
            }
        }
    }

    /**
     * Looks up the {@code DataType} constant whose position in {@code DataType.values()} equals
     * the int returned by the native {@code torchDType} routine.
     *
     * @param ptNDArray the array whose handle is passed
     * @return the constant at that index
     */
    public static DataType getDataType(PtNDArray ptNDArray) {
        DataType[] allTypes = DataType.values();
        int typeIndex = PyTorchLibrary.LIB.torchDType(ptNDArray.getHandle());
        return allTypes[typeIndex];
    }

    /**
     * Builds a {@code Device} from the int array returned by the native {@code torchDevice}
     * routine: element 0 is translated by {@code PtDeviceType.fromDeviceType} and element 1 is
     * used as the device id.
     *
     * @param ptNDArray the array whose handle is passed
     * @return the device produced by {@code Device.of}
     */
    public static Device getDevice(PtNDArray ptNDArray) {
        int[] deviceInfo = PyTorchLibrary.LIB.torchDevice(ptNDArray.getHandle());
        int typeCode = deviceInfo[0];
        return Device.of(PtDeviceType.fromDeviceType(typeCode), deviceInfo[1]);
    }

    /**
     * Returns the cached feature set, filling it on first use.
     *
     * <p>When the cache field is {@code null}, a new set is passed to the native
     * {@code torchShowConfig} routine and then stored in the cache field. No synchronization is
     * performed here.
     *
     * @return the cached set, or the newly filled one on the first call
     */
    public static Set<String> getFeatures() {
        Set<String> cached = configs;
        if (cached != null) {
            return cached;
        }
        // Parameterized type instead of a raw HashSet, so no unchecked assignment is needed.
        Set<String> features = new HashSet<>();
        PyTorchLibrary.LIB.torchShowConfig(features);
        configs = features;
        return features;
    }

    /**
     * Invokes the native {@code torchGrad} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray getGradient(PtNDArray ptNDArray) {
        PtNDArray gradient =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchGrad(ptNDArray.getHandle()));
        return gradient;
    }

    /**
     * Invokes the native {@code torchGradFnName} routine on the array's handle.
     *
     * @param ptNDArray the array whose handle is passed
     * @return the string returned by the native call
     */
    public static String getGradientFunctionNames(PtNDArray ptNDArray) {
        String names = PyTorchLibrary.LIB.torchGradFnName(ptNDArray.getHandle());
        return names;
    }

    /**
     * Builds a {@code Shape} from the value returned by the native {@code torchSizes} routine.
     *
     * @param ptNDArray the array whose handle is passed
     * @return the newly constructed shape
     */
    public static Shape getShape(PtNDArray ptNDArray) {
        Shape shape = new Shape(PyTorchLibrary.LIB.torchSizes(ptNDArray.getHandle()));
        return shape;
    }

    /**
     * Translates the int returned by the native {@code torchLayout} routine into a
     * {@code SparseFormat}: 0 maps to {@code DENSE} and 1 maps to {@code COO}.
     *
     * @param ptNDArray the array whose handle is passed
     * @return the matching sparse format
     * @throws UnsupportedOperationException if the native value is neither 0 nor 1
     */
    public static SparseFormat getSparseFormat(PtNDArray ptNDArray) {
        int layout = PyTorchLibrary.LIB.torchLayout(ptNDArray.getHandle());
        switch (layout) {
            case 0:
                return SparseFormat.DENSE;
            case 1:
                return SparseFormat.COO;
            default:
                throw new UnsupportedOperationException("Unsupported data format");
        }
    }

    /**
     * Invokes the native {@code torchNNAdaptiveAvgPool} routine with an all-ones shape.
     *
     * <p>The shape has one entry when {@code i} is 1, two entries when {@code i} is 2, and
     * three entries for any other value.
     *
     * @param ptNDArray the source array
     * @param i int argument; forwarded to the native call and used to pick the shape
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray globalAvgPool(PtNDArray ptNDArray, int i) {
        Shape onesShape;
        if (i == 1) {
            onesShape = new Shape(1);
        } else if (i == 2) {
            onesShape = new Shape(1, 1);
        } else {
            onesShape = new Shape(1, 1, 1);
        }
        PtNDArray pooled =
                ptNDArray
                        .getManager()
                        .create(
                                PyTorchLibrary.LIB.torchNNAdaptiveAvgPool(
                                        ptNDArray.getHandle(), i, onesShape.getShape()));
        return pooled;
    }

    /**
     * Invokes the native {@code torchNNAdaptiveMaxPool} routine with an all-ones shape.
     *
     * <p>The shape has one entry when {@code i} is 1, two entries when {@code i} is 2, and
     * three entries for any other value.
     *
     * @param ptNDArray the source array
     * @param i int argument; forwarded to the native call and used to pick the shape
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray globalMaxPool(PtNDArray ptNDArray, int i) {
        Shape onesShape;
        if (i == 1) {
            onesShape = new Shape(1);
        } else if (i == 2) {
            onesShape = new Shape(1, 1);
        } else {
            onesShape = new Shape(1, 1, 1);
        }
        PtNDArray pooled =
                ptNDArray
                        .getManager()
                        .create(
                                PyTorchLibrary.LIB.torchNNAdaptiveMaxPool(
                                        ptNDArray.getHandle(), i, onesShape.getShape()));
        return pooled;
    }

    /**
     * Invokes the native {@code torchGt} routine with the handles of both arrays.
     *
     * @param ptNDArray first operand; its manager creates the returned array
     * @param ptNDArray2 second operand
     * @return the array created from the native result
     */
    public static PtNDArray gt(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        PtNDArray result =
                ptNDArray
                        .getManager()
                        .create(PyTorchLibrary.LIB.torchGt(ptNDArray.getHandle(), ptNDArray2.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchGte} routine with the handles of both arrays.
     *
     * @param ptNDArray first operand; its manager creates the returned array
     * @param ptNDArray2 second operand
     * @return the array created from the native result
     */
    public static PtNDArray gte(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        PtNDArray result =
                ptNDArray
                        .getManager()
                        .create(PyTorchLibrary.LIB.torchGte(ptNDArray.getHandle(), ptNDArray2.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchIndex} routine with the array handle and the three long
     * arrays.
     *
     * @param ptNDArray the source array
     * @param jArr first long array, forwarded unchanged
     * @param jArr2 second long array, forwarded unchanged
     * @param jArr3 third long array, forwarded unchanged
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray index(PtNDArray ptNDArray, long[] jArr, long[] jArr2, long[] jArr3) {
        PtNDArray result =
                ptNDArray
                        .getManager()
                        .create(PyTorchLibrary.LIB.torchIndex(ptNDArray.getHandle(), jArr, jArr2, jArr3));
        return result;
    }

    /**
     * Invokes the native {@code torchIndexPut} routine with both array handles and the three
     * long arrays.
     *
     * @param ptNDArray first array whose handle is passed
     * @param ptNDArray2 second array whose handle is passed
     * @param jArr first long array, forwarded unchanged
     * @param jArr2 second long array, forwarded unchanged
     * @param jArr3 third long array, forwarded unchanged
     */
    public static void indexSet(PtNDArray ptNDArray, PtNDArray ptNDArray2, long[] jArr, long[] jArr2, long[] jArr3) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        lib.torchIndexPut(ptNDArray.getHandle(), ptNDArray2.getHandle(), jArr, jArr2, jArr3);
    }

    /**
     * Invokes the native {@code torchIsInf} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray isInf(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchIsInf(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchIsNaN} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray isNaN(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchIsNaN(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Decompiler-generated array factory referenced from {@code cat}.
     *
     * @param i length of the array to allocate
     * @return a new {@code Pointer} array of length {@code i}
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public static /* synthetic */ Pointer[] lambda$cat$3(int i) {
        Pointer[] handles = new Pointer[i];
        return handles;
    }

    /**
     * Decompiler-generated array factory referenced from {@code stack}.
     *
     * @param i length of the array to allocate
     * @return a new {@code Pointer} array of length {@code i}
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public static /* synthetic */ Pointer[] lambda$stack$1(int i) {
        Pointer[] handles = new Pointer[i];
        return handles;
    }

    /**
     * Maps a {@code SparseFormat} to the int layout code passed to the native calls.
     *
     * <p>{@code COO} maps to 1. {@code DENSE} maps to 2 when the boolean system property
     * {@code ai.djl.pytorch.use_mkldnn} is set to true, and to 0 otherwise.
     *
     * @param sparseFormat the format to translate
     * @return the layout code
     * @throws IllegalArgumentException for any other format, including {@code null}
     */
    private static int layoutMapper(SparseFormat sparseFormat) {
        if (sparseFormat == SparseFormat.COO) {
            return 1;
        }
        if (sparseFormat != SparseFormat.DENSE) {
            throw new IllegalArgumentException("Current PyTorch only support SparseFormat.DENSE and SparseFormat.COO");
        }
        boolean useMkldnn = Boolean.getBoolean("ai.djl.pytorch.use_mkldnn");
        return useMkldnn ? 2 : 0;
    }

    /**
     * Creates an array through the native {@code torchLinspace} routine.
     *
     * <p>The device is passed as a two-element array: the value of
     * {@code PtDeviceType.toDeviceType(device)} followed by the device id, where {@code -1} is
     * used when the device equals {@code Device.cpu()}. The trailing native flag is always
     * {@code false}.
     *
     * @param ptNDManager manager that creates the returned array
     * @param f first float argument, forwarded unchanged
     * @param f2 second float argument, forwarded unchanged
     * @param i int argument, forwarded unchanged
     * @param dataType data type; its ordinal is passed
     * @param device target device
     * @param sparseFormat format translated by {@code layoutMapper}
     * @return the array created from the native result
     */
    public static PtNDArray linspace(PtNDManager ptNDManager, float f, float f2, int i, DataType dataType, Device device, SparseFormat sparseFormat) {
        int layout = layoutMapper(sparseFormat);
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        int typeOrdinal = dataType.ordinal();
        int[] deviceSpec = {
            PtDeviceType.toDeviceType(device),
            device.equals(Device.cpu()) ? -1 : device.getDeviceId()
        };
        return ptNDManager.create(lib.torchLinspace(f, f2, i, typeOrdinal, layout, deviceSpec, false));
    }

    /**
     * Loads a module through the native {@code moduleLoad} routine and wraps the result in a
     * {@code PtSymbolBlock}.
     *
     * <p>The native call receives {@code path.toString()} and a two-element device array: the
     * value of {@code PtDeviceType.toDeviceType(device)} followed by the device id, where
     * {@code -1} is used when the device equals {@code Device.cpu()}.
     *
     * @param ptNDManager manager handed to the new block
     * @param path location passed as a string to the native call
     * @param device target device
     * @return the new block wrapping the native result
     */
    public static PtSymbolBlock loadModule(PtNDManager ptNDManager, Path path, Device device) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        String location = path.toString();
        int[] deviceSpec = {
            PtDeviceType.toDeviceType(device),
            device.equals(Device.cpu()) ? -1 : device.getDeviceId()
        };
        return new PtSymbolBlock(ptNDManager, lib.moduleLoad(location, deviceSpec));
    }

    /**
     * Invokes the native {@code torchLog} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray log(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchLog(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchLog10} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray log10(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchLog10(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchLog2} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray log2(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchLog2(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchLogSoftmax} routine.
     *
     * @param ptNDArray the source array
     * @param j long argument, forwarded unchanged
     * @param dataType data type; its ordinal is passed to the native call
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray logSoftmax(PtNDArray ptNDArray, long j, DataType dataType) {
        PtNDArray result =
                ptNDArray
                        .getManager()
                        .create(PyTorchLibrary.LIB.torchLogSoftmax(ptNDArray.getHandle(), j, dataType.ordinal()));
        return result;
    }

    /**
     * Invokes the native {@code torchLogicalAnd} routine with the handles of both arrays.
     *
     * @param ptNDArray first operand; its manager creates the returned array
     * @param ptNDArray2 second operand
     * @return the array created from the native result
     */
    public static PtNDArray logicalAnd(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        PtNDArray result =
                ptNDArray
                        .getManager()
                        .create(PyTorchLibrary.LIB.torchLogicalAnd(ptNDArray.getHandle(), ptNDArray2.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchLogicalNot} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray logicalNot(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchLogicalNot(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchLogicalOr} routine with the handles of both arrays.
     *
     * @param ptNDArray first operand; its manager creates the returned array
     * @param ptNDArray2 second operand
     * @return the array created from the native result
     */
    public static PtNDArray logicalOr(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        PtNDArray result =
                ptNDArray
                        .getManager()
                        .create(PyTorchLibrary.LIB.torchLogicalOr(ptNDArray.getHandle(), ptNDArray2.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchLogicalXor} routine with the handles of both arrays.
     *
     * @param ptNDArray first operand; its manager creates the returned array
     * @param ptNDArray2 second operand
     * @return the array created from the native result
     */
    public static PtNDArray logicalXor(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        PtNDArray result =
                ptNDArray
                        .getManager()
                        .create(PyTorchLibrary.LIB.torchLogicalXor(ptNDArray.getHandle(), ptNDArray2.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchLt} routine with the handles of both arrays.
     *
     * @param ptNDArray first operand; its manager creates the returned array
     * @param ptNDArray2 second operand
     * @return the array created from the native result
     */
    public static PtNDArray lt(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        PtNDArray result =
                ptNDArray
                        .getManager()
                        .create(PyTorchLibrary.LIB.torchLt(ptNDArray.getHandle(), ptNDArray2.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchLte} routine with the handles of both arrays.
     *
     * @param ptNDArray first operand; its manager creates the returned array
     * @param ptNDArray2 second operand
     * @return the array created from the native result
     */
    public static PtNDArray lte(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        PtNDArray result =
                ptNDArray
                        .getManager()
                        .create(PyTorchLibrary.LIB.torchLte(ptNDArray.getHandle(), ptNDArray2.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchMatmul} routine with the handles of both arrays.
     *
     * @param ptNDArray first operand; its manager creates the returned array
     * @param ptNDArray2 second operand
     * @return the array created from the native result
     */
    public static PtNDArray matmul(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        PtNDArray product =
                ptNDArray
                        .getManager()
                        .create(PyTorchLibrary.LIB.torchMatmul(ptNDArray.getHandle(), ptNDArray2.getHandle()));
        return product;
    }

    /**
     * Invokes the single-argument native {@code torchMax} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray max(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchMax(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchMax} routine with the array handle, a long and a boolean.
     *
     * @param ptNDArray the source array
     * @param j long argument, forwarded unchanged
     * @param z boolean argument, forwarded unchanged
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray max(PtNDArray ptNDArray, long j, boolean z) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchMax(ptNDArray.getHandle(), j, z));
        return result;
    }

    /**
     * Invokes the native {@code torchMax} routine with the handles of both arrays.
     *
     * @param ptNDArray first operand; its manager creates the returned array
     * @param ptNDArray2 second operand
     * @return the array created from the native result
     */
    public static PtNDArray max(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        PtNDArray result =
                ptNDArray
                        .getManager()
                        .create(PyTorchLibrary.LIB.torchMax(ptNDArray.getHandle(), ptNDArray2.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchNNMaxPool} routine.
     *
     * <p>The native call receives the array handle, the dimension count of {@code shape}, the
     * raw values of the three shapes, and a flag telling whether {@code poolingConvention}
     * equals {@code PoolingConvention.FULL}.
     *
     * @param ptNDArray the source array
     * @param shape first shape; also supplies the dimension count
     * @param shape2 second shape
     * @param shape3 third shape
     * @param poolingConvention convention compared against {@code PoolingConvention.FULL}
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray maxPool(PtNDArray ptNDArray, Shape shape, Shape shape2, Shape shape3, PoolingConvention poolingConvention) {
        boolean isFull = PoolingConvention.FULL.equals(poolingConvention);
        PtNDArray pooled =
                ptNDArray
                        .getManager()
                        .create(
                                PyTorchLibrary.LIB.torchNNMaxPool(
                                        ptNDArray.getHandle(),
                                        shape.dimension(),
                                        shape.getShape(),
                                        shape2.getShape(),
                                        shape3.getShape(),
                                        isFull));
        return pooled;
    }

    /**
     * Invokes the single-argument native {@code torchMean} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray mean(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchMean(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Invokes the three-argument native {@code torchMean} routine.
     *
     * @param ptNDArray the source array
     * @param j long argument, forwarded unchanged
     * @param z boolean argument, forwarded unchanged
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray mean(PtNDArray ptNDArray, long j, boolean z) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchMean(ptNDArray.getHandle(), j, z));
        return result;
    }

    /**
     * Invokes the single-argument native {@code torchMin} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray min(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchMin(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchMin} routine with the array handle, a long and a boolean.
     *
     * @param ptNDArray the source array
     * @param j long argument, forwarded unchanged
     * @param z boolean argument, forwarded unchanged
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray min(PtNDArray ptNDArray, long j, boolean z) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchMin(ptNDArray.getHandle(), j, z));
        return result;
    }

    /**
     * Invokes the native {@code torchMin} routine with the handles of both arrays.
     *
     * @param ptNDArray first operand; its manager creates the returned array
     * @param ptNDArray2 second operand
     * @return the array created from the native result
     */
    public static PtNDArray min(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        PtNDArray result =
                ptNDArray
                        .getManager()
                        .create(PyTorchLibrary.LIB.torchMin(ptNDArray.getHandle(), ptNDArray2.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchMul} routine with the handles of both arrays.
     *
     * @param ptNDArray first operand; its manager creates the returned array
     * @param ptNDArray2 second operand
     * @return the array created from the native result
     */
    public static PtNDArray mul(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        PtNDArray result =
                ptNDArray
                        .getManager()
                        .create(PyTorchLibrary.LIB.torchMul(ptNDArray.getHandle(), ptNDArray2.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchMuli} routine with the handles of both arrays.
     *
     * @param ptNDArray first array whose handle is passed
     * @param ptNDArray2 second array whose handle is passed
     */
    public static void muli(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        lib.torchMuli(ptNDArray.getHandle(), ptNDArray2.getHandle());
    }

    /**
     * Invokes the native {@code torchNeg} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray neg(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchNeg(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchNegi} routine on the array's handle.
     *
     * @param ptNDArray the array whose handle is passed
     */
    public static void negi(PtNDArray ptNDArray) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        lib.torchNegi(ptNDArray.getHandle());
    }

    /**
     * Invokes the native {@code torchNeq} routine with the handles of both arrays.
     *
     * @param ptNDArray first operand; its manager creates the returned array
     * @param ptNDArray2 second operand
     * @return the array created from the native result
     */
    public static PtNDArray neq(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        PtNDArray result =
                ptNDArray
                        .getManager()
                        .create(PyTorchLibrary.LIB.torchNeq(ptNDArray.getHandle(), ptNDArray2.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchNone} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray none(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchNone(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Creates an array through the native {@code atNormal} routine.
     *
     * <p>The layout code is always the one {@code layoutMapper} returns for
     * {@code SparseFormat.DENSE}. The device is passed as a two-element array: the value of
     * {@code PtDeviceType.toDeviceType(device)} followed by the device id, where {@code -1} is
     * used when the device equals {@code Device.cpu()}. The trailing native flag is always
     * {@code false}.
     *
     * @param ptNDManager manager that creates the returned array
     * @param d first double argument, forwarded unchanged
     * @param d2 second double argument, forwarded unchanged
     * @param shape shape whose raw values are passed
     * @param dataType data type; its ordinal is passed
     * @param device target device
     * @return the array created from the native result
     */
    public static PtNDArray normal(PtNDManager ptNDManager, double d, double d2, Shape shape, DataType dataType, Device device) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        long[] dims = shape.getShape();
        int typeOrdinal = dataType.ordinal();
        int layout = layoutMapper(SparseFormat.DENSE);
        int[] deviceSpec = {
            PtDeviceType.toDeviceType(device),
            device.equals(Device.cpu()) ? -1 : device.getDeviceId()
        };
        return ptNDManager.create(lib.atNormal(d, d2, dims, typeOrdinal, layout, deviceSpec, false));
    }

    /**
     * Creates an array through the native {@code torchOnesLike} routine.
     *
     * <p>The device is passed as a two-element array: the value of
     * {@code PtDeviceType.toDeviceType(device)} followed by the device id, where {@code -1} is
     * used when the device equals {@code Device.cpu()}. The trailing native flag is always
     * {@code false}.
     *
     * @param ptNDArray the source array; its manager creates the returned array
     * @param dataType data type; its ordinal is passed
     * @param device target device
     * @param sparseFormat format translated by {@code layoutMapper}
     * @return the array created from the native result
     */
    public static PtNDArray onesLike(PtNDArray ptNDArray, DataType dataType, Device device, SparseFormat sparseFormat) {
        int layout = layoutMapper(sparseFormat);
        PtNDManager owner = ptNDArray.getManager();
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        Pointer source = ptNDArray.getHandle();
        int typeOrdinal = dataType.ordinal();
        int[] deviceSpec = {
            PtDeviceType.toDeviceType(device),
            device.equals(Device.cpu()) ? -1 : device.getDeviceId()
        };
        return owner.create(lib.torchOnesLike(source, typeOrdinal, layout, deviceSpec, false));
    }

    /**
     * Invokes the native {@code torchPermute} routine.
     *
     * @param ptNDArray the source array
     * @param jArr long array, forwarded unchanged
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray permute(PtNDArray ptNDArray, long[] jArr) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchPermute(ptNDArray.getHandle(), jArr));
        return result;
    }

    /**
     * Invokes the native {@code torchGather} routine on {@code ptNDArray} using
     * {@code ptNDArray2} as the index array.
     *
     * <p>If the two arrays have a different number of dimensions, the method searches for an
     * offset at which the index shape equals a contiguous run of the data shape's dimensions.
     * The index is then reshaped so those dimensions keep their sizes and every other dimension
     * has size 1. The index is converted to {@code DataType.INT64} when it has a different data
     * type. The trailing native flag is always {@code false}.
     *
     * @param ptNDArray the data array; its manager creates the returned array
     * @param ptNDArray2 the index array
     * @param j long argument, forwarded unchanged to the native call
     * @return the array created from the native result
     * @throws IllegalArgumentException if the dimension counts differ and no matching offset
     *     exists
     */
    public static PtNDArray pick(PtNDArray ptNDArray, PtNDArray ptNDArray2, long j) {
        Shape indexShape = ptNDArray2.getShape();
        Shape dataShape = ptNDArray.getShape();
        int indexDims = indexShape.dimension();
        int dataDims = dataShape.dimension();
        if (indexDims != dataDims) {
            Shape expanded = null;
            // The last valid offset is dataDims - indexDims, hence the inclusive bound.
            for (int offset = 0; offset <= dataDims - indexDims; offset++) {
                // Shape.slice takes an end index, so the window is [offset, offset + indexDims).
                if (indexShape.equals(dataShape.slice(offset, offset + indexDims))) {
                    long[] indexSizes = indexShape.getShape();
                    long[] padded = new long[dataDims];
                    Arrays.fill(padded, 1L);
                    // Copy every index dimension; filling with one repeated value would change
                    // the element count and make the reshape below invalid.
                    System.arraycopy(indexSizes, 0, padded, offset, indexSizes.length);
                    expanded = new Shape(padded);
                    break;
                }
            }
            if (expanded == null) {
                throw new IllegalArgumentException("expand shape failed! Cannot expand from " + indexShape + " to " + dataShape);
            }
            ptNDArray2 = ptNDArray2.reshape(expanded);
        }
        if (ptNDArray2.getDataType() != DataType.INT64) {
            ptNDArray2 = ptNDArray2.toType(DataType.INT64, true);
        }
        return ptNDArray.getManager().create(PyTorchLibrary.LIB.torchGather(ptNDArray.getHandle(), ptNDArray2.getHandle(), j, false));
    }

    /**
     * Invokes the native {@code torchPow} routine with the handles of both arrays.
     *
     * @param ptNDArray first operand; its manager creates the returned array
     * @param ptNDArray2 second operand
     * @return the array created from the native result
     */
    public static PtNDArray pow(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        PtNDArray result =
                ptNDArray
                        .getManager()
                        .create(PyTorchLibrary.LIB.torchPow(ptNDArray.getHandle(), ptNDArray2.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchPowi} routine with the handles of both arrays.
     *
     * @param ptNDArray first array whose handle is passed
     * @param ptNDArray2 second array whose handle is passed
     */
    public static void powi(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        lib.torchPowi(ptNDArray.getHandle(), ptNDArray2.getHandle());
    }

    /**
     * Invokes the single-argument native {@code torchProd} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray prod(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchProd(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Invokes the three-argument native {@code torchProd} routine.
     *
     * @param ptNDArray the source array
     * @param j long argument, forwarded unchanged
     * @param z boolean argument, forwarded unchanged
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray prod(PtNDArray ptNDArray, long j, boolean z) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchProd(ptNDArray.getHandle(), j, z));
        return result;
    }

    /**
     * Invokes the native {@code torchNNRelu} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray relu(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchNNRelu(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchRemainder} routine with the handles of both arrays.
     *
     * @param ptNDArray first operand; its manager creates the returned array
     * @param ptNDArray2 second operand
     * @return the array created from the native result
     */
    public static PtNDArray remainder(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        PtNDArray result =
                ptNDArray
                        .getManager()
                        .create(PyTorchLibrary.LIB.torchRemainder(ptNDArray.getHandle(), ptNDArray2.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchRemainderi} routine with the handles of both arrays.
     *
     * @param ptNDArray first array whose handle is passed
     * @param ptNDArray2 second array whose handle is passed
     */
    public static void remainderi(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        lib.torchRemainderi(ptNDArray.getHandle(), ptNDArray2.getHandle());
    }

    /**
     * Invokes the native {@code torchRepeatInterleave} routine.
     *
     * @param ptNDArray the source array
     * @param j first long argument, forwarded unchanged
     * @param j2 second long argument, forwarded unchanged
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray repeat(PtNDArray ptNDArray, long j, long j2) {
        PtNDArray result =
                ptNDArray
                        .getManager()
                        .create(PyTorchLibrary.LIB.torchRepeatInterleave(ptNDArray.getHandle(), j, j2));
        return result;
    }

    /**
     * Invokes the native {@code torchRequiresGrad} routine on the array's handle.
     *
     * @param ptNDArray the array whose handle is passed
     * @return the boolean returned by the native call
     */
    public static boolean requiresGrad(PtNDArray ptNDArray) {
        boolean required = PyTorchLibrary.LIB.torchRequiresGrad(ptNDArray.getHandle());
        return required;
    }

    /**
     * Invokes the native {@code torchReshape} routine.
     *
     * @param ptNDArray the source array
     * @param jArr long array, forwarded unchanged
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray reshape(PtNDArray ptNDArray, long[] jArr) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchReshape(ptNDArray.getHandle(), jArr));
        return result;
    }

    /**
     * Invokes the native {@code torchRound} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray round(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchRound(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchSet} routine with the handles of both arrays.
     *
     * @param ptNDArray first array whose handle is passed
     * @param ptNDArray2 second array whose handle is passed
     */
    public static void set(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        lib.torchSet(ptNDArray.getHandle(), ptNDArray2.getHandle());
    }

    /**
     * Passes the value to the native {@code torchSetNumInteropThreads} routine.
     *
     * @param i int argument, forwarded unchanged
     */
    public static void setNumInteropThreads(int i) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        lib.torchSetNumInteropThreads(i);
    }

    /**
     * Passes the value to the native {@code torchSetNumThreads} routine.
     *
     * @param i int argument, forwarded unchanged
     */
    public static void setNumThreads(int i) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        lib.torchSetNumThreads(i);
    }

    /**
     * Passes the value to the native {@code torchManualSeed} routine.
     *
     * @param j long argument, forwarded unchanged
     */
    public static void setSeed(long j) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        lib.torchManualSeed(j);
    }

    /**
     * Invokes the native {@code sgdUpdate} routine.
     *
     * <p>The third array may be {@code null}, in which case {@code null} is passed in place of
     * its handle.
     *
     * @param ptNDArray first array whose handle is passed
     * @param ptNDArray2 second array whose handle is passed
     * @param ptNDArray3 third array, or {@code null}
     * @param f first float argument, forwarded unchanged
     * @param f2 second float argument, forwarded unchanged
     * @param f3 third float argument, forwarded unchanged
     * @param f4 fourth float argument, forwarded unchanged
     * @param f5 fifth float argument, forwarded unchanged
     */
    public static void sgdUpdate(PtNDArray ptNDArray, PtNDArray ptNDArray2, PtNDArray ptNDArray3, float f, float f2, float f3, float f4, float f5) {
        PyTorchLibrary lib = PyTorchLibrary.LIB;
        lib.sgdUpdate(
                ptNDArray.getHandle(),
                ptNDArray2.getHandle(),
                ptNDArray3 != null ? ptNDArray3.getHandle() : null,
                f,
                f2,
                f3,
                f4,
                f5);
    }

    /**
     * Invokes the native {@code torchSigmoid} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray sigmoid(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchSigmoid(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchSin} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray sin(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchSin(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchSinh} routine on the array's handle.
     *
     * @param ptNDArray the source array
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray sinh(PtNDArray ptNDArray) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchSinh(ptNDArray.getHandle()));
        return result;
    }

    /**
     * Invokes the native {@code torchSlice} routine with the array handle and four longs.
     *
     * @param ptNDArray the source array
     * @param j first long argument, forwarded unchanged
     * @param j2 second long argument, forwarded unchanged
     * @param j3 third long argument, forwarded unchanged
     * @param j4 fourth long argument, forwarded unchanged
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray slice(PtNDArray ptNDArray, long j, long j2, long j3, long j4) {
        PtNDArray result =
                ptNDArray
                        .getManager()
                        .create(PyTorchLibrary.LIB.torchSlice(ptNDArray.getHandle(), j, j2, j3, j4));
        return result;
    }

    /**
     * Invokes the native {@code torchSoftmax} routine.
     *
     * @param ptNDArray the source array
     * @param j long argument, forwarded unchanged
     * @param dataType data type; its ordinal is passed to the native call
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray softmax(PtNDArray ptNDArray, long j, DataType dataType) {
        PtNDArray result =
                ptNDArray
                        .getManager()
                        .create(PyTorchLibrary.LIB.torchSoftmax(ptNDArray.getHandle(), j, dataType.ordinal()));
        return result;
    }

    /**
     * Invokes the native {@code torchSort} routine.
     *
     * @param ptNDArray the source array
     * @param j long argument, forwarded unchanged
     * @param z boolean argument, forwarded unchanged
     * @return the array that the source's manager creates from the native result
     */
    public static PtNDArray sort(PtNDArray ptNDArray, long j, boolean z) {
        PtNDArray result =
                ptNDArray.getManager().create(PyTorchLibrary.LIB.torchSort(ptNDArray.getHandle(), j, z));
        return result;
    }

    /**
     * Invokes the native {@code torchSplit} routine with two longs and wraps every returned
     * pointer in an array created by the source's manager.
     *
     * @param ptNDArray the source array
     * @param j first long argument, forwarded unchanged
     * @param j2 second long argument, forwarded unchanged
     * @return a list holding one array per returned pointer, in the order returned
     */
    public static NDList split(PtNDArray ptNDArray, long j, long j2) {
        Pointer[] handles = PyTorchLibrary.LIB.torchSplit(ptNDArray.getHandle(), j, j2);
        NDList pieces = new NDList();
        for (int k = 0; k < handles.length; k++) {
            pieces.add(ptNDArray.getManager().create(handles[k]));
        }
        return pieces;
    }

    /**
     * Invokes the native {@code torchSplit} routine with a long array and a long, and wraps
     * every returned pointer in an array created by the source's manager.
     *
     * @param ptNDArray the source array
     * @param jArr long array, forwarded unchanged
     * @param j long argument, forwarded unchanged
     * @return a list holding one array per returned pointer, in the order returned
     */
    public static NDList split(PtNDArray ptNDArray, long[] jArr, long j) {
        Pointer[] handles = PyTorchLibrary.LIB.torchSplit(ptNDArray.getHandle(), jArr, j);
        NDList pieces = new NDList();
        for (int k = 0; k < handles.length; k++) {
            pieces.add(ptNDArray.getManager().create(handles[k]));
        }
        return pieces;
    }

    /**
     * Applies the native {@code torchSqrt} operation to the given array.
     *
     * @param ptNDArray the input array
     * @return a new array created by the input's manager from the native result handle
     */
    public static PtNDArray sqrt(PtNDArray ptNDArray) {
        PtNDManager manager = ptNDArray.getManager();
        Pointer input = ptNDArray.getHandle();
        return manager.create(PyTorchLibrary.LIB.torchSqrt(input));
    }

    /**
     * Applies the native {@code torchSquare} operation to the given array.
     *
     * @param ptNDArray the input array
     * @return a new array created by the input's manager from the native result handle
     */
    public static PtNDArray square(PtNDArray ptNDArray) {
        PtNDManager manager = ptNDArray.getManager();
        Pointer input = ptNDArray.getHandle();
        return manager.create(PyTorchLibrary.LIB.torchSquare(input));
    }

    /**
     * Calls the single-argument native {@code torchSqueeze} operation on the given array.
     *
     * @param ptNDArray the input array
     * @return a new array created by the input's manager from the native result handle
     */
    public static PtNDArray squeeze(PtNDArray ptNDArray) {
        PtNDManager manager = ptNDArray.getManager();
        Pointer input = ptNDArray.getHandle();
        return manager.create(PyTorchLibrary.LIB.torchSqueeze(input));
    }

    /**
     * Calls the two-argument native {@code torchSqueeze} operation on the given array.
     *
     * @param ptNDArray the input array
     * @param j forwarded unchanged to the native call (presumably the dimension; TODO confirm)
     * @return a new array created by the input's manager from the native result handle
     */
    public static PtNDArray squeeze(PtNDArray ptNDArray, long j) {
        PtNDManager manager = ptNDArray.getManager();
        Pointer input = ptNDArray.getHandle();
        return manager.create(PyTorchLibrary.LIB.torchSqueeze(input, j));
    }

    /**
     * Stacks the given arrays through the native {@code torchStack} call.
     *
     * <p>The handle of every element is collected into a {@code Pointer[]} in input
     * order and handed to the native library; the result is created by the manager
     * of the first array.
     *
     * @param nDArrayArr the arrays to stack; every element is cast to {@code PtNDArray}
     * @param i forwarded unchanged to the native call (presumably the axis; TODO confirm)
     * @return a new array created by the first element's manager from the native result handle
     * @throws ArrayIndexOutOfBoundsException if {@code nDArrayArr} is empty
     * @throws ClassCastException if an element is not a {@code PtNDArray} or the first
     *     element's manager is not a {@code PtNDManager}
     */
    public static PtNDArray stack(NDArray[] nDArrayArr, int i) {
        // Resolve the manager first so an empty input fails before any native work.
        PtNDManager manager = (PtNDManager) nDArrayArr[0].getManager();
        Pointer[] handles = new Pointer[nDArrayArr.length];
        for (int idx = 0; idx < nDArrayArr.length; idx++) {
            handles[idx] = ((PtNDArray) nDArrayArr[idx]).getHandle();
        }
        return manager.create(PyTorchLibrary.LIB.torchStack(handles, i));
    }

    /**
     * Calls the native {@code torchSub} operation with the two given arrays.
     *
     * @param ptNDArray the first operand; its manager creates the result
     * @param ptNDArray2 the second operand
     * @return a new array created by the first operand's manager from the native result handle
     */
    public static PtNDArray sub(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        PtNDManager manager = ptNDArray.getManager();
        Pointer left = ptNDArray.getHandle();
        Pointer right = ptNDArray2.getHandle();
        return manager.create(PyTorchLibrary.LIB.torchSub(left, right));
    }

    /**
     * Calls the native {@code torchSubi} operation with the two given arrays.
     *
     * <p>No result is returned; presumably the first array is updated in place
     * (TODO confirm against the native implementation).
     *
     * @param ptNDArray the first operand
     * @param ptNDArray2 the second operand
     */
    public static void subi(PtNDArray ptNDArray, PtNDArray ptNDArray2) {
        Pointer left = ptNDArray.getHandle();
        Pointer right = ptNDArray2.getHandle();
        PyTorchLibrary.LIB.torchSubi(left, right);
    }

    /**
     * Calls the single-argument native {@code torchSum} operation on the given array.
     *
     * @param ptNDArray the input array
     * @return a new array created by the input's manager from the native result handle
     */
    public static PtNDArray sum(PtNDArray ptNDArray) {
        PtNDManager manager = ptNDArray.getManager();
        Pointer input = ptNDArray.getHandle();
        return manager.create(PyTorchLibrary.LIB.torchSum(input));
    }

    /**
     * Calls the three-argument native {@code torchSum} operation on the given array.
     *
     * @param ptNDArray the input array
     * @param jArr forwarded unchanged to the native call (presumably the dimensions to reduce; TODO confirm)
     * @param z forwarded unchanged to the native call (presumably a keep-dimensions flag; TODO confirm)
     * @return a new array created by the input's manager from the native result handle
     */
    public static PtNDArray sum(PtNDArray ptNDArray, long[] jArr, boolean z) {
        PtNDManager manager = ptNDArray.getManager();
        Pointer input = ptNDArray.getHandle();
        return manager.create(PyTorchLibrary.LIB.torchSum(input, jArr, z));
    }

    /**
     * Applies the native {@code torchTan} operation to the given array.
     *
     * @param ptNDArray the input array
     * @return a new array created by the input's manager from the native result handle
     */
    public static PtNDArray tan(PtNDArray ptNDArray) {
        PtNDManager manager = ptNDArray.getManager();
        Pointer input = ptNDArray.getHandle();
        return manager.create(PyTorchLibrary.LIB.torchTan(input));
    }

    /**
     * Applies the native {@code torchTanh} operation to the given array.
     *
     * @param ptNDArray the input array
     * @return a new array created by the input's manager from the native result handle
     */
    public static PtNDArray tanh(PtNDArray ptNDArray) {
        PtNDManager manager = ptNDArray.getManager();
        Pointer input = ptNDArray.getHandle();
        return manager.create(PyTorchLibrary.LIB.torchTanh(input));
    }

    /**
     * Tiles the given array by delegating to the native {@code torchRepeat} call.
     *
     * @param ptNDArray the input array
     * @param jArr forwarded unchanged to the native call (presumably per-dimension repeat counts; TODO confirm)
     * @return a new array created by the input's manager from the native result handle
     */
    public static PtNDArray tile(PtNDArray ptNDArray, long[] jArr) {
        PtNDManager manager = ptNDArray.getManager();
        Pointer input = ptNDArray.getHandle();
        return manager.create(PyTorchLibrary.LIB.torchRepeat(input, jArr));
    }

    /**
     * Calls the native {@code torchTo} operation with a target data type and device.
     *
     * <p>The device is sent as a two-element int array: the value of
     * {@code PtDeviceType.toDeviceType(device)} followed by the device id, where
     * {@code -1} is used when the device equals {@code Device.cpu()}.
     *
     * @param ptNDArray the input array
     * @param dataType target data type; its enum ordinal is sent to the native call
     * @param device target device
     * @param z forwarded unchanged to the native call (presumably a copy flag; TODO confirm)
     * @return a new array created by the input's manager from the native result handle
     */
    public static PtNDArray to(PtNDArray ptNDArray, DataType dataType, Device device, boolean z) {
        PtNDManager manager = ptNDArray.getManager();
        Pointer input = ptNDArray.getHandle();
        int typeCode = dataType.ordinal();
        int deviceType = PtDeviceType.toDeviceType(device);
        int deviceId;
        if (device.equals(Device.cpu())) {
            // CPU is signalled to the native side with an id of -1.
            deviceId = -1;
        } else {
            deviceId = device.getDeviceId();
        }
        int[] deviceSpec = {deviceType, deviceId};
        return manager.create(PyTorchLibrary.LIB.torchTo(input, typeCode, deviceSpec, z));
    }

    /**
     * Applies the native {@code torchToDense} operation to the given array.
     *
     * @param ptNDArray the input array
     * @return a new array created by the input's manager from the native result handle
     */
    public static PtNDArray toDense(PtNDArray ptNDArray) {
        PtNDManager manager = ptNDArray.getManager();
        Pointer input = ptNDArray.getHandle();
        return manager.create(PyTorchLibrary.LIB.torchToDense(input));
    }

    /**
     * Applies the native {@code torchToSparse} operation to the given array.
     *
     * @param ptNDArray the input array
     * @return a new array created by the input's manager from the native result handle
     */
    public static PtNDArray toSparse(PtNDArray ptNDArray) {
        PtNDManager manager = ptNDArray.getManager();
        Pointer input = ptNDArray.getHandle();
        return manager.create(PyTorchLibrary.LIB.torchToSparse(input));
    }

    /**
     * Calls the native {@code torchTranspose} operation on the given array.
     *
     * @param ptNDArray the input array
     * @param j forwarded unchanged to the native call (presumably the first dimension to swap; TODO confirm)
     * @param j2 forwarded unchanged to the native call (presumably the second dimension to swap; TODO confirm)
     * @return a new array created by the input's manager from the native result handle
     */
    public static PtNDArray transpose(PtNDArray ptNDArray, long j, long j2) {
        PtNDManager manager = ptNDArray.getManager();
        Pointer input = ptNDArray.getHandle();
        return manager.create(PyTorchLibrary.LIB.torchTranspose(input, j, j2));
    }

    /**
     * Applies the native {@code torchTrunc} operation to the given array.
     *
     * @param ptNDArray the input array
     * @return a new array created by the input's manager from the native result handle
     */
    public static PtNDArray trunc(PtNDArray ptNDArray) {
        PtNDManager manager = ptNDArray.getManager();
        Pointer input = ptNDArray.getHandle();
        return manager.create(PyTorchLibrary.LIB.torchTrunc(input));
    }

    /**
     * Creates an array through the native {@code tensorUniform} call.
     *
     * <p>The layout is always the mapping of {@code SparseFormat.DENSE}, and the last
     * native argument is always {@code false}. The device is sent as a two-element
     * int array: the value of {@code PtDeviceType.toDeviceType(device)} followed by
     * the device id, where {@code -1} is used when the device equals {@code Device.cpu()}.
     *
     * @param ptNDManager the manager that creates the resulting array
     * @param d forwarded unchanged to the native call (presumably the lower bound; TODO confirm)
     * @param d2 forwarded unchanged to the native call (presumably the upper bound; TODO confirm)
     * @param shape shape whose dimension array is sent to the native call
     * @param dataType data type; its enum ordinal is sent to the native call
     * @param device target device
     * @return a new array created by {@code ptNDManager} from the native result handle
     */
    public static PtNDArray uniform(PtNDManager ptNDManager, double d, double d2, Shape shape, DataType dataType, Device device) {
        long[] dims = shape.getShape();
        int typeCode = dataType.ordinal();
        int layout = layoutMapper(SparseFormat.DENSE);
        int deviceType = PtDeviceType.toDeviceType(device);
        int deviceId;
        if (device.equals(Device.cpu())) {
            // CPU is signalled to the native side with an id of -1.
            deviceId = -1;
        } else {
            deviceId = device.getDeviceId();
        }
        int[] deviceSpec = {deviceType, deviceId};
        return ptNDManager.create(PyTorchLibrary.LIB.tensorUniform(d, d2, dims, typeCode, layout, deviceSpec, false));
    }

    /**
     * Calls the native {@code torchUnsqueeze} operation on the given array.
     *
     * @param ptNDArray the input array
     * @param j forwarded unchanged to the native call (presumably the dimension to insert; TODO confirm)
     * @return a new array created by the input's manager from the native result handle
     */
    public static PtNDArray unsqueeze(PtNDArray ptNDArray, long j) {
        PtNDManager manager = ptNDArray.getManager();
        Pointer input = ptNDArray.getHandle();
        return manager.create(PyTorchLibrary.LIB.torchUnsqueeze(input, j));
    }

    /**
     * Calls the native {@code torchUpsampleBilinear2d} operation on the given array.
     *
     * @param ptNDArray the input array
     * @param jArr forwarded unchanged to the native call (presumably the output size; TODO confirm)
     * @param z forwarded unchanged to the native call (presumably an align-corners flag; TODO confirm)
     * @return a new array created by the input's manager from the native result handle
     */
    public static PtNDArray upsampleBilinear2d(PtNDArray ptNDArray, long[] jArr, boolean z) {
        PtNDManager manager = ptNDArray.getManager();
        Pointer input = ptNDArray.getHandle();
        return manager.create(PyTorchLibrary.LIB.torchUpsampleBilinear2d(input, jArr, z));
    }

    /**
     * Calls the native {@code torchWhere} operation with the three given arrays.
     *
     * <p>The result is created by the manager of the second argument, not the first.
     *
     * @param ptNDArray first native argument (presumably the condition; TODO confirm)
     * @param ptNDArray2 second native argument; its manager creates the result
     * @param ptNDArray3 third native argument
     * @return a new array created by {@code ptNDArray2}'s manager from the native result handle
     */
    public static PtNDArray where(PtNDArray ptNDArray, PtNDArray ptNDArray2, PtNDArray ptNDArray3) {
        PtNDManager manager = ptNDArray2.getManager();
        Pointer first = ptNDArray.getHandle();
        Pointer second = ptNDArray2.getHandle();
        Pointer third = ptNDArray3.getHandle();
        return manager.create(PyTorchLibrary.LIB.torchWhere(first, second, third));
    }

    /**
     * Invokes the native {@code zeroGrad} routine on the given array's handle.
     *
     * @param ptNDArray the array whose handle is passed to the native call
     */
    public static void zeroGrad(PtNDArray ptNDArray) {
        Pointer target = ptNDArray.getHandle();
        PyTorchLibrary.LIB.zeroGrad(target);
    }

    /**
     * Calls the native {@code torchZerosLike} operation using the given array as the template.
     *
     * <p>The last native argument is always {@code false}. The device is sent as a
     * two-element int array: the value of {@code PtDeviceType.toDeviceType(device)}
     * followed by the device id, where {@code -1} is used when the device equals
     * {@code Device.cpu()}.
     *
     * @param ptNDArray the template array; its manager creates the result
     * @param dataType data type; its enum ordinal is sent to the native call
     * @param device target device
     * @param sparseFormat sparse format, translated with {@code layoutMapper} before the call
     * @return a new array created by the template's manager from the native result handle
     */
    public static PtNDArray zerosLike(PtNDArray ptNDArray, DataType dataType, Device device, SparseFormat sparseFormat) {
        int layout = layoutMapper(sparseFormat);
        PtNDManager manager = ptNDArray.getManager();
        Pointer template = ptNDArray.getHandle();
        int typeCode = dataType.ordinal();
        int deviceType = PtDeviceType.toDeviceType(device);
        int deviceId;
        if (device.equals(Device.cpu())) {
            // CPU is signalled to the native side with an id of -1.
            deviceId = -1;
        } else {
            deviceId = device.getDeviceId();
        }
        int[] deviceSpec = {deviceType, deviceId};
        return manager.create(PyTorchLibrary.LIB.torchZerosLike(template, typeCode, layout, deviceSpec, false));
    }
}
