/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.logical;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.MapExpression;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.util.Holder;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.TemporaryNameUtils;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.Eval;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;

/**
 * Optimizer rule that pulls nested (non-attribute) expressions out of an {@link Aggregate} into a new
 * {@link Eval} placed directly under it, so the aggregate itself only works on attributes.
 * <p>
 * The rewrite has two phases:
 * <ol>
 *   <li><b>Groupings.</b> Every aliased grouping is moved into the eval and replaced by its attribute.
 *       The exception is a {@link GroupingFunction.NonEvaluatableGroupingFunction}, which stays in the
 *       groupings; only those of its arguments that are neither {@link Attribute}s nor
 *       {@link MapExpression}s are moved into the eval.</li>
 *   <li><b>Aggregates.</b> For each alias in the aggregates (except ones wrapping an aggregate function that
 *       {@link #skipOptimisingAgg} excludes):
 *       <ul>
 *         <li>an alias whose name matches a moved grouping is replaced by that grouping's attribute;</li>
 *         <li>aggregate-function fields that are neither attributes nor foldable are moved into the eval,
 *             de-duplicated by their canonical form;</li>
 *         <li>evaluatable grouping functions are replaced by the attribute of the grouping they came from.</li>
 *       </ul>
 *   </li>
 * </ol>
 * The plan is only rewritten when at least one expression ended up in the eval.
 */
