/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.logical;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeMap;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.TemporaryNameUtils;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.Eval;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
import org.elasticsearch.xpack.esql.rule.Rule;

/**
 * Optimizer rule that folds adjacent projection-like nodes into one another:
 * <ul>
 *   <li>a {@link Project} directly on top of another {@link Project} is merged into a single projection;</li>
 *   <li>a {@link Project} directly on top of an {@link Aggregate} is folded into the aggregate's output list
 *       (unless the projection aliases the same underlying expression more than once);</li>
 *   <li>an {@link Aggregate} directly on top of a {@link Project} has the projection's renames resolved into its
 *       groupings and aggregates, and the projection is removed.</li>
 * </ul>
 * The rule is applied with {@link OptimizerRules.TransformDirection#UP}.
 */
public final class CombineProjections
extends OptimizerRules.OptimizerRule<UnaryPlan>
implements OptimizerRules.LocalAware<UnaryPlan> {
    /**
     * When {@code true}, groupings that become duplicates after resolving renames are kept (re-pointed at a
     * freshly named alias evaluated below the aggregate) instead of being dropped.
     * NOTE(review): presumably because the grouping layout of a local plan must not change - confirm.
     */
    private final boolean local;

    /** Creates the non-local variant of the rule. */
    public CombineProjections() {
        this(false);
    }

    /**
     * @param local whether duplicate groupings must be preserved rather than removed (see {@link #local})
     */
    public CombineProjections(boolean local) {
        super(OptimizerRules.TransformDirection.UP);
        this.local = local;
    }

    /**
     * Dispatches on the shape of {@code plan}: a {@link Project} (possibly over another {@link Project} and/or
     * an {@link Aggregate}), or an {@link Aggregate} over a {@link Project}. Any other shape is returned unchanged.
     *
     * @param plan the node being visited
     * @return the rewritten plan, or {@code plan} itself when nothing applies
     * @throws EsqlIllegalArgumentException if an aggregate over a projection has a grouping that is neither an
     *         attribute nor an alias of a non-evaluatable grouping function
     */
    @Override
    protected LogicalPlan rule(UnaryPlan plan) {
        LogicalPlan child = plan.child();
        if (plan instanceof Project) {
            return foldProject((Project) plan, child);
        }
        if (plan instanceof Aggregate && child instanceof Project) {
            return foldAggregateOverProject((Aggregate) plan, (Project) child);
        }
        return plan;
    }

    /**
     * Handles a {@link Project}: merges it with a child {@link Project} if there is one, then tries to fold the
     * (possibly merged) projection into a child {@link Aggregate}.
     */
    private LogicalPlan foldProject(Project project, LogicalPlan child) {
        LogicalPlan result = project;
        if (child instanceof Project) {
            Project lower = (Project) child;
            // Replace both projections by a single one sitting on the lower projection's child.
            project = lower.withProjections(CombineProjections.combineProjections(project.projections(), lower.projections()));
            child = project.child();
            result = project;
            // Deliberately no early return: the new child may be an Aggregate that can be folded below.
        }
        if (child instanceof Aggregate) {
            Aggregate aggregate = (Aggregate) child;
            List<? extends NamedExpression> oldAggs = aggregate.aggregates();
            List<? extends NamedExpression> newAggs = CombineProjections.projectAggregations(project.projections(), oldAggs);
            // null means the projection could not be folded into the aggregate; keep the Project in that case.
            if (newAggs != null) {
                List<Expression> newGroupings = this.replacePrunedAliasesUsedInGroupBy(aggregate.groupings(), oldAggs, newAggs);
                result = aggregate.with(newGroupings, newAggs);
            }
        }
        return result;
    }

    /**
     * Handles an {@link Aggregate} whose child is a {@link Project}: resolves the projection's renames inside
     * the groupings and aggregates and removes the projection.
     */
    private LogicalPlan foldAggregateOverProject(Aggregate aggregate, Project lower) {
        List<Expression> groupings = aggregate.groupings();
        // Sanity checks on the shape of the groupings.
        for (Expression grouping : groupings) {
            if (!CombineProjections.isSupportedGrouping(grouping)) {
                throw new EsqlIllegalArgumentException("Expected an attribute or grouping function, got {}", grouping);
            }
        }
        assert (groupings.size() <= 1 || !groupings.stream().anyMatch(group -> group.anyMatch(expr -> expr instanceof GroupingFunction.NonEvaluatableGroupingFunction))) : "CombineProjections only tested with a single CATEGORIZE with no additional groups";

        // Map every attribute produced by the lower projection to the attribute it stands for.
        AttributeMap.Builder<Attribute> aliasesBuilder = AttributeMap.builder();
        for (NamedExpression projection : lower.projections()) {
            aliasesBuilder.put(projection.toAttribute(), (Attribute) Alias.unwrap(projection));
        }
        final AttributeMap<Attribute> aliases = aliasesBuilder.build();

        // Propagate the renames from the lower projection into the groupings.
        List<Expression> resolvedGroupings = new ArrayList<>();
        for (Expression grouping : groupings) {
            resolvedGroupings.add(grouping.transformUp(Attribute.class, as -> aliases.resolve(as, as)));
        }

        List<NamedExpression> newAggregates = CombineProjections.combineProjections(aggregate.aggregates(), lower.projections());
        if (!this.local) {
            // Drop duplicates while keeping first-occurrence order.
            List<Expression> newGroupings = new ArrayList<>(new LinkedHashSet<>(resolvedGroupings));
            return aggregate.with(lower.child(), newGroupings, newAggregates);
        }

        // Local mode: keep the number of groupings. A grouping that duplicates an earlier one is re-pointed at
        // a uniquely named alias of its first referenced attribute; those aliases are computed in an Eval below.
        HashSet<Expression> seenResolvedGroupings = new HashSet<>(resolvedGroupings.size());
        List<Expression> newGroupings = new ArrayList<>();
        List<Alias> aliasesAgainstDuplication = new ArrayList<>();
        for (Expression resolvedGrouping : resolvedGroupings) {
            if (seenResolvedGroupings.add(resolvedGrouping)) {
                newGroupings.add(resolvedGrouping);
                continue;
            }
            Attribute coreAttribute = resolvedGrouping.references().iterator().next();
            Alias renameAgainstDuplication = new Alias(coreAttribute.source(), TemporaryNameUtils.locallyUniqueTemporaryName(coreAttribute.name()), coreAttribute);
            aliasesAgainstDuplication.add(renameAgainstDuplication);
            AttributeMap.Builder<Attribute> resolverBuilder = AttributeMap.builder();
            resolverBuilder.put(coreAttribute, renameAgainstDuplication.toAttribute());
            final AttributeMap<Attribute> resolver = resolverBuilder.build();
            newGroupings.add(resolvedGrouping.transformUp(Attribute.class, attr -> resolver.resolve(attr, attr)));
        }
        LogicalPlan newChild = aliasesAgainstDuplication.isEmpty() ? lower.child() : new Eval(lower.source(), lower.child(), aliasesAgainstDuplication);
        return aggregate.with(newChild, newGroupings, newAggregates);
    }

    /**
     * @return {@code true} if {@code grouping} is an {@link Attribute}, or an {@link Alias} whose child is a
     *         {@link GroupingFunction.NonEvaluatableGroupingFunction}
     */
    private static boolean isSupportedGrouping(Expression grouping) {
        if (grouping instanceof Attribute) {
            return true;
        }
        return grouping instanceof Alias && ((Alias) grouping).child() instanceof GroupingFunction.NonEvaluatableGroupingFunction;
    }

    /**
     * Folds {@code upperProjection} into {@code lowerAggregations}.
     *
     * @return the combined list, or {@code null} when two entries of the upper projection unwrap to the same
     *         expression (the projection cannot be folded in that case)
     */
    private static List<? extends NamedExpression> projectAggregations(List<? extends NamedExpression> upperProjection, List<? extends NamedExpression> lowerAggregations) {
        AttributeSet.Builder seen = AttributeSet.builder();
        for (NamedExpression upper : upperProjection) {
            Expression unwrapped = Alias.unwrap(upper);
            // The projection refers to the same underlying expression twice.
            if (seen.contains(unwrapped)) {
                return null;
            }
            seen.add(Expressions.attribute(unwrapped));
        }
        return CombineProjections.combineProjections(upperProjection, lowerAggregations);
    }

    /**
     * Rewrites every expression of {@code upper} so that references to attributes declared in {@code lower}
     * are replaced by the declarations themselves; aliases that end up nested below the top level are removed.
     *
     * @param upper the outer list of named expressions
     * @param lower the inner list whose declarations are inlined into {@code upper}
     * @return a new list with one entry per element of {@code upper}
     */
    private static List<NamedExpression> combineProjections(List<? extends NamedExpression> upper, List<? extends NamedExpression> lower) {
        // Declarations of the lower list, keyed by the attribute they produce.
        AttributeMap.Builder<NamedExpression> namedExpressionsBuilder = AttributeMap.builder();
        // Attribute -> unwrapped expression, used to resolve chains inside the lower list (f1 = x, f2 = f1, ...).
        AttributeMap.Builder<Expression> aliasesBuilder = AttributeMap.builder(lower.size());
        for (NamedExpression ne : lower) {
            aliasesBuilder.put(ne.toAttribute(), Alias.unwrap(ne));
            if (ne instanceof Alias) {
                Alias as = (Alias) ne;
                Expression child = as.child();
                namedExpressionsBuilder.put(ne.toAttribute(), as.replaceChild(aliasesBuilder.build().resolve(child, child)));
            } else if (ne instanceof ReferenceAttribute) {
                ReferenceAttribute ra = (ReferenceAttribute) ne;
                namedExpressionsBuilder.put(ra, ra);
            }
        }
        final AttributeMap<NamedExpression> namedExpressions = namedExpressionsBuilder.build();
        List<NamedExpression> replaced = new ArrayList<>(upper.size());
        for (NamedExpression ne : upper) {
            NamedExpression replacedExp = (NamedExpression) ne.transformUp(Attribute.class, a -> namedExpressions.resolve(a, a));
            replaced.add((NamedExpression) CombineProjections.trimNonTopLevelAliases(replacedExp));
        }
        return replaced;
    }

    /**
     * Rewrites {@code groupings} so they no longer reference aliases that exist in {@code oldAggs} but not in
     * {@code newAggs}: each such reference is replaced by the alias' child expression. Groupings that become
     * equal (as attributes) to an already collected grouping are skipped.
     *
     * @return {@code groupings} itself when no alias was removed, otherwise a new list
     */
    private List<Expression> replacePrunedAliasesUsedInGroupBy(List<Expression> groupings, List<? extends NamedExpression> oldAggs, List<? extends NamedExpression> newAggs) {
        AttributeMap.Builder<Expression> removedAliasesBuilder = AttributeMap.builder();
        AttributeSet currentAliases = AttributeSet.of(Expressions.asAttributes(newAggs));
        for (NamedExpression oldAgg : oldAggs) {
            if (!(oldAgg instanceof Alias)) {
                continue;
            }
            Attribute attr = oldAgg.toAttribute();
            if (!currentAliases.contains(attr)) {
                removedAliasesBuilder.put(attr, ((Alias) oldAgg).child());
            }
        }
        final AttributeMap<Expression> removedAliases = removedAliasesBuilder.build();
        if (removedAliases.isEmpty()) {
            return groupings;
        }
        List<Expression> newGroupings = new ArrayList<>(groupings.size());
        for (Expression group : groupings) {
            Expression transformed = group.transformUp(Attribute.class, a -> removedAliases.resolve(a, a));
            if (Expressions.anyMatch(newGroupings, g -> Expressions.equalsAsAttribute(g, transformed))) {
                continue;
            }
            newGroupings.add(transformed);
        }
        return newGroupings;
    }

    /**
     * Removes every {@link Alias} nested inside {@code e} while keeping {@code e} itself if it is an alias.
     *
     * @param e the expression to clean up
     * @return {@code e} with all non-top-level aliases replaced by their children
     */
    public static Expression trimNonTopLevelAliases(Expression e) {
        if (e instanceof Alias) {
            Alias a = (Alias) e;
            return a.replaceChild(CombineProjections.trimAliases(a.child()));
        }
        return CombineProjections.trimAliases(e);
    }

    /** Replaces every {@link Alias} found in {@code e} (including {@code e} itself) by its child. */
    private static Expression trimAliases(Expression e) {
        return e.transformDown(Alias.class, Alias::child);
    }

    /** @return a new instance of this rule constructed with {@code local = true} */
    @Override
    public Rule<UnaryPlan, LogicalPlan> local() {
        return new CombineProjections(true);
    }
}

