/*
 * Decompiled with CFR 0.152.
 */
package ru.ispras.ml.classification.meta;

import ru.ispras.ml.TrainingException;
import ru.ispras.ml.classification.ClassificationAlgorithmFacade;
import ru.ispras.ml.classification.IClassifier;
import ru.ispras.ml.classification.IClassifierTrainer;
import ru.ispras.ml.classification.meta.ByClassifierErrorsPartitioning;
import ru.ispras.ml.classification.meta.ClassifiedInstances;
import ru.ispras.ml.classification.meta.ITrainableClassifierFactory;
import ru.ispras.ml.classification.meta.InstancesSetPartitioning;
import ru.ispras.ml.datamodel.ILabeledDataset;
import ru.ispras.ml.datamodel.ILabelledInstance;

/**
 * Trainer that builds a classifier iteratively, growing the training set with the instances the
 * current classifier gets wrong.
 *
 * <p>The labelled dataset is first split into an initial training set and a test set (see
 * {@link #divideIntoTrainAndTestSet}). Each round then does three things:
 * <ol>
 *   <li>obtain a fresh classifier from the factory and train it on the current training set;</li>
 *   <li>add every test instance it misclassifies to the training set;</li>
 *   <li>keep only the correctly classified instances as the next test set.</li>
 * </ol>
 * Rounds repeat until a classifier produces no misclassifications on the remaining test set.
 * That final classifier is the one returned by {@link #getPredictor()}.
 *
 * @param <LabelType> type of the class labels
 */
public class IterativeTrainer<LabelType>
implements IClassifierTrainer<LabelType> {
    private static final long serialVersionUID = -7783913539787864617L;
    /** Initial training-set proportion used by {@link #IterativeTrainer(ITrainableClassifierFactory)}. */
    private static final double DEFAULT_TRAIN_INSTANCES_PROPORTION = 0.2;
    /** Source of a new, untrained classifier for every training round. */
    private final ITrainableClassifierFactory<LabelType> classifierTrainerFactory;
    /** Proportion handed to {@link InstancesSetPartitioning} to size the initial training set. */
    protected double initialTrainInstancesProportion;
    /** Classifier produced by the last successful {@link #train(ILabeledDataset)}; null before that. */
    private ClassificationAlgorithmFacade<LabelType> trainedClassifier;

    /**
     * Creates a trainer that uses the default initial training-set proportion (0.2).
     *
     * @param factory creates the classifier trained in each round
     */
    public IterativeTrainer(ITrainableClassifierFactory<LabelType> factory) {
        this(factory, DEFAULT_TRAIN_INSTANCES_PROPORTION);
    }

    /**
     * Creates a trainer with an explicit initial training-set proportion.
     *
     * @param factory creates the classifier trained in each round
     * @param trainProportion proportion passed to {@link InstancesSetPartitioning} when the dataset
     *     is split; presumably a fraction in (0, 1] — TODO confirm the accepted range there
     */
    public IterativeTrainer(ITrainableClassifierFactory<LabelType> factory, double trainProportion) {
        this.classifierTrainerFactory = factory;
        this.initialTrainInstancesProportion = trainProportion;
    }

    /**
     * Incremental (single-instance) training is not supported by this trainer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void train(ILabelledInstance<LabelType> precedent) throws TrainingException {
        throw new UnsupportedOperationException(
                "IterativeTrainer supports only training on a whole dataset");
    }

    /**
     * Runs the iterative training procedure on the given dataset.
     *
     * <p>On success, the final classifier replaces any previously trained one and is returned by
     * {@link #getPredictor()}. If a training round throws, the previously trained classifier (if
     * any) is kept.
     *
     * @param precedents labelled instances to split into training and test data
     * @throws TrainingException if training a classifier in any round fails
     */
    @Override
    public void train(ILabeledDataset<LabelType> precedents) throws TrainingException {
        ClassifiedInstances<LabelType> dataset = new ClassifiedInstances<LabelType>(precedents);
        InstancesSetPartitioning<LabelType> splitter = this.divideIntoTrainAndTestSet(dataset);
        ClassifiedInstances<LabelType> trainSet = splitter.getTrainSet();
        ClassifiedInstances<LabelType> testSet = splitter.getTestSet();
        ClassificationAlgorithmFacade<LabelType> classifier;
        ClassifiedInstances<LabelType> incorrectlyClassified;
        do {
            classifier = this.trainClassifier(trainSet);
            ByClassifierErrorsPartitioning<LabelType> evaluator =
                    new ByClassifierErrorsPartitioning<LabelType>(classifier, testSet);
            incorrectlyClassified = evaluator.getIncorrectlyClassified();
            // Misclassified instances become extra training data for the next round; only the
            // instances this classifier already handles stay in the test set.
            testSet = evaluator.getCorrectlyClassified();
            trainSet.add(incorrectlyClassified);
        } while (!incorrectlyClassified.isEmpty());
        // Keep the final classifier; previously it was discarded and getPredictor() returned an
        // untrained one.
        this.trainedClassifier = classifier;
    }

    /**
     * Splits the dataset into the initial training set and the test set.
     * Subclasses may override this to change the splitting strategy.
     *
     * @param dataset the whole labelled dataset
     * @return the partitioning that supplies the initial train and test sets
     */
    protected InstancesSetPartitioning<LabelType> divideIntoTrainAndTestSet(ClassifiedInstances<LabelType> dataset) {
        return new InstancesSetPartitioning<LabelType>(this.initialTrainInstancesProportion, dataset);
    }

    /** Creates a fresh classifier from the factory and trains it on {@code trainSet}. */
    private ClassificationAlgorithmFacade<LabelType> trainClassifier(ClassifiedInstances<LabelType> trainSet) throws TrainingException {
        ClassificationAlgorithmFacade<LabelType> result = this.classifierTrainerFactory.createClassifier();
        result.train(trainSet.getLabelledDataset());
        return result;
    }

    /**
     * Returns the classifier produced by the last successful {@link #train(ILabeledDataset)}.
     *
     * <p>If no training has completed yet, a fresh (untrained) classifier from the factory is
     * returned. This matches the previous behaviour of this method.
     *
     * @return the trained classifier, or a new factory classifier if none has been trained
     */
    @Override
    public IClassifier<LabelType> getPredictor() {
        if (this.trainedClassifier == null) {
            return this.classifierTrainerFactory.createClassifier();
        }
        return this.trainedClassifier;
    }
}

