package ai.djl.modality.cv.translator;

import ai.djl.modality.cv.output.Joints;
import ai.djl.modality.cv.translator.BaseImageTranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.TranslatorContext;
import java.util.ArrayList;
import java.util.List;

/* loaded from: classes.dex */
/**
 * An image translator that converts a pose model's output into {@link Joints}.
 *
 * <p>The single output array is read as one score map per joint along its first axis. For each
 * joint, the location of the maximum score becomes the joint position, normalized by the map
 * width and height. The maximum score itself is the joint confidence. Joints whose confidence
 * does not exceed the configured threshold are dropped.
 */
public class SimplePoseTranslator extends BaseImageTranslator<Joints> {

    /** Exclusive lower bound a joint's confidence must exceed to be reported. */
    private final float threshold;

    /** Builder for {@link SimplePoseTranslator}. */
    public static class Builder extends BaseImageTranslator.BaseBuilder<Builder> {

        // Left uninitialized, so it defaults to 0.0f unless optThreshold is called.
        float threshold;

        Builder() {}

        /**
         * Builds the translator after running the base builder's {@code validate()}.
         *
         * @return a new {@link SimplePoseTranslator} configured from this builder
         */
        public SimplePoseTranslator build() {
            validate();
            return new SimplePoseTranslator(this);
        }

        /**
         * Sets the confidence threshold. Only joints whose confidence is strictly greater than
         * this value are returned. Defaults to {@code 0}.
         *
         * @param threshold the exclusive minimum confidence
         * @return this builder
         */
        public Builder optThreshold(float threshold) {
            this.threshold = threshold;
            return self();
        }

        /** {@inheritDoc} */
        @Override
        public Builder self() {
            return this;
        }
    }

    /**
     * Creates a translator from the given builder.
     *
     * @param builder the builder holding the configuration
     */
    public SimplePoseTranslator(Builder builder) {
        super(builder);
        this.threshold = builder.threshold;
    }

    /**
     * Creates a new builder for this translator.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Converts the model output into joints.
     *
     * <p>Assumes the single output array has shape {@code (numJoints, height, width)}. This is
     * inferred from the three {@code getShape().get(..)} reads below; confirm against the model.
     *
     * @param ctx the translator context
     * @param list the model output; must contain exactly one array
     * @return the joints whose confidence exceeds the threshold; each (x, y) is divided by the
     *     map width and height respectively
     */
    @Override
    public Joints processOutput(TranslatorContext ctx, NDList list) {
        NDArray pred = list.singletonOrThrow();
        Shape shape = pred.getShape();
        int numJoints = (int) shape.get(0);
        int height = (int) shape.get(1);
        int width = (int) shape.get(2);

        // Flatten each joint's map so its peak can be found with a single argMax / max.
        NDArray flattened = pred.reshape(new Shape(1, numJoints, -1));
        NDArray maxIndices =
                flattened
                        .argMax(2)
                        .reshape(new Shape(1, numJoints, -1))
                        .toType(DataType.FLOAT32, false);
        NDArray maxValues = flattened.max(new int[] {2}, true);

        // Duplicate each flat index into two channels, then turn it into
        // x = index % width and y = floor(index / width).
        NDArray coords = maxIndices.tile(2, 2L);
        coords.set(new NDIndex(":, :, 0"), coords.get(":, :, 0").mod(width));
        coords.set(new NDIndex(":, :, 1"), coords.get(":, :, 1").div(width).floor());

        // Read every joint's coordinates without boolean masking. Masking on "score > 0"
        // compacted the array, so xy[2 * i] no longer referred to joint i whenever an
        // earlier joint had a non-positive peak. The threshold filtering below is sufficient.
        float[] xy = coords.toFloatArray();
        float[] confidences = maxValues.toFloatArray();

        List<Joints.Joint> joints = new ArrayList<>(numJoints);
        for (int joint = 0; joint < numJoints; joint++) {
            float confidence = confidences[joint];
            if (confidence > threshold) {
                float x = xy[2 * joint] / width;
                float y = xy[2 * joint + 1] / height;
                joints.add(new Joints.Joint(x, y, confidence));
            }
        }
        return new Joints(joints);
    }
}
