package ai.djl.training;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.training.optimizer.Optimizer;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * A {@link ParameterServer} that keeps pushed gradients in local memory and applies the
 * {@link Optimizer} directly when the parameter is pulled.
 *
 * <p>Gradients are stored per parameter id. {@link #push} hands ownership of the gradient arrays
 * to this server; any previously pushed, not-yet-pulled arrays for the same id are closed. {@link
 * #pull} consumes them: it sums them onto the device of the first gradient and updates every
 * weight with the sum. It then removes the gradient arrays from this server and closes them.
 */
public class LocalParameterServer implements ParameterServer {

    /** Gradients pushed per parameter id that have not been consumed by a pull yet. */
    private final Map<String, NDArray[]> gradMap = new ConcurrentHashMap<>();

    /** Optimizer used to apply the aggregated gradient to each weight. */
    private final Optimizer optimizer;

    /**
     * Creates a local parameter server.
     *
     * @param optimizer the optimizer used to apply aggregated gradients to weights
     */
    public LocalParameterServer(Optimizer optimizer) {
        this.optimizer = optimizer;
    }

    /**
     * Does nothing. Note that gradients which were pushed but never pulled are not released by
     * this method.
     */
    @Override // ai.djl.training.ParameterServer, java.lang.AutoCloseable
    public void close() {}

    /**
     * Does nothing; this implementation keeps no per-parameter initial state.
     *
     * @param parameterId the parameter id (unused)
     * @param value the initial values (unused)
     */
    @Override // ai.djl.training.ParameterServer
    public void init(String parameterId, NDArray[] value) {}

    /**
     * Aggregates the gradients pushed for {@code parameterId} and applies them to every weight.
     *
     * <p>All gradients are summed in place into the first gradient array. The other arrays are
     * first copied to its device. The optimizer is then invoked once per weight. When a weight
     * lives on a different device, the summed gradient is first copied to that device. Afterwards
     * the pushed gradients are removed from this server and closed, even if an update fails.
     *
     * @param parameterId the id of the parameter to update
     * @param weights the copies of the parameter to update, possibly on different devices
     * @param priority ignored by this implementation
     * @throws IllegalStateException if no gradients are pending for {@code parameterId}
     */
    @Override // ai.djl.training.ParameterServer
    public void pull(String parameterId, NDArray[] weights, int priority) {
        // Atomically take ownership so that a repeated pull (or a later push) never touches
        // arrays that this method has already closed.
        NDArray[] grads = gradMap.remove(parameterId);
        if (grads == null || grads.length == 0) {
            throw new IllegalStateException(
                    "No gradients have been pushed for parameter: " + parameterId);
        }
        try {
            Device firstDevice = grads[0].getDevice();
            // Reduce: accumulate every gradient into grads[0] on its device.
            for (int i = 1; i < grads.length; i++) {
                try (NDArray gradCopy = grads[i].toDevice(firstDevice, true)) {
                    grads[0].addi(gradCopy);
                }
            }
            // The optimizer receives a duplicate rather than grads[0] itself. This is presumably
            // so that an optimizer which mutates its gradient argument cannot affect the updates of
            // the remaining weights. TODO: confirm against the Optimizer implementations.
            try (NDArray aggregatedGrad = grads[0].duplicate()) {
                for (NDArray weight : weights) {
                    Device weightDevice = weight.getDevice();
                    if (weightDevice.equals(firstDevice)) {
                        optimizer.update(parameterId, weight, aggregatedGrad);
                    } else {
                        try (NDArray gradOnDevice = aggregatedGrad.toDevice(weightDevice, true)) {
                            optimizer.update(parameterId, weight, gradOnDevice);
                        }
                    }
                }
            }
        } finally {
            closeAll(grads);
        }
    }

    /**
     * Stores the gradients for {@code parameterId}, taking ownership of the arrays.
     *
     * <p>If gradients were already pending for this id and were not pulled, they are closed and
     * replaced.
     *
     * @param parameterId the id of the parameter the gradients belong to
     * @param grads the gradients, one per device copy of the parameter
     * @param priority ignored by this implementation
     */
    @Override // ai.djl.training.ParameterServer
    public void push(String parameterId, NDArray[] grads, int priority) {
        NDArray[] previous = gradMap.put(parameterId, grads);
        // Guard against re-pushing the same array instance, which would close what was just stored.
        if (previous != null && previous != grads) {
            closeAll(previous);
        }
    }

    /** Closes every array in {@code arrays}. */
    private static void closeAll(NDArray[] arrays) {
        Arrays.stream(arrays).forEach(NDArray::close);
    }
}
