package ai.djl.ndarray;

import ai.djl.Device;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.function.Consumer;
import java.util.function.Predicate;

/* loaded from: classes.dex */
/**
 * An ordered, mutable list of {@link NDArray}s that can be moved between devices, attached to or
 * detached from a manager, encoded to bytes, and closed as a unit.
 *
 * <p>{@link #close()} closes every contained array and then empties the list.
 */
public class NDList extends ArrayList<NDArray> implements AutoCloseable {
    private static final long serialVersionUID = 1;

    /** Creates an empty list. */
    public NDList() {
    }

    /**
     * Creates an empty list with the given initial capacity.
     *
     * @param i the initial capacity
     */
    public NDList(int i) {
        super(i);
    }

    /**
     * Creates a list containing the arrays of {@code collection}, in its iteration order.
     *
     * @param collection the arrays to copy into the new list
     */
    public NDList(Collection<NDArray> collection) {
        super(collection);
    }

    /**
     * Creates a list containing the given arrays, in argument order.
     *
     * @param nDArrayArr the arrays to add
     */
    public NDList(NDArray... nDArrayArr) {
        super(Arrays.asList(nDArrayArr));
    }

    /**
     * Decodes a list from the byte layout written by {@link #encode()}: a big-endian {@code int}
     * element count followed by each array's encoding, read back with
     * {@code NDManager.decode(DataInputStream)}.
     *
     * @param nDManager the manager used to decode each array
     * @param bArr the encoded bytes
     * @return the decoded list
     * @throws IllegalArgumentException if the data is truncated, malformed, or declares a
     *     negative element count
     */
    public static NDList decode(NDManager nDManager, byte[] bArr) {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(bArr))) {
            int count = in.readInt();
            if (count < 0) {
                throw new IllegalArgumentException("Malformed data: negative array count " + count);
            }
            // The count is read from untrusted input; don't let a corrupted header force a huge
            // up-front allocation. The list still grows as needed while decoding.
            NDList list = new NDList(Math.min(count, bArr.length));
            for (int i = 0; i < count; i++) {
                list.add(nDManager.decode(in));
            }
            return list;
        } catch (IOException e) {
            throw new IllegalArgumentException("Malformed data", e);
        }
    }

    /**
     * Returns whether {@code nDArray} is on exactly the given {@code device} instance.
     *
     * <p>Retained (it was a compiler-generated lambda body exposed by decompilation) so the public
     * surface of this class is unchanged; {@link #asInDevice} uses it for its device check.
     *
     * @param device the device to compare against (compared by identity)
     * @param nDArray the array to test
     * @return {@code true} if {@code nDArray.getDevice() == device}
     */
    public static boolean lambda$asInDevice$0(Device device, NDArray nDArray) {
        return nDArray.getDevice() == device;
    }

    /**
     * Appends every array of {@code nDList} to this list.
     *
     * <p>Safe to call with {@code this} as the argument: the source is snapshotted before any
     * element is appended.
     *
     * @param nDList the arrays to append
     * @return this list, for chaining
     */
    public NDList addAll(NDList nDList) {
        // ArrayList.addAll(Collection) copies the argument via toArray() first, so self-append
        // does not hit a ConcurrentModificationException as an iterate-and-add loop would.
        super.addAll(nDList);
        return this;
    }

    /**
     * Returns a list whose arrays are on {@code device}.
     *
     * <p>When {@code z} is {@code false} and every array is already on that exact device
     * instance, this list itself is returned. Otherwise a new list is built from
     * {@code toDevice(device, z)} applied to each array, in order.
     *
     * @param device the target device
     * @param z whether to force a copy (passed through to {@code NDArray.toDevice})
     * @return this list, or a new list of the converted arrays
     */
    public NDList asInDevice(final Device device, final boolean z) {
        if (!z) {
            boolean allOnDevice = true;
            for (NDArray array : this) {
                if (!lambda$asInDevice$0(device, array)) {
                    allOnDevice = false;
                    break;
                }
            }
            if (allOnDevice) {
                return this;
            }
        }
        // Collect into the new list, never into `this` (mutating the list being iterated would
        // throw ConcurrentModificationException).
        NDList moved = new NDList(size());
        for (NDArray array : this) {
            moved.add(array.toDevice(device, z));
        }
        return moved;
    }

    /**
     * Attaches every array in this list to {@code nDManager}.
     *
     * @param nDManager the manager to attach to
     */
    public void attach(final NDManager nDManager) {
        for (NDArray array : this) {
            array.attach(nDManager);
        }
    }

    /** Closes every array in this list, then removes them all from the list. */
    @Override // java.lang.AutoCloseable
    public void close() {
        for (NDArray array : this) {
            array.close();
        }
        clear();
    }

    /**
     * Returns whether some array in this list has the given name.
     *
     * @param str the name to look for; must not be {@code null}
     * @return {@code true} if an array's {@code getName()} equals {@code str}
     */
    public boolean contains(String str) {
        for (NDArray array : this) {
            if (str.equals(array.getName())) {
                return true;
            }
        }
        return false;
    }

    /** Detaches every array in this list. */
    public void detach() {
        for (NDArray array : this) {
            array.detach();
        }
    }

    /**
     * Encodes this list as a big-endian {@code int} element count followed by each array's
     * {@code encode()} bytes, in list order. The result can be read back with
     * {@link #decode(NDManager, byte[])}.
     *
     * @return the encoded bytes
     * @throws IllegalStateException if writing fails
     */
    public byte[] encode() {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeInt(size());
            for (NDArray array : this) {
                out.write(array.encode());
            }
            out.flush();
        } catch (IOException e) {
            throw new IllegalStateException("NDList is not writable", e);
        }
        // Closing a ByteArrayOutputStream has no effect, so its contents remain readable here.
        return bytes.toByteArray();
    }

    /**
     * Returns the first array.
     *
     * @return the element at index 0
     * @throws IndexOutOfBoundsException if the list is empty
     */
    public NDArray head() {
        return get(0);
    }

    /**
     * Removes the first array whose name equals {@code str}.
     *
     * @param str the name to look for; must not be {@code null}
     * @return the removed array, or {@code null} if no array has that name
     */
    public NDArray remove(String str) {
        for (int i = 0; i < size(); i++) {
            if (str.equals(get(i).getName())) {
                return remove(i);
            }
        }
        return null;
    }

    /**
     * Returns the only array in this list.
     *
     * @return the single element
     * @throws IndexOutOfBoundsException if the list does not contain exactly one element
     */
    public NDArray singletonOrThrow() {
        if (size() == 1) {
            return get(0);
        }
        throw new IndexOutOfBoundsException("Incorrect number of elements in NDList.singletonOrThrow: Expected 1 and was " + size());
    }

    /**
     * Returns a new list holding the arrays from index {@code i} to the end.
     *
     * @param i the start index (inclusive)
     * @return a new list copied from {@code subList(i, size())}
     */
    public NDList subNDList(int i) {
        return new NDList(subList(i, size()));
    }

    /**
     * Returns a multi-line summary: the size, then one line per array with its index, name (if
     * any), shape, and data type.
     */
    @Override // java.util.AbstractCollection
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        sb.append("NDList size: ").append(size()).append('\n');
        for (int i = 0; i < size(); i++) {
            NDArray array = get(i);
            String name = array.getName();
            sb.append(i).append(' ');
            if (name != null) {
                sb.append(name);
            }
            sb.append(": ").append(array.getShape()).append(' ').append(array.getDataType()).append('\n');
        }
        return sb.toString();
    }
}
