/*
 * Decompiled with CFR 0.152.
 */
package ru.ispras.ml.maxent;

import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import opennlp.maxent.GIS;
import opennlp.model.AbstractModel;
import opennlp.model.DataIndexer;
import opennlp.model.Event;
import opennlp.model.EventStream;
import opennlp.model.OnePassRealValueDataIndexer;
import ru.ispras.ml.TrainingException;
import ru.ispras.ml.classification.IConfidenceClassifierTrainer;
import ru.ispras.ml.datamodel.ILabeledDataset;
import ru.ispras.ml.datamodel.ILabelledInstance;
import ru.ispras.ml.maxent.MaxEntAdapter;
import ru.ispras.ml.maxent.MaxEntHelper;
import ru.ispras.ml.maxent.MaxEntModel;

/**
 * Trains a maximum-entropy classifier using OpenNLP's GIS algorithm.
 *
 * <p>OpenNLP works with string outcomes, so every label of type {@code L} is mapped to a
 * string: each previously unseen label is assigned the current mapping size (as a decimal
 * string). The reverse mapping is handed to the {@link MaxEntAdapter} returned by
 * {@link #getPredictor()}. Only batch training through {@link #train(ILabeledDataset)} is
 * supported.
 *
 * @param <L> the label type
 */
public final class MaxEntTrainer<L> implements IConfidenceClassifierTrainer<L> {
    private static final long serialVersionUID = -1794869863478991373L;

    /** Number of GIS iterations used by the no-argument constructor. */
    private static final int DEFAULT_ITERATIONS_NUMBER = 64;

    /**
     * Cutoff passed to {@link OnePassRealValueDataIndexer}. In the OpenNLP API this is the
     * minimum number of times a feature must be observed to be kept; 0 keeps every feature.
     */
    private static final int FEATURE_CUTOFF = 0;

    private final int numberOfIterations;

    // Bidirectional label <-> outcome-string mapping. Both maps are only accessed while
    // holding the monitor of labelToString.
    private final Map<String, L> stringToLabel = new HashMap<String, L>();
    private final Map<L, String> labelToString = new HashMap<L, String>();

    // Model produced by the most recent train(ILabeledDataset) call; null until then.
    private MaxEntModel model = null;

    /** Creates a trainer that runs the default number of GIS iterations (64). */
    public MaxEntTrainer() {
        this(DEFAULT_ITERATIONS_NUMBER);
    }

    /**
     * Creates a trainer with a custom iteration count.
     *
     * @param numberOfIterations number of iterations passed to
     *     {@link GIS#trainModel(int, DataIndexer)}
     */
    public MaxEntTrainer(int numberOfIterations) {
        this.numberOfIterations = numberOfIterations;
    }

    /**
     * Incremental (per-instance) training is not supported by this trainer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void train(ILabelledInstance<L> precedent) throws TrainingException {
        throw new UnsupportedOperationException();
    }

    /**
     * Converts every instance of {@code dataset} into an OpenNLP {@link Event}, trains a GIS
     * model on them and stores it for {@link #getPredictor()}. The dataset's stream is closed
     * once all instances have been converted and the model is trained.
     *
     * @param dataset labelled training data
     * @throws TrainingException if OpenNLP indexing fails with an {@link IOException}
     */
    @Override
    public void train(ILabeledDataset<L> dataset) throws TrainingException {
        try (Stream<? extends ILabelledInstance<L>> instances = dataset.stream()) {
            Collection<Event> events = instances.map(this::toEvent).collect(Collectors.toList());
            this.model = trainModel(events);
        }
    }

    /** Builds an OpenNLP event (outcome string, feature names, feature values) for one instance. */
    private Event toEvent(ILabelledInstance<L> instance) {
        MaxEntHelper.MaxEntInstance maxEntInstance = MaxEntHelper.toMaxEntInstance(instance);
        String outcome = getLabel(instance);
        return new Event(outcome, maxEntInstance.getFeatures(), maxEntInstance.getValues());
    }

    /**
     * Returns the outcome string for the instance's label, registering a new one in both
     * mapping directions if the label has not been seen before.
     */
    private String getLabel(ILabelledInstance<L> instance) {
        L label = instance.getLabel();
        synchronized (labelToString) {
            String outcome = labelToString.get(label);
            if (outcome == null) {
                // Outcomes are dense indices "0", "1", ... in order of first appearance.
                outcome = String.valueOf(labelToString.size());
                labelToString.put(label, outcome);
                stringToLabel.put(outcome, label);
            }
            return outcome;
        }
    }

    /**
     * Indexes the events and trains a GIS model on them.
     *
     * @throws TrainingException wrapping any {@link IOException} raised by OpenNLP
     */
    private MaxEntModel trainModel(Collection<Event> events) throws TrainingException {
        final Iterator<Event> remaining = events.iterator();
        // Adapts the in-memory events to the pull-style EventStream the indexer consumes.
        EventStream eventStream = new EventStream() {
            @Override
            public Event next() {
                return remaining.next();
            }

            @Override
            public boolean hasNext() {
                return remaining.hasNext();
            }
        };
        try {
            DataIndexer indexer = new OnePassRealValueDataIndexer(eventStream, FEATURE_CUTOFF);
            AbstractModel trained = GIS.trainModel(numberOfIterations, indexer);
            return new MaxEntModel(trained);
        } catch (IOException e) {
            throw new TrainingException(e);
        }
    }

    /**
     * Returns a predictor backed by the most recently trained model.
     *
     * <p>The predictor receives a snapshot of the outcome-to-label mapping, so later calls to
     * {@link #train(ILabeledDataset)} cannot mutate a map it is reading concurrently.
     *
     * @throws IllegalStateException if no model has been trained yet
     */
    @Override
    public MaxEntAdapter<L> getPredictor() {
        if (model == null) {
            throw new IllegalStateException("Model has not been trained yet!");
        }
        Map<String, L> labelsSnapshot;
        synchronized (labelToString) {
            labelsSnapshot = new HashMap<>(stringToLabel);
        }
        return new MaxEntAdapter<L>(model, labelsSnapshot);
    }
}

