/*
 * Decompiled with CFR 0.152.
 */
package ru.ispras.ml.liblinear.dense;

import de.bwaldvogel.liblinear.SolverType;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import ru.ispras.ml.TrainingException;
import ru.ispras.ml.classification.IConfidenceClassifierTrainer;
import ru.ispras.ml.datamodel.Feature;
import ru.ispras.ml.datamodel.ILabeledDataset;
import ru.ispras.ml.datamodel.ILabelledInstance;
import ru.ispras.ml.liblinear.classification.LibLinearClassifierModel;
import ru.ispras.ml.liblinear.classification.LibLinearClassifierTrainer;
import ru.ispras.ml.liblinear.dense.DenseLibLinearClassifier;
import ru.ispras.ml.liblinear.dense.DenseLibLinearClassifierModel;

/**
 * Trainer that delegates learning to a {@link LibLinearClassifierTrainer} and then converts the
 * resulting liblinear weight vector into a {@link DenseLibLinearClassifierModel}.
 *
 * <p>The converted model stores, for every feature, the list of (label index, weight) pairs whose
 * weight passes the {@code minFeatureWeight} filter, plus a separate list of bias contributions.
 * Only batch training via {@link #train(ILabeledDataset)} is supported.
 *
 * @param <Label> the class label type
 */
