package org.nd4j.linalg.primitives;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/* loaded from: input_file:org/nd4j/linalg/primitives/Counter.class */
/**
 * A map from elements to {@code double} counts that also keeps the sum of all counts.
 *
 * <p>Counts live in a {@link ConcurrentHashMap} of {@link AtomicDouble} values. Single-key
 * updates go through atomic map/value operations. Compound operations such as
 * {@link #normalize()}, {@link #keepTopNElements(int)} and {@link #keySetSorted()} are not
 * atomic with respect to concurrent writers.
 *
 * <p>The running total is updated incrementally by {@link #incrementCount(Object, double)}.
 * Operations that overwrite or remove counts set the {@code dirty} flag instead. In that case
 * {@link #totalCount()} recomputes the total on its next call.
 *
 * @param <T> element type; used as a {@link ConcurrentHashMap} key, so it must not be null
 */
public class Counter<T> implements Serializable {
    private static final long serialVersionUID = 119;

    /** Per-element counts. */
    protected ConcurrentHashMap<T, AtomicDouble> map = new ConcurrentHashMap<>();
    /** Cached sum of all counts; only reliable while {@link #dirty} is false. */
    protected AtomicDouble totalCount = new AtomicDouble(0.0d);
    /** Set when {@link #totalCount} may no longer equal the sum of the stored counts. */
    protected AtomicBoolean dirty = new AtomicBoolean(false);

    /**
     * Orders pairs by descending count. A {@link PriorityQueue} built with it polls the
     * largest count first.
     */
    public class PairComparator implements Comparator<Pair<T, Double>> {
        protected PairComparator() {
        }

        @Override
        public int compare(Pair<T, Double> pair, Pair<T, Double> pair2) {
            return Double.compare(pair2.value.doubleValue(), pair.value.doubleValue());
        }
    }

    /**
     * Orders pairs by ascending count. A {@link PriorityQueue} built with it polls the
     * smallest count first.
     */
    protected class ReversedPairComparator implements Comparator<Pair<T, Double>> {
        protected ReversedPairComparator() {
        }

        @Override
        public int compare(Pair<T, Double> pair, Pair<T, Double> pair2) {
            return Double.compare(pair.value.doubleValue(), pair2.value.doubleValue());
        }
    }

    /**
     * Returns the count stored for {@code t}.
     *
     * @param t element to look up
     * @return the stored count, or {@code 0.0} if {@code t} has no entry
     */
    public double getCount(T t) {
        AtomicDouble count = map.get(t);
        return count == null ? 0.0d : count.get();
    }

    /**
     * Adds {@code d} to the count of {@code t}, creating the entry if needed, and adds
     * {@code d} to the running total.
     *
     * @param t element whose count is increased
     * @param d amount to add (may be negative)
     */
    public void incrementCount(T t, double d) {
        AtomicDouble current = map.get(t);
        if (current == null) {
            // putIfAbsent closes the race where two threads both see a missing key and
            // the later put() would silently discard the earlier thread's count.
            current = map.putIfAbsent(t, new AtomicDouble(d));
            if (current == null) {
                totalCount.addAndGet(d);
                return;
            }
        }
        current.addAndGet(d);
        totalCount.addAndGet(d);
    }

    /**
     * Increments the count of every element of {@code collection} by {@code d}.
     * Duplicates in the collection are incremented once per occurrence.
     *
     * @param collection elements to increment
     * @param d amount added per occurrence
     */
    public void incrementAll(Collection<T> collection, double d) {
        for (T element : collection) {
            incrementCount(element, d);
        }
    }

    /**
     * Adds every count of {@code counter} into this counter.
     *
     * @param counter source of counts; it is not modified
     * @param <T2> element type of the source, a subtype of {@code T}
     */
    public <T2 extends T> void incrementAll(Counter<T2> counter) {
        for (Map.Entry<T2, AtomicDouble> entry : counter.entrySet()) {
            incrementCount(entry.getKey(), entry.getValue().get());
        }
    }

    /**
     * Returns the count of {@code t} divided by the current total.
     *
     * @param t element to look up
     * @return {@code getCount(t) / totalCount()}
     * @throws IllegalStateException if the total is not positive
     */
    public double getProbability(T t) {
        // Read the total once so a dirty counter is rebuilt at most one time here.
        double total = totalCount();
        if (total <= 0.0d) {
            throw new IllegalStateException("Can't calculate probability with empty counter");
        }
        return getCount(t) / total;
    }

    /**
     * Replaces the count of {@code t} with {@code d}, creating the entry if needed.
     *
     * @param t element whose count is set
     * @param d new count
     * @return the previous count, or {@code 0.0} if {@code t} had no entry
     */
    public double setCount(T t, double d) {
        AtomicDouble current = map.get(t);
        if (current == null) {
            current = map.putIfAbsent(t, new AtomicDouble(d));
            if (current == null) {
                // New entry: the total simply grows by the new count.
                totalCount.addAndGet(d);
                return 0.0d;
            }
        }
        double previous = current.getAndSet(d);
        // An existing value was overwritten, not added to, so the cached total is stale.
        dirty.set(true);
        return previous;
    }

    /**
     * Returns the live key view of the backing map. Removing from it removes counts
     * but does not mark the total dirty.
     */
    public Set<T> keySet() {
        return map.keySet();
    }

