/*
 * Decompiled with CFR 0.152.
 */
package ru.ispras.ml.regression.linear;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import ru.ispras.ml.TrainingException;
import ru.ispras.ml.datamodel.Feature;
import ru.ispras.ml.datamodel.ILabeledDataset;
import ru.ispras.ml.datamodel.ILabelledInstance;
import ru.ispras.ml.datamodel.value.IValue;
import ru.ispras.ml.datamodel.value.cast.DoubleValueCaster;
import ru.ispras.ml.regression.IRegressor;
import ru.ispras.ml.regression.linear.ILinearRegressorTrainer;
import ru.ispras.ml.regression.linear.InSampleQualityMetric;
import ru.ispras.ml.regression.linear.LinearRegressor;
import ru.ispras.ml.regression.linear.LinearRegressorModel;
import ru.ispras.ml.regression.linear.featurefunctions.IFeatureFunction;
import ru.ispras.ml.regression.linear.featurefunctions.PolynomialFeatureFunction;

/**
 * Trains a {@link LinearRegressorModel} by ordinary least squares.
 *
 * <p>Each feature value is expanded through an {@link IFeatureFunction} into
 * {@code getCoefficientsNumber()} regressors ({@code new PolynomialFeatureFunction(1)} by
 * default). A constant column of ones is prepended for the intercept, and the coefficients are
 * obtained by solving the normal equations {@code (X^T X) b = X^T y}.
 */
public class LinearRegressorTrainer implements ILinearRegressorTrainer {
    private static final long serialVersionUID = -1849786724303907348L;

    /** Model produced by the most recent successful {@link #train(ILabeledDataset)}; {@code null} before that. */
    private LinearRegressorModel model;

    /** Maps one scalar feature value to the regressor values used for that feature. */
    private final IFeatureFunction featureFunction;

    /**
     * Creates a trainer that expands every feature value with the given function.
     *
     * @param featureFunction function mapping one feature value to its regressor values
     */
    public LinearRegressorTrainer(IFeatureFunction featureFunction) {
        this.featureFunction = featureFunction;
    }

    /** Creates a trainer that uses {@code new PolynomialFeatureFunction(1)} as the feature function. */
    public LinearRegressorTrainer() {
        this(new PolynomialFeatureFunction(1));
    }

    /**
     * Incremental (single-instance) training is not supported by this trainer.
     *
     * @param instance ignored
     * @throws UnsupportedOperationException always
     */
    @Override
    public void train(ILabelledInstance<Double> instance) throws TrainingException {
        throw new UnsupportedOperationException(
                "Incremental training is not supported; use train(ILabeledDataset) instead");
    }

    /**
     * Fits the regression on the whole dataset and stores the resulting model.
     *
     * @param precedents labelled training data; at least two instances are required
     * @throws IllegalArgumentException if {@code precedents} has fewer than two instances or an
     *     instance has a {@code null} label
     * @throws IllegalStateException if the feature function or the dataset stream is inconsistent
     *     with its declared sizes
     * @throws TrainingException if {@code X^T X} is singular, i.e. the feature columns are
     *     linearly dependent
     */
    @Override
    public void train(ILabeledDataset<Double> precedents) throws TrainingException {
        if (precedents.size() < 2L) {
            throw new IllegalArgumentException("Not enough data for training (at least 2 samples are needed)");
        }
        // A deterministic column order guarantees that a given coefficient index always belongs
        // to the same feature, independent of the iteration order of the feature set.
        List<Feature> orderedFeatures = new ArrayList<Feature>(precedents.getFeatureSet());
        orderedFeatures.sort(Comparator.comparing(Feature::getName));

        Matrices matrices = this.createMatrices(precedents, orderedFeatures);
        DoubleMatrix2D precedentsMatrix = matrices.precedentsMatrix;
        DoubleMatrix1D answersVector = matrices.answersVector;

        DoubleMatrix1D coefficients = this.computeRegressionCoefficients(precedentsMatrix, answersVector);
        Map<Feature, List<Double>> coefficientsMap = this.mapFeaturesToFunctionCoefficients(coefficients, orderedFeatures);
        InSampleQualityMetric qualityMetric = new InSampleQualityMetric(precedentsMatrix, coefficients, answersVector);
        // Column 0 of the design matrix is the constant 1.0, so coefficient 0 is the intercept.
        double freeCoefficient = coefficients.get(0);
        this.model = new LinearRegressorModel(coefficientsMap, freeCoefficient, this.featureFunction, qualityMetric);
    }

