/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.IntStream;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.TermQueryBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.SingleBucketAggregation;
import org.elasticsearch.search.aggregations.metrics.Percentiles;
import org.elasticsearch.search.aggregations.metrics.PercentilesAggregationBuilder;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.common.AbstractAucRoc;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

/**
 * Classification evaluation metric that computes an AUC ROC score for one class,
 * {@code class_name}, against all remaining classes.
 *
 * <p>{@link #aggs} requests two filter aggregations: one over documents whose actual
 * field equals the class name and one over documents whose actual field does not.
 * Under each, a nested aggregation over the top-classes field is filtered to entries
 * whose predicted class equals the class name, and the 1st..99th percentiles of the
 * predicted probability are collected. {@link #process} turns those two percentile
 * sets into a ROC curve and a score using the helpers inherited from
 * {@link AbstractAucRoc}.
 *
 * <p>The result is stored once; after it has been set, {@link #aggs} returns no
 * aggregations and {@link #process} returns immediately.
 */
public class AucRoc extends AbstractAucRoc {
    public static final ParseField INCLUDE_CURVE = new ParseField("include_curve");
    public static final ParseField CLASS_NAME = new ParseField("class_name");

    /** Parser building an {@link AucRoc} from {@code include_curve} and {@code class_name}. */
    public static final ConstructingObjectParser<AucRoc, Void> PARSER = new ConstructingObjectParser<>(
        NAME.getPreferredName(),
        a -> new AucRoc((Boolean) a[0], (String) a[1])
    );