public final class ReplaceAggregateNestedExpressionWithEval
extends OptimizerRules.OptimizerRule<Aggregate> {
    @Override
    protected LogicalPlan rule(Aggregate aggregate) {
        // Expressions to evaluate in the new Eval, in insertion order.
        List<Alias> evals = new ArrayList<>();
        // Name of each grouping alias moved into the eval -> its attribute, so aggregates can reuse it by name.
        Map<String, Attribute> evalNames = new HashMap<>();
        // Evaluatable grouping function -> attribute of the grouping it was moved into.
        Map<GroupingFunction.EvaluatableGroupingFunction, Attribute> groupingAttributes = new HashMap<>();
        List<Expression> newGroupings = new ArrayList<>(aggregate.groupings());
        boolean groupingChanged = false;

        // Handle the groupings first: the aggregates may refer back to them.
        for (int i = 0; i < newGroupings.size(); i++) {
            Expression grouping = newGroupings.get(i);
            if (!(grouping instanceof Alias)) {
                continue;
            }
            Alias alias = (Alias) grouping;
            Expression aliasChild = alias.child();
            if (aliasChild instanceof GroupingFunction.NonEvaluatableGroupingFunction) {
                // Non-evaluatable grouping functions remain in the groupings; only their nested
                // argument expressions are pushed into the eval.
                GroupingFunction.NonEvaluatableGroupingFunction nonEvaluatable =
                    (GroupingFunction.NonEvaluatableGroupingFunction) aliasChild;
                Expression replacement = transformNonEvaluatableGroupingFunction(nonEvaluatable, evals);
                if (replacement != nonEvaluatable) {
                    groupingChanged = true;
                    newGroupings.set(i, alias.replaceChild(replacement));
                }
            } else {
                // Move the whole alias into the eval and group on its attribute instead.
                groupingChanged = true;
                Attribute attribute = alias.toAttribute();
                evals.add(alias);
                evalNames.put(alias.name(), attribute);
                newGroupings.set(i, attribute);
                if (aliasChild instanceof GroupingFunction.EvaluatableGroupingFunction) {
                    groupingAttributes.put((GroupingFunction.EvaluatableGroupingFunction) aliasChild, attribute);
                }
            }
        }

        Holder<Boolean> aggsChanged = new Holder<>(false);
        List<? extends NamedExpression> aggs = aggregate.aggregates();
        List<NamedExpression> newAggs = new ArrayList<>(aggs.size());

        // Canonical expression -> attribute already produced by the eval; lets identical nested
        // expressions in different aggregate functions share a single eval entry.
        Map<Expression, Attribute> expToAttribute = new HashMap<>();
        for (Alias alias : evals) {
            expToAttribute.put(alias.child().canonical(), alias.toAttribute());
        }

        // Synthetic-name counter shared across all aggregates; an array so the lambdas can mutate it.
        int[] counter = new int[] { 0 };
        for (NamedExpression agg : aggs) {
            NamedExpression rewritten = (NamedExpression) agg.transformDown(Alias.class, as -> {
                Expression child = as.child();
                if (child instanceof AggregateFunction && skipOptimisingAgg((AggregateFunction) child)) {
                    return as;
                }
                // An aggregate that re-states a grouping alias is replaced by the grouping's attribute.
                Attribute ref = evalNames.get(as.name());
                if (ref != null) {
                    aggsChanged.set(true);
                    return ref;
                }
                // 1. Pull nested fields of aggregate functions into the eval.
                Expression replaced = child.transformUp(
                    AggregateFunction.class,
                    af -> transformAggregateFunction(af, expToAttribute, evals, counter, aggsChanged)
                );
                // 2. Point references to grouping functions at the grouping's attribute.
                replaced = replaced.transformDown(GroupingFunction.EvaluatableGroupingFunction.class, gf -> {
                    aggsChanged.set(true);
                    // NOTE(review): yields null if gf does not appear among the groupings; presumably
                    // an earlier verification step guarantees it does - confirm.
                    return groupingAttributes.get(gf);
                });
                return as.replaceChild(replaced);
            });
            newAggs.add(rewritten);
        }

        if (!evals.isEmpty()) {
            List<Expression> groupings = groupingChanged ? newGroupings : aggregate.groupings();
            // aggregate.aggregates() is List<? extends NamedExpression>, so the result must use the wildcard type.
            List<? extends NamedExpression> aggregates = aggsChanged.get() ? newAggs : aggregate.aggregates();
            Eval newEval = new Eval(aggregate.source(), aggregate.child(), evals);
            aggregate = aggregate.with(newEval, groupings, aggregates);
        }
        return aggregate;
    }

    /**
     * Moves each argument of {@code gf} that is neither an {@link Attribute} nor a {@link MapExpression}
     * into {@code evals} as a synthetic alias, and points the function at that alias' attribute instead.
     *
     * @param gf    the non-evaluatable grouping function to rewrite
     * @param evals receives one new alias per moved argument
     * @return {@code gf} itself if no argument was moved, otherwise a copy of {@code gf} with replaced children
     */
    private static Expression transformNonEvaluatableGroupingFunction(GroupingFunction.NonEvaluatableGroupingFunction gf, List<Alias> evals) {
        int counter = 0;
        boolean childrenChanged = false;
        List<Expression> newChildren = new ArrayList<>(gf.children().size());
        for (Expression ex : gf.children()) {
            if (ex instanceof Attribute || ex instanceof MapExpression) {
                newChildren.add(ex);
            } else {
                Alias alias = new Alias(ex.source(), syntheticName(ex, gf, counter++), ex, null, true);
                evals.add(alias);
                newChildren.add(alias.toAttribute());
                childrenChanged = true;
            }
        }
        return childrenChanged ? gf.replaceChildren(newChildren) : gf;
    }

    /**
     * Decides whether an aggregate function must be left untouched.
     *
     * @return {@code true} if its field is already an {@link Attribute} (nothing to extract) or the field's
     *         expression tree contains another {@link AggregateFunction} (nested aggregates are not rewritten)
     */
    private static boolean skipOptimisingAgg(AggregateFunction af) {
        if (af.field() instanceof Attribute) {
            return true;
        }
        Holder<Boolean> foundNestedAggs = new Holder<>(Boolean.FALSE);
        af.field().forEachDown(AggregateFunction.class, unused -> foundNestedAggs.set(Boolean.TRUE));
        return foundNestedAggs.get();
    }

    /**
     * Replaces a nested (non-attribute, non-foldable) field of {@code af} with an attribute computed by the eval.
     * The field is looked up by its canonical form in {@code expToAttribute}; a new synthetic alias is added to
     * {@code evals} (and recorded in the map) only the first time that canonical expression is seen.
     *
     * @param af             the aggregate function to rewrite
     * @param expToAttribute canonical expression -> attribute already available from the eval; updated in place
     * @param evals          receives any newly created alias
     * @param counter        single-element counter used for synthetic names; incremented per new alias
     * @param aggsChanged    set to {@code true} whenever the field is replaced
     * @return {@code af} unchanged, or a copy whose first child is the eval attribute
     */
    private static Expression transformAggregateFunction(AggregateFunction af, Map<Expression, Attribute> expToAttribute, List<Alias> evals, int[] counter, Holder<Boolean> aggsChanged) {
        Expression field = af.field();
        if (field instanceof Attribute || field.foldable()) {
            return af;
        }
        Attribute attr = expToAttribute.computeIfAbsent(field.canonical(), canonical -> {
            Alias newAlias = new Alias(canonical.source(), syntheticName(canonical, af, counter[0]++), canonical, null, true);
            evals.add(newAlias);
            return newAlias.toAttribute();
        });
        aggsChanged.set(true);
        // The field is the first child of the aggregate function; swap it for the eval attribute.
        List<Expression> newChildren = new ArrayList<>(af.children());
        newChildren.set(0, attr);
        return af.replaceChildren(newChildren);
    }

    /** Builds the temporary name for an expression extracted from {@code func}. */
    private static String syntheticName(Expression expression, Expression func, int counter) {
        return TemporaryNameUtils.temporaryName(expression, func, counter);
    }
}

