/*
 * Decompiled with CFR 0.152.
 */
package ru.ispras.ml.classification.passiveaggressive.models;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.lang.Validate;
import ru.ispras.ml.classification.passiveaggressive.kernels.IKernelFunction;
import ru.ispras.ml.classification.passiveaggressive.kernels.LinearKernel;
import ru.ispras.ml.classification.passiveaggressive.models.AbstractBinaryClassificationModel;
import ru.ispras.ml.classification.passiveaggressive.models.IPAModel;
import ru.ispras.ml.datamodel.Feature;
import ru.ispras.ml.datamodel.IInstance;
import ru.ispras.ml.datamodel.value.cast.DoubleValueCaster;

/**
 * Binary classification model that keeps an explicit bias term and one weight per
 * {@link Feature}.
 *
 * <p>The margin of an instance is {@code bias + sum(weight(f) * value(f))} over the
 * features that have a stored weight and are present on the instance. The model is
 * updated incrementally through {@link #add(IInstance, double)}.
 *
 * <p>The class is mutable and performs no synchronization of its own.
 *
 * @param <LabelType> type of the labels handed to the superclass
 */
public class LinearPAModel<LabelType>
        extends AbstractBinaryClassificationModel<LabelType>
        implements IPAModel<LabelType> {
    private static final long serialVersionUID = 30353652267761037L;

    /** Constant term added to every margin. */
    private double bias;

    /** Weight per feature; a feature without an entry contributes nothing to the margin. */
    private final Map<Feature, Double> featureWeights;

    /**
     * Creates a model with zero bias and no feature weights.
     *
     * @param positiveLabel label passed to the superclass as the positive one
     * @param negativeLabel label passed to the superclass as the negative one
     */
    public LinearPAModel(LabelType positiveLabel, LabelType negativeLabel) {
        this(positiveLabel, negativeLabel, 0.0, new HashMap<Feature, Double>());
    }

    /**
     * Creates a model with the given initial state.
     *
     * @param positiveLabel  label passed to the superclass as the positive one
     * @param negativeLabel  label passed to the superclass as the negative one
     * @param bias           initial bias
     * @param featureWeights initial weights; the map is stored by reference (not copied)
     *                       and is modified by {@link #add(IInstance, double)}
     * @throws IllegalArgumentException if {@code featureWeights} is {@code null}
     */
    protected LinearPAModel(LabelType positiveLabel, LabelType negativeLabel, double bias, Map<Feature, Double> featureWeights) {
        super(positiveLabel, negativeLabel);
        // Fail fast: every other method dereferences the map, so a null would only
        // surface later as an unrelated NullPointerException.
        Validate.notNull(featureWeights, "Feature weights should not be null");
        this.bias = bias;
        this.featureWeights = featureWeights;
    }

    /**
     * @return the current bias term
     */
    public double getBias() {
        return this.bias;
    }

    /**
     * @return a read-only view of the feature weights; the view is backed by the live
     *         map, so later calls to {@link #add(IInstance, double)} are visible through it
     */
    public Map<Feature, Double> getFeatureWeights() {
        return Collections.unmodifiableMap(this.featureWeights);
    }

    /**
     * Computes the margin of the given instance: the bias plus, for every weighted
     * feature the instance has, the weight multiplied by the feature value cast to double.
     *
     * @param instance instance to score, must not be {@code null}
     * @return the margin of {@code instance}
     * @throws IllegalArgumentException if {@code instance} is {@code null}
     */
    @Override
    public double getMarginOn(IInstance instance) {
        Validate.notNull(instance, "Instance should not be null");
        double margin = this.bias;
        // Iterate entries so each weight is read without a second hash lookup.
        for (Map.Entry<Feature, Double> weighted : this.featureWeights.entrySet()) {
            Feature feature = weighted.getKey();
            if (!instance.hasFeature(feature)) {
                continue;
            }
            margin += weighted.getValue() * DoubleValueCaster.cast(instance.getFeatureValue(feature)).getDouble();
        }
        return margin;
    }

    /**
     * Adds the given instance to the model scaled by {@code weight}: the bias grows by
     * {@code weight} and the weight of every feature of the instance grows by
     * {@code weight * value}. Features without a stored weight start from zero.
     *
     * @param instance instance to add, must not be {@code null}
     * @param weight   scaling factor applied to the instance
     * @throws IllegalArgumentException if {@code instance} is {@code null}
     */
    @Override
    public void add(IInstance instance, double weight) {
        Validate.notNull(instance, "Instance should not be null");
        this.bias += weight;
        for (Feature feature : instance.getFeaturesValues().keySet()) {
            Double stored = this.featureWeights.get(feature);
            double previous = stored == null ? 0.0 : stored;
            double value = DoubleValueCaster.cast(instance.getFeatureValue(feature)).getDouble();
            this.featureWeights.put(feature, previous + weight * value);
        }
    }

    /**
     * @return a new {@link LinearKernel} constructed with the argument {@code 1.0}
     */
    @Override
    public IKernelFunction getKernel() {
        return new LinearKernel(1.0);
    }

    /**
     * Two models are equal when the superclass considers them equal and both the bias
     * and the feature weights match.
     *
     * <p>The bias is compared with {@link Double#compare(double, double)} rather than
     * {@code ==} so that the result agrees with {@link #hashCode()}: {@code 0.0} and
     * {@code -0.0} are distinct, and a {@code NaN} bias equals itself. This is also how
     * the boxed weights inside the map are compared.
     */
    @Override
    public boolean equals(Object other) {
        if (!super.equals(other)) {
            return false;
        }
        // Guard the cast in case the superclass check admits other subclasses.
        if (!(other instanceof LinearPAModel)) {
            return false;
        }
        LinearPAModel<?> otherModel = (LinearPAModel<?>) other;
        return Double.compare(this.bias, otherModel.bias) == 0
                && this.featureWeights.equals(otherModel.featureWeights);
    }

    /**
     * Combines the superclass hash with the hashes of the bias and the feature weights.
     */
    @Override
    public int hashCode() {
        return super.hashCode() ^ Double.valueOf(this.bias).hashCode() ^ this.featureWeights.hashCode();
    }
}

