/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.expression.promql.function;

import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.BiFunction;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.function.Function;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AbsentOverTime;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Avg;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AvgOverTime;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
import org.elasticsearch.xpack.esql.expression.function.aggregate.CountOverTime;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Delta;
import org.elasticsearch.xpack.esql.expression.function.aggregate.FirstOverTime;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Idelta;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Increase;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Irate;
import org.elasticsearch.xpack.esql.expression.function.aggregate.LastOverTime;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Max;
import org.elasticsearch.xpack.esql.expression.function.aggregate.MaxOverTime;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Min;
import org.elasticsearch.xpack.esql.expression.function.aggregate.MinOverTime;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile;
import org.elasticsearch.xpack.esql.expression.function.aggregate.PercentileOverTime;
import org.elasticsearch.xpack.esql.expression.function.aggregate.PresentOverTime;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate;
import org.elasticsearch.xpack.esql.expression.function.aggregate.StdDev;
import org.elasticsearch.xpack.esql.expression.function.aggregate.StdDevOverTime;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
import org.elasticsearch.xpack.esql.expression.function.aggregate.SumOverTime;
import org.elasticsearch.xpack.esql.expression.function.aggregate.TimeSeriesAggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Variance;
import org.elasticsearch.xpack.esql.expression.function.aggregate.VarianceOverTime;
import org.elasticsearch.xpack.esql.expression.promql.function.FunctionType;
import org.elasticsearch.xpack.esql.parser.ParsingException;

/**
 * Registry mapping PromQL function names to builders of the equivalent ES|QL aggregate functions.
 *
 * <p>Each supported function is stored under its lower-cased name ({@link Locale#ROOT}) as a
 * {@link FunctionDefinition}. A definition carries a {@link FunctionType}, an {@link Arity} and a
 * builder producing the ES|QL {@link Function}. A separate set records PromQL function names that
 * are recognised but not implemented yet, so {@link #functionExists(String)} can tell "unsupported"
 * apart from "unknown".
 */
public class PromqlFunctionRegistry {
    // Declared before INSTANCE on purpose: static fields are initialised in textual order, so this
    // set must already be assigned if anything reachable from the singleton's constructor reads it.
    private static final Set<String> NOT_IMPLEMENTED = Set.of(
        "bottomk", "topk", "group", "count_values",
        "changes", "deriv", "holt_winters", "mad_over_time", "predict_linear", "resets",
        "abs", "absent", "ceil", "clamp", "clamp_max", "clamp_min", "exp", "floor", "ln", "log2", "log10",
        "round", "scalar", "sgn", "sort", "sort_desc", "sqrt",
        "acos", "acosh", "asin", "asinh", "atan", "atanh", "cos", "cosh", "deg", "rad", "sin", "sinh", "tan", "tanh",
        "day_of_month", "day_of_week", "day_of_year", "days_in_month", "hour", "minute", "month", "timestamp", "year",
        "label_join", "label_replace",
        "histogram_avg", "histogram_count", "histogram_fraction", "histogram_quantile", "histogram_stddev",
        "histogram_stdvar", "histogram_sum",
        "pi", "time", "vector"
    );

    /** Shared, fully-populated registry instance. */
    public static final PromqlFunctionRegistry INSTANCE = new PromqlFunctionRegistry();

    /** Supported functions keyed by normalized (lower-cased) name; only written during construction. */
    private final Map<String, FunctionDefinition> promqlFunctions = new HashMap<>();

    /** Creates a registry populated with every supported PromQL function. */
    public PromqlFunctionRegistry() {
        register(functionDefinitions());
    }