    /** Returns {@code true} if no element has an entry. */
    public boolean isEmpty() {
        return map.isEmpty();
    }

    /** Returns the live entry view of the backing map. */
    public Set<Map.Entry<T, AtomicDouble>> entrySet() {
        return map.entrySet();
    }

    /**
     * Returns all elements ordered from highest to lowest count.
     *
     * @return a new list; empty when the counter is empty
     */
    public List<T> keySetSorted() {
        PriorityQueue<Pair<T, Double>> queue = asPriorityQueue();
        List<T> sorted = new ArrayList<>(queue.size());
        while (!queue.isEmpty()) {
            sorted.add(queue.poll().getFirst());
        }
        return sorted;
    }

    /**
     * Divides every count by the current total and then recomputes the total.
     * When the total is exactly zero the counter is left unchanged, because dividing
     * would only produce NaN/infinite counts.
     */
    public void normalize() {
        // Use totalCount() (not the raw field) so a stale total is rebuilt first.
        double total = totalCount();
        if (total == 0.0d) {
            return;
        }
        for (AtomicDouble count : map.values()) {
            count.set(count.get() / total);
        }
        rebuildTotals();
    }

    /** Recomputes the cached total from the stored counts and clears the dirty flag. */
    protected void rebuildTotals() {
        double sum = 0.0d;
        for (AtomicDouble count : map.values()) {
            sum += count.get();
        }
        // Publish the finished sum in one step so readers never observe a partial total.
        totalCount.set(sum);
        dirty.set(false);
    }

    /**
     * Returns the sum of all counts, recomputing it first if it may be stale.
     */
    public double totalCount() {
        if (dirty.get()) {
            rebuildTotals();
        }
        return totalCount.get();
    }

    /**
     * Removes the entry for {@code t}.
     *
     * @param t element to remove
     * @return the removed count, or {@code 0.0} if {@code t} had no entry
     */
    public double removeKey(T t) {
        AtomicDouble removed = map.remove(t);
        if (removed == null) {
            return 0.0d;
        }
        dirty.set(true);
        return removed.get();
    }

    /**
     * Returns an element with the largest count. Among equal counts, the one reached
     * first in map iteration order is kept.
     *
     * @return the element with the largest count, or {@code null} if the counter is empty
     */
    public T argMax() {
        T best = null;
        double bestCount = -Double.MAX_VALUE;
        for (Map.Entry<T, AtomicDouble> entry : map.entrySet()) {
            double count = entry.getValue().get();
            // The null check guarantees the first entry is taken, whatever its value.
            if (best == null || count > bestCount) {
                best = entry.getKey();
                bestCount = count;
            }
        }
        return best;
    }

    /**
     * Removes every element whose count is strictly less than {@code d}.
     *
     * @param d inclusive lower bound for counts that are kept
     */
    public void dropElementsBelowThreshold(double d) {
        // Iterate entries directly: avoids a second lookup that could return null if the
        // key is removed concurrently.
        Iterator<Map.Entry<T, AtomicDouble>> it = map.entrySet().iterator();
        while (it.hasNext()) {
            if (it.next().getValue().get() < d) {
                it.remove();
                dirty.set(true);
            }
        }
    }

    /** Returns {@code true} if {@code t} has an entry (even with a zero count). */
    public boolean containsElement(T t) {
        return map.containsKey(t);
    }

    /** Removes all entries and resets the total to zero. */
    public void clear() {
        map.clear();
        totalCount.set(0.0d);
        dirty.set(false);
    }

    /** Returns the number of elements with an entry. */
    public int size() {
        return map.size();
    }

    /**
     * Keeps only the {@code i} elements with the highest counts and discards the rest.
     *
     * @param i number of elements to keep; a value of zero or less empties the counter
     */
    public void keepTopNElements(int i) {
        PriorityQueue<Pair<T, Double>> queue = asPriorityQueue();
        clear();
        for (int kept = 0; kept < i; kept++) {
            Pair<T, Double> next = queue.poll();
            if (next != null) {
                incrementCount(next.getFirst(), next.getSecond().doubleValue());
            }
        }
    }

    /**
     * Returns a snapshot of all (element, count) pairs as a queue that polls the highest
     * count first.
     */
    public PriorityQueue<Pair<T, Double>> asPriorityQueue() {
        return snapshotQueue(new PairComparator());
    }

    /**
     * Returns a snapshot of all (element, count) pairs as a queue that polls the lowest
     * count first.
     */
    public PriorityQueue<Pair<T, Double>> asReversedPriorityQueue() {
        return snapshotQueue(new ReversedPairComparator());
    }

    /** Copies the current entries into a new priority queue ordered by {@code order}. */
    private PriorityQueue<Pair<T, Double>> snapshotQueue(Comparator<Pair<T, Double>> order) {
        // PriorityQueue rejects an initial capacity below 1. Without the clamp, every
        // queue-based method would throw on an empty counter.
        PriorityQueue<Pair<T, Double>> queue =
                new PriorityQueue<>(Math.max(1, map.size()), order);
        for (Map.Entry<T, AtomicDouble> entry : map.entrySet()) {
            queue.add(Pair.create(entry.getKey(), Double.valueOf(entry.getValue().get())));
        }
        return queue;
    }
}
