/*
 * Decompiled with CFR 0.152.
 */
package ru.ispras.ml.classification.decisiontrees.id3;

import java.io.Serializable;
import java.util.Random;
import ru.ispras.ml.TrainingException;
import ru.ispras.ml.classification.ClassificationException;
import ru.ispras.ml.classification.decisiontrees.DecisionTreeClassifier;
import ru.ispras.ml.classification.decisiontrees.DecisionTreeModel;
import ru.ispras.ml.classification.decisiontrees.FeatureValuesInSample;
import ru.ispras.ml.classification.decisiontrees.Sample;
import ru.ispras.ml.classification.decisiontrees.id3.ID3Trainer;
import ru.ispras.ml.datamodel.ILabeledDataset;
import ru.ispras.ml.datamodel.ILabelledInstance;
import ru.ispras.ml.prediction.PredictionResult;

/**
 * ID3 trainer that grows its tree on a "window" of the training data.
 *
 * <p>A random subset of the precedents (each one selected with probability
 * {@link #PROBABILITY_TO_GET_IN_WINDOW}) forms the initial window. A tree is built
 * on the window and used to classify the remaining precedents. Every precedent it
 * misclassifies is moved into the window, and the tree is rebuilt. This repeats
 * until a tree classifies all remaining precedents correctly or no precedents
 * remain outside the window.
 *
 * @param <LabelType> type of the class labels
 */
public class IterativeID3Trainer<LabelType>
extends ID3Trainer<LabelType> {
    private static final long serialVersionUID = 4350046375452166009L;
    /** Probability that a precedent is placed in the initial training window. */
    private static final double PROBABILITY_TO_GET_IN_WINDOW = 0.2;

    /**
     * Trains the decision tree with the windowing procedure described on the class.
     *
     * @param precedents labelled training data
     * @throws TrainingException if the underlying tree builder fails
     */
    @Override
    public void train(ILabeledDataset<LabelType> precedents) throws TrainingException {
        Sample<LabelType> sample = new Sample<LabelType>(precedents);
        FeatureValuesInSample featureSet = sample.getFeatureSet();
        WindowedSample<LabelType> windowedSample = new WindowedSample<LabelType>(sample);
        DecisionTreeModel currentModel;
        boolean nextStepRequired;
        do {
            currentModel = new DecisionTreeModel(
                    this.treeBuilder.buildSubtree(windowedSample.getTrainSample(), featureSet, null));
            // Every precedent is already in the window: nothing left to validate against.
            if (windowedSample.getTestSample().isEmpty()) {
                break;
            }
            DecisionTreeClassifier<LabelType> classifier = new DecisionTreeClassifier<LabelType>(currentModel);
            nextStepRequired = this.validateOnTestSample(classifier, windowedSample);
        } while (nextStepRequired);
        this.model = currentModel;
    }

    /**
     * Classifies every precedent outside the window with {@code classifier}.
     * Misclassified precedents are moved into the window, and correctly classified
     * ones stay in the (replaced) test sample.
     *
     * @param classifier     classifier built from the current window
     * @param windowedSample window/test split. It is mutated in place.
     * @return {@code true} if at least one precedent was moved into the window,
     *         meaning the tree must be rebuilt
     */
    private boolean validateOnTestSample(DecisionTreeClassifier<LabelType> classifier, WindowedSample<LabelType> windowedSample) {
        boolean nextStepRequired = false;
        Sample<LabelType> newTestSample = new Sample<LabelType>();
        for (ILabelledInstance<LabelType> precedent : windowedSample.getTestSample().getObjects()) {
            PredictionResult result;
            try {
                result = classifier.predict(precedent);
            }
            catch (ClassificationException e) {
                // A precedent the current tree cannot classify is treated as
                // misclassified, so it is moved into the window below.
                result = null;
            }
            // NOTE(review): this compares the label with the PredictionResult object
            // itself, not with the predicted label it carries. Unless PredictionResult
            // overrides equals() to accept a label, this is always false, and every
            // test precedent is moved into the window on the first pass. Confirm
            // against PredictionResult's API and compare the predicted label instead.
            if (!precedent.getLabel().equals(result)) {
                windowedSample.getTrainSample().addObject(precedent);
                nextStepRequired = true;
            } else {
                newTestSample.addObject(precedent);
            }
        }
        windowedSample.setTestSample(newTestSample);
        return nextStepRequired;
    }

    /**
     * Random split of a sample into a growing training window and the remaining
     * test precedents.
     */
    private static class WindowedSample<LabelType>
    implements Serializable {
        private static final long serialVersionUID = 6143498894540880870L;
        private final Sample<LabelType> trainSample;
        private Sample<LabelType> testSample;

        /**
         * Randomly assigns each precedent of {@code sample} to the window with
         * probability {@link IterativeID3Trainer#PROBABILITY_TO_GET_IN_WINDOW}.
         * All other precedents go to the test sample. The given sample itself is
         * not modified.
         *
         * @param sample full training sample
         */
        public WindowedSample(Sample<LabelType> sample) {
            Sample<LabelType> window = new Sample<LabelType>();
            Sample<LabelType> rest = new Sample<LabelType>();
            Random rand = new Random();
            for (ILabelledInstance<LabelType> precedent : sample.getObjects()) {
                if (rand.nextDouble() < PROBABILITY_TO_GET_IN_WINDOW) {
                    window.addObject(precedent);
                } else {
                    rest.addObject(precedent);
                }
            }
            if (window.isEmpty()) {
                // No precedent was drawn into the window, which is likely for small
                // datasets. A tree built from no examples carries no information, so
                // train on every precedent directly and leave nothing to validate.
                window = rest;
                rest = new Sample<LabelType>();
            }
            this.trainSample = window;
            this.testSample = rest;
        }

        /** @return the training window. It is mutated as precedents are added. */
        public Sample<LabelType> getTrainSample() {
            return this.trainSample;
        }

        /** @return precedents not (yet) in the window */
        public Sample<LabelType> getTestSample() {
            return this.testSample;
        }

        /** @param testSample replacement set of precedents outside the window */
        public void setTestSample(Sample<LabelType> testSample) {
            this.testSample = testSample;
        }
    }
}

