/*
 * Decompiled with CFR 0.152.
 */
package ru.ispras.ml.classification.randomforest;

import java.io.Serializable;
import java.util.List;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import quickml.data.instances.ClassifierInstance;
import quickml.supervised.ensembles.randomForest.randomDecisionForest.RandomDecisionForest;
import quickml.supervised.ensembles.randomForest.randomDecisionForest.RandomDecisionForestBuilder;
import quickml.supervised.tree.decisionTree.DecisionTreeBuilder;
import ru.ispras.ml.TrainingException;
import ru.ispras.ml.classification.IConfidenceClassifier;
import ru.ispras.ml.classification.IConfidenceClassifierTrainer;
import ru.ispras.ml.classification.randomforest.AttributesMapHelper;
import ru.ispras.ml.classification.randomforest.RandomForestClassifier;
import ru.ispras.ml.datamodel.ILabeledDataset;
import ru.ispras.ml.datamodel.ILabelledInstance;

/**
 * Trainer that fits a quickml {@link RandomDecisionForest} on a labelled dataset and
 * exposes the result as an {@link IConfidenceClassifier}.
 *
 * <p>Each dataset instance is converted to a {@link ClassifierInstance} whose weight is
 * obtained by applying the configured label weighter to the instance's label.
 *
 * @param <L> label type
 */
public final class RandomForestTrainer<L extends Serializable>
        implements IConfidenceClassifierTrainer<L> {
    private static final long serialVersionUID = 1609103231653240137L;

    /** Number of trees used by the no-argument constructor. */
    private static final int DEFAULT_NUM_TREES = 10;

    /** Supplies the tree builder handed to the forest builder on every {@code train} call. */
    private final Supplier<DecisionTreeBuilder<ClassifierInstance>> treeBuilderSupplier;
    /** Number of trees requested from the forest builder. */
    private final int numTrees;
    /** Maps a label to the weight given to every training instance carrying that label. */
    private final Function<L, Double> labelWeighter;
    /** Most recently trained forest; {@code null} until a dataset has been trained on. */
    private RandomDecisionForest model = null;

    /** Creates a trainer with 10 trees, a default tree builder and unit label weights. */
    public RandomForestTrainer() {
        this(DEFAULT_NUM_TREES);
    }

    /**
     * Creates a trainer with a default {@link DecisionTreeBuilder} and unit label weights.
     *
     * @param numTrees number of trees to build
     */
    public RandomForestTrainer(int numTrees) {
        // The intersection cast makes the lambda serializable; the type arguments must match
        // the target parameter exactly, since Supplier<raw> is not a Supplier<DecisionTreeBuilder<...>>.
        this(numTrees,
                (Supplier<DecisionTreeBuilder<ClassifierInstance>> & Serializable)
                        () -> new DecisionTreeBuilder<>());
    }

    /**
     * Creates a trainer in which every label has weight {@code 1.0}.
     *
     * @param numTrees            number of trees to build
     * @param treeBuilderSupplier source of the tree builder used for each training run
     */
    public RandomForestTrainer(int numTrees,
                               Supplier<DecisionTreeBuilder<ClassifierInstance>> treeBuilderSupplier) {
        // Typed as Function<L, Double> so it is assignable to the labelWeighter parameter.
        this(numTrees, treeBuilderSupplier, (Function<L, Double> & Serializable) ignore -> 1.0);
    }

    /**
     * Creates a fully configured trainer.
     *
     * @param numTrees            number of trees to build
     * @param treeBuilderSupplier source of the tree builder used for each training run
     * @param labelWeighter       function giving the instance weight for a label
     */
    public RandomForestTrainer(int numTrees,
                               Supplier<DecisionTreeBuilder<ClassifierInstance>> treeBuilderSupplier,
                               Function<L, Double> labelWeighter) {
        this.numTrees = numTrees;
        this.treeBuilderSupplier = treeBuilderSupplier;
        this.labelWeighter = labelWeighter;
    }

    /**
     * Training on a single instance is not supported by this trainer.
     *
     * @param precedent ignored
     * @throws UnsupportedOperationException always
     */
    @Override
    public void train(ILabelledInstance<L> precedent) throws TrainingException {
        throw new UnsupportedOperationException(
                "RandomForestTrainer cannot be trained on a single instance; use train(ILabeledDataset)");
    }

    /**
     * Builds a new forest from {@code data}, replacing any previously trained model.
     *
     * @param data labelled training instances
     * @throws TrainingException declared by the trainer interface; this implementation
     *                           does not throw it itself
     */
    @Override
    public void train(ILabeledDataset<L> data) throws TrainingException {
        List<ClassifierInstance> instances = data.stream()
                .map(instance -> new ClassifierInstance(
                        AttributesMapHelper.getAttributes(instance),
                        instance.getLabel(),
                        // Auto-unboxed to the double weight expected by ClassifierInstance.
                        this.labelWeighter.apply(instance.getLabel())))
                .collect(Collectors.toList());
        this.model = new RandomDecisionForestBuilder<>(this.treeBuilderSupplier.get())
                .numTrees(this.numTrees)
                .buildPredictiveModel(instances);
    }

    /**
     * Wraps the current model in a {@link RandomForestClassifier}.
     *
     * <p>NOTE(review): if no dataset has been trained on yet, the classifier is constructed
     * around a {@code null} model — confirm whether {@code RandomForestClassifier} tolerates that.
     *
     * @return classifier backed by the most recently trained forest
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public IConfidenceClassifier<L> getPredictor() {
        // Raw construction kept deliberately: RandomForestClassifier's type parameters are
        // declared elsewhere and are not visible here — TODO switch to <> once confirmed generic.
        return new RandomForestClassifier(this.model);
    }
}