    /**
     * Solves the normal equations {@code (X^T X) b = X^T y} for the coefficient vector {@code b}.
     *
     * <p>The system is solved directly via LU decomposition ({@link Algebra#solve}) rather than
     * by forming {@code (X^T X)^-1} explicitly. This is cheaper and numerically more accurate.
     *
     * @param precedentsMatrix design matrix {@code X}, one row per precedent
     * @param answersVector labels {@code y}, row-aligned with {@code X}
     * @return coefficient vector of length {@code X.columns()}
     * @throws TrainingException if {@code X^T X} is singular
     */
    private DoubleMatrix1D computeRegressionCoefficients(DoubleMatrix2D precedentsMatrix, DoubleMatrix1D answersVector) throws TrainingException {
        Algebra algebra = new Algebra();
        DoubleMatrix2D transposedPrecedentsMatrix = precedentsMatrix.viewDice();
        DoubleMatrix2D gramMatrix = algebra.mult(transposedPrecedentsMatrix, precedentsMatrix);
        DoubleMatrix1D moments = algebra.mult(transposedPrecedentsMatrix, answersVector);

        // Algebra.solve works on 2-D right-hand sides, so wrap X^T y as a single column.
        DoubleMatrix2D rightHandSide = new DenseDoubleMatrix2D(moments.size(), 1);
        rightHandSide.viewColumn(0).assign(moments);

        DoubleMatrix2D solution;
        try {
            solution = algebra.solve(gramMatrix, rightHandSide);
        }
        catch (IllegalArgumentException e) {
            // Colt signals a singular system with IllegalArgumentException.
            String message = "Feature values for precedents are linear dependent";
            throw new TrainingException(message, e);
        }
        return solution.viewColumn(0).copy();
    }

    /**
     * Splits the flat coefficient vector into per-feature coefficient lists.
     *
     * <p>Index 0 (the intercept) is skipped. The remaining entries are consumed in consecutive
     * blocks of {@code featureFunction.getCoefficientsNumber()}, one block per feature, in
     * {@code orderedFeatures} order. This matches the column layout built by
     * {@link #createMatrices}.
     *
     * @param coefficients full coefficient vector, intercept first
     * @param orderedFeatures features in design-matrix column order
     * @return map from each feature to its coefficients
     */
    private Map<Feature, List<Double>> mapFeaturesToFunctionCoefficients(DoubleMatrix1D coefficients, List<Feature> orderedFeatures) {
        int coefficientsPerFeature = this.featureFunction.getCoefficientsNumber();
        Map<Feature, List<Double>> coefficientsMap = new HashMap<Feature, List<Double>>();
        int index = 1;
        for (Feature currentFeature : orderedFeatures) {
            List<Double> functionCoefficients = new ArrayList<Double>(coefficientsPerFeature);
            for (int i = 0; i < coefficientsPerFeature; ++i) {
                functionCoefficients.add(coefficients.get(index++));
            }
            coefficientsMap.put(currentFeature, functionCoefficients);
        }
        return coefficientsMap;
    }

