package ai.djl.ndarray.index;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.index.dim.NDIndexAll;
import ai.djl.ndarray.index.dim.NDIndexBooleans;
import ai.djl.ndarray.index.dim.NDIndexElement;
import ai.djl.ndarray.index.dim.NDIndexFixed;
import ai.djl.ndarray.index.dim.NDIndexPick;
import ai.djl.ndarray.index.dim.NDIndexSlice;
import ai.djl.ndarray.types.DataType;
import com.tianshaokai.mathkeyboard.manager.LatexConstant;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;

/* loaded from: classes.dex */
/**
 * Describes how to index into an {@link NDArray}.
 *
 * <p>An {@code NDIndex} holds an ordered list of {@link NDIndexElement}s, an optional ellipsis
 * position, and a {@code rank} counter tracking how many array dimensions the elements address.
 * It can be built from a NumPy-like string (see {@link #addIndices(String, Object...)}) or
 * incrementally through the {@code add*} methods, which return {@code this} for chaining.
 *
 * <p>Instances are mutable and not synchronized.
 */
public class NDIndex {

    /**
     * Grammar of one comma-separated index item. Every numeric position accepts either an
     * optionally negative integer literal or a {@code "{}"} placeholder filled from the argument
     * list:
     *
     * <ul>
     *   <li>group 1: {@code *}
     *   <li>groups 2-6: a slice {@code start:stop} or {@code start:stop:step} with every part
     *       optional (start = group 3, stop = group 4, step = group 6)
     *   <li>group 7: a single value with no colon
     * </ul>
     */
    private static final Pattern ITEM_PATTERN =
            Pattern.compile(
                    "(\\*)|((-?\\d+|\\{})?:(-?\\d+|\\{})?(:(-?\\d+|\\{}))?)|(-?\\d+|\\{})");

    /** Token that is replaced by the next element of the argument list. */
    private static final String PLACEHOLDER = "{}";

    /** Position in {@link #indices} where an ellipsis was given, or -1 if there is none. */
    private int ellipsisIndex;

    /** The index elements, in the order they were added. */
    private final List<NDIndexElement> indices;

    /** Number of array dimensions addressed by {@link #indices} (see {@link #getRank()}). */
    private int rank;

    /** Creates an empty index. */
    public NDIndex() {
        this.rank = 0;
        this.indices = new ArrayList<>();
        this.ellipsisIndex = -1;
    }

    /**
     * Creates an index parsed from a NumPy-like index string.
     *
     * @param indexSpec the index string, see {@link #addIndices(String, Object...)}
     * @param args values substituted for {@code "{}"} placeholders, in order
     * @throws IllegalArgumentException if the string or arguments are invalid
     */
    public NDIndex(String indexSpec, Object... args) {
        this();
        addIndices(indexSpec, args);
    }

    /**
     * Creates an index that fixes one position per given value.
     *
     * @param fixedIndices the fixed position for each leading dimension
     */
    public NDIndex(long... fixedIndices) {
        this();
        addIndices(fixedIndices);
    }

    /**
     * Parses a single comma-separated index item and appends the corresponding element.
     *
     * @param item the raw item text; surrounding whitespace is ignored
     * @param argIndex position of the next unused element of {@code args}
     * @param args arguments substituted for {@code "{}"} placeholders, in order
     * @return the position of the next unused argument after this item
     * @throws IllegalArgumentException if the item does not match {@link #ITEM_PATTERN}, a
     *     placeholder has no matching argument, or an argument has an unsupported type
     */
    private int addIndexItem(String item, int argIndex, Object[] args) {
        String trimmed = item.trim();
        Matcher matcher = ITEM_PATTERN.matcher(trimmed);
        if (!matcher.matches()) {
            throw new IllegalArgumentException("Invalid argument index: " + trimmed);
        }

        // "*": select everything along this dimension.
        if (matcher.group(1) != null) {
            this.indices.add(new NDIndexAll());
            return argIndex;
        }

        // Colon-free single value: a literal fixes the position, while a placeholder may also
        // supply a boolean mask or an integer pick array.
        String single = matcher.group(7);
        if (single != null) {
            if (!PLACEHOLDER.equals(single)) {
                this.indices.add(new NDIndexFixed(Long.parseLong(single)));
                return argIndex;
            }
            addPlaceholderElement(argAt(argIndex, args));
            return argIndex + 1;
        }

        // Slice "start:stop[:step]": every part is optional and placeholders are consumed
        // left to right.
        int next = argIndex;
        String startPart = matcher.group(3);
        Long start = parseSliceItem(startPart, next, args);
        next += PLACEHOLDER.equals(startPart) ? 1 : 0;
        String stopPart = matcher.group(4);
        Long stop = parseSliceItem(stopPart, next, args);
        next += PLACEHOLDER.equals(stopPart) ? 1 : 0;
        String stepPart = matcher.group(6);
        Long step = parseSliceItem(stepPart, next, args);
        next += PLACEHOLDER.equals(stepPart) ? 1 : 0;

        if (start == null && stop == null && step == null) {
            // An all-empty slice such as ":" or "::" selects everything.
            this.indices.add(new NDIndexAll());
        } else {
            this.indices.add(new NDIndexSlice(start, stop, step));
        }
        return next;
    }

