package ai.djl.modality.nlp;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/* loaded from: classes.dex */
/**
 * A {@link Vocabulary} built by counting token occurrences over tokenized sentences.
 *
 * <p>Every non-reserved token is counted. When a token's count first reaches the minimum
 * frequency, it receives the next index, starting at 0. Indices are therefore assigned in the
 * order tokens reach the threshold. Reserved tokens, which always include the unknown token, are
 * neither counted nor indexed.
 *
 * <p>All mutation happens in the constructor; the public methods only read state.
 */
public class SimpleVocabulary implements Vocabulary {
    /** Minimum number of occurrences a token needs before it is treated as known. */
    private final int minFrequency;
    /** Tokens that are never counted or indexed; always contains {@link #unknownToken}. */
    private final Set<String> reservedTokens;
    /** Token returned by {@link #getToken(long)} for indices outside the vocabulary. */
    private final String unknownToken;
    /** Frequency and index for every distinct non-reserved token seen during construction. */
    private final Map<String, TokenInfo> tokens = new ConcurrentHashMap<>();
    /** Indexed tokens in index order: {@code indexToToken.get(i)} is the token with index i. */
    private final List<String> indexToToken = new ArrayList<>();

    /**
     * Per-token bookkeeping: the occurrence count and the assigned index ({@code -1} until the
     * count reaches the indexing threshold).
     */
    // NOTE(review): the decompiler reported this class was originally private; it stays public
    // here so any existing references keep compiling.
    public static final class TokenInfo {
        int frequency;
        long index = -1;
    }

    /**
     * Collects tokenized sentences and options for a {@link SimpleVocabulary}.
     *
     * <p>Defaults: minimum frequency {@code 10}, unknown token {@code "<unk>"}, and no extra
     * reserved tokens. Sentences are stored by reference and only read when {@link #build()} is
     * called.
     */
    public static class VocabularyBuilder {
        protected List<List<String>> sentences = new ArrayList<>();
        protected Set<String> reservedTokens = new HashSet<>();
        protected int minFrequency = 10;
        protected String unknownToken = "<unk>";

        /**
         * Adds one tokenized sentence.
         *
         * @param list the tokens of the sentence
         * @return this builder
         */
        public VocabularyBuilder add(List<String> list) {
            this.sentences.add(list);
            return this;
        }

        /**
         * Adds several tokenized sentences.
         *
         * @param list the sentences, each a list of tokens
         * @return this builder
         */
        public VocabularyBuilder addAll(List<List<String>> list) {
            this.sentences.addAll(list);
            return this;
        }

        /**
         * Builds the vocabulary from the sentences added so far.
         *
         * @return a new {@link SimpleVocabulary}
         */
        public SimpleVocabulary build() {
            return new SimpleVocabulary(this);
        }

        /**
         * Sets the minimum number of occurrences a token needs to be indexed.
         *
         * @param i the minimum frequency; values below 1 behave like 1 for indexing
         * @return this builder
         */
        public VocabularyBuilder optMinFrequency(int i) {
            this.minFrequency = i;
            return this;
        }

        /**
         * Adds tokens that must never be counted or indexed.
         *
         * @param collection the tokens to reserve
         * @return this builder
         */
        public VocabularyBuilder optReservedTokens(Collection<String> collection) {
            this.reservedTokens.addAll(collection);
            return this;
        }

        /**
         * Sets the token returned for out-of-range indices; it is also treated as reserved.
         *
         * @param str the unknown token
         * @return this builder
         */
        public VocabularyBuilder optUnknownToken(String str) {
            this.unknownToken = str;
            return this;
        }
    }

    /**
     * Builds the vocabulary by counting every token of every sentence in the builder.
     *
     * @param vocabularyBuilder the builder holding sentences and options
     */
    public SimpleVocabulary(VocabularyBuilder vocabularyBuilder) {
        // Copy the set so that building neither mutates the builder nor sees later changes to it.
        this.reservedTokens = new HashSet<>(vocabularyBuilder.reservedTokens);
        this.minFrequency = vocabularyBuilder.minFrequency;
        this.unknownToken = vocabularyBuilder.unknownToken;
        this.reservedTokens.add(this.unknownToken);
        for (List<String> sentence : vocabularyBuilder.sentences) {
            addAllTokens(sentence);
        }
    }

    /** Counts every token in {@code collection}. */
    private void addAllTokens(Collection<String> collection) {
        for (String token : collection) {
            addToken(token);
        }
    }

    /** Counts one occurrence of {@code token}, indexing it when it first reaches the threshold. */
    private void addToken(String token) {
        if (reservedTokens.contains(token)) {
            return;
        }
        TokenInfo info = tokens.computeIfAbsent(token, key -> new TokenInfo());
        info.frequency++;
        // Index exactly once, when the count first reaches the threshold. A threshold below 1 is
        // clamped to 1; otherwise no token would ever be indexed, because counts start at 1.
        if (info.frequency == Math.max(minFrequency, 1)) {
            info.index = indexToToken.size();
            indexToToken.add(token);
        }
    }

    /**
     * Returns the reserved tokens followed by the indexed tokens in index order.
     *
     * @return a new mutable list; the order of the reserved tokens is unspecified
     */
    public List<String> getAllTokens() {
        List<String> all = new ArrayList<>(reservedTokens.size() + indexToToken.size());
        all.addAll(reservedTokens);
        for (String token : indexToToken) {
            if (!reservedTokens.contains(token)) {
                all.add(token);
            }
        }
        return all;
    }

    /**
     * Returns the index of {@code token}.
     *
     * <p>A token without an index yields {@code 0}. This covers tokens that were never seen,
     * reserved tokens (including the unknown token), and tokens seen fewer than the minimum
     * frequency.
     *
     * <p>NOTE(review): {@code 0} is also the index of the first indexed token, so tokens without
     * an index cannot be told apart from it. The unknown token itself never receives an index.
     *
     * @param token the token to look up; must not be {@code null}
     * @return the token's index, or {@code 0} if it has none
     */
    @Override // ai.djl.modality.nlp.Vocabulary
    public long getIndex(String token) {
        TokenInfo info = tokens.get(token);
        // Sub-threshold tokens still carry index -1; treat them like unseen tokens.
        return (info != null && info.index >= 0) ? info.index : 0L;
    }

    /**
     * Returns the token with the given index.
     *
     * @param index the index to look up
     * @return the indexed token, or the unknown token if {@code index} is out of range
     */
    @Override // ai.djl.modality.nlp.Vocabulary
    public String getToken(long index) {
        if (index < 0 || index >= indexToToken.size()) {
            return unknownToken;
        }
        return indexToToken.get((int) index);
    }

    /**
     * Returns the token used for out-of-vocabulary indices.
     *
     * @return the unknown token
     */
    public String getUnknownToken() {
        return unknownToken;
    }

    /**
     * Returns whether {@code token} is reserved or was seen at least the minimum frequency.
     *
     * @param token the token to check; must not be {@code null} unless it is reserved
     * @return {@code true} if the token is known
     */
    public boolean isKnownToken(String token) {
        if (reservedTokens.contains(token)) {
            return true;
        }
        TokenInfo info = tokens.get(token);
        return info != null && info.frequency >= minFrequency;
    }

    /**
     * Returns the number of distinct non-reserved tokens seen during construction.
     *
     * <p>NOTE(review): this count includes tokens below the minimum frequency, so it can exceed
     * the number of indexed tokens. It excludes reserved tokens.
     *
     * @return the number of distinct counted tokens
     */
    public int size() {
        return tokens.size();
    }
}