    /** Returns the supported function definitions, grouped by the kind of builder they use. */
    private static FunctionDefinition[][] functionDefinitions() {
        return new FunctionDefinition[][] {
            {
                withinSeries("delta", Delta::new),
                withinSeries("idelta", Idelta::new),
                withinSeries("increase", Increase::new),
                withinSeries("irate", Irate::new),
                withinSeries("rate", Rate::new) },
            {
                withinSeriesOverTimeWithWindow("avg_over_time", AvgOverTime::new),
                withinSeriesOverTimeWithWindow("count_over_time", CountOverTime::new),
                withinSeriesOverTimeWithWindow("max_over_time", MaxOverTime::new),
                withinSeriesOverTimeWithWindow("min_over_time", MinOverTime::new),
                withinSeriesOverTimeWithWindow("sum_over_time", SumOverTime::new),
                withinSeriesOverTime("stddev_over_time", StdDevOverTime::new),
                withinSeriesOverTime("stdvar_over_time", VarianceOverTime::new) },
            {
                withinSeries("first_over_time", FirstOverTime::new),
                withinSeries("last_over_time", LastOverTime::new) },
            {
                withinSeriesOverTime("absent_over_time", AbsentOverTime::new),
                withinSeriesOverTime("present_over_time", PresentOverTime::new) },
            {
                withinSeriesOverTime("quantile_over_time", PercentileOverTime::new) },
            {
                acrossSeries("avg", Avg::new),
                acrossSeries("count", Count::new),
                acrossSeries("max", Max::new),
                acrossSeries("min", Min::new),
                acrossSeries("sum", Sum::new),
                acrossSeries("stddev", StdDev::new),
                acrossSeries("stdvar", Variance::new) },
            {
                acrossSeriesBinary("quantile", Percentile::new) } };
    }

    /** Within-series function of one argument: builds {@code builder(source, params[0])}. */
    private static FunctionDefinition withinSeriesOverTime(String name, OverTimeWithinSeries<?> builder) {
        return new FunctionDefinition(
            name,
            FunctionType.WITHIN_SERIES_AGGREGATION,
            Arity.ONE,
            (source, params) -> builder.build(source, params.get(0))
        );
    }

    /**
     * Within-series function of two arguments: builds {@code builder(source, params[0], params[1])}.
     */
    private static FunctionDefinition withinSeriesOverTime(String name, OverTimeWithinSeriesBinary<?> builder) {
        // NOTE(review): acrossSeriesBinary swaps the two params (treating params[0] as the scalar
        // parameter and params[1] as the field) while this passes them through in order. Confirm
        // this matches the constructor argument order of the builders registered here
        // (e.g. PercentileOverTime).
        return new FunctionDefinition(
            name,
            FunctionType.WITHIN_SERIES_AGGREGATION,
            Arity.TWO,
            (source, params) -> builder.build(source, params.get(0), params.get(1))
        );
    }

    /**
     * Within-series function of one argument whose builder also takes a window; always passes
     * {@link AggregateFunction#NO_WINDOW} as that second argument.
     */
    private static FunctionDefinition withinSeriesOverTimeWithWindow(String name, OverTimeWithinSeriesBinary<?> builder) {
        return new FunctionDefinition(
            name,
            FunctionType.WITHIN_SERIES_AGGREGATION,
            Arity.ONE,
            (source, params) -> builder.build(source, params.get(0), (Expression) AggregateFunction.NO_WINDOW)
        );
    }

    /** Within-series function built from {@code (source, valueField = params[0], timestampField = params[1])}. */
    private static FunctionDefinition withinSeries(String name, WithinSeries<?> builder) {
        // NOTE(review): the declared arity is ONE yet params[1] is read as the timestamp field;
        // presumably the caller appends the timestamp after the user-supplied argument. TODO confirm.
        return new FunctionDefinition(name, FunctionType.WITHIN_SERIES_AGGREGATION, Arity.ONE, (source, params) -> {
            Expression valueField = params.get(0);
            Expression timestampField = params.get(1);
            return builder.build(source, valueField, timestampField);
        });
    }

