package ai.djl.nn.convolutional;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.IOException;
import java.util.function.Function;

/**
 * Base block for convolution layers.
 *
 * <p>This class holds the hyper-parameters shared by every convolution (kernel, stride, padding,
 * dilation, number of filters, number of groups, optional bias), registers the {@code weight} and
 * optional {@code bias} parameters, and forwards the actual computation to the engine through
 * {@code NDArrayInternal.convolution}. Concrete subclasses supply the number of dimensions and the
 * expected data layout.
 */
public abstract class Convolution extends AbstractBlock {
    /** Metadata encoding version written by this block and checked in {@link #loadMetadata}. */
    private static final byte VERSION = 2;

    /** Bias parameter; stays {@code null} when the block is built with {@code includeBias == false}. */
    protected Parameter bias;
    /** Per-axis dilation of the kernel. */
    protected Shape dilate;
    /** Whether a bias parameter is registered and fed to the operator. */
    protected boolean includeBias;
    /** Per-axis kernel size. */
    protected Shape kernel;
    /** Number of output filters (the size of output axis 1, see {@link #getOutputShapes}). */
    protected int numFilters;
    /** Number of groups handed to the convolution operator. */
    protected int numGroups;
    /** Per-axis padding; counted twice (both sides) in {@link #getOutputShapes}. */
    protected Shape pad;
    /** Per-axis stride. */
    protected Shape stride;
    /** Weight (kernel) parameter. */
    protected Parameter weight;

    /**
     * Fluent builder base class collecting the hyper-parameters of a {@link Convolution}.
     *
     * <p>{@code numGroups} defaults to 1 and {@code includeBias} to {@code true}. {@code stride},
     * {@code pad} and {@code dilate} have no default here.
     * NOTE(review): concrete builders presumably assign per-dimension defaults for them - confirm,
     * since {@link Convolution#getOutputShapes} dereferences all three.
     *
     * @param <T> the concrete builder type returned by the fluent setters
     */
    @SuppressWarnings("rawtypes") // raw bound kept so existing subclasses continue to compile
    public abstract static class ConvolutionBuilder<T extends ConvolutionBuilder> {
        protected Shape dilate;
        protected Shape kernel;
        protected int numFilters;
        protected Shape pad;
        protected Shape stride;
        protected int numGroups = 1;
        protected boolean includeBias = true;

        /**
         * Sets whether the block includes a bias parameter.
         *
         * @param z {@code true} to include a bias parameter
         * @return this builder
         */
        public T optBias(boolean z) {
            this.includeBias = z;
            return self();
        }

        /**
         * Sets the per-axis dilation.
         *
         * @param shape the dilation for each convolved axis
         * @return this builder
         */
        public T optDilate(Shape shape) {
            this.dilate = shape;
            return self();
        }

        /**
         * Sets the number of groups.
         *
         * @param i the number of groups
         * @return this builder
         */
        public T optNumGroups(int i) {
            this.numGroups = i;
            return self();
        }

        /**
         * Sets the per-axis padding.
         *
         * @param shape the padding for each convolved axis
         * @return this builder
         */
        public T optPad(Shape shape) {
            this.pad = shape;
            return self();
        }

        /**
         * Sets the per-axis stride.
         *
         * @param shape the stride for each convolved axis
         * @return this builder
         */
        public T optStride(Shape shape) {
            this.stride = shape;
            return self();
        }

        /**
         * Returns this builder typed as the concrete builder class.
         *
         * @return this builder
         */
        protected abstract T self();

        /**
         * Sets the kernel size (required).
         *
         * @param shape the kernel size for each convolved axis
         * @return this builder
         */
        public T setKernel(Shape shape) {
            this.kernel = shape;
            return self();
        }

        /**
         * Sets the number of output filters (required).
         *
         * @param i the number of filters
         * @return this builder
         */
        public T setNumFilters(int i) {
            this.numFilters = i;
            return self();
        }

        /**
         * Checks that the required settings were supplied.
         *
         * @throws IllegalArgumentException if the kernel is {@code null} or {@code numFilters} is 0
         */
        public void validate() {
            if (this.kernel == null || this.numFilters == 0) {
                throw new IllegalArgumentException("Kernel and numFilters must be set");
            }
        }
    }

    /**
     * Copies the hyper-parameters from the builder and registers the block parameters.
     *
     * <p>The weight parameter is registered with a shape callback (its shape depends on the input
     * shape); the bias parameter is registered with shape {@code (numFilters)} and only when the
     * builder requested a bias.
     *
     * @param convolutionBuilder the builder holding the configuration
     */
    public Convolution(ConvolutionBuilder<?> convolutionBuilder) {
        super(VERSION);
        this.kernel = convolutionBuilder.kernel;
        this.stride = convolutionBuilder.stride;
        this.pad = convolutionBuilder.pad;
        this.dilate = convolutionBuilder.dilate;
        this.numFilters = convolutionBuilder.numFilters;
        this.numGroups = convolutionBuilder.numGroups;
        this.includeBias = convolutionBuilder.includeBias;
        // The weight shape is computed lazily from the input shapes via the callback.
        this.weight =
                addParameter(
                        new Parameter("weight", this, ParameterType.WEIGHT),
                        inputShapes -> lambda$new$0$Convolution(inputShapes));
        if (this.includeBias) {
            this.bias =
                    addParameter(
                            new Parameter("bias", this, ParameterType.BIAS),
                            new Shape(this.numFilters));
        }
    }

