/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression;

import com.oracle.labs.mlrg.olcut.util.MutableDouble;
import com.oracle.labs.mlrg.olcut.util.MutableLong;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.regression.RegressionInfo;
import org.tribuo.regression.Regressor;

/**
 * An immutable {@link RegressionInfo} which assigns a fixed integer id to every regression
 * dimension name.
 * <p>
 * Ids are either assigned in lexicographic order of the dimension names, or taken from an
 * explicit {@code Regressor -> id} mapping supplied by the caller. Per-dimension statistics
 * (min, max, mean, variance) are cached in id-indexed arrays. Those arrays are transient and
 * are rebuilt from the serialized maps in {@link #readObject(ObjectInputStream)}.
 */
public class ImmutableRegressionInfo
extends RegressionInfo
implements ImmutableOutputInfo<Regressor> {
    private static final Logger logger = Logger.getLogger(ImmutableRegressionInfo.class.getName());
    private static final long serialVersionUID = 2L;
    /** Maps dimension id to dimension name. */
    private final Map<Integer, String> idLabelMap;
    /** Maps dimension name to dimension id; the inverse of {@link #idLabelMap}. */
    private final Map<String, Integer> labelIDMap;
    /** Unmodifiable domain: one {@code DimensionTuple} per dimension, ordered by name. */
    private final Set<Regressor> domain;
    // Id-indexed statistic caches. They are derived from the (serialized) statistic maps,
    // so they are transient and recomputed after deserialization.
    private transient double[] minArray;
    private transient double[] maxArray;
    private transient double[] meanArray;
    private transient double[] varianceArray;

    /**
     * Copy constructor; copies the id mappings and recomputes the derived state.
     * @param info The info to copy.
     */
    private ImmutableRegressionInfo(ImmutableRegressionInfo info) {
        super(info);
        this.idLabelMap = new LinkedHashMap<Integer, String>(info.idLabelMap);
        this.labelIDMap = new LinkedHashMap<String, Integer>(info.labelIDMap);
        this.domain = ImmutableRegressionInfo.calculateDomain(this.minMap);
        this.computeStatisticArrays();
    }

    /**
     * Builds an immutable info whose ids follow the lexicographic order of the dimension names,
     * starting from zero.
     * @param info The info to copy statistics from.
     */
    ImmutableRegressionInfo(RegressionInfo info) {
        super(info);
        this.idLabelMap = new LinkedHashMap<Integer, String>();
        this.labelIDMap = new LinkedHashMap<String, Integer>();
        TreeSet<String> names = new TreeSet<String>(this.countMap.keySet());
        int counter = 0;
        for (String name : names) {
            this.idLabelMap.put(counter, name);
            this.labelIDMap.put(name, counter);
            ++counter;
        }
        this.domain = ImmutableRegressionInfo.calculateDomain(this.minMap);
        this.computeStatisticArrays();
    }

    /**
     * Builds an immutable info using an externally supplied id mapping.
     * @param info The info to copy statistics from.
     * @param mapping A bijection from single-dimension regressors to ids in {@code [0, mapping.size())}.
     * @throws IllegalStateException If the mapping and the info have different sizes.
     * @throws IllegalArgumentException If a regressor has other than exactly one dimension, an id is
     *                                  null or out of range, two regressors share an id, or the mapped
     *                                  names differ from the names in the info.
     */
    ImmutableRegressionInfo(RegressionInfo info, Map<Regressor, Integer> mapping) {
        super(info);
        if (mapping.size() != info.size()) {
            throw new IllegalStateException("Mapping and info come from different sources, mapping.size() = " + mapping.size() + ", info.size() = " + info.size());
        }
        String[] names = new String[mapping.size()];
        for (Map.Entry<Regressor, Integer> e : mapping.entrySet()) {
            Integer id = e.getValue();
            String[] curNames = e.getKey().getNames();
            // Check the dimensionality first, so the messages below can safely use curNames[0].
            if (curNames.length != 1) {
                throw new IllegalArgumentException("Mapping must contain a single regression dimension per id, but contains " + Arrays.toString(curNames) + " -> " + id);
            }
            if (id == null || id < 0 || id >= names.length) {
                throw new IllegalArgumentException("Mapping ids must lie in [0, " + names.length + "), but found id " + id + " for '" + curNames[0] + "'");
            }
            if (names[id] != null) {
                throw new IllegalArgumentException("Mapping must be a bijection, but found two mappings for index " + id + ", '" + names[id] + "' and '" + curNames[0] + "'");
            }
            names[id] = curNames[0];
        }
        // The size equality plus the range and uniqueness checks guarantee every slot is filled.
        this.idLabelMap = new LinkedHashMap<Integer, String>();
        this.labelIDMap = new LinkedHashMap<String, Integer>();
        for (int i = 0; i < names.length; ++i) {
            this.idLabelMap.put(i, names[i]);
            this.labelIDMap.put(names[i], i);
        }
        if (!this.countMap.keySet().containsAll(this.labelIDMap.keySet()) || !this.labelIDMap.keySet().containsAll(this.countMap.keySet())) {
            throw new IllegalArgumentException("Mapping must contain an entry for each element in the info, found " + this.labelIDMap.keySet() + " and " + this.countMap.keySet());
        }
        this.domain = ImmutableRegressionInfo.calculateDomain(this.minMap);
        this.computeStatisticArrays();
    }

    /**
     * Builds the unmodifiable domain set: one {@code DimensionTuple} per dimension, ordered by
     * dimension name, carrying that dimension's minimum value.
     * @param minMap The per-dimension minimum values.
     * @return An unmodifiable, name-ordered set of dimension tuples.
     */
    private static Set<Regressor> calculateDomain(Map<String, MutableDouble> minMap) {
        TreeSet<Regressor.DimensionTuple> outputs = new TreeSet<Regressor.DimensionTuple>(Comparator.comparing(Regressor.DimensionTuple::getName));
        for (Map.Entry<String, MutableDouble> e : minMap.entrySet()) {
            outputs.add(new Regressor.DimensionTuple(e.getKey(), e.getValue().doubleValue()));
        }
        // Copy into a LinkedHashSet to keep the sorted order without keeping the comparator.
        LinkedHashSet<Regressor.DimensionTuple> preSortedOutputs = new LinkedHashSet<Regressor.DimensionTuple>(outputs);
        return Collections.unmodifiableSet(preSortedOutputs);
    }

    /**
     * Returns the unmodifiable domain, one dimension tuple per dimension, ordered by name.
     * @return The domain.
     */
    @Override
    public Set<Regressor> getDomain() {
        return this.domain;
    }

    /**
     * Looks up the id of a regressor by its dimension-names string.
     * @param output The regressor.
     * @return The id, or -1 if the names string is not a known dimension.
     */
    public int getID(Regressor output) {
        return this.labelIDMap.getOrDefault(output.getDimensionNamesString(), -1);
    }

    /**
     * Builds a single-dimension regressor with value 1.0 for the given id.
     * @param id The dimension id.
     * @return The regressor, or null (after logging at INFO) if the id is unknown.
     */
    public Regressor getOutput(int id) {
        String label = this.idLabelMap.get(id);
        if (label != null) {
            return new Regressor(label, 1.0);
        }
        logger.log(Level.INFO, "No entry found for id " + id);
        return null;
    }

    /**
     * Returns the minimum observed value of the dimension with this id.
     * @param id The dimension id; an invalid id throws {@link ArrayIndexOutOfBoundsException}.
     * @return The minimum.
     */
    public double getMin(int id) {
        return this.minArray[id];
    }

    /**
     * Returns the maximum observed value of the dimension with this id.
     * @param id The dimension id; an invalid id throws {@link ArrayIndexOutOfBoundsException}.
     * @return The maximum.
     */
    public double getMax(int id) {
        return this.maxArray[id];
    }

    /**
     * Returns the mean value of the dimension with this id.
     * @param id The dimension id; an invalid id throws {@link ArrayIndexOutOfBoundsException}.
     * @return The mean.
     */
    public double getMean(int id) {
        return this.meanArray[id];
    }

    /**
     * Returns the variance of the dimension with this id, computed as sumSquares / (count - 1).
     * With a single observation this is NaN or infinite.
     * @param id The dimension id; an invalid id throws {@link ArrayIndexOutOfBoundsException}.
     * @return The variance.
     */
    public double getVariance(int id) {
        return this.varianceArray[id];
    }

    /**
     * Returns the total number of observations recorded by the underlying info.
     * @return The observation count.
     */
    public long getTotalObservations() {
        return this.overallCount;
    }

    /**
     * Divides an accumulated sum of squares by {@code count - 1}.
     * NOTE(review): presumably RegressionInfo stores squared deviations from the mean, which makes
     * this the unbiased sample variance — confirm. Yields NaN/infinity when count is 1.
     */
    private static double variance(double sumSquares, long count) {
        return sumSquares / (double) (count - 1L);
    }

    /**
     * Fills the id-indexed statistic caches from the name-keyed statistic maps.
     */
    private void computeStatisticArrays() {
        int size = this.labelIDMap.size();
        this.minArray = new double[size];
        this.maxArray = new double[size];
        this.meanArray = new double[size];
        this.varianceArray = new double[size];
        for (int i = 0; i < size; ++i) {
            String name = this.idLabelMap.get(i);
            this.minArray[i] = this.minMap.get(name).doubleValue();
            this.maxArray[i] = this.maxMap.get(name).doubleValue();
            this.meanArray[i] = this.meanMap.get(name).doubleValue();
            this.varianceArray[i] = variance(this.sumSquaresMap.get(name).doubleValue(), this.countMap.get(name).longValue());
        }
    }

    @Override
    public ImmutableRegressionInfo copy() {
        return new ImmutableRegressionInfo(this);
    }

    /**
     * Describes every dimension as {@code {name,id,count,max,min,mean,variance}}, comma separated.
     * An empty info renders as {@code RegressionInfo()}.
     */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        builder.append("RegressionInfo(");
        // Emit the separator before each entry after the first, so an empty info stays well-formed.
        String separator = "";
        for (Map.Entry<String, MutableLong> e : this.countMap.entrySet()) {
            String name = e.getKey();
            long count = e.getValue().longValue();
            builder.append(separator);
            builder.append(String.format("{name=%s,id=%d,count=%d,max=%f,min=%f,mean=%f,variance=%f}", name, this.labelIDMap.get(name), count, this.maxMap.get(name).doubleValue(), this.minMap.get(name).doubleValue(), this.meanMap.get(name).doubleValue(), variance(this.sumSquaresMap.get(name).doubleValue(), count)));
            separator = ",";
        }
        builder.append(')');
        return builder.toString();
    }

    /**
     * Checks whether the ids follow the lexicographic order of the dimension names.
     * @return True if id i holds the i-th name in sorted order, for every id.
     */
    public boolean validateMapping() {
        String[] names = new String[this.idLabelMap.size()];
        for (Map.Entry<Integer, String> e : this.idLabelMap.entrySet()) {
            names[e.getKey()] = e.getValue();
        }
        String[] sortedNames = Arrays.copyOf(names, names.length);
        Arrays.sort(sortedNames);
        return Arrays.equals(names, sortedNames);
    }

    /**
     * Computes the position of each id in the lexicographic order of the dimension names.
     * @return An array where {@code mapping[id]} is that dimension's lexicographic rank.
     */
    public int[] getIDtoNaturalOrderMapping() {
        int[] mapping = new int[this.idLabelMap.size()];
        TreeMap<String, Integer> sortedMap = new TreeMap<String, Integer>(this.labelIDMap);
        int rank = 0;
        for (Integer id : sortedMap.values()) {
            mapping[id] = rank++;
        }
        return mapping;
    }

    /**
     * Computes the id of each dimension, listed in lexicographic order of the dimension names.
     * @return An array where {@code mapping[rank]} is the id of the rank-th name.
     */
    public int[] getNaturalOrderToIDMapping() {
        int[] mapping = new int[this.idLabelMap.size()];
        TreeMap<String, Integer> sortedMap = new TreeMap<String, Integer>(this.labelIDMap);
        int rank = 0;
        for (Integer id : sortedMap.values()) {
            mapping[rank++] = id;
        }
        return mapping;
    }

    public String toReadableString() {
        return this.toString();
    }

    /**
     * Iterates over {@code (id, DimensionTuple(name, 1.0))} pairs in id-map insertion order.
     * @return An iterator over the id mapping.
     */
    public Iterator<Pair<Integer, Regressor>> iterator() {
        return new ImmutableInfoIterator(this.idLabelMap);
    }

    /**
     * Checks that the other info has the same size and maps every id to the same dimension name.
     * @param other The info to compare against.
     * @return True if the domains and id assignments match.
     */
    public boolean domainAndIDEquals(ImmutableOutputInfo<Regressor> other) {
        if (this.size() != other.size()) {
            return false;
        }
        for (Map.Entry<Integer, String> e : this.idLabelMap.entrySet()) {
            Regressor otherReg = other.getOutput(e.getKey());
            if (otherReg == null || !otherReg.getDimensionNamesString().equals(e.getValue())) {
                return false;
            }
        }
        return true;
    }

    /**
     * Restores the serialized fields, then rebuilds the transient statistic caches.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        this.computeStatisticArrays();
    }

    /**
     * Adapts the id-to-name map into an iterator of {@code (id, DimensionTuple(name, 1.0))} pairs.
     */
    private static class ImmutableInfoIterator
    implements Iterator<Pair<Integer, Regressor>> {
        private final Iterator<Map.Entry<Integer, String>> itr;

        public ImmutableInfoIterator(Map<Integer, String> idLabelMap) {
            this.itr = idLabelMap.entrySet().iterator();
        }

        @Override
        public boolean hasNext() {
            return this.itr.hasNext();
        }

        @Override
        public Pair<Integer, Regressor> next() {
            Map.Entry<Integer, String> e = this.itr.next();
            return new Pair<Integer, Regressor>(e.getKey(), new Regressor.DimensionTuple(e.getValue(), 1.0));
        }
    }
}