    /**
     * Appends the element selected by a single colon-free {@code "{}"} argument.
     *
     * @param arg an {@link Integer} or {@link Long} (fixed position), a boolean {@link NDArray}
     *     (mask), or an integer-typed {@link NDArray} (pick)
     * @throws IllegalArgumentException if {@code arg} is of any other type
     */
    private void addPlaceholderElement(Object arg) {
        if (arg instanceof Integer) {
            this.indices.add(new NDIndexFixed(((Integer) arg).intValue()));
            return;
        }
        if (arg instanceof Long) {
            this.indices.add(new NDIndexFixed(((Long) arg).longValue()));
            return;
        }
        if (arg instanceof NDArray) {
            NDArray array = (NDArray) arg;
            if (array.getDataType() == DataType.BOOLEAN) {
                // addIndices already counted one dimension for this item; a boolean mask covers
                // as many dimensions as it has, matching the accounting in addBooleanIndex.
                this.rank += array.getShape().dimension() - 1;
                this.indices.add(new NDIndexBooleans(array));
                return;
            }
            if (array.getDataType().isInteger()) {
                this.indices.add(new NDIndexPick(array));
                return;
            }
        }
        throw new IllegalArgumentException("Unknown argument: " + arg);
    }

    /**
     * Parses one part of a slice item.
     *
     * @param part the matched text: an integer literal, {@code "{}"}, or {@code null} when the
     *     part was omitted
     * @param argIndex position in {@code args} used when {@code part} is a placeholder
     * @param args the placeholder arguments
     * @return the parsed value, or {@code null} if {@code part} is {@code null}
     * @throws IllegalArgumentException if the literal is not a valid {@code long}, or the
     *     placeholder argument is missing or is neither an {@link Integer} nor a {@link Long}
     */
    private static Long parseSliceItem(String part, int argIndex, Object[] args) {
        if (part == null) {
            return null;
        }
        if (!PLACEHOLDER.equals(part)) {
            return Long.parseLong(part);
        }
        Object arg = argAt(argIndex, args);
        if (arg instanceof Integer) {
            return ((Integer) arg).longValue();
        }
        if (arg instanceof Long) {
            return (Long) arg;
        }
        throw new IllegalArgumentException("Unknown slice argument: " + arg);
    }

    /**
     * Returns {@code args[argIndex]}. A missing argument is reported as an {@link
     * IllegalArgumentException} instead of an {@link ArrayIndexOutOfBoundsException}.
     */
    private static Object argAt(int argIndex, Object[] args) {
        if (argIndex >= args.length) {
            throw new IllegalArgumentException("Incorrect number of index arguments");
        }
        return args[argIndex];
    }

    /**
     * Creates an index that selects everything along the first {@code axis} dimensions and then
     * slices dimension {@code axis} from {@code start} to {@code stop}.
     *
     * @param axis the dimension to slice; a non-positive value adds no leading dimensions
     * @param start the slice start
     * @param stop the slice stop
     * @return the new index
     */
    public static NDIndex sliceAxis(int axis, long start, long stop) {
        NDIndex index = new NDIndex();
        for (int dim = 0; dim < axis; dim++) {
            index.addAllDim();
        }
        index.addSliceDim(start, stop);
        return index;
    }

    /**
     * Appends one dimension that selects every entry.
     *
     * @return this index
     */
    public NDIndex addAllDim() {
        this.rank++;
        this.indices.add(new NDIndexAll());
        return this;
    }

    /**
     * Appends {@code count} dimensions that each select every entry.
     *
     * @param count the number of dimensions to add
     * @return this index
     * @throws IllegalArgumentException if {@code count} is negative
     */
    public NDIndex addAllDim(int count) {
        if (count < 0) {
            throw new IllegalArgumentException("The number of index dimensions to add can't be negative");
        }
        this.rank += count;
        for (int n = 0; n < count; n++) {
            this.indices.add(new NDIndexAll());
        }
        return this;
    }

    /**
     * Appends a boolean-mask element; it counts as many dimensions as the mask array has.
     *
     * @param mask the mask array
     * @return this index
     */
    public NDIndex addBooleanIndex(NDArray mask) {
        this.rank += mask.getShape().dimension();
        this.indices.add(new NDIndexBooleans(mask));
        return this;
    }