    /**
     * Assembles the operator inputs: the single data array, the weight, and the bias when present.
     *
     * @param parameterStore the store from which parameter values are fetched
     * @param nDList the block inputs; must contain exactly one array ({@code singletonOrThrow})
     * @return a list of {@code [data, weight]} or {@code [data, weight, bias]}
     */
    private NDList opInputs(ParameterStore parameterStore, NDList nDList) {
        NDArray data = nDList.singletonOrThrow();
        // Parameter values are requested for the device the data array reports.
        Device device = data.getDevice();
        NDList operands = new NDList(3);
        operands.add(data);
        operands.add(parameterStore.getValue(this.weight, device));
        if (this.bias != null) {
            operands.add(parameterStore.getValue(this.bias, device));
        }
        return operands;
    }

    /**
     * Records the input shapes and validates the layout of the first input against
     * {@link #getExpectedLayout()}.
     *
     * @param shapeArr the input shapes; only the first one is validated
     */
    @Override
    public void beforeInitialize(Shape[] shapeArr) {
        this.inputShapes = shapeArr;
        Block.validateLayout(getExpectedLayout(), shapeArr[0].getLayout());
    }

    /**
     * Runs the convolution by delegating to the internal convolution operator of the data array.
     *
     * @param parameterStore the store providing the weight and bias values
     * @param nDList the block inputs; must contain exactly one array
     * @param z not used by this implementation
     * @param pairList extra arguments passed through to the operator unchanged
     * @return the operator result
     */
    @Override
    public NDList forward(
            ParameterStore parameterStore,
            NDList nDList,
            boolean z,
            PairList<String, Object> pairList) {
        NDList operands = opInputs(parameterStore, nDList);
        // The operator receives the negation of includeBias (presumably a "no bias" flag - confirm).
        return operands
                .head()
                .getNDArrayInternal()
                .convolution(
                        operands,
                        this.kernel,
                        this.stride,
                        this.pad,
                        this.dilate,
                        this.numFilters,
                        this.numGroups,
                        getStringLayout(),
                        !this.includeBias,
                        pairList);
    }

    /**
     * Returns the layout the first input must have; checked in {@link #beforeInitialize}.
     *
     * @return the expected layout
     */
    protected abstract LayoutType[] getExpectedLayout();

    /**
     * Computes the output shape for the first input shape.
     *
     * <p>Axis 0 is copied from the input, axis 1 is {@code numFilters}, and every remaining axis
     * {@code i} is
     * {@code (in + 2 * pad[i] - dilate[i] * (kernel[i] - 1) - 1) / stride[i] + 1}
     * using integer division.
     *
     * @param nDManager not used by this implementation
     * @param shapeArr the input shapes; only the first one is read
     * @return a single-element array holding the output shape
     */
    @Override
    public Shape[] getOutputShapes(NDManager nDManager, Shape[] shapeArr) {
        int dims = numDimensions();
        Shape input = shapeArr[0];
        long[] out = new long[dims];
        out[0] = input.get(0);
        out[1] = this.numFilters;
        for (int i = 0; i < dims - 2; i++) {
            int axis = i + 2;
            // Dilation and stride are read per axis, like pad and kernel; reading index 0 for
            // every axis yields wrong sizes whenever they differ between axes.
            long effectiveKernel = this.dilate.get(i) * (this.kernel.get(i) - 1);
            long padded = input.get(axis) + (this.pad.get(i) * 2);
            out[axis] = ((padded - effectiveKernel - 1) / this.stride.get(i)) + 1;
        }
        return new Shape[] {new Shape(out)};
    }

    /**
     * Returns the layout string handed to the convolution operator in {@link #forward}.
     *
     * @return the layout string
     */
    protected abstract String getStringLayout();

    /**
     * Shape callback for the weight parameter: {@code (numFilters, inputAxis1, kernel...)}.
     *
     * <p>The name is the compiler-generated one from the original lambda; it is kept so existing
     * references to it keep resolving.
     * NOTE(review): the second dimension is the full size of input axis 1 regardless of
     * {@code numGroups}; some engines expect it divided by the group count - confirm.
     *
     * @param shapeArr the input shapes; only the first one is read
     * @return the weight shape
     */
    public /* synthetic */ Shape lambda$new$0$Convolution(Shape[] shapeArr) {
        return new Shape(this.numFilters, shapeArr[0].get(1)).addAll(this.kernel);
    }

    /**
     * Loads block metadata for the given encoding version.
     *
     * <p>The current version reads the input shapes from the stream; version 1 is accepted and
     * reads nothing.
     *
     * @param b2 the encoding version found in the stream
     * @param dataInputStream the stream to read from
     * @throws IOException if reading the input shapes fails
     * @throws MalformedModelException if the version is neither the current version nor 1
     */
    @Override
    public void loadMetadata(byte b2, DataInputStream dataInputStream)
            throws IOException, MalformedModelException {
        if (b2 == VERSION) {
            readInputShapes(dataInputStream);
        } else if (b2 != 1) {
            throw new MalformedModelException("Unsupported encoding version: " + b2);
        }
    }

    /**
     * Returns the total number of axes of the input/output, i.e. the two leading axes plus the
     * convolved axes (see {@link #getOutputShapes}).
     *
     * @return the number of dimensions
     */
    protected abstract int numDimensions();
}
