package ai.djl.nn.core;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.IOException;

/* loaded from: classes.dex */
public class Prelu extends AbstractBlock {

    /** Current metadata encoding version written/accepted by this block. */
    private static final byte VERSION = 2;

    /** Learnable scalar slope (shape with no dimensions) passed to the engine's {@code prelu}. */
    private final Parameter alpha;

    /**
     * Creates a PReLU block with a single scalar learnable parameter named {@code "alpha"}.
     *
     * <p>NOTE(review): the activation is presumably {@code max(0, x) + alpha * min(0, x)};
     * the actual math lives in the engine's {@code prelu} implementation — confirm there.
     */
    public Prelu() {
        super(VERSION);
        // Scalar parameter: an empty Shape means zero dimensions.
        alpha = addParameter(new Parameter("alpha", this, ParameterType.OTHER), new Shape());
    }

    /**
     * Applies the engine's {@code prelu} operator to the single input array.
     *
     * @param parameterStore store used to fetch {@code alpha} on the input's device
     * @param inputs must contain exactly one array ({@link NDList#singletonOrThrow()})
     * @param training whether this is a training pass (unused by this block)
     * @param params extra operator arguments forwarded unchanged to {@code prelu}
     * @return the result of {@code prelu(input, alpha)}
     */
    @Override // ai.djl.nn.Block
    public NDList forward(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        NDArray input = inputs.singletonOrThrow();
        // Fetch alpha on the same device as the input so the operator sees co-located arrays.
        NDArray alphaValue = parameterStore.getValue(alpha, input.getDevice());
        return input.getNDArrayInternal().prelu(new NDList(input, alphaValue), params);
    }

    /**
     * Returns the output shape, which is identical to the first input shape.
     *
     * @param manager unused
     * @param inputShapes input shapes; only element 0 is read
     * @return a one-element array containing {@code inputShapes[0]}
     */
    @Override // ai.djl.nn.Block
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        return new Shape[] {inputShapes[0]};
    }

    /**
     * Reads block metadata for the given encoding version.
     *
     * <p>Version {@link #VERSION} (2) reads the stored input shapes; version 1 carries no extra
     * metadata and is accepted as-is; any other version is rejected.
     *
     * @param version the encoding version read from the stream
     * @param is the stream positioned at this block's metadata
     * @throws IOException if reading the input shapes fails
     * @throws MalformedModelException if {@code version} is neither 1 nor {@link #VERSION}
     */
    @Override // ai.djl.nn.AbstractBlock
    public void loadMetadata(byte version, DataInputStream is)
            throws IOException, MalformedModelException {
        if (version == VERSION) {
            readInputShapes(is);
        } else if (version != 1) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
    }
}