    /**
     * Within-series function whose builder takes a window argument; builds
     * {@code (source, valueField = params[0], NO_WINDOW, timestampField = params[1])}.
     */
    private static FunctionDefinition withinSeries(String name, WithinSeriesWindow<?> builder) {
        // NOTE(review): same arity/params mismatch as the non-window overload above. TODO confirm.
        return new FunctionDefinition(name, FunctionType.WITHIN_SERIES_AGGREGATION, Arity.ONE, (source, params) -> {
            Expression valueField = params.get(0);
            Expression timestampField = params.get(1);
            return builder.build(source, valueField, (Expression) AggregateFunction.NO_WINDOW, timestampField);
        });
    }

    /** Across-series aggregation of one argument: builds {@code builder(source, params[0])}. */
    private static FunctionDefinition acrossSeries(String name, AcrossSeriesUnary<?> builder) {
        return new FunctionDefinition(
            name,
            FunctionType.ACROSS_SERIES_AGGREGATION,
            Arity.ONE,
            (source, params) -> builder.build(source, params.get(0))
        );
    }

    /**
     * Across-series aggregation with a leading parameter: {@code params[0]} is the parameter and
     * {@code params[1]} the field, passed to the builder as {@code (source, field, param)}.
     */
    private static FunctionDefinition acrossSeriesBinary(String name, AcrossSeriesBinary<?> builder) {
        return new FunctionDefinition(name, FunctionType.ACROSS_SERIES_AGGREGATION, Arity.TWO, (source, params) -> {
            Expression param = params.get(0);
            Expression field = params.get(1);
            return builder.build(source, field, param);
        });
    }

    /**
     * Adds every definition to the registry under its normalized name.
     *
     * @throws IllegalStateException if two definitions normalize to the same name, which would
     *         otherwise silently shadow one of them
     */
    private void register(FunctionDefinition[][] definitionGroups) {
        for (FunctionDefinition[] group : definitionGroups) {
            for (FunctionDefinition def : group) {
                String normalized = normalize(def.name());
                FunctionDefinition previous = promqlFunctions.putIfAbsent(normalized, def);
                if (previous != null) {
                    throw new IllegalStateException("Duplicate PromQL function definition [" + normalized + "]");
                }
            }
        }
    }

    /** Lower-cases a function name using {@link Locale#ROOT} so lookups are locale-independent. */
    private static String normalize(String name) {
        return name.toLowerCase(Locale.ROOT);
    }

    /**
     * Reports whether a PromQL function is known to this registry (case-insensitive).
     *
     * @param name the PromQL function name
     * @return {@code TRUE} if the function is supported, {@code null} if it is a recognised but
     *         not-yet-implemented function, and {@code FALSE} otherwise
     */
    public Boolean functionExists(String name) {
        String normalized = normalize(name);
        if (promqlFunctions.containsKey(normalized)) {
            return true;
        }
        if (NOT_IMPLEMENTED.contains(normalized)) {
            return null;
        }
        return false;
    }

    /**
     * Looks up the definition of a supported function (case-insensitive).
     *
     * @param name the PromQL function name
     * @return the definition, or {@code null} if the function is not supported
     */
    public FunctionDefinition functionMetadata(String name) {
        return promqlFunctions.get(normalize(name));
    }

    /**
     * Builds the ES|QL function corresponding to a supported PromQL function.
     *
     * @param name   the PromQL function name (case-insensitive)
     * @param source source location used for the resulting function and for error reporting
     * @param params the function arguments, in the order the registered builder expects
     * @return the ES|QL function
     * @throws ParsingException if the function is not supported or its builder fails
     */
    public Function buildEsqlFunction(String name, Source source, List<Expression> params) {
        FunctionDefinition metadata = functionMetadata(name);
        if (metadata == null) {
            // Previously this surfaced as a NullPointerException message wrapped below.
            throw new ParsingException(source, "Unknown PromQL function [{}]", name);
        }
        try {
            return metadata.esqlBuilder().apply(source, params);
        } catch (ParsingException e) {
            // Already a user-facing parse error carrying its own location; do not re-wrap it.
            throw e;
        } catch (Exception e) {
            ParsingException wrapped = new ParsingException(
                source,
                "Error building ESQL function for [{}]: {}",
                name,
                e.getMessage()
            );
            // Keep the original failure (and its stack trace) attached for diagnostics.
            wrapped.addSuppressed(e);
            throw wrapped;
        }
    }