    /**
     * Builds the design matrix and the label vector for a dataset.
     *
     * <p>Row {@code r} of the design matrix holds a leading {@code 1.0} (the intercept column).
     * It is followed by {@code featureFunction.fn(v)} for each feature value {@code v} of the
     * {@code r}-th instance, in {@code orderedFeatures} order. A feature absent from an instance
     * is treated as {@code 0.0}. Entry {@code r} of the label vector is that instance's label.
     *
     * @param precedents labelled dataset to convert
     * @param orderedFeatures features defining the column layout
     * @return the design matrix and the label vector
     * @throws IllegalArgumentException if an instance has a {@code null} label
     * @throws IllegalStateException if the feature function returns a number of values other than
     *     {@code getCoefficientsNumber()}, or the dataset stream does not yield exactly
     *     {@code size()} instances
     * @throws ArithmeticException if {@code precedents.size()} does not fit in an {@code int}
     */
    Matrices createMatrices(ILabeledDataset<Double> precedents, List<Feature> orderedFeatures) {
        int coefficientsPerFeature = this.featureFunction.getCoefficientsNumber();
        int numberOfColumns = orderedFeatures.size() * coefficientsPerFeature + 1;
        // Fail loudly instead of silently truncating an oversized dataset.
        int numberOfPrecedents = Math.toIntExact(precedents.size());
        DoubleMatrix2D matrix = new DenseDoubleMatrix2D(numberOfPrecedents, numberOfColumns);
        DoubleMatrix1D vector = new DenseDoubleMatrix1D(numberOfPrecedents);

        int currentRow = 0;
        try (Stream<?> stream = precedents.stream()) {
            Iterator<?> it = stream.iterator();
            while (it.hasNext()) {
                if (currentRow >= numberOfPrecedents) {
                    throw new IllegalStateException("Dataset stream yielded more than the reported "
                            + numberOfPrecedents + " precedents");
                }
                ILabelledInstance<?> instance = (ILabelledInstance<?>) it.next();
                matrix.set(currentRow, 0, 1.0);
                int currentColumn = 1;
                for (Feature currentFeature : orderedFeatures) {
                    IValue featureValue = instance.getFeaturesValues().get(currentFeature);
                    double doubleFeatureValue = featureValue == null
                            ? 0.0
                            : DoubleValueCaster.cast(featureValue).getDouble();
                    double[] values = this.featureFunction.fn(doubleFeatureValue);
                    // A size mismatch would shift later features into the wrong columns.
                    if (values.length != coefficientsPerFeature) {
                        throw new IllegalStateException("Feature function returned " + values.length
                                + " values for feature '" + currentFeature.getName()
                                + "', expected " + coefficientsPerFeature);
                    }
                    for (double currentValue : values) {
                        matrix.set(currentRow, currentColumn++, currentValue);
                    }
                }
                Object label = instance.getLabel();
                if (label == null) {
                    throw new IllegalArgumentException("Precedent #" + currentRow + " has no label");
                }
                vector.set(currentRow, (Double) label);
                ++currentRow;
            }
        }
        // Unfilled all-zero rows would otherwise silently skew the in-sample quality metric.
        if (currentRow != numberOfPrecedents) {
            throw new IllegalStateException("Dataset reported " + numberOfPrecedents
                    + " precedents but its stream yielded " + currentRow);
        }
        return new Matrices(matrix, vector);
    }

    /**
     * Reports whether a model is available.
     *
     * @return {@code true} once {@link #train(ILabeledDataset)} has completed successfully
     */
    public boolean hasModel() {
        return this.model != null;
    }

    /**
     * Returns the trained model.
     *
     * @return the model from the last successful training, or {@code null} if none
     */
    @Override
    public LinearRegressorModel getModel() {
        return this.model;
    }

    /**
     * Wraps the current model in a new {@link LinearRegressor}.
     *
     * <p>NOTE(review): when no model has been trained yet, this passes {@code null} to the
     * {@code LinearRegressor} constructor. Callers should check {@link #hasModel()} first.
     *
     * @return a regressor backed by the current model
     */
    public IRegressor getPredictor() {
        return new LinearRegressor(this.model);
    }

    /** Design matrix and label vector, as built by {@link LinearRegressorTrainer#createMatrices}. */
    public static class Matrices {
        /** Design matrix: one row per precedent, intercept column first. */
        public DoubleMatrix2D precedentsMatrix;
        /** Labels, row-aligned with {@link #precedentsMatrix}. */
        public DoubleMatrix1D answersVector;

        /**
         * Bundles a design matrix with its label vector.
         *
         * @param precedentsMatrix design matrix
         * @param answersVector label vector
         */
        public Matrices(DoubleMatrix2D precedentsMatrix, DoubleMatrix1D answersVector) {
            this.precedentsMatrix = precedentsMatrix;
            this.answersVector = answersVector;
        }
    }
}

