/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.logical;

import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.util.Holder;
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.Filter;
import org.elasticsearch.xpack.esql.plan.logical.Limit;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.TopN;
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;

/**
 * Logical optimizer rule that propagates the value of a {@link Limit} into {@link Knn} function calls appearing in
 * {@link Filter} conditions beneath that limit, by replacing each {@code Knn} with {@link Knn#withImplicitK}.
 * <p>
 * Propagation stops once the traversal reaches a barrier node: a second {@link Limit}, a {@link TopN},
 * a {@link Rerank} or an {@link Aggregate}. After that point, nodes are returned unchanged.
 */
public class PushLimitToKnn
extends OptimizerRules.ParameterizedOptimizerRule<Limit, LogicalOptimizerContext> {
    /**
     * Creates the rule with {@link OptimizerRules.TransformDirection#DOWN} as its transform direction.
     */
    public PushLimitToKnn() {
        super(OptimizerRules.TransformDirection.DOWN);
    }

    /**
     * Pushes the value of {@code limit} into every {@link Knn} found in {@link Filter} conditions under it,
     * until a barrier node is reached.
     *
     * @param limit the limit whose value is propagated; it is also the root of the traversal
     * @param ctx   optimizer context, used to fold the limit expression into a constant
     * @return the plan rooted at {@code limit}, with filters rebuilt wherever a {@code Knn} was rewritten
     */
    @Override
    public LogicalPlan rule(Limit limit, LogicalOptimizerContext ctx) {
        // Mutable flags captured by the lambda; they are shared by every node visited during this one traversal
        // (including sibling branches), so once a barrier is hit nothing further is rewritten.
        Holder<Boolean> breakerReached = new Holder<>(false);
        // Tracks whether a Limit has already been seen. The first one is expected to be `limit` itself
        // (assuming transformDown visits the root before its children — TODO confirm).
        Holder<Boolean> firstLimit = new Holder<>(false);
        return limit.transformDown(plan -> {
            if (breakerReached.get()) {
                // A barrier was already reached: leave all remaining nodes untouched.
                return plan;
            }
            if (plan instanceof Filter) {
                Filter filter = (Filter) plan;
                Expression limitAppliedExpression = this.limitFilterExpressions(filter.condition(), limit, ctx);
                // Only rebuild the Filter when the condition actually changed.
                if (!limitAppliedExpression.equals(filter.condition())) {
                    return filter.with(limitAppliedExpression);
                }
            } else if (plan instanceof Limit) {
                // The first Limit seen does not break propagation; any subsequent Limit does.
                breakerReached.set(firstLimit.get());
                firstLimit.set(true);
            } else if (plan instanceof TopN || plan instanceof Rerank || plan instanceof Aggregate) {
                // NOTE(review): these nodes presumably change which/how many rows reach `limit`,
                // so its value no longer bounds the knn results below them.
                breakerReached.set(true);
            }
            return plan;
        });
    }

    /**
     * Rewrites every {@link Knn} inside {@code condition} to use the folded value of {@code limit} as its implicit k.
     *
     * @param condition the filter condition to rewrite
     * @param limit     the limit whose value becomes the implicit k
     * @param ctx       optimizer context providing the fold context
     * @return the rewritten condition, or an expression equal to {@code condition} if it contains no {@code Knn}
     */
    private Expression limitFilterExpressions(Expression condition, Limit limit, LogicalOptimizerContext ctx) {
        return condition.transformDown(exp -> {
            if (exp instanceof Knn) {
                Knn knn = (Knn) exp;
                // Folded lazily, only when a Knn is present; assumes the limit folds to an Integer.
                return knn.withImplicitK((Integer) limit.limit().fold(ctx.foldCtx()));
            }
            return exp;
        });
    }
}

