/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.logical;

import java.util.ArrayList;
import java.util.List;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.util.Holder;
import org.elasticsearch.xpack.esql.expression.function.fulltext.Score;
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
import org.elasticsearch.xpack.esql.plan.logical.Eval;
import org.elasticsearch.xpack.esql.plan.logical.Fork;
import org.elasticsearch.xpack.esql.plan.logical.Limit;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.MvExpand;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.RegexExtract;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
import org.elasticsearch.xpack.esql.plan.logical.UnionAll;
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin;
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes;
import org.elasticsearch.xpack.esql.rule.Rule;

/**
 * Logical optimizer rule that moves {@link Limit} nodes towards the data source and merges adjacent limits.
 *
 * <p>Depending on the child of the limit, the rule either merges it with a child {@link Limit}, moves it below the
 * child, copies it below the child while keeping the original on top (marking it {@code duplicated} so the copy is
 * made only once), or aligns it with a smaller limit found further down the plan.
 *
 * <p>A "local" variant of the rule is available through {@link #local()}.
 */
public final class PushDownAndCombineLimits
extends OptimizerRules.ParameterizedOptimizerRule<Limit, LogicalOptimizerContext>
implements OptimizerRules.LocalAware<Limit> {
    /**
     * {@code true} for the instance returned by {@link #local()}. Only the non-local instance refuses to move a limit
     * below an {@link Eval} that contains a {@link Score} expression.
     */
    private final boolean local;

    /** Creates the non-local variant of this rule. */
    public PushDownAndCombineLimits() {
        this(false);
    }

    private PushDownAndCombineLimits(boolean local) {
        super(OptimizerRules.TransformDirection.DOWN);
        this.local = local;
    }

    /**
     * Applies the rule to a single {@link Limit}.
     *
     * @param limit the limit being optimized
     * @param ctx   optimizer context; its fold context is used to evaluate limit values
     * @return the rewritten plan, or {@code limit} itself when no rewrite applies
     */
    @Override
    public LogicalPlan rule(Limit limit, LogicalOptimizerContext ctx) {
        LogicalPlan child = limit.child();
        if (child instanceof Limit childLimit) {
            return combineLimits(limit, childLimit, ctx.foldCtx());
        }
        if (child instanceof UnaryPlan unary) {
            return pushDownThroughUnary(limit, unary, ctx);
        }
        if (child instanceof Join join && join.config().type() == JoinTypes.LEFT && !(join instanceof InlineJoin)) {
            // The limit is copied onto the join's first child while the original stays on top of the join.
            // Presumably a LEFT join can emit more rows than its left input, so the top limit must be kept - confirm.
            return duplicateLimitAsFirstGrandchild(limit, false);
        }
        if (child instanceof Fork fork) {
            return maybePushDownLimitToFork(limit, fork, ctx);
        }
        return limit;
    }

    /** Handles the case where the limit's child is a {@link UnaryPlan} (other than another {@link Limit}). */
    private LogicalPlan pushDownThroughUnary(Limit limit, UnaryPlan unary, LogicalOptimizerContext ctx) {
        if (unary instanceof Eval || unary instanceof Project || unary instanceof RegexExtract || unary instanceof InferencePlan) {
            // The limit is moved below these nodes outright, i.e. they are treated as not changing the row count.
            if (!local && unary instanceof Eval eval && evalAliasNeedsData(eval)) {
                // The non-local rule keeps the limit above an Eval computing a Score.
                // NOTE(review): presumably the score must see the full data before rows are dropped - confirm.
                return limit;
            }
            return unary.replaceChild(limit.replaceChild(unary.child()));
        }
        if (unary instanceof MvExpand) {
            // Copy below, keep the original above (MV_EXPAND is not moved past outright).
            return duplicateLimitAsFirstGrandchild(limit, false);
        }
        if (unary instanceof Enrich enrich) {
            if (enrich.mode() == Enrich.Mode.REMOTE) {
                // Remote enrich: the copy placed below the enrich is marked local.
                return duplicateLimitAsFirstGrandchild(limit, true);
            }
            return enrich.replaceChild(limit.replaceChild(enrich.child()));
        }
        // If a limit further down the unary chain is already smaller than or equal to this one,
        // align this limit with it (e.g. "| limit 1 | sort x | limit 10").
        Limit descendantLimit = descendantLimit(unary);
        if (descendantLimit != null) {
            int current = limitValue(limit, ctx.foldCtx());
            int descendant = limitValue(descendantLimit, ctx.foldCtx());
            if (descendant <= current) {
                return limit.withLimit(descendantLimit.limit());
            }
        }
        return limit;
    }

    /**
     * Tries to add {@code limit} on top of each branch of {@code fork}. {@link UnionAll} is never rewritten.
     *
     * @return {@code limit} with a rewritten fork if any branch changed, otherwise {@code limit} unchanged
     */
    private static LogicalPlan maybePushDownLimitToFork(Limit limit, Fork fork, LogicalOptimizerContext ctx) {
        if (fork instanceof UnionAll) {
            return limit;
        }
        List<LogicalPlan> branches = fork.children();
        List<LogicalPlan> newBranches = new ArrayList<>(branches.size());
        boolean changed = false;
        for (LogicalPlan branch : branches) {
            LogicalPlan newBranch = maybePushDownLimitToForkBranch(limit, branch, ctx);
            changed |= newBranch != branch;
            newBranches.add(newBranch);
        }
        return changed ? limit.replaceChild(fork.replaceChildren(newBranches)) : limit;
    }

    /**
     * Wraps a single fork branch in a copy of {@code limit} when the branch is a {@link UnaryPlan} whose unary chain
     * already contains a limit strictly greater than {@code limit}; otherwise returns the branch unchanged.
     */
    private static LogicalPlan maybePushDownLimitToForkBranch(Limit limit, LogicalPlan forkBranch, LogicalOptimizerContext ctx) {
        if (!(forkBranch instanceof UnaryPlan unaryBranch)) {
            return forkBranch;
        }
        Limit descendantLimit = descendantLimit(unaryBranch);
        if (descendantLimit == null) {
            return forkBranch;
        }
        int descendant = limitValue(descendantLimit, ctx.foldCtx());
        int current = limitValue(limit, ctx.foldCtx());
        return descendant > current ? new Limit(forkBranch.source(), limit.limit(), forkBranch) : forkBranch;
    }

    /**
     * Merges two directly stacked limits.
     *
     * <ul>
     *   <li>{@code lower} smaller: {@code lower} is returned as-is.</li>
     *   <li>equal: {@code lower} is returned, marked local if either limit is local.</li>
     *   <li>{@code upper} smaller: {@code upper}'s value and flags are kept, placed directly on {@code lower}'s child.</li>
     * </ul>
     */
    private static Limit combineLimits(Limit upper, Limit lower, FoldContext ctx) {
        int upperLimitValue = limitValue(upper, ctx);
        int lowerLimitValue = limitValue(lower, ctx);
        if (lowerLimitValue < upperLimitValue) {
            return lower;
        }
        if (lowerLimitValue == upperLimitValue) {
            return lower.local() ? lower : lower.withLocal(upper.local());
        }
        return new Limit(upper.source(), upper.limit(), lower.child(), upper.duplicated(), upper.local());
    }

    /** Returns {@code true} if any expression of {@code eval} (as visited by {@code forEachExpression}) is a {@link Score}. */
    private static boolean evalAliasNeedsData(Eval eval) {
        Holder<Boolean> hasScore = new Holder<>(false);
        eval.forEachExpression(expr -> {
            if (expr instanceof Score) {
                hasScore.set(true);
            }
        });
        return hasScore.get();
    }

    /**
     * Walks down a chain of {@link UnaryPlan}s starting at {@code unary} (inclusive) and returns the first
     * {@link Limit} found. Returns {@code null} when the walk reaches an {@link Aggregate}, an {@link MvExpand},
     * or a node whose child is not a {@link UnaryPlan}.
     */
    private static Limit descendantLimit(UnaryPlan unary) {
        UnaryPlan plan = unary;
        while (!(plan instanceof Aggregate)) {
            if (plan instanceof Limit limit) {
                return limit;
            }
            if (plan instanceof MvExpand) {
                // A limit below MV_EXPAND bounds the rows before expansion, not after, so it is not used here.
                return null;
            }
            if (plan.child() instanceof UnaryPlan next) {
                plan = next;
            } else {
                break;
            }
        }
        return null;
    }

    /**
     * Places a copy of {@code limit} on top of the first child of {@code limit}'s child, keeps the original limit
     * above, and marks the original as {@code duplicated}. An already-duplicated limit is returned unchanged, which
     * prevents the rule from copying the same limit again on later passes.
     *
     * @param withLocal when {@code true}, the copy is marked local
     */
    private static Limit duplicateLimitAsFirstGrandchild(Limit limit, boolean withLocal) {
        if (limit.duplicated()) {
            return limit;
        }
        List<LogicalPlan> grandChildren = limit.child().children();
        Limit copy = withLocal ? limit.withLocal(true) : limit;
        List<LogicalPlan> newGrandChildren = new ArrayList<>(grandChildren.size());
        newGrandChildren.add(copy.replaceChild(grandChildren.getFirst()));
        newGrandChildren.addAll(grandChildren.subList(1, grandChildren.size()));
        LogicalPlan newChild = limit.child().replaceChildren(newGrandChildren);
        return limit.replaceChild(newChild).withDuplicated(true);
    }

    /** Returns the local variant of this rule. */
    @Override
    public Rule<Limit, LogicalPlan> local() {
        return new PushDownAndCombineLimits(true);
    }

    /** Folds the limit expression of {@code limit} to its integer value. */
    private static int limitValue(Limit limit, FoldContext foldCtx) {
        return (Integer) limit.limit().fold(foldCtx);
    }
}

