/*
 * Decompiled with CFR 0.152.
 */
package ru.ispras.ml.classification.passiveaggressive.trainers;

import java.util.Iterator;
import java.util.stream.Stream;
import org.apache.commons.lang.Validate;
import ru.ispras.ml.TrainingException;
import ru.ispras.ml.classification.IClassifier;
import ru.ispras.ml.classification.IClassifierTrainer;
import ru.ispras.ml.classification.passiveaggressive.PAClassifier;
import ru.ispras.ml.classification.passiveaggressive.models.IPAModel;
import ru.ispras.ml.datamodel.ILabeledDataset;
import ru.ispras.ml.datamodel.ILabelledInstance;
import ru.ispras.ml.datamodel.value.IValue;
import ru.ispras.ml.datamodel.value.cast.DoubleValueCaster;

/**
 * Base class for online passive-aggressive trainers over a binary {@link IPAModel}.
 *
 * <p>For every training instance the trainer computes the signed margin
 * {@code y * f(x)}, where {@code y} is {@code +1} for the model's positive label and
 * {@code -1} for its negative label. When that margin is below the required margin, the
 * instance is added to the model with coefficient {@code y * tau}. The step size
 * {@code tau} is supplied by the concrete subclass via
 * {@link #getUpdateCoefficient(double, double)}.
 *
 * @param <LabelType> type of the class labels
 */
public abstract class AbstractPATrainer<LabelType>
implements IClassifierTrainer<LabelType> {
    private static final long serialVersionUID = 6720991552248150012L;

    /** Margin used when none is given explicitly. */
    private static final double DEFAULT_MARGIN = 1.0;

    /** Required non-negative margin; instances whose signed margin is below it trigger an update. */
    private final double margin;

    /** Model being trained and wrapped by {@link #getPredictor()}. */
    private IPAModel<LabelType> model;

    /**
     * Creates a trainer with the default margin of {@value #DEFAULT_MARGIN}.
     *
     * @param model model to train; must not be {@code null}
     * @throws IllegalArgumentException if {@code model} is {@code null}
     */
    protected AbstractPATrainer(IPAModel<LabelType> model) {
        this(model, DEFAULT_MARGIN);
    }

    /**
     * Creates a trainer with an explicit margin.
     *
     * @param model  model to train; must not be {@code null}
     * @param margin required margin; must be {@code >= 0} (NaN is rejected)
     * @throws IllegalArgumentException if {@code model} is {@code null} or {@code margin} is
     *                                  negative or NaN
     */
    protected AbstractPATrainer(IPAModel<LabelType> model, double margin) {
        Validate.notNull(model, "Model should not be null");
        // Zero is accepted (and NaN fails the comparison), so the message says "non-negative".
        Validate.isTrue(margin >= 0.0, "Margin should be non-negative");
        this.model = model;
        this.margin = margin;
    }

    /**
     * Returns the model being trained.
     *
     * @return the current model, never {@code null}
     */
    public IPAModel<LabelType> getModel() {
        return this.model;
    }

    /**
     * Replaces the model being trained.
     *
     * @param model new model; must not be {@code null}
     * @throws IllegalArgumentException if {@code model} is {@code null}
     */
    public void setModel(IPAModel<LabelType> model) {
        Validate.notNull(model, "Model should not be null");
        this.model = model;
    }

    /**
     * Performs one online update step on a single labelled instance.
     *
     * @param precedent labelled instance; it, its label and all of its feature values must be
     *                  non-null, and every feature value must be castable to {@code double}
     * @throws IllegalArgumentException if any of the preconditions above is violated
     * @throws TrainingException        if the label is neither the model's positive nor its
     *                                  negative label, or if the subclass fails to compute
     *                                  the step size
     */
    @Override
    public void train(ILabelledInstance<LabelType> precedent) throws TrainingException {
        Validate.notNull(precedent, "Precedent should not be null");
        Validate.notNull(precedent.getLabel(), "Answer should not be null");
        for (IValue value : precedent.getFeaturesValues().values()) {
            Validate.isTrue(DoubleValueCaster.canCast(value),
                    "It should be possible to cast all features to double");
        }

        int sign = this.getLabelSign(precedent.getLabel());
        double instanceMargin = sign * this.model.getMarginOn(precedent);
        if (instanceMargin < this.margin) {
            // Hinge loss: positive here because the margin requirement is violated.
            double loss = this.margin - instanceMargin;
            // Kernel self-product of the instance, presumably K(x, x).
            double squaredNorm = this.model.getKernel().innerProduct(precedent);
            this.model.add(precedent, sign * this.getUpdateCoefficient(loss, squaredNorm));
        }
    }

    /**
     * Computes the (unsigned) step size for an update. The result is multiplied by the
     * instance's label sign before being added to the model.
     *
     * @param loss        hinge loss of the instance; strictly positive when called from
     *                    {@link #train(ILabelledInstance)}
     * @param squaredNorm kernel self-product of the instance
     * @return step size for the update
     * @throws TrainingException if the step size cannot be computed
     */
    protected abstract double getUpdateCoefficient(double loss, double squaredNorm) throws TrainingException;

    /**
     * Trains on every instance of the dataset, one after another, in encounter order.
     *
     * @param precedents dataset to train on; must not be {@code null}
     * @throws IllegalArgumentException if {@code precedents} or any instance is invalid
     * @throws TrainingException        if an instance has an unknown label or the update fails
     */
    @Override
    public void train(ILabeledDataset<LabelType> precedents) throws TrainingException {
        Validate.notNull(precedents, "Precedents should not be null");
        try (Stream<?> stream = precedents.stream()) {
            // Use an explicit loop rather than Stream.forEach so that a TrainingException from
            // train(ILabelledInstance) propagates directly, whether or not it is a checked exception.
            Iterator<?> iterator = stream.sequential().iterator();
            while (iterator.hasNext()) {
                @SuppressWarnings("unchecked") // elements of an ILabeledDataset<LabelType>
                ILabelledInstance<LabelType> instance = (ILabelledInstance<LabelType>) iterator.next();
                this.train(instance);
            }
        }
    }

    /**
     * Maps a label to its sign.
     *
     * @param label non-null label (validated by the caller)
     * @return {@code +1} for the model's positive label, {@code -1} for its negative label
     * @throws TrainingException if the label matches neither
     */
    private int getLabelSign(LabelType label) throws TrainingException {
        if (label.equals(this.model.getPositiveLabel())) {
            return 1;
        }
        if (label.equals(this.model.getNegativeLabel())) {
            return -1;
        }
        throw new TrainingException("Unknown class label");
    }

    /**
     * Returns a classifier backed by the current model. The model instance is shared, not
     * copied, so later training also affects the returned classifier.
     *
     * @return classifier wrapping the current model
     */
    @Override
    public IClassifier<LabelType> getPredictor() {
        return new PAClassifier<LabelType>(this.model);
    }
}

