/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.logical.promql;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.action.PromqlFeatures;
import org.elasticsearch.xpack.esql.capabilities.ConfigurationAware;
import org.elasticsearch.xpack.esql.core.QlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.expression.function.Function;
import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDouble;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.EndsWith;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.StartsWith;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.RLike;
import org.elasticsearch.xpack.esql.expression.predicate.Predicates;
import org.elasticsearch.xpack.esql.expression.predicate.logical.And;
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNotNull;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;
import org.elasticsearch.xpack.esql.expression.promql.function.PromqlFunctionRegistry;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.promql.AutomatonUtils;
import org.elasticsearch.xpack.esql.plan.logical.Eval;
import org.elasticsearch.xpack.esql.plan.logical.Filter;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;
import org.elasticsearch.xpack.esql.plan.logical.promql.AcrossSeriesAggregate;
import org.elasticsearch.xpack.esql.plan.logical.promql.PlaceholderRelation;
import org.elasticsearch.xpack.esql.plan.logical.promql.PromqlCommand;
import org.elasticsearch.xpack.esql.plan.logical.promql.PromqlFunctionCall;
import org.elasticsearch.xpack.esql.plan.logical.promql.WithinSeriesAggregate;
import org.elasticsearch.xpack.esql.plan.logical.promql.selector.LabelMatcher;
import org.elasticsearch.xpack.esql.plan.logical.promql.selector.LabelMatchers;
import org.elasticsearch.xpack.esql.plan.logical.promql.selector.Selector;

