package ai.djl.nn;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.training.initializer.Initializer;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Objects;
import java.util.UUID;

/* loaded from: classes.dex */
/**
 * A named value held by a {@link Block}, backed by an {@link NDArray}.
 *
 * <p>The array is either created lazily by {@link #initialize} through an {@link Initializer},
 * supplied directly with {@link #setArray}, or read back with {@link #load}. {@link #close()}
 * closes the held array and clears it.
 */
public class Parameter implements AutoCloseable {

    /** Encoding version written by {@link #save} and required by {@link #load}. */
    private static final byte VERSION = 1;

    /** Stream marker written by {@link #save} when there is no array to serialize. */
    private static final char MARKER_EMPTY = 'N';

    /** Stream marker written by {@link #save} before a serialized parameter. */
    private static final char MARKER_PARAMETER = 'P';

    private final String id;
    private final String name;
    private final Block block;
    private final ParameterType type;
    private final boolean requiresGrad;
    private final SparseFormat gradientFormat;

    private NDArray array;
    private Initializer initializer;
    private DataType mandatoryDataType;

    /**
     * Creates a parameter that requires a gradient, using a dense gradient format.
     *
     * @param name the parameter name
     * @param block the block the parameter shape is requested from during initialization
     * @param type the parameter type; its initializer becomes the default initializer
     */
    public Parameter(String name, Block block, ParameterType type) {
        this(name, block, type, true, SparseFormat.DENSE);
    }

    /**
     * Creates a parameter using a dense gradient format.
     *
     * @param name the parameter name
     * @param block the block the parameter shape is requested from during initialization
     * @param type the parameter type; its initializer becomes the default initializer
     * @param requiresGrad whether a gradient is attached to the array on initialization
     */
    public Parameter(String name, Block block, ParameterType type, boolean requiresGrad) {
        this(name, block, type, requiresGrad, SparseFormat.DENSE);
    }

    /**
     * Creates a parameter.
     *
     * @param name the parameter name
     * @param block the block the parameter shape is requested from during initialization
     * @param type the parameter type; its initializer becomes the default initializer
     * @param requiresGrad whether a gradient is attached to the array on initialization
     * @param gradientFormat the format passed to {@code attachGradient}
     */
    public Parameter(
            String name,
            Block block,
            ParameterType type,
            boolean requiresGrad,
            SparseFormat gradientFormat) {
        this.id = UUID.randomUUID().toString();
        this.name = name;
        this.block = block;
        this.type = type;
        this.requiresGrad = requiresGrad;
        this.initializer = type.getInitializer();
        this.gradientFormat = gradientFormat;
    }

    /** Closes the held array, if any, and marks this parameter as uninitialized. */
    @Override
    public void close() {
        if (array != null) {
            array.close();
            array = null;
        }
    }

    /**
     * Returns the backing array.
     *
     * @return the array held by this parameter
     * @throws IllegalStateException if no array has been set, initialized, or loaded
     */
    public NDArray getArray() {
        if (!isInitialized()) {
            throw new IllegalStateException("The array has not been initialized");
        }
        return array;
    }

    /**
     * Returns the random UUID string generated when this parameter was constructed.
     *
     * @return the identifier
     */
    public String getId() {
        return id;
    }

    /**
     * Returns the parameter name.
     *
     * @return the name, or an empty string if the parameter was created with a {@code null} name
     */
    public String getName() {
        return name == null ? "" : name;
    }

    /**
     * Returns the parameter type.
     *
     * @return the type supplied at construction
     */
    public ParameterType getType() {
        return type;
    }

    /**
     * Creates the backing array if it does not exist yet, then attaches a gradient when required.
     *
     * <p>The shape is obtained from {@code block.getParameterShape(name, shapeArr)}. If a mandatory
     * data type has been set, it takes precedence over {@code dataType}. The gradient is attached
     * on every call when {@link #requireGradient()} is true, even if the array already existed.
     *
     * @param nDManager the manager used to allocate the array
     * @param dataType the data type to use when no mandatory data type is set
     * @param shapeArr the input shapes passed to the block to compute this parameter's shape
     * @throws NullPointerException if no initializer has been set (checked even when the array
     *     already exists)
     */
    public void initialize(NDManager nDManager, DataType dataType, Shape[] shapeArr) {
        Objects.requireNonNull(initializer, "No initializer has been set");
        if (!isInitialized()) {
            Shape shape = block.getParameterShape(name, shapeArr);
            DataType effectiveType = mandatoryDataType != null ? mandatoryDataType : dataType;
            NDArray created = initializer.initialize(nDManager, shape, effectiveType);
            array = created;
            created.setName(name);
        }
        if (requireGradient()) {
            array.attachGradient(gradientFormat);
        }
    }

