/*
 * Decompiled with CFR 0.152.
 */
package ru.ispras.texterra.core.nlp.annotators.ml;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.OptionalInt;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import ru.ispras.ml.datamodel.ILabeledDataset;
import ru.ispras.ml.datamodel.ILabelledInstance;
import ru.ispras.ml.datamodel.ListLabelledDataset;
import ru.ispras.texterra.core.nlp.annotators.ml.IDatasetPreprocessor;

/**
 * Dataset preprocessor that equalizes group sizes by random resampling.
 *
 * <p>The dataset is split with {@code group(dataset)} into a map keyed by the label type {@code L}.
 * {@code group} is not defined in this class; presumably it is a default method of
 * {@link IDatasetPreprocessor} that groups instances by label. Confirm this against the interface.
 *
 * <p>Every group is resampled to the size of the smallest group, and the resampled groups are
 * concatenated into a new dataset. All shuffling uses the {@link Random} supplied at construction.
 */
public final class RandomDatasetBalancer implements IDatasetPreprocessor, Serializable {
    private static final long serialVersionUID = 3733478207825888461L;

    /** Random source used for every shuffle performed by this instance. */
    private final Random random;

    /**
     * Creates a balancer that shuffles with the given random source.
     *
     * @param random random source used to choose which instances are kept or duplicated
     */
    public RandomDatasetBalancer(Random random) {
        this.random = random;
    }

    /** Creates a balancer backed by a new, unseeded {@link Random}. */
    public RandomDatasetBalancer() {
        this(new Random());
    }

    /**
     * Returns a new dataset in which every group has the size of the smallest group.
     *
     * @param dataset the labelled dataset to balance
     * @param <L> label type
     * @return a new dataset containing the resampled groups, concatenated in the iteration order
     *     of the grouping map
     * @throws IllegalStateException if grouping produces no groups, so no minimum size exists
     */
    @Override
    public <L> ILabeledDataset<L> preprocess(ILabeledDataset<L> dataset) {
        Map<L, List<ILabelledInstance<L>>> grouped = this.group(dataset);
        int size = grouped.values().stream()
                .mapToInt(List::size)
                .min()
                .orElseThrow(() -> new IllegalStateException("Cannot determine number of items in group"));
        List<ILabelledInstance<L>> result = grouped.values().stream()
                // Each group is passed read-only so the returned list can never alias a mutable group.
                .flatMap(group -> this.getBalanced(Collections.unmodifiableList(group), size).stream())
                .collect(Collectors.toList());
        // The generic signature of ListLabelledDataset is not visible here, so the raw construction
        // from the original is kept. Switch to the diamond if it is declared generic.
        @SuppressWarnings({"rawtypes", "unchecked"})
        ILabeledDataset<L> balanced = new ListLabelledDataset(result);
        return balanced;
    }

    /**
     * Resamples {@code instances} to exactly {@code size} elements.
     *
     * <ul>
     *   <li>If the list already has {@code size} elements, it is returned unchanged.</li>
     *   <li>If it is larger, a random subset of {@code size} elements is returned.</li>
     *   <li>If it is smaller, the whole list is repeated {@code size / n} times. The remaining
     *       slots are then filled with a random subset.</li>
     * </ul>
     *
     * @param instances the group to resample; it is not modified
     * @param size target number of elements
     * @return a list of exactly {@code size} instances
     * @throws IllegalArgumentException if {@code instances} is empty but {@code size > 0}
     */
    private <L> List<ILabelledInstance<L>> getBalanced(List<ILabelledInstance<L>> instances, int size) {
        int count = instances.size();
        if (count == size) {
            return instances;
        }
        if (count > size) {
            // Undersample: keep a random subset of the requested size.
            return this.getShuffled(instances).subList(0, size);
        }
        if (count == 0) {
            // Without this check, the oversampling arithmetic below would divide by zero.
            throw new IllegalArgumentException("Cannot resample an empty group to size " + size);
        }
        // Oversample: repeat every instance as many whole times as fit, then top up randomly.
        List<ILabelledInstance<L>> result = new ArrayList<>(size);
        int fullCopies = size / count;
        for (int i = 0; i < fullCopies; ++i) {
            result.addAll(instances);
        }
        int remainder = size - result.size();
        if (remainder > 0) {
            result.addAll(this.getShuffled(instances).subList(0, remainder));
        }
        return result;
    }

    /**
     * Returns a shuffled copy of {@code instances}; the input list is left untouched.
     *
     * @param instances elements to shuffle
     * @return a new mutable list holding the same elements in random order
     */
    private <L> List<ILabelledInstance<L>> getShuffled(List<ILabelledInstance<L>> instances) {
        List<ILabelledInstance<L>> shuffledInstances = new ArrayList<>(instances);
        Collections.shuffle(shuffledInstances, this.random);
        return shuffledInstances;
    }
}