public final class TranslatePromqlToTimeSeriesAggregate
extends OptimizerRules.OptimizerRule<PromqlCommand> {
    public static final Duration DEFAULT_LOOKBACK = Duration.ofMinutes(5L);

    public TranslatePromqlToTimeSeriesAggregate() {
        super(OptimizerRules.TransformDirection.UP);
    }

    @Override
    protected LogicalPlan rule(PromqlCommand promqlCommand) {
        if (!PromqlFeatures.isEnabled()) {
            throw new EsqlIllegalArgumentException("PromQL translation attempted but feature is disabled. This should have been caught by the parser.");
        }
        LogicalPlan promqlPlan = promqlCommand.promqlPlan();
        promqlPlan = promqlPlan.transformUp(PlaceholderRelation.class, pr -> TranslatePromqlToTimeSeriesAggregate.withTimestampFilter(promqlCommand, promqlCommand.child()));
        TimeSeriesAggregate tsAggregate = (TimeSeriesAggregate)TranslatePromqlToTimeSeriesAggregate.map(promqlCommand, promqlPlan).plan();
        promqlPlan = tsAggregate;
        Alias convertedValue = new Alias(promqlCommand.source(), promqlCommand.valueColumnName(), new ToDouble(promqlCommand.source(), tsAggregate.output().getFirst().toAttribute()), promqlCommand.valueId());
        promqlPlan = new Eval(promqlCommand.source(), promqlPlan, List.of(convertedValue));
        ArrayList<NamedExpression> projections = new ArrayList<NamedExpression>();
        projections.add(convertedValue.toAttribute());
        List<Attribute> output = tsAggregate.output();
        for (int i = 1; i < output.size(); ++i) {
            projections.add(output.get(i));
        }
        return new Project(promqlCommand.source(), promqlPlan, projections);
    }

    private static LogicalPlan withTimestampFilter(PromqlCommand promqlCommand, LogicalPlan plan) {
        if (promqlCommand.start().value() != null && promqlCommand.end().value() != null) {
            Source promqlSource = promqlCommand.source();
            Expression timestamp = promqlCommand.timestamp();
            plan = new Filter(promqlSource, plan, new And(promqlSource, new GreaterThanOrEqual(promqlSource, timestamp, promqlCommand.start()), new LessThanOrEqual(promqlSource, timestamp, promqlCommand.end())));
        }
        return plan;
    }

    private static MapResult map(PromqlCommand promqlCommand, LogicalPlan p) {
        if (p instanceof Selector) {
            Selector selector = (Selector)p;
            return TranslatePromqlToTimeSeriesAggregate.mapSelector(selector);
        }
        if (p instanceof PromqlFunctionCall) {
            PromqlFunctionCall functionCall = (PromqlFunctionCall)p;
            return TranslatePromqlToTimeSeriesAggregate.mapFunction(promqlCommand, functionCall);
        }
        throw new QlIllegalArgumentException("Unsupported PromQL plan node: {}", new Object[]{p});
    }

    private static MapResult mapSelector(Selector selector) {
        LabelMatchers matchers = selector.labelMatchers();
        Expression matcherCondition = TranslatePromqlToTimeSeriesAggregate.translateLabelMatchers(selector.source(), selector.labels(), matchers);
        ArrayList<Expression> selectorConditions = new ArrayList<Expression>();
        selectorConditions.add(new IsNotNull(selector.source(), selector.series()));
        if (matcherCondition != null) {
            selectorConditions.add(matcherCondition);
        }
        HashMap<String, Expression> extras = new HashMap<String, Expression>();
        extras.put("field", selector.series());
        Filter p = new Filter(selector.source(), selector.child(), Predicates.combineAnd(selectorConditions));
        return new MapResult(p, extras);
    }

    private static MapResult mapFunction(PromqlCommand promqlCommand, PromqlFunctionCall functionCall) {
        MapResult childResult = TranslatePromqlToTimeSeriesAggregate.map(promqlCommand, functionCall.child());
        Map<String, Expression> extras = childResult.extras;
        Expression target = extras.get("field");
        if (functionCall instanceof WithinSeriesAggregate) {
            WithinSeriesAggregate withinAggregate = (WithinSeriesAggregate)functionCall;
            Function esqlFunction = PromqlFunctionRegistry.INSTANCE.buildEsqlFunction(withinAggregate.functionName(), withinAggregate.source(), List.of(target, promqlCommand.timestamp()));
            extras.put("field", esqlFunction);
            return new MapResult(childResult.plan, extras);
        }
        if (functionCall instanceof AcrossSeriesAggregate) {
            AcrossSeriesAggregate acrossAggregate = (AcrossSeriesAggregate)functionCall;
            ArrayList<NamedExpression> aggs = new ArrayList<NamedExpression>();
            ArrayList<Expression> groupings = new ArrayList<Expression>(acrossAggregate.groupings().size());
            Alias stepBucket = TranslatePromqlToTimeSeriesAggregate.createStepBucketAlias(promqlCommand);
            TranslatePromqlToTimeSeriesAggregate.initAggregatesAndGroupings(acrossAggregate, target, aggs, groupings, stepBucket.toAttribute());
            Eval p = new Eval(stepBucket.source(), childResult.plan, List.of(stepBucket));
            TimeSeriesAggregate timeSeriesAggregate = new TimeSeriesAggregate(acrossAggregate.source(), p, groupings, aggs, null);
            return new MapResult(timeSeriesAggregate, extras);
        }
        throw new QlIllegalArgumentException("Unsupported PromQL function call: {}", new Object[]{functionCall});
    }

    private static void initAggregatesAndGroupings(AcrossSeriesAggregate acrossAggregate, Expression target, List<NamedExpression> aggs, List<Expression> groupings, Attribute stepBucket) {
        Function esqlFunction = PromqlFunctionRegistry.INSTANCE.buildEsqlFunction(acrossAggregate.functionName(), acrossAggregate.source(), List.of(target));
        Alias value = new Alias(acrossAggregate.source(), acrossAggregate.sourceText(), esqlFunction);
        aggs.add(value);
        aggs.add(stepBucket);
        groupings.add(stepBucket);
        for (NamedExpression grouping : acrossAggregate.groupings()) {
            aggs.add(grouping);
            groupings.add(grouping.toAttribute());
        }
    }

    private static Alias createStepBucketAlias(PromqlCommand promqlCommand) {
        Literal timeBucketSize = promqlCommand.isRangeQuery() ? promqlCommand.step() : Literal.timeDuration(promqlCommand.source(), DEFAULT_LOOKBACK);
        Bucket b = new Bucket(promqlCommand.source(), promqlCommand.timestamp(), timeBucketSize, null, null, ConfigurationAware.CONFIGURATION_MARKER);
        return new Alias(b.source(), "step", b, promqlCommand.stepId());
    }

    static Expression translateLabelMatchers(Source source, List<Expression> fields, LabelMatchers labelMatchers) {
        ArrayList<Expression> conditions = new ArrayList<Expression>();
        boolean hasNameMatcher = false;
        List<LabelMatcher> matchers = labelMatchers.matchers();
        int s = matchers.size();
        for (int i = 0; i < s; ++i) {
            LabelMatcher matcher = matchers.get(i);
            if ("__name__".equals(matcher.name())) {
                hasNameMatcher = true;
                continue;
            }
            Expression field = fields.get(hasNameMatcher ? i - 1 : i);
            Expression condition = TranslatePromqlToTimeSeriesAggregate.translateLabelMatcher(source, field, matcher);
            if (condition == null) continue;
            conditions.add(condition);
        }
        if (conditions.isEmpty()) {
            return null;
        }
        return Predicates.combineAnd(conditions);
    }

    private static Expression translateLabelMatcher(Source source, Expression field, LabelMatcher matcher) {
        if (matcher.matchesAll()) {
            return Literal.fromBoolean(source, true);
        }
        if (matcher.matchesNone()) {
            return Literal.fromBoolean(source, false);
        }
        String exactMatch = AutomatonUtils.matchesExact(matcher.automaton());
        if (exactMatch != null) {
            return new Equals(source, field, Literal.keyword(source, exactMatch));
        }
        List<AutomatonUtils.PatternFragment> fragments = AutomatonUtils.extractFragments(matcher.value());
        if (fragments != null && !fragments.isEmpty()) {
            return TranslatePromqlToTimeSeriesAggregate.translateDisjointPatterns(source, field, fragments);
        }
        return new RLike(source, field, new RLikePattern(matcher.toString()));
    }

    private static Expression translateDisjointPatterns(Source source, Expression field, List<AutomatonUtils.PatternFragment> fragments) {
        ArrayList<AutomatonUtils.PatternFragment> sortedFragments = new ArrayList<AutomatonUtils.PatternFragment>(fragments);
        sortedFragments.sort(Comparator.comparingInt(a -> a.type().ordinal()));
        AutomatonUtils.PatternFragment.Type firstType = ((AutomatonUtils.PatternFragment)sortedFragments.get(0)).type();
        boolean homogeneous = true;
        for (AutomatonUtils.PatternFragment fragment : sortedFragments) {
            if (fragment.type() == firstType) continue;
            homogeneous = false;
            break;
        }
        if (homogeneous && firstType == AutomatonUtils.PatternFragment.Type.EXACT) {
            ArrayList<Expression> values = new ArrayList<Expression>(sortedFragments.size());
            for (AutomatonUtils.PatternFragment fragment : sortedFragments) {
                values.add(Literal.keyword(source, fragment.value()));
            }
            return new In(source, field, values);
        }
        ArrayList<Expression> conditions = new ArrayList<Expression>(sortedFragments.size());
        for (AutomatonUtils.PatternFragment fragment : sortedFragments) {
            Expression condition = TranslatePromqlToTimeSeriesAggregate.translatePatternFragment(source, field, fragment);
            conditions.add(condition);
        }
        return Predicates.combineOr(conditions);
    }

    private static Expression translatePatternFragment(Source source, Expression field, AutomatonUtils.PatternFragment fragment) {
        Literal value = Literal.keyword(source, fragment.value());
        return switch (fragment.type()) {
            default -> throw new MatchException(null, null);
            case AutomatonUtils.PatternFragment.Type.EXACT -> new Equals(source, field, value);
            case AutomatonUtils.PatternFragment.Type.PREFIX -> new StartsWith(source, field, value);
            case AutomatonUtils.PatternFragment.Type.PROPER_PREFIX -> new And(source, new NotEquals(source, field, value), new StartsWith(source, field, value));
            case AutomatonUtils.PatternFragment.Type.SUFFIX -> new EndsWith(source, field, value);
            case AutomatonUtils.PatternFragment.Type.PROPER_SUFFIX -> new And(source, new NotEquals(source, field, value), new EndsWith(source, field, value));
            case AutomatonUtils.PatternFragment.Type.REGEX -> new RLike(source, field, new RLikePattern(fragment.value()));
        };
    }

    private record MapResult(LogicalPlan plan, Map<String, Expression> extras) {
    }
}

