/*
 * Decompiled with CFR 0.152.
 */
package ru.ispras.texterra.core.nlp.annotators.ml.pipelines;

import java.util.Arrays;
import java.util.Objects;
import java.util.function.Supplier;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import ru.ispras.ml.datamodel.InstanceFactory;
import ru.ispras.ml.datamodel.identification.HashCodeInstanceIdGeneratingStrategy;
import ru.ispras.ml.datamodel.identification.InstanceIdGeneratingStrategy;
import ru.ispras.ml.featureextractors.IFeatureExtractor;
import ru.ispras.ml.pipelines.PredictionPipeline;
import ru.ispras.ml.pipelines.SourceLabelPair;
import ru.ispras.texterra.core.nlp.annotators.ml.IPredictorTrainerFactory;

/**
 * Trains a {@link PredictionPipeline} from labelled sources by delegating to
 * {@code ru.ispras.ml.pipelines.PredictionPipelineTrainer}.
 *
 * <p>The delegate is configured with an {@link InstanceFactory} wrapping the single
 * feature extractor given at construction time (paired with a
 * {@link HashCodeInstanceIdGeneratingStrategy}) and with whatever
 * {@link IPredictorTrainerFactory#create()} returns.
 *
 * @param <S> type of the source objects features are extracted from
 * @param <L> type of the labels attached to the sources
 */
public final class PredictionPipelineTrainer<S, L> {
    private final IFeatureExtractor<S> featureExtractor;
    private final IPredictorTrainerFactory<L> predictorTrainerFactory;

    /**
     * Creates a trainer bound to a feature extractor and a predictor-trainer factory.
     *
     * @param featureExtractor        extractor handed to the instance factory; must not be {@code null}
     * @param predictorTrainerFactory factory whose {@code create()} is invoked on every training run;
     *                                must not be {@code null}
     * @throws NullPointerException if either argument is {@code null}
     */
    public PredictionPipelineTrainer(IFeatureExtractor<S> featureExtractor, IPredictorTrainerFactory<L> predictorTrainerFactory) {
        // Fail at construction time rather than deep inside a later training run.
        this.featureExtractor = Objects.requireNonNull(featureExtractor, "featureExtractor");
        this.predictorTrainerFactory = Objects.requireNonNull(predictorTrainerFactory, "predictorTrainerFactory");
    }

    /**
     * Trains a pipeline from an in-memory (or otherwise re-iterable) dataset.
     *
     * <p>The dataset is exposed to {@link #trainFromStream(Supplier)} as a supplier that builds a new
     * sequential stream from {@code dataset.spliterator()} each time it is asked.
     *
     * @param dataset labelled training examples; must not be {@code null}
     * @return the pipeline produced by the delegate trainer
     * @throws NullPointerException if {@code dataset} is {@code null}
     */
    public PredictionPipeline<S, L> train(Iterable<SourceLabelPair<S, L>> dataset) {
        // Without this check a null dataset would only blow up lazily, when the supplier is invoked.
        Objects.requireNonNull(dataset, "dataset");
        return this.trainFromStream(() -> StreamSupport.stream(dataset.spliterator(), false));
    }

    /**
     * Trains a pipeline from a supplier of dataset streams.
     *
     * @param dataset supplier of streams over the labelled training examples; must not be {@code null}
     * @return the pipeline produced by the delegate trainer
     * @throws NullPointerException if {@code dataset} is {@code null}
     */
    // Raw types are kept deliberately: the generic signatures of InstanceFactory and the delegate
    // trainer are not visible from this file, so the original (unchecked) call shapes are preserved.
    @SuppressWarnings({"unchecked", "rawtypes"})
    public PredictionPipeline<S, L> trainFromStream(Supplier<Stream<SourceLabelPair<S, L>>> dataset) {
        Objects.requireNonNull(dataset, "dataset");

        InstanceIdGeneratingStrategy idStrategy = new HashCodeInstanceIdGeneratingStrategy();
        InstanceFactory instanceFactory = new InstanceFactory(Arrays.asList(this.featureExtractor), idStrategy);

        // create() is called once per training run, after the instance factory has been built.
        ru.ispras.ml.pipelines.PredictionPipelineTrainer delegate =
                new ru.ispras.ml.pipelines.PredictionPipelineTrainer(instanceFactory, this.predictorTrainerFactory.create());
        return delegate.train(dataset);
    }
}