    /**
     * Describes one supported PromQL function.
     *
     * @param name         the PromQL function name
     * @param functionType whether it aggregates within or across series
     * @param arity        declared argument count
     * @param esqlBuilder  builds the ES|QL function from a source and argument list
     */
    public record FunctionDefinition(
        String name,
        FunctionType functionType,
        Arity arity,
        BiFunction<Source, List<Expression>, Function> esqlBuilder
    ) {
        public FunctionDefinition {
            Objects.requireNonNull(name, "name cannot be null");
            Objects.requireNonNull(functionType, "functionType cannot be null");
            Objects.requireNonNull(arity, "arity cannot be null");
            Objects.requireNonNull(esqlBuilder, "esqlBuilder cannot be null");
        }
    }

    /** Builds a within-series function from {@code (source, valueField, timestampField)}. */
    @FunctionalInterface
    protected interface WithinSeries<T extends TimeSeriesAggregateFunction> {
        T build(Source source, Expression valueField, Expression timestampField);
    }

    /** Builds a within-series function from {@code (source, valueField, window, timestampField)}. */
    @FunctionalInterface
    protected interface WithinSeriesWindow<T extends TimeSeriesAggregateFunction> {
        T build(Source source, Expression valueField, Expression window, Expression timestampField);
    }

    /** Builds a within-series function from a source and two expressions. */
    @FunctionalInterface
    protected interface OverTimeWithinSeriesBinary<T extends TimeSeriesAggregateFunction> {
        T build(Source source, Expression first, Expression second);
    }

    /** Builds a within-series function from a source and a single field. */
    @FunctionalInterface
    protected interface OverTimeWithinSeries<T extends TimeSeriesAggregateFunction> {
        T build(Source source, Expression field);
    }

    /** Builds an across-series aggregate from a source and a single field. */
    @FunctionalInterface
    protected interface AcrossSeriesUnary<T extends AggregateFunction> {
        T build(Source source, Expression field);
    }

    /** Builds an across-series aggregate from {@code (source, field, parameter)}. */
    @FunctionalInterface
    protected interface AcrossSeriesBinary<T extends AggregateFunction> {
        T build(Source source, Expression field, Expression parameter);
    }

    /**
     * Inclusive range of accepted argument counts.
     *
     * @param min minimum number of arguments (non-negative)
     * @param max maximum number of arguments ({@code >= min})
     */
    public record Arity(int min, int max) {
        public static final Arity NONE = new Arity(0, 0);
        public static final Arity ONE = new Arity(1, 1);
        public static final Arity TWO = new Arity(2, 2);
        /** One or more arguments. */
        public static final Arity VARIADIC = new Arity(1, Integer.MAX_VALUE);

        public Arity {
            if (min < 0) {
                throw new IllegalArgumentException("min must be non-negative");
            }
            if (max < min) {
                throw new IllegalArgumentException("max must be >= min");
            }
        }

        /** Exactly {@code count} arguments; reuses the shared constants for 0, 1 and 2. */
        public static Arity fixed(int count) {
            return switch (count) {
                case 0 -> NONE;
                case 1 -> ONE;
                case 2 -> TWO;
                default -> new Arity(count, count);
            };
        }

        /** Between {@code min} and {@code max} arguments, inclusive. */
        public static Arity range(int min, int max) {
            return min == max ? fixed(min) : new Arity(min, max);
        }

        /** At least {@code min} arguments, with no upper bound. */
        public static Arity atLeast(int min) {
            return min == 1 ? VARIADIC : new Arity(min, Integer.MAX_VALUE);
        }

        /** Zero up to {@code max} arguments. */
        public static Arity optional(int max) {
            return new Arity(0, max);
        }

        /** Returns whether {@code paramCount} lies within {@code [min, max]}. */
        public boolean validate(int paramCount) {
            return paramCount >= min && paramCount <= max;
        }
    }
}

