/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.IntStream;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationFeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionFeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

public final class InferenceHelpers {
    private InferenceHelpers() {
    }

    public static Tuple<TopClassificationValue, List<TopClassEntry>> topClasses(double[] probabilities, List<String> classificationLabels, @Nullable double[] classificationWeights, int numToInclude, PredictionFieldType predictionFieldType) {
        if (classificationLabels != null && probabilities.length != classificationLabels.size()) {
            throw ExceptionsHelper.serverError("model returned classification probabilities of size [{}] which is not equal to classification labels size [{}]", null, probabilities.length, classificationLabels.size());
        }
        double[] scores = classificationWeights == null ? probabilities : IntStream.range(0, probabilities.length).mapToDouble(i -> probabilities[i] * classificationWeights[i]).toArray();
        int[] sortedIndices = IntStream.range(0, scores.length).boxed().sorted(Comparator.comparing(i -> scores[(Integer)i]).reversed()).mapToInt(i -> i).toArray();
        TopClassificationValue topClassificationValue = new TopClassificationValue(sortedIndices[0], probabilities[sortedIndices[0]], scores[sortedIndices[0]]);
        if (numToInclude == 0) {
            return Tuple.tuple((Object)topClassificationValue, Collections.emptyList());
        }
        List<String> labels = classificationLabels == null ? IntStream.range(0, probabilities.length).mapToObj(String::valueOf).toList() : classificationLabels;
        int count = numToInclude < 0 ? probabilities.length : Math.min(numToInclude, probabilities.length);
        ArrayList<TopClassEntry> topClassEntries = new ArrayList<TopClassEntry>(count);
        for (int i2 = 0; i2 < count; ++i2) {
            int idx = sortedIndices[i2];
            topClassEntries.add(new TopClassEntry(predictionFieldType.transformPredictedValue(Double.valueOf(idx), labels.get(idx)), probabilities[idx], scores[idx]));
        }
        return Tuple.tuple((Object)topClassificationValue, topClassEntries);
    }

    public static String classificationLabel(Integer inferenceValue, @Nullable List<String> classificationLabels) {
        if (classificationLabels == null) {
            return String.valueOf(inferenceValue);
        }
        if (inferenceValue < 0 || inferenceValue >= classificationLabels.size()) {
            throw ExceptionsHelper.serverError("model returned classification value of [{}] which is not a valid index in classification labels [{}]", null, inferenceValue, classificationLabels);
        }
        return classificationLabels.get(inferenceValue);
    }

    public static Double toDouble(Object value) {
        if (value instanceof Number) {
            Number number = (Number)value;
            return number.doubleValue();
        }
        if (value instanceof String) {
            String str = (String)value;
            return InferenceHelpers.stringToDouble(str);
        }
        return null;
    }

    private static Double stringToDouble(String value) {
        if (value.isEmpty()) {
            return null;
        }
        try {
            return Double.valueOf(value);
        }
        catch (NumberFormatException nfe) {
            assert (false) : "value is not properly formatted double [" + value + "]";
            return null;
        }
    }

    public static Map<String, double[]> decodeFeatureImportances(Map<String, String> processedFeatureToOriginalFeatureMap, Map<String, double[]> featureImportances) {
        if (processedFeatureToOriginalFeatureMap == null || processedFeatureToOriginalFeatureMap.isEmpty()) {
            return featureImportances;
        }
        HashMap<String, double[]> originalFeatureImportance = new HashMap<String, double[]>();
        featureImportances.forEach((feature, importance) -> {
            String featureName = processedFeatureToOriginalFeatureMap.getOrDefault(feature, (String)feature);
            originalFeatureImportance.compute(featureName, (f, v1) -> v1 == null ? importance : InferenceHelpers.sumDoubleArrays(importance, v1));
        });
        return originalFeatureImportance;
    }

    public static List<RegressionFeatureImportance> transformFeatureImportanceRegression(Map<String, double[]> featureImportance) {
        ArrayList<RegressionFeatureImportance> importances = new ArrayList<RegressionFeatureImportance>(featureImportance.size());
        featureImportance.forEach((k, v) -> importances.add(new RegressionFeatureImportance((String)k, v[0])));
        return importances;
    }

    public static List<ClassificationFeatureImportance> transformFeatureImportanceClassification(Map<String, double[]> featureImportance, @Nullable List<String> classificationLabels, @Nullable PredictionFieldType predictionFieldType) {
        ArrayList<ClassificationFeatureImportance> importances = new ArrayList<ClassificationFeatureImportance>(featureImportance.size());
        PredictionFieldType fieldType = predictionFieldType == null ? PredictionFieldType.STRING : predictionFieldType;
        featureImportance.forEach((k, v) -> {
            if (((double[])v).length == 1) {
                String zeroLabel = classificationLabels == null ? null : (String)classificationLabels.get(0);
                String oneLabel = classificationLabels == null ? null : (String)classificationLabels.get(1);
                importances.add(new ClassificationFeatureImportance((String)k, Arrays.asList(new ClassificationFeatureImportance.ClassImportance(fieldType.transformPredictedValue(0.0, zeroLabel), -v[0]), new ClassificationFeatureImportance.ClassImportance(fieldType.transformPredictedValue(1.0, oneLabel), v[0]))));
            } else {
                ArrayList<ClassificationFeatureImportance.ClassImportance> classImportance = new ArrayList<ClassificationFeatureImportance.ClassImportance>(((double[])v).length);
                assert (classificationLabels == null || classificationLabels.size() == ((double[])v).length);
                for (int i = 0; i < ((double[])v).length; ++i) {
                    String label = classificationLabels == null ? null : (String)classificationLabels.get(i);
                    classImportance.add(new ClassificationFeatureImportance.ClassImportance(fieldType.transformPredictedValue(Double.valueOf(i), label), v[i]));
                }
                importances.add(new ClassificationFeatureImportance((String)k, (List<ClassificationFeatureImportance.ClassImportance>)classImportance));
            }
        });
        return importances;
    }

    public static double[] sumDoubleArrays(double[] sumTo, double[] inc) {
        return InferenceHelpers.sumDoubleArrays(sumTo, inc, 1);
    }

    public static double[] sumDoubleArrays(double[] sumTo, double[] inc, int weight) {
        assert (sumTo != null && inc != null && sumTo.length == inc.length);
        for (int i = 0; i < inc.length; ++i) {
            int n = i;
            sumTo[n] = sumTo[n] + inc[i] * (double)weight;
        }
        return sumTo;
    }

    public static void divMut(double[] xs, int v) {
        if (xs.length == 0) {
            return;
        }
        if (v == 0) {
            throw new IllegalArgumentException("unable to divide by [" + v + "] as it results in undefined behavior");
        }
        int i = 0;
        while (i < xs.length) {
            int n = i++;
            xs[n] = xs[n] / (double)v;
        }
    }

    public static class TopClassificationValue {
        private final int value;
        private final double probability;
        private final double score;

        TopClassificationValue(int value, double probability, double score) {
            this.value = value;
            this.probability = probability;
            this.score = score;
        }

        public int getValue() {
            return this.value;
        }

        public double getProbability() {
            return this.probability;
        }

        public double getScore() {
            return this.score;
        }
    }
}

