/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.logical;

import java.util.ArrayList;
import java.util.List;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BlockUtils;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.util.Holder;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
import org.elasticsearch.xpack.esql.expression.function.aggregate.CountDistinct;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.Eval;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin;
import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation;
import org.elasticsearch.xpack.esql.plan.logical.local.LocalSupplier;
import org.elasticsearch.xpack.esql.planner.PlannerUtils;

/**
 * Optimizer rule that replaces aggregate functions whose filter is the literal {@code false} with a constant.
 * <p>
 * Such an aggregate is given the constant {@code 0} when it is a {@link Count} or {@link CountDistinct},
 * and {@code null} for every other {@link AggregateFunction}. The constants are computed by an {@link Eval}
 * placed on top of the aggregate, which keeps only the remaining expressions. A {@link Project} then
 * re-establishes the original order of the aggregate's outputs.
 * <p>
 * The rule applies to a standalone {@link Aggregate}. It also applies to the first {@link Aggregate} reached
 * top-down on the right-hand side of an {@link InlineJoin}. When no expression of the aggregate survives the
 * rewrite, the aggregate is dropped:
 * <ul>
 *   <li>a standalone one becomes a single-row {@link LocalRelation} holding the constants;</li>
 *   <li>one under an {@link InlineJoin} becomes an {@link Eval} over its child, and the join is replaced by a
 *       {@link Project} (with the join's former output) over the rewritten right-hand side.</li>
 * </ul>
 */
public class ReplaceStatsFilteredAggWithEval extends OptimizerRules.OptimizerRule<LogicalPlan>
    implements
        OptimizerRules.CoordinatorOnly {

    @Override
    protected LogicalPlan rule(LogicalPlan plan) {
        final Aggregate aggregate;
        final InlineJoin ij;
        if (plan instanceof Aggregate) {
            aggregate = (Aggregate) plan;
            ij = null;
        } else if (plan instanceof InlineJoin) {
            ij = (InlineJoin) plan;
            // only the first Aggregate reached top-down on the right-hand side is analyzed
            Holder<Aggregate> aggHolder = new Holder<>();
            ij.right().forEachDown(Aggregate.class, aggHolder::setIfAbsent);
            aggregate = aggHolder.get();
        } else {
            return plan; // neither an Aggregate nor an InlineJoin: nothing to do
        }
        if (aggregate == null) {
            return plan;
        }

        int oldAggSize = aggregate.aggregates().size();
        List<NamedExpression> newAggs = new ArrayList<>(oldAggSize);
        List<Alias> newEvals = new ArrayList<>(oldAggSize);
        List<NamedExpression> newProjections = new ArrayList<>(oldAggSize);
        for (NamedExpression ne : aggregate.aggregates()) {
            AggregateFunction alwaysFiltered = falseFilteredAggregate(ne);
            if (alwaysFiltered != null) {
                Object value = alwaysFiltered instanceof Count || alwaysFiltered instanceof CountDistinct ? Long.valueOf(0L) : null;
                Alias newAlias = ((Alias) ne).replaceChild(Literal.of(alwaysFiltered, value));
                newEvals.add(newAlias);
                newProjections.add(newAlias.toAttribute());
            } else {
                // kept as-is in the aggregate
                newAggs.add(ne);
                newProjections.add(ne.toAttribute());
            }
        }

        if (newEvals.isEmpty()) {
            return plan; // no aggregate function has a literal false filter
        }

        if (newAggs.isEmpty() == false) {
            if (ij == null) {
                return updateAggregate(aggregate, newAggs, newEvals, newProjections);
            }
            // The replacement lists were derived from `aggregate` alone, so only that node may be rewritten;
            // any other Aggregate on the right-hand side must be left untouched.
            return ij.replaceRight(
                ij.right()
                    .transformDown(
                        Aggregate.class,
                        agg -> agg == aggregate ? updateAggregate(agg, newAggs, newEvals, newProjections) : agg
                    )
            );
        }

        // every aggregate expression was replaced by a constant: the Aggregate node is dropped
        if (ij == null) {
            return localRelation(aggregate.source(), newEvals);
        }
        LogicalPlan leftHandSide = ij.left();
        // the Aggregate becomes a plain Eval of the constants over its former child
        LogicalPlan newRight = ij.right()
            .transformDown(Aggregate.class, agg -> agg == aggregate ? new Eval(aggregate.source(), aggregate.child(), newEvals) : agg);
        // NOTE(review): presumably swaps the right side's stub relation for the left-hand side - see InlineJoin.replaceStub
        newRight = InlineJoin.replaceStub(leftHandSide, newRight);
        // the InlineJoin is removed; its former output is preserved by the Project
        return new Project(ij.source(), newRight, ij.output());
    }

    /**
     * Returns the aggregate function aliased by {@code ne} when that function has a filter that is the literal
     * {@code false}; returns {@code null} for anything else (non-alias expressions, aliases of non-aggregate
     * expressions, unfiltered aggregates, or aggregates with any other filter).
     */
    private static AggregateFunction falseFilteredAggregate(NamedExpression ne) {
        if (ne instanceof Alias == false) {
            return null;
        }
        Expression child = ((Alias) ne).child();
        if (child instanceof AggregateFunction == false) {
            return null;
        }
        AggregateFunction aggFunction = (AggregateFunction) child;
        if (aggFunction.hasFilter() == false) {
            return null;
        }
        Expression filter = aggFunction.filter();
        if (filter instanceof Literal == false) {
            return null;
        }
        return Boolean.FALSE.equals(((Literal) filter).value()) ? aggFunction : null;
    }

    /**
     * Rebuilds {@code agg} as {@code Project(Eval(Aggregate'))}.
     * <ul>
     *   <li>{@code Aggregate'} keeps the original child and groupings but only {@code newAggs}.</li>
     *   <li>The {@link Eval} adds the constant {@code newEvals} on top of it.</li>
     *   <li>The {@link Project} emits {@code newProjections}, i.e. the outputs in their original order.</li>
     * </ul>
     */
    private static LogicalPlan updateAggregate(
        Aggregate agg,
        List<NamedExpression> newAggs,
        List<Alias> newEvals,
        List<NamedExpression> newProjections
    ) {
        UnaryPlan trimmed = agg.with(agg.child(), agg.groupings(), newAggs);
        UnaryPlan withConstants = new Eval(agg.source(), trimmed, newEvals);
        return new Project(agg.source(), withConstants, newProjections);
    }

    /**
     * Builds a single-row {@link LocalRelation} whose columns are the attributes of {@code newEvals}.
     * Each column's value is the literal held by the corresponding alias. Every alias passed here is expected
     * to wrap a {@link Literal}, as produced by {@link #rule}.
     */
    private static LocalRelation localRelation(Source source, List<Alias> newEvals) {
        BlockFactory blockFactory = PlannerUtils.NON_BREAKING_BLOCK_FACTORY;
        Block[] blocks = new Block[newEvals.size()];
        List<Attribute> attributes = new ArrayList<>(newEvals.size());
        for (int i = 0; i < newEvals.size(); i++) {
            Alias alias = newEvals.get(i);
            attributes.add(alias.toAttribute());
            // one position per block: the relation has exactly one row
            blocks[i] = BlockUtils.constantBlock(blockFactory, ((Literal) alias.child()).value(), 1);
        }
        return new LocalRelation(source, attributes, LocalSupplier.of(new Page(blocks)));
    }
}

