/*
 * Decompiled with CFR 0.152.
 */
package ru.ispras.ml.classification.meta;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import ru.ispras.ml.classification.meta.ClassifiedInstances;

/**
 * Randomly splits a {@link ClassifiedInstances} dataset into a training set and a test set.
 *
 * <p>The dataset's indices are shuffled once. The first {@code trainSetSize} shuffled indices
 * form the training set and the remaining ones form the test set, so every instance lands in
 * exactly one of the two sets. The split is computed eagerly in the constructor.
 *
 * @param <LabelType> type of the classification results stored in the dataset
 */
public class InstancesSetPartitioning<LabelType> {
    private final ClassifiedInstances<LabelType> trainSet;
    private final ClassifiedInstances<LabelType> testSet;
    /**
     * Randomness source used for shuffling; {@code null} means the default source used by
     * {@link Collections#shuffle(List)}.
     */
    private final Random random;

    /**
     * Partitions {@code dataset} using an unseeded shuffle, so repeated constructions generally
     * produce different splits.
     *
     * @param trainProportion fraction of instances to put into the training set, in [0.0, 1.0]
     * @param dataset instances to split; not modified by this class
     * @throws IllegalArgumentException if {@code trainProportion} is outside [0.0, 1.0] or NaN
     * @throws NullPointerException if {@code dataset} is {@code null}
     */
    public InstancesSetPartitioning(double trainProportion, ClassifiedInstances<LabelType> dataset) {
        this(trainProportion, dataset, null);
    }

    /**
     * Partitions {@code dataset} using the given randomness source, which makes the split
     * reproducible when {@code random} is seeded.
     *
     * @param trainProportion fraction of instances to put into the training set, in [0.0, 1.0]
     * @param dataset instances to split; not modified by this class
     * @param random randomness source for the shuffle, or {@code null} to use the default source
     *     of {@link Collections#shuffle(List)}
     * @throws IllegalArgumentException if {@code trainProportion} is outside [0.0, 1.0] or NaN
     * @throws NullPointerException if {@code dataset} is {@code null}
     */
    public InstancesSetPartitioning(double trainProportion, ClassifiedInstances<LabelType> dataset, Random random) {
        if (dataset == null) {
            throw new NullPointerException("dataset");
        }
        // Must be assigned before getShuffledIndexes() is invoked below.
        this.random = random;
        int size = dataset.size();
        int trainSetSize = this.computeTrainSetSize(size, trainProportion);
        List<Integer> indexes = this.getShuffledIndexes(size);
        this.trainSet = this.collect(indexes.subList(0, trainSetSize), dataset);
        this.testSet = this.collect(indexes.subList(trainSetSize, size), dataset);
    }

    /**
     * Computes how many instances go into the training set.
     *
     * <p>The result is {@code floor(size * trainProportion)}, raised to at least 1 so the training
     * set is never empty for a non-empty dataset, and capped at {@code size}. For an empty
     * dataset the result is 0.
     *
     * @param size number of instances in the dataset
     * @param trainProportion requested training fraction
     * @return number of training instances, in [0, size]
     * @throws IllegalArgumentException if {@code trainProportion} is outside [0.0, 1.0] or NaN
     */
    private int computeTrainSetSize(int size, double trainProportion) {
        // Written as a negated in-range test so that NaN (for which every comparison is false)
        // is rejected instead of silently producing a one-instance training set.
        if (!(trainProportion >= 0.0 && trainProportion <= 1.0)) {
            throw new IllegalArgumentException("trainProportion should be in range [0.0, 1.0], given: " + String.valueOf(trainProportion));
        }
        int trainSize = (int) ((double) size * trainProportion);
        return Math.min(size, Math.max(1, trainSize));
    }

    /**
     * Returns a random permutation of the indices {@code 0 .. size-1}.
     *
     * <p>This method is called from the constructor. Overrides must not rely on subclass state
     * initialized after {@code super(...)}. They must return exactly {@code size} distinct indices
     * in {@code [0, size)}, because the result is split with {@code subList} and used to look up
     * dataset entries.
     *
     * @param size number of indices to generate
     * @return shuffled index list
     */
    protected List<Integer> getShuffledIndexes(int size) {
        List<Integer> result = new ArrayList<Integer>(size);
        for (int i = 0; i < size; ++i) {
            result.add(i);
        }
        if (this.random == null) {
            Collections.shuffle(result);
        } else {
            Collections.shuffle(result, this.random);
        }
        return result;
    }

    /**
     * Copies the instances at the given indices, with their classification results, into a new
     * collection.
     *
     * @param indexes positions in {@code dataset} to copy, in the order they should be added
     * @param dataset source of instances and classification results
     * @return a new {@link ClassifiedInstances} containing the selected entries
     */
    private ClassifiedInstances<LabelType> collect(List<Integer> indexes, ClassifiedInstances<LabelType> dataset) {
        ClassifiedInstances<LabelType> result = new ClassifiedInstances<LabelType>();
        for (int index : indexes) {
            result.add(dataset.getInstance(index), dataset.getClassificationResult(index));
        }
        return result;
    }

    /**
     * Returns the training part of the split.
     *
     * @return the internal training set (not a copy)
     */
    public ClassifiedInstances<LabelType> getTrainSet() {
        return this.trainSet;
    }

    /**
     * Returns the test part of the split; it may be empty, e.g. when {@code trainProportion} is 1.0.
     *
     * @return the internal test set (not a copy)
     */
    public ClassifiedInstances<LabelType> getTestSet() {
        return this.testSet;
    }
}

