/*
 * Decompiled with CFR 0.152.
 */
package ru.ispras.ml.classification.naivebayes;

import java.util.Map;
import ru.ispras.ml.classification.ClassificationException;
import ru.ispras.ml.classification.ConfidenceClassificationResult;
import ru.ispras.ml.classification.IConfidenceClassifier;
import ru.ispras.ml.classification.naivebayes.NaiveBayesModel;
import ru.ispras.ml.datamodel.Feature;
import ru.ispras.ml.datamodel.IInstance;
import ru.ispras.ml.datamodel.value.IValue;

/**
 * Naive Bayes classifier backed by a pre-computed {@link NaiveBayesModel}.
 *
 * <p>For every class {@code c} known to the model the classifier computes the log-score
 * {@code log P(c) + sum_f log P(value_f | c)}, where each per-feature likelihood is
 * additively smoothed with {@link #getLaplasSmoothParam()}. The predicted label is the
 * class with the highest score (the first one encountered on ties), and the reported
 * confidence is that class's posterior probability normalised over all classes.
 *
 * @param <LabelType> type of the class labels produced by the model
 */
public class NaiveBayesClassifier<LabelType>
implements IConfidenceClassifier<LabelType> {
    private static final long serialVersionUID = -4791395629494530817L;
    /** Smoothing constant used by the single-argument constructor (i.e. no smoothing). */
    private static final double DEFAULT_LAPLAS_SMOOTH_PARAMETER = 0.0;
    private final NaiveBayesModel<LabelType> model;
    /** Additive smoothing constant added to every value count; finite and {@code >= 0}. */
    final double laplasSmoothParam;

    /**
     * Creates a classifier over the given model with the given additive smoothing.
     *
     * @param model trained model supplying class and feature statistics; must not be null
     * @param smoothParam additive (Laplace) smoothing constant; must be finite and non-negative
     * @throws NullPointerException if {@code model} is null
     * @throws IllegalArgumentException if {@code smoothParam} is negative, NaN or infinite
     */
    public NaiveBayesClassifier(NaiveBayesModel<LabelType> model, double smoothParam) {
        // Explicit checks rather than `assert`: assertions are disabled by default at
        // runtime, and an invalid smoothing constant yields invalid probabilities
        // (an infinite one turns every estimate into inf/inf = NaN).
        if (model == null) {
            throw new NullPointerException("model must not be null");
        }
        if (Double.isNaN(smoothParam) || Double.isInfinite(smoothParam) || smoothParam < 0.0) {
            throw new IllegalArgumentException(
                    "smoothParam must be a finite non-negative number, got " + smoothParam);
        }
        this.model = model;
        this.laplasSmoothParam = smoothParam;
    }

    /**
     * Creates a classifier over the given model without smoothing.
     *
     * @param model trained model supplying class and feature statistics; must not be null
     * @throws NullPointerException if {@code model} is null
     */
    public NaiveBayesClassifier(NaiveBayesModel<LabelType> model) {
        this(model, DEFAULT_LAPLAS_SMOOTH_PARAMETER);
    }

    /**
     * Returns the additive smoothing constant used when estimating feature likelihoods.
     *
     * @return the smoothing constant passed at construction time
     */
    public double getLaplasSmoothParam() {
        return this.laplasSmoothParam;
    }

    /**
     * Predicts the most probable class for {@code instance}.
     *
     * @param instance instance whose feature values are scored against every class of the model
     * @return the highest-scoring label together with its posterior probability normalised
     *     over all classes of the model
     * @throws ClassificationException if no class obtains a non-zero, well-defined probability
     */
    @Override
    public ConfidenceClassificationResult<LabelType> predict(IInstance instance) throws ClassificationException {
        // Work in log space with a running log-sum-exp. Exponentiating the raw
        // log-scores directly underflows to 0 for instances with many features, which
        // would make every class look impossible. Invariant:
        // scaledSum == sum over accepted classes of exp(score - bestScore).
        LabelType best = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        double scaledSum = 0.0;
        for (LabelType classResult : this.model.getClasses()) {
            double score = this.computeValueForClass(instance, classResult);
            if (Double.isNaN(score) || score == Double.NEGATIVE_INFINITY) {
                // Zero or undefined probability: contributes nothing to the posterior.
                continue;
            }
            if (score > bestScore) {
                // Rescale the accumulated sum to the new maximum. On the first accepted
                // class bestScore is -inf, so exp(-inf) == 0 and the sum restarts at 1.
                scaledSum = scaledSum * Math.exp(bestScore - score) + 1.0;
                bestScore = score;
                best = classResult;
            } else {
                // Strict '>' above keeps the first class on ties, as before.
                scaledSum += Math.exp(score - bestScore);
            }
        }
        if (best == null) {
            throw new ClassificationException("Cannot determine class for an instance!");
        }
        // exp(bestScore) / sum_c exp(score_c) == 1 / sum_c exp(score_c - bestScore)
        return new ConfidenceClassificationResult<LabelType>(best, 1.0 / scaledSum);
    }

    /**
     * Computes the unnormalised log-posterior of {@code classResult} for {@code instance}:
     * the log prior of the class plus the log of the smoothed likelihood of every
     * feature value present in the instance.
     *
     * @param instance instance being classified
     * @param classResult class whose score is computed
     * @return the log-score; {@link Double#NEGATIVE_INFINITY} if any factor has zero probability
     */
    private double computeValueForClass(IInstance instance, LabelType classResult) {
        NaiveBayesModel.ClassStatistics statistics = this.model.getClassStatistics(classResult);
        // log P(c) = log(#instances in class) - log(#instances in model)
        double result = Math.log(statistics.getNumberOfInstancesInClass()) - Math.log(this.model.getNumberOfInstances());
        for (Map.Entry<Feature, IValue> featureValue : instance.getFeaturesValues().entrySet()) {
            Feature feature = featureValue.getKey();
            IValue value = featureValue.getValue();
            // NOTE(review): assumes getFeatureStatistics never returns null for a feature
            // present in the instance (a null would throw NPE below) -- confirm in the model.
            NaiveBayesModel.FeatureStatistics featureStatistics = statistics.getFeatureStatistics(feature);
            result += Math.log(this.getLaplasEstimation(value, featureStatistics));
        }
        return result;
    }

    /**
     * Returns the additively smoothed estimate of P(value | class) from per-class
     * feature statistics:
     * {@code (count(value) + alpha) / (count(feature) + alpha * numberOfClasses)}.
     *
     * @param value observed feature value
     * @param featureStatistics statistics of the feature within one class
     * @return the estimate, or 0 when the denominator is zero (feature never observed for
     *     the class and no smoothing) instead of the NaN that 0/0 would produce
     */
    private double getLaplasEstimation(IValue value, NaiveBayesModel.FeatureStatistics featureStatistics) {
        // NOTE(review): textbook additive smoothing scales alpha by the number of distinct
        // values of the feature, not by the number of classes. Verify whether this
        // denominator is intentional or whether FeatureStatistics exposes that count.
        double numberOfInstancesWithValue = (double) featureStatistics.getCountForValue(value) + this.laplasSmoothParam;
        double numberOfInstancesWithFeature = (double) featureStatistics.getNumberOfInstancesWithFeature()
                + this.laplasSmoothParam * (double) this.model.getNumberOfClasses();
        if (numberOfInstancesWithFeature <= 0.0) {
            // 0/0 would yield NaN and poison the normalisation sum of every prediction.
            return 0.0;
        }
        return numberOfInstancesWithValue / numberOfInstancesWithFeature;
    }
}