    static {
        PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), INCLUDE_CURVE);
        // NOTE(review): class_name is declared optional to the parser although the constructor
        // rejects null, so a missing value is reported by the constructor rather than by the
        // parser. Kept as-is to preserve the existing error; confirm whether constructorArg()
        // is the intended declaration.
        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), CLASS_NAME);
    }

    private static final String TRUE_AGG_NAME = NAME.getPreferredName() + "_true";
    private static final String NON_TRUE_AGG_NAME = NAME.getPreferredName() + "_non_true";
    private static final String NESTED_AGG_NAME = "nested";
    private static final String NESTED_FILTER_AGG_NAME = "nested_filter";
    private static final String PERCENTILES_AGG_NAME = "percentiles";

    private final boolean includeCurve;
    private final String className;
    /** Fields passed to {@link #aggs}; read by {@link #process} only to build error messages. */
    private final SetOnce<EvaluationFields> fields = new SetOnce<>();
    private final SetOnce<EvaluationMetricResult> result = new SetOnce<>();

    /**
     * Parses an {@link AucRoc} from the given parser using {@link #PARSER}.
     *
     * @param parser the parser to read from
     * @return the parsed metric
     */
    public static AucRoc fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    /**
     * @param includeCurve whether the result should contain the curve points; {@code null} means {@code false}
     * @param className    the class to evaluate; passed through {@code ExceptionsHelper.requireNonNull}
     */
    public AucRoc(Boolean includeCurve, String className) {
        this.includeCurve = includeCurve != null && includeCurve;
        this.className = ExceptionsHelper.requireNonNull(className, CLASS_NAME.getPreferredName());
    }

    /**
     * Reads the metric from a stream: a boolean followed by an optional string,
     * mirroring {@link #writeTo}.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public AucRoc(StreamInput in) throws IOException {
        this.includeCurve = in.readBoolean();
        this.className = in.readOptionalString();
    }

    @Override
    public String getWriteableName() {
        return MlEvaluationNamedXContentProvider.registeredMetricName(Classification.NAME, NAME);
    }

    /**
     * Writes {@code includeCurve} as a boolean followed by {@code className} as an optional string.
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeBoolean(this.includeCurve);
        out.writeOptionalString(this.className);
    }

    /**
     * Renders this metric as an object with {@code include_curve} and, when the class name
     * is non-null, {@code class_name}.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(INCLUDE_CURVE.getPreferredName(), this.includeCurve);
        // className may be null when the instance was read from a stream (optional string).
        if (this.className != null) {
            builder.field(CLASS_NAME.getPreferredName(), this.className);
        }
        builder.endObject();
        return builder;
    }

    /**
     * @return the preferred names of the actual, predicted-class and predicted-probability fields
     */
    @Override
    public Set<String> getRequiredFields() {
        return Sets.newHashSet(
            EvaluationFields.ACTUAL_FIELD.getPreferredName(),
            EvaluationFields.PREDICTED_CLASS_FIELD.getPreferredName(),
            EvaluationFields.PREDICTED_PROBABILITY_FIELD.getPreferredName()
        );
    }

    /** Two instances are equal when both {@code includeCurve} and {@code className} match. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        AucRoc that = (AucRoc) o;
        return this.includeCurve == that.includeCurve && Objects.equals(this.className, that.className);
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.includeCurve, this.className);
    }

    /**
     * Builds the aggregations this metric needs, or none if a result has already been set.
     *
     * @param parameters       evaluation parameters (not used by this metric)
     * @param evaluationFields field names used to build the queries and aggregations
     * @return a tuple of the two top-level filter aggregations and an empty list of pipeline aggregations
     */
    @Override
    public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(
        EvaluationParameters parameters,
        EvaluationFields evaluationFields
    ) {
        if (this.result.get() != null) {
            return Tuple.tuple(List.of(), List.of());
        }
        // Remember the fields so that process() can name them in its error messages.
        this.fields.trySet(evaluationFields);

        // Percentiles 1.0, 2.0, ..., 99.0 (the upper bound of range() is exclusive).
        double[] percentiles = IntStream.range(1, 100).mapToDouble(v -> v).toArray();
        PercentilesAggregationBuilder percentilesAgg = ((PercentilesAggregationBuilder) AggregationBuilders
            .percentiles(PERCENTILES_AGG_NAME)
            .field(evaluationFields.getPredictedProbabilityField())).percentiles(percentiles);

        // Within the nested top-classes entries, keep only those predicting className.
        AggregationBuilder nestedAgg = AggregationBuilders.nested(NESTED_AGG_NAME, evaluationFields.getTopClassesField())
            .subAggregation(
                AggregationBuilders.filter(
                    NESTED_FILTER_AGG_NAME,
                    QueryBuilders.termQuery(evaluationFields.getPredictedClassField(), this.className)
                ).subAggregation(percentilesAgg)
            );

        TermQueryBuilder actualIsTrueQuery = QueryBuilders.termQuery(evaluationFields.getActualField(), this.className);
        AggregationBuilder percentilesForClassValueAgg = AggregationBuilders.filter(TRUE_AGG_NAME, actualIsTrueQuery)
            .subAggregation(nestedAgg);
        AggregationBuilder percentilesForRestAgg = AggregationBuilders.filter(
            NON_TRUE_AGG_NAME,
            QueryBuilders.boolQuery().mustNot(actualIsTrueQuery)
        ).subAggregation(nestedAgg);
        return Tuple.tuple(List.of(percentilesForClassValueAgg, percentilesForRestAgg), List.of());
    }

    /**
     * Computes and stores the result from the aggregations requested by {@link #aggs}.
     * Does nothing if a result has already been set.
     *
     * <p>A bad-request exception (from {@code ExceptionsHelper.badRequestException}) is thrown when
     * no document has the class name as its actual value, when every document has it, or when the
     * nested filters matched fewer documents than the two top-level filters did.
     *
     * @param aggs the aggregation results for the search built from {@link #aggs}
     */
    @Override
    public void process(InternalAggregations aggs) {
        if (this.result.get() != null) {
            return;
        }
        SingleBucketAggregation classAgg = (SingleBucketAggregation) aggs.get(TRUE_AGG_NAME);
        SingleBucketAggregation classNested = (SingleBucketAggregation) classAgg.getAggregations().get(NESTED_AGG_NAME);
        SingleBucketAggregation classNestedFilter = (SingleBucketAggregation) classNested.getAggregations()
            .get(NESTED_FILTER_AGG_NAME);

        SingleBucketAggregation restAgg = (SingleBucketAggregation) aggs.get(NON_TRUE_AGG_NAME);
        SingleBucketAggregation restNested = (SingleBucketAggregation) restAgg.getAggregations().get(NESTED_AGG_NAME);
        SingleBucketAggregation restNestedFilter = (SingleBucketAggregation) restNested.getAggregations()
            .get(NESTED_FILTER_AGG_NAME);

        if (classAgg.getDocCount() == 0L) {
            throw ExceptionsHelper.badRequestException(
                "[{}] requires at least one [{}] to have the value [{}]",
                this.getName(),
                this.fields.get().getActualField(),
                this.className
            );
        }
        if (restAgg.getDocCount() == 0L) {
            throw ExceptionsHelper.badRequestException(
                "[{}] requires at least one [{}] to have a different value than [{}]",
                this.getName(),
                this.fields.get().getActualField(),
                this.className
            );
        }

        // Compare how many documents the nested filters matched with how many the
        // top-level filters matched; fewer means className was absent for some documents.
        long filteredDocCount = classNestedFilter.getDocCount() + restNestedFilter.getDocCount();
        long totalDocCount = classAgg.getDocCount() + restAgg.getDocCount();
        if (filteredDocCount < totalDocCount) {
            throw ExceptionsHelper.badRequestException(
                "[{}] requires that [{}] appears as one of the [{}] for every document (appeared in {} out of {}). "
                    + "This is probably caused by the {} value being less than the total number of actual classes in the dataset.",
                this.getName(),
                this.className,
                this.fields.get().getPredictedClassField(),
                filteredDocCount,
                totalDocCount,
                org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification.NUM_TOP_CLASSES.getPreferredName()
            );
        }

        Percentiles classPercentiles = (Percentiles) classNestedFilter.getAggregations().get(PERCENTILES_AGG_NAME);
        double[] tpPercentiles = AucRoc.percentilesArray(classPercentiles);
        Percentiles restPercentiles = (Percentiles) restNestedFilter.getAggregations().get(PERCENTILES_AGG_NAME);
        double[] fpPercentiles = AucRoc.percentilesArray(restPercentiles);

        List<AbstractAucRoc.AucRocPoint> aucRocCurve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
        double aucRocScore = AucRoc.calculateAucScore(aucRocCurve);
        // The curve is only included in the result when it was requested.
        this.result.set(new AbstractAucRoc.Result(aucRocScore, this.includeCurve ? aucRocCurve : Collections.emptyList()));
    }

    /**
     * @return the stored result, or an empty {@link Optional} if {@link #process} has not set one yet
     */
    public Optional<EvaluationMetricResult> getResult() {
        return Optional.ofNullable(this.result.get());
    }
}