    /**
     * Returns whether this parameter currently holds an array.
     *
     * @return {@code true} if an array is present
     */
    public boolean isInitialized() {
        return array != null;
    }

    /**
     * Reads a parameter previously written by {@link #save}.
     *
     * <p>If the stream holds the empty marker, this parameter is left unchanged. Otherwise the
     * marker, encoding version and name are validated and the array is decoded from the stream.
     * NOTE(review): a previously held array is replaced without being closed — confirm callers
     * only load into fresh parameters.
     *
     * @param nDManager the manager used to decode the array
     * @param dataInputStream the stream to read from
     * @throws IOException if reading from the stream fails
     * @throws MalformedModelException if the marker, version, or parameter name is unexpected
     */
    public void load(NDManager nDManager, DataInputStream dataInputStream)
            throws IOException, MalformedModelException {
        char marker = dataInputStream.readChar();
        if (marker == MARKER_EMPTY) {
            return;
        }
        if (marker != MARKER_PARAMETER) {
            throw new MalformedModelException("Invalid input data.");
        }
        byte version = dataInputStream.readByte();
        if (version != VERSION) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
        String storedName = dataInputStream.readUTF();
        String expectedName = getName();
        if (!storedName.equals(expectedName)) {
            // Report the same normalized name that was compared against.
            throw new MalformedModelException(
                    "Unexpected parameter name: " + storedName + ", expected: " + expectedName);
        }
        array = nDManager.decode(dataInputStream);
    }

    /**
     * Returns whether a gradient is attached during {@link #initialize}.
     *
     * @return the {@code requiresGrad} flag supplied at construction
     */
    public boolean requireGradient() {
        return requiresGrad;
    }

    /**
     * Writes this parameter to a stream in the format read by {@link #load}.
     *
     * <p>An uninitialized parameter is written as the single marker {@code 'N'}. Otherwise the
     * output is the marker {@code 'P'}, the encoding version byte, the name (as modified UTF-8),
     * and the bytes returned by {@code array.encode()}.
     *
     * @param dataOutputStream the stream to write to
     * @throws IOException if writing to the stream fails
     */
    public void save(DataOutputStream dataOutputStream) throws IOException {
        if (!isInitialized()) {
            dataOutputStream.writeChar(MARKER_EMPTY);
            return;
        }
        dataOutputStream.writeChar(MARKER_PARAMETER);
        dataOutputStream.writeByte(VERSION);
        dataOutputStream.writeUTF(getName());
        dataOutputStream.write(array.encode());
    }

    /**
     * Replaces the backing array and names it after this parameter.
     *
     * <p>The previously held array, if any, is not closed.
     *
     * @param nDArray the new array; must not be {@code null}
     * @throws NullPointerException if {@code nDArray} is {@code null}; the parameter is left
     *     unchanged in that case
     */
    public void setArray(NDArray nDArray) {
        // Validate before assigning so a null argument cannot leave the parameter half-updated.
        Objects.requireNonNull(nDArray, "array");
        array = nDArray;
        nDArray.setName(name);
    }

    /**
     * Sets the initializer used by {@link #initialize}.
     *
     * @param initializer the new initializer
     * @param overwrite if {@code false}, the initializer is only set when none is present
     */
    public void setInitializer(Initializer initializer, boolean overwrite) {
        if (overwrite || this.initializer == null) {
            this.initializer = initializer;
        }
    }

    /**
     * Forces the data type used by {@link #initialize}, overriding the type passed by callers.
     *
     * @param dataType the mandatory data type, or {@code null} to use the caller's type
     */
    public void setMandatoryDataType(DataType dataType) {
        this.mandatoryDataType = dataType;
    }
}
