/*
 * Decompiled with CFR 0.152.
 */
package ciir.umass.edu.learning.boosting;

import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.learning.boosting.WeakRanker;
import ciir.umass.edu.metric.MetricScorer;
import ciir.umass.edu.utilities.KeyValuePair;
import ciir.umass.edu.utilities.RankLibError;
import ciir.umass.edu.utilities.SimpleMath;
import java.io.BufferedReader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

public class AdaRank
extends Ranker {
    public static int nIteration = 500;
    public static double tolerance = 0.002;
    public static boolean trainWithEnqueue = true;
    public static int maxSelCount = 5;
    protected HashMap<Integer, Integer> usedFeatures = new HashMap();
    protected double[] sweight = null;
    protected List<WeakRanker> rankers = null;
    protected List<Double> rweight = null;
    protected List<WeakRanker> bestModelRankers = null;
    protected List<Double> bestModelWeights = null;
    int lastFeature = -1;
    int lastFeatureConsecutiveCount = 0;
    boolean performanceChanged = false;
    List<Integer> featureQueue = null;
    protected double[] backupSampleWeight = null;
    protected double backupTrainScore = 0.0;
    protected double lastTrainedScore = -1.0;

    public AdaRank() {
    }

    public AdaRank(List<RankList> samples, int[] features, MetricScorer scorer) {
        super(samples, features, scorer);
    }

    private void updateBestModelOnValidation() {
        this.bestModelRankers.clear();
        this.bestModelRankers.addAll(this.rankers);
        this.bestModelWeights.clear();
        this.bestModelWeights.addAll(this.rweight);
    }

    private WeakRanker learnWeakRanker() {
        double bestScore = -1.0;
        WeakRanker bestWR = null;
        for (int i : this.features) {
            if (this.featureQueue.contains(i) || this.usedFeatures.get(i) != null) continue;
            WeakRanker wr = new WeakRanker(i);
            double s = 0.0;
            for (int j = 0; j < this.samples.size(); ++j) {
                double t = this.scorer.score(wr.rank((RankList)this.samples.get(j))) * this.sweight[j];
                s += t;
            }
            if (!(bestScore < s)) continue;
            bestScore = s;
            bestWR = wr;
        }
        return bestWR;
    }

    private int learn(int startIteration, boolean withEnqueue) {
        int t;
        for (t = startIteration; t <= nIteration; ++t) {
            String status;
            this.PRINT(new int[]{7}, new String[]{t + ""});
            WeakRanker bestWR = this.learnWeakRanker();
            if (bestWR == null) break;
            if (withEnqueue) {
                if (bestWR.getFID() == this.lastFeature) {
                    this.featureQueue.add(this.lastFeature);
                    this.rankers.remove(this.rankers.size() - 1);
                    this.rweight.remove(this.rweight.size() - 1);
                    this.copy(this.backupSampleWeight, this.sweight);
                    this.bestScoreOnValidationData = 0.0;
                    this.lastTrainedScore = this.backupTrainScore;
                    this.PRINTLN(new int[]{8, 9, 9, 9}, new String[]{bestWR.getFID() + "", "", "", "ROLLBACK"});
                    continue;
                }
                this.lastFeature = bestWR.getFID();
                this.copy(this.sweight, this.backupSampleWeight);
                this.backupTrainScore = this.lastTrainedScore;
            }
            double num = 0.0;
            double denom = 0.0;
            for (int i = 0; i < this.samples.size(); ++i) {
                double tmp = this.scorer.score(bestWR.rank((RankList)this.samples.get(i)));
                num += this.sweight[i] * (1.0 + tmp);
                denom += this.sweight[i] * (1.0 - tmp);
            }
            this.rankers.add(bestWR);
            double alpha_t = 0.5 * SimpleMath.ln(num / denom);
            this.rweight.add(alpha_t);
            double trainedScore = 0.0;
            double total = 0.0;
            for (RankList sample : this.samples) {
                double tmp = this.scorer.score(this.rank(sample));
                total += Math.exp(-alpha_t * tmp);
                trainedScore += tmp;
            }
            double delta = (trainedScore /= (double)this.samples.size()) + tolerance - this.lastTrainedScore;
            String string = status = delta > 0.0 ? "OK" : "DAMN";
            if (!withEnqueue) {
                if (trainedScore != this.lastTrainedScore) {
                    this.performanceChanged = true;
                    this.lastFeatureConsecutiveCount = 0;
                    this.usedFeatures.clear();
                } else {
                    this.performanceChanged = false;
                    if (this.lastFeature == bestWR.getFID()) {
                        ++this.lastFeatureConsecutiveCount;
                        if (this.lastFeatureConsecutiveCount == maxSelCount) {
                            status = "F. REM.";
                            this.lastFeatureConsecutiveCount = 0;
                            this.usedFeatures.put(this.lastFeature, 1);
                        }
                    } else {
                        this.lastFeatureConsecutiveCount = 0;
                        this.usedFeatures.clear();
                    }
                }
                this.lastFeature = bestWR.getFID();
            }
            this.PRINT(new int[]{8, 9}, new String[]{bestWR.getFID() + "", SimpleMath.round(trainedScore, 4) + ""});
            if (t % 1 == 0 && this.validationSamples != null) {
                double scoreOnValidation = this.scorer.score(this.rank(this.validationSamples));
                if (scoreOnValidation > this.bestScoreOnValidationData) {
                    this.bestScoreOnValidationData = scoreOnValidation;
                    this.updateBestModelOnValidation();
                }
                this.PRINT(new int[]{9, 9}, new String[]{SimpleMath.round(scoreOnValidation, 4) + "", status});
            } else {
                this.PRINT(new int[]{9, 9}, new String[]{"", status});
            }
            this.PRINTLN("");
            if (delta <= 0.0) {
                this.rankers.remove(this.rankers.size() - 1);
                this.rweight.remove(this.rweight.size() - 1);
                break;
            }
            this.lastTrainedScore = trainedScore;
            for (int i = 0; i < this.sweight.length; ++i) {
                int n = i;
                this.sweight[n] = this.sweight[n] * (Math.exp(-alpha_t * this.scorer.score(this.rank((RankList)this.samples.get(i)))) / total);
            }
        }
        return t;
    }

    @Override
    public void init() {
        this.PRINT("Initializing... ");
        this.usedFeatures.clear();
        this.sweight = new double[this.samples.size()];
        for (int i = 0; i < this.sweight.length; ++i) {
            this.sweight[i] = 1.0f / (float)this.samples.size();
        }
        this.backupSampleWeight = new double[this.sweight.length];
        this.copy(this.sweight, this.backupSampleWeight);
        this.lastTrainedScore = -1.0;
        this.rankers = new ArrayList<WeakRanker>();
        this.rweight = new ArrayList<Double>();
        this.featureQueue = new ArrayList<Integer>();
        this.bestScoreOnValidationData = 0.0;
        this.bestModelRankers = new ArrayList<WeakRanker>();
        this.bestModelWeights = new ArrayList<Double>();
        this.PRINTLN("[Done]");
    }

    @Override
    public void learn() {
        this.PRINTLN("---------------------------");
        this.PRINTLN("Training starts...");
        this.PRINTLN("--------------------------------------------------------");
        this.PRINTLN(new int[]{7, 8, 9, 9, 9}, new String[]{"#iter", "Sel. F.", this.scorer.name() + "-T", this.scorer.name() + "-V", "Status"});
        this.PRINTLN("--------------------------------------------------------");
        if (trainWithEnqueue) {
            int t = this.learn(1, true);
            for (int i = this.featureQueue.size() - 1; i >= 0; --i) {
                this.featureQueue.remove(i);
                t = this.learn(t, false);
            }
        } else {
            this.learn(1, false);
        }
        if (this.validationSamples != null && this.bestModelRankers.size() > 0) {
            this.rankers.clear();
            this.rweight.clear();
            this.rankers.addAll(this.bestModelRankers);
            this.rweight.addAll(this.bestModelWeights);
        }
        this.scoreOnTrainingData = SimpleMath.round(this.scorer.score(this.rank(this.samples)), 4);
        this.PRINTLN("--------------------------------------------------------");
        this.PRINTLN("Finished sucessfully.");
        this.PRINTLN(this.scorer.name() + " on training data: " + this.scoreOnTrainingData);
        if (this.validationSamples != null) {
            this.bestScoreOnValidationData = this.scorer.score(this.rank(this.validationSamples));
            this.PRINTLN(this.scorer.name() + " on validation data: " + SimpleMath.round(this.bestScoreOnValidationData, 4));
        }
        this.PRINTLN("---------------------------------");
    }

    @Override
    public double eval(DataPoint p) {
        double score = 0.0;
        for (int j = 0; j < this.rankers.size(); ++j) {
            score += this.rweight.get(j) * (double)p.getFeatureValue(this.rankers.get(j).getFID());
        }
        return score;
    }

    @Override
    public Ranker createNew() {
        return new AdaRank();
    }

    @Override
    public String toString() {
        String output = "";
        for (int i = 0; i < this.rankers.size(); ++i) {
            output = output + this.rankers.get(i).getFID() + ":" + this.rweight.get(i) + (i == this.rankers.size() - 1 ? "" : " ");
        }
        return output;
    }

    @Override
    public String model() {
        String output = "## " + this.name() + "\n";
        output = output + "## Iteration = " + nIteration + "\n";
        output = output + "## Train with enqueue: " + (trainWithEnqueue ? "Yes" : "No") + "\n";
        output = output + "## Tolerance = " + tolerance + "\n";
        output = output + "## Max consecutive selection count = " + maxSelCount + "\n";
        output = output + this.toString();
        return output;
    }

    @Override
    public void loadFromString(String fullText) {
        try (BufferedReader in = new BufferedReader(new StringReader(fullText));){
            String content = "";
            KeyValuePair kvp = null;
            while ((content = in.readLine()) != null) {
                if ((content = content.trim()).length() == 0 || content.indexOf("##") == 0) continue;
                kvp = new KeyValuePair(content);
                break;
            }
            assert (kvp != null);
            List<String> keys = kvp.keys();
            List<String> values = kvp.values();
            this.rweight = new ArrayList<Double>();
            this.rankers = new ArrayList<WeakRanker>();
            this.features = new int[keys.size()];
            for (int i = 0; i < keys.size(); ++i) {
                this.features[i] = Integer.parseInt(keys.get(i));
                this.rankers.add(new WeakRanker(this.features[i]));
                this.rweight.add(Double.parseDouble(values.get(i)));
            }
        }
        catch (Exception ex) {
            throw RankLibError.create("Error in AdaRank::load(): ", ex);
        }
    }

    @Override
    public void printParameters() {
        this.PRINTLN("No. of rounds: " + nIteration);
        this.PRINTLN("Train with 'enequeue': " + (trainWithEnqueue ? "Yes" : "No"));
        this.PRINTLN("Tolerance: " + tolerance);
        this.PRINTLN("Max Sel. Count: " + maxSelCount);
    }

    @Override
    public String name() {
        return "AdaRank";
    }
}