    /**
     * Appends index elements parsed from a NumPy-like index string.
     *
     * <p>The string is split into items. Each item is one of the following:
     *
     * <ul>
     *   <li>{@code *}
     *   <li>a single value
     *   <li>a slice {@code start:stop[:step]}, with every part optional
     *   <li>{@code ...} (an ellipsis), allowed at most once per index
     * </ul>
     *
     * <p>Numeric positions may be written as {@code "{}"}, which consumes the next element of
     * {@code args}. Accepted arguments are {@link Integer} and {@link Long} values, plus, for
     * single values only, boolean or integer-typed {@link NDArray}s.
     *
     * @param indexSpec the index string, e.g. {@code "0, 1:{}, ..."}
     * @param args values for the {@code "{}"} placeholders; every one must be consumed
     * @return this index
     * @throws IllegalArgumentException if an item is malformed, more than one ellipsis is
     *     present, an argument has an unsupported type, or the number of arguments does not
     *     match the number of placeholders
     */
    public final NDIndex addIndices(String indexSpec, Object... args) {
        // NOTE(review): LatexConstant.COMMA is presumably ","; this couples the class to an
        // unrelated package (likely a decompiler constant substitution) - confirm and inline.
        String[] items = indexSpec.split(LatexConstant.COMMA);
        this.rank += items.length;
        int argIndex = 0;
        boolean sawEllipsis = false;
        for (String item : items) {
            if ("...".equals(item.trim())) {
                if (this.ellipsisIndex != -1) {
                    throw new IllegalArgumentException("an index can only have a single ellipsis (\"...\")");
                }
                // Record the position in the whole element list so that elements added by
                // earlier calls are accounted for.
                this.ellipsisIndex = this.indices.size();
                sawEllipsis = true;
            } else {
                argIndex = addIndexItem(item, argIndex, args);
            }
        }
        // The ellipsis item addresses no dimension of its own. Only undo the count added by
        // this call, so an ellipsis from an earlier call is not subtracted again.
        if (sawEllipsis) {
            this.rank--;
        }
        if (argIndex != args.length) {
            throw new IllegalArgumentException("Incorrect number of index arguments");
        }
        return this;
    }

    /**
     * Appends one fixed-position element per given value.
     *
     * @param fixedIndices the fixed position for each successive dimension
     * @return this index
     */
    public final NDIndex addIndices(long... fixedIndices) {
        this.rank += fixedIndices.length;
        for (long position : fixedIndices) {
            this.indices.add(new NDIndexFixed(position));
        }
        return this;
    }

    /**
     * Appends a pick element backed by the given array.
     *
     * @param pick the array of picked positions
     * @return this index
     */
    public NDIndex addPickDim(NDArray pick) {
        this.rank++;
        this.indices.add(new NDIndexPick(pick));
        return this;
    }

    /**
     * Appends a slice element with no explicit step.
     *
     * @param start the slice start
     * @param stop the slice stop
     * @return this index
     */
    public NDIndex addSliceDim(long start, long stop) {
        this.rank++;
        this.indices.add(new NDIndexSlice(start, stop, null));
        return this;
    }

    /**
     * Appends a slice element with an explicit step.
     *
     * @param start the slice start
     * @param stop the slice stop
     * @param step the slice step
     * @return this index
     */
    public NDIndex addSliceDim(long start, long stop, long step) {
        this.rank++;
        this.indices.add(new NDIndexSlice(start, stop, step));
        return this;
    }

    /**
     * Returns the element at the given position.
     *
     * @param position the element position
     * @return the element
     * @throws IndexOutOfBoundsException if {@code position} is out of range
     */
    public NDIndexElement get(int position) {
        return this.indices.get(position);
    }

    /**
     * Returns the position in {@link #getIndices()} at which an ellipsis was given.
     *
     * @return the ellipsis position, or -1 if the index has no ellipsis
     */
    public int getEllipsisIndex() {
        return this.ellipsisIndex;
    }

    /**
     * Returns the live backing list of elements; modifying it modifies this index.
     *
     * @return the index elements
     */
    public List<NDIndexElement> getIndices() {
        return this.indices;
    }

    /**
     * Returns the number of array dimensions this index addresses.
     *
     * <p>Each all, fixed, slice, and pick element counts as one dimension. A boolean mask counts
     * its array's number of dimensions. An ellipsis counts none.
     *
     * @return the rank
     */
    public int getRank() {
        return this.rank;
    }

    /**
     * Returns a sequential stream over the index elements.
     *
     * @return the element stream
     */
    public Stream<NDIndexElement> stream() {
        return this.indices.stream();
    }
}