public class DenseLibLinearClassifierTrainer<Label>
implements IConfidenceClassifierTrainer<Label> {
    private static final long serialVersionUID = 171581932016553634L;
    private final LibLinearClassifierTrainer<Label> trainer;
    /** Weights whose absolute value is strictly below this threshold are dropped. */
    private final double minFeatureWeight;
    /** Result of the last successful {@link #train(ILabeledDataset)} call; {@code null} before that. */
    private DenseLibLinearClassifierModel<Label> denseModel = null;

    /**
     * Creates a trainer that keeps only weights with {@code |weight| >= minFeatureWeight}.
     *
     * @param trainer the underlying liblinear trainer; must not be {@code null}
     * @param minFeatureWeight threshold on the absolute value of a weight for it to be kept
     * @throws NullPointerException if {@code trainer} is {@code null}
     */
    public DenseLibLinearClassifierTrainer(LibLinearClassifierTrainer<Label> trainer, double minFeatureWeight) {
        if (trainer == null) {
            throw new NullPointerException("trainer");
        }
        this.trainer = trainer;
        this.minFeatureWeight = minFeatureWeight;
    }

    /**
     * Creates a trainer with threshold {@link Double#MIN_VALUE}. That value is the smallest
     * <em>positive</em> double, so only weights that are exactly zero are dropped.
     *
     * @param trainer the underlying liblinear trainer; must not be {@code null}
     * @throws NullPointerException if {@code trainer} is {@code null}
     */
    public DenseLibLinearClassifierTrainer(LibLinearClassifierTrainer<Label> trainer) {
        this(trainer, Double.MIN_VALUE);
    }

    /**
     * Returns a new classifier wrapping the most recently trained dense model.
     * NOTE(review): before {@link #train(ILabeledDataset)} has been called the wrapped model is
     * {@code null}; how {@link DenseLibLinearClassifier} handles that is not visible here.
     */
    @Override
    public DenseLibLinearClassifier<Label> getPredictor() {
        return new DenseLibLinearClassifier<Label>(this.denseModel);
    }

    /**
     * Not supported: this trainer only learns from a whole dataset.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void train(ILabelledInstance<Label> precedent) throws TrainingException {
        throw new UnsupportedOperationException("On-the-fly learning is not supported.");
    }

    /**
     * Trains the underlying liblinear trainer on {@code data} and replaces the stored dense model
     * with the converted result.
     *
     * @param data the training dataset
     * @throws TrainingException if the underlying trainer fails
     */
    @Override
    public void train(ILabeledDataset<Label> data) throws TrainingException {
        LibLinearClassifierModel<Label> model = this.getLibLinearModel(data);
        this.denseModel = this.convertToDenseModel(model);
    }

    /** Runs the wrapped trainer and returns its model. */
    private LibLinearClassifierModel<Label> getLibLinearModel(ILabeledDataset<Label> data) throws TrainingException {
        this.trainer.train(data);
        // The wrapped trainer is parameterized by Label, so its model is assumed to be as well.
        @SuppressWarnings("unchecked")
        LibLinearClassifierModel<Label> model = (LibLinearClassifierModel<Label>) this.trainer.getModel();
        return model;
    }

    /** Builds the dense model from the label list, per-feature weights and bias weights. */
    private DenseLibLinearClassifierModel<Label> convertToDenseModel(LibLinearClassifierModel<Label> model) {
        List<Label> labels = this.getLabels(model);
        double[] weights = model.getModel().getFeatureWeights();
        Map<Feature, List<DenseLibLinearClassifierModel.WeightedLabelIndex>> features = this.getFeatures(model, weights);
        List<DenseLibLinearClassifierModel.WeightedLabelIndex> biases = this.getBiases(model, weights);
        return new DenseLibLinearClassifierModel<Label>(labels, features, biases);
    }

    /**
     * Returns the labels as a list in which position {@code k} holds the label mapped from index
     * {@code k} in {@code model.getIndicesToClasses()}.
     * Assumes those indices are exactly {@code 0..size-1}; any other key makes
     * {@link List#set} throw {@link IndexOutOfBoundsException}.
     */
    private List<Label> getLabels(LibLinearClassifierModel<Label> model) {
        List<Label> result = new ArrayList<Label>(
                Collections.<Label>nCopies(model.getIndicesToClasses().size(), null));
        for (Map.Entry<Integer, Label> e : model.getIndicesToClasses().entrySet()) {
            result.set(e.getKey(), e.getValue());
        }
        return result;
    }

    /**
     * Collects, for every feature, the (label index, weight) pairs that pass {@link #useWeight}.
     * The weight for feature row {@code f} and column {@code i} is read from
     * {@code weights[numberOfClassesInWeights * f + i]}.
     * NOTE(review): the values in {@code getFeaturesOrder()} appear to be 1-based feature indices
     * (1 is subtracted below). The column index {@code i} is also used as the label index into
     * {@link #getLabels}; confirm that liblinear's column order matches those label indices.
     */
    private Map<Feature, List<DenseLibLinearClassifierModel.WeightedLabelIndex>> getFeatures(LibLinearClassifierModel<Label> model, double[] weights) {
        Map<Feature, List<DenseLibLinearClassifierModel.WeightedLabelIndex>> result = new HashMap<Feature, List<DenseLibLinearClassifierModel.WeightedLabelIndex>>();
        int numberOfClassesInWeights = this.getNumberOfClassesInWeights(model);
        for (Map.Entry<Feature, Integer> featurePair : model.getFeaturesOrder().entrySet()) {
            Feature feature = featurePair.getKey();
            int featureBasePosition = featurePair.getValue() - 1;
            for (int i = 0; i < numberOfClassesInWeights; ++i) {
                double weight = weights[numberOfClassesInWeights * featureBasePosition + i];
                if (!this.useWeight(weight)) {
                    continue;
                }
                this.getWeightedLabelIndexList(result, feature).add(new DenseLibLinearClassifierModel.WeightedLabelIndex(i, weight));
            }
        }
        return result;
    }

    /** Returns the list stored for {@code feature}, creating and storing an empty one if absent. */
    private List<DenseLibLinearClassifierModel.WeightedLabelIndex> getWeightedLabelIndexList(Map<Feature, List<DenseLibLinearClassifierModel.WeightedLabelIndex>> map, Feature feature) {
        List<DenseLibLinearClassifierModel.WeightedLabelIndex> existing = map.get(feature);
        if (existing != null) {
            return existing;
        }
        // Most features contribute to a single label, so start small.
        List<DenseLibLinearClassifierModel.WeightedLabelIndex> created = new ArrayList<DenseLibLinearClassifierModel.WeightedLabelIndex>(1);
        map.put(feature, created);
        return created;
    }

    /**
     * Returns the bias contributions {@code bias * weight} per label index. The list is empty when
     * the model's bias is negative, which is how liblinear marks "no bias term".
     * The bias weights are read from the row right after the feature rows, at row
     * {@code getFeaturesOrder().size()}. NOTE(review): this assumes the feature indices are exactly
     * {@code 1..size}.
     */
    private List<DenseLibLinearClassifierModel.WeightedLabelIndex> getBiases(LibLinearClassifierModel<Label> model, double[] weights) {
        List<DenseLibLinearClassifierModel.WeightedLabelIndex> result = new ArrayList<DenseLibLinearClassifierModel.WeightedLabelIndex>(1);
        double bias = model.getModel().getBias();
        if (bias >= 0.0) {
            int biasBasePosition = model.getFeaturesOrder().size();
            int numberOfClassesInWeights = this.getNumberOfClassesInWeights(model);
            for (int i = 0; i < numberOfClassesInWeights; ++i) {
                double weight = weights[numberOfClassesInWeights * biasBasePosition + i];
                if (!this.useWeight(weight)) {
                    continue;
                }
                result.add(new DenseLibLinearClassifierModel.WeightedLabelIndex(i, bias * weight));
            }
        }
        return result;
    }

    /**
     * Returns the number of weight columns per feature row. This is 1 for a two-class problem
     * solved by any solver other than {@code MCSVM_CS}, and the number of classes otherwise.
     */
    private int getNumberOfClassesInWeights(LibLinearClassifierModel<Label> model) {
        int numberOfClasses = model.getIndicesToClasses().size();
        if (this.isSimpleTwoClassProblemTrainer(numberOfClasses)) {
            return 1;
        }
        return numberOfClasses;
    }

    /** True when liblinear stores a single weight vector for this two-class problem. */
    private boolean isSimpleTwoClassProblemTrainer(int numberOfClasses) {
        return numberOfClasses == 2 && this.trainer.getSolverType() != SolverType.MCSVM_CS;
    }

    /** True when {@code |weight| >= minFeatureWeight}. */
    private boolean useWeight(double weight) {
        return Math.abs(weight) >= this.minFeatureWeight;
    }
}

