/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.logical;

import java.lang.runtime.SwitchBootstraps;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.Predicate;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeMap;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.tree.Node;
import org.elasticsearch.xpack.esql.expression.predicate.Predicates;
import org.elasticsearch.xpack.esql.plan.logical.Eval;
import org.elasticsearch.xpack.esql.plan.logical.Filter;
import org.elasticsearch.xpack.esql.plan.logical.Limit;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.Subquery;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
import org.elasticsearch.xpack.esql.plan.logical.UnionAll;
import org.elasticsearch.xpack.esql.rule.Rule;

/**
 * Optimizer rule that pushes {@link Filter}s and {@link Limit}s down into the branches of a {@link UnionAll}
 * and into {@link Subquery} nodes.
 * <p>
 * The rule makes two top-down passes over the plan:
 * <ol>
 *   <li>Every {@code Filter} whose child is a {@code UnionAll} is split into its AND-ed conjuncts. Conjuncts that only
 *   reference {@code UnionAll} output attributes are re-resolved by name against each branch and pushed into it.
 *   Branches must be shaped as a {@code Project} or a {@code Limit}. Pushdown is all-or-nothing across branches: if
 *   any branch has an unsupported shape or cannot absorb the predicates, the original {@code Filter} is returned
 *   untouched.</li>
 *   <li>Every {@code Limit} whose child is a {@code Subquery}, or a {@code Filter} over a {@code Subquery}, is moved
 *   (together with that {@code Filter}) inside the {@code Subquery}.</li>
 * </ol>
 */
public final class PushDownFilterAndLimitIntoUnionAll extends Rule<LogicalPlan, LogicalPlan> {
    private static final String UNIONALL = "unionall";
    /**
     * Prefix stripped from attribute names in {@link #resolveUnionAllOutputByName}.
     * NOTE(review): presumably this is what {@code Attribute.rawTemporaryName(UNIONALL, name)} prepends to
     * {@code name}; the two must stay in sync. TODO confirm against {@code Attribute.rawTemporaryName}.
     */
    private static final String UNIONALL_PREFIX = "$$unionall$";

    @Override
    public LogicalPlan apply(LogicalPlan logicalPlan) {
        LogicalPlan planWithFilterPushedDownPastUnionAll = logicalPlan.transformDown(
            Filter.class,
            filter -> filter.child() instanceof UnionAll unionAll ? maybePushDownPastUnionAll(filter, unionAll) : filter
        );
        return planWithFilterPushedDownPastUnionAll.transformDown(
            Limit.class,
            PushDownFilterAndLimitIntoUnionAll::pushLimitAndFilterPastSubquery
        );
    }

    /**
     * Tries to push the conjuncts of {@code filter} into every branch of {@code unionAll}.
     *
     * @param filter   the filter sitting directly on {@code unionAll}
     * @param unionAll the child of {@code filter}
     * @return {@code filter} unchanged if nothing can be pushed, or if any branch is of an unsupported shape or is left
     *         unchanged. Otherwise returns the rewritten {@code UnionAll}, wrapped in a {@code Filter} of the remaining
     *         non-pushable conjuncts if there are any.
     */
    private static LogicalPlan maybePushDownPastUnionAll(Filter filter, UnionAll unionAll) {
        AttributeSet unionAllOutputSet = unionAll.outputSet();
        // A conjunct can be pushed only if every attribute it references is produced by the UnionAll itself.
        Tuple<List<Expression>, List<Expression>> pushablesAndNonPushables = splitPushableAndNonPushablePredicates(
            Predicates.splitAnd(filter.condition()),
            exp -> isSubset(exp.references(), unionAllOutputSet) == false
        );
        List<Expression> pushable = pushablesAndNonPushables.v1();
        List<Expression> nonPushable = pushablesAndNonPushables.v2();
        if (pushable.isEmpty()) {
            return filter;
        }

        List<LogicalPlan> newChildren = new ArrayList<>();
        boolean changed = false;
        for (LogicalPlan child : unionAll.children()) {
            LogicalPlan newChild = switch (child) {
                case Project project -> maybePushDownFilterPastProjectForUnionAllChild(pushable, project);
                case Limit limit -> maybePushDownFilterPastLimitForUnionAllChild(pushable, limit);
                default -> null; // unsupported branch shape
            };
            // All-or-nothing: if any branch cannot take the predicates, keep the original filter above the UnionAll.
            if (newChild == null || newChild == child) {
                return filter;
            }
            changed = true;
            newChildren.add(newChild);
        }
        if (changed == false) {
            return filter;
        }
        LogicalPlan newUnionAll = unionAll.replaceChildren(newChildren);
        return nonPushable.isEmpty() ? newUnionAll : filter.with(newUnionAll, Predicates.combineAnd(nonPushable));
    }

    /**
     * Pushes {@code pushable} (expressed over the UnionAll output) into a branch rooted at a {@code Project}.
     * <p>
     * Steps:
     * <ol>
     *   <li>Re-resolve the predicates by name against the project's projections.</li>
     *   <li>Split them into predicates whose references are all produced by the project's child, and the rest.</li>
     *   <li>Push the former below the project, but only when that child is an {@code Eval} or a {@code Limit}.</li>
     *   <li>Place the rest in a {@code Filter} above the project.</li>
     * </ol>
     *
     * @return {@code project} itself when the predicates cannot be resolved or nothing could be pushed below it,
     *         otherwise the rewritten branch
     */
    private static LogicalPlan maybePushDownFilterPastProjectForUnionAllChild(List<Expression> pushable, Project project) {
        List<Expression> resolvedPushable = resolvePushableAgainstOutput(pushable, project.projections());
        if (resolvedPushable == null) {
            return project;
        }
        LogicalPlan child = project.child();
        AttributeSet childOutputSet = child.outputSet(); // loop-invariant for the predicate check below
        Tuple<List<Expression>, List<Expression>> pushablesAndNonPushables = splitPushableAndNonPushablePredicates(
            resolvedPushable,
            exp -> isSubset(exp.references(), childOutputSet) == false
        );
        List<Expression> newResolvedPushable = pushablesAndNonPushables.v1();
        List<Expression> newResolvedNonPushable = pushablesAndNonPushables.v2();
        if (newResolvedPushable.isEmpty()) {
            return newResolvedNonPushable.isEmpty() ? project : filterWithPlanAsChild(project, newResolvedNonPushable);
        }

        LogicalPlan planWithNewResolvedPushablePushedDown = project;
        if (child instanceof Eval eval) {
            planWithNewResolvedPushablePushedDown = pushDownFilterPastEvalForUnionAllChild(newResolvedPushable, project, eval);
        } else if (child instanceof Limit limit) {
            LogicalPlan newLimit = pushDownFilterPastLimitForUnionAllChild(newResolvedPushable, limit);
            planWithNewResolvedPushablePushedDown = project.replaceChild(newLimit);
        }
        if (planWithNewResolvedPushablePushedDown == project) {
            return project;
        }
        return newResolvedNonPushable.isEmpty()
            ? planWithNewResolvedPushablePushedDown
            : filterWithPlanAsChild(planWithNewResolvedPushablePushedDown, newResolvedNonPushable);
    }

    /**
     * Pushes {@code pushable} into a branch rooted at a {@code Limit}. The predicates are first re-resolved by name
     * against the limit's output.
     *
     * @return {@code limit} itself when the predicates cannot be resolved, otherwise the limit with a {@code Filter}
     *         inserted beneath it
     */
    private static LogicalPlan maybePushDownFilterPastLimitForUnionAllChild(List<Expression> pushable, Limit limit) {
        List<Expression> resolvedPushable = resolvePushableAgainstOutput(pushable, limit.output());
        if (resolvedPushable == null) {
            return limit;
        }
        return pushDownFilterPastLimitForUnionAllChild(resolvedPushable, limit);
    }

    /**
     * Handles the {@code Project -> Eval} branch shape.
     * <p>
     * Predicates that reference any attribute defined by the eval stay directly above the eval. The others are pushed
     * below the eval, but only when the eval's child is a {@code Limit}.
     *
     * @return {@code project} itself when the eval's child is not a {@code Limit} (and something was pushable), or when
     *         no predicates were given; otherwise the rewritten project
     */
    private static LogicalPlan pushDownFilterPastEvalForUnionAllChild(List<Expression> pushable, Project project, Eval eval) {
        AttributeMap<Expression> evalAliases = buildEvalAliases(eval);
        // Predicates that read an Eval-computed attribute cannot be evaluated below the Eval.
        Tuple<List<Expression>, List<Expression>> pushablesAndNonPushables = splitPushableAndNonPushablePredicates(
            pushable,
            exp -> exp.references().stream().anyMatch(evalAliases::containsKey)
        );
        List<Expression> pushables = pushablesAndNonPushables.v1();
        List<Expression> nonPushables = pushablesAndNonPushables.v2();
        if (pushables.isEmpty()) {
            return nonPushables.isEmpty() ? project : projectWithFilterAsChild(project, eval, nonPushables);
        }
        if (eval.child() instanceof Limit limit) {
            LogicalPlan newLimit = pushDownFilterPastLimitForUnionAllChild(pushables, limit);
            UnaryPlan newEval = eval.replaceChild(newLimit);
            return nonPushables.isEmpty() ? project.replaceChild(newEval) : projectWithFilterAsChild(project, newEval, nonPushables);
        }
        return project;
    }

    /** Returns {@code project} with a new {@code Filter(child, AND(predicates))} as its child. */
    private static LogicalPlan projectWithFilterAsChild(Project project, LogicalPlan child, List<Expression> predicates) {
        Expression combined = Predicates.combineAnd(predicates);
        return project.replaceChild(new Filter(project.source(), child, combined));
    }

    /** Wraps {@code logicalPlan} in a new {@code Filter} of the AND of {@code predicates}. */
    private static Filter filterWithPlanAsChild(LogicalPlan logicalPlan, List<Expression> predicates) {
        Expression combined = Predicates.combineAnd(predicates);
        return new Filter(logicalPlan.source(), logicalPlan, combined);
    }

    /**
     * Inserts a {@code Filter} of the AND of {@code pushable} between {@code limit} and its child.
     * <p>
     * NOTE(review): placing the filter beneath the limit changes which rows the limit keeps. Presumably the limits
     * reached here are implicit/default limits for which this is intended. Confirm.
     *
     * @return {@code limit} itself when {@code pushable} is empty
     */
    private static LogicalPlan pushDownFilterPastLimitForUnionAllChild(List<Expression> pushable, Limit limit) {
        if (pushable.isEmpty()) {
            return limit;
        }
        Expression combined = Predicates.combineAnd(pushable);
        Filter pushed = new Filter(limit.source(), limit.child(), combined);
        return limit.replaceChild(pushed);
    }

    /**
     * Returns true if every attribute in {@code subset} has a counterpart in {@code superset} with the same name and
     * the same id.
     */
    private static boolean isSubset(AttributeSet subset, AttributeSet superset) {
        return subset.stream()
            .allMatch(
                attr -> superset.stream()
                    .anyMatch(superAttr -> superAttr.name().equals(attr.name()) && superAttr.id().equals(attr.id()))
            );
    }

    /** Maps each attribute defined by {@code eval} to the expression that computes it. */
    private static AttributeMap<Expression> buildEvalAliases(Eval eval) {
        AttributeMap.Builder<Expression> builder = AttributeMap.builder();
        for (Alias alias : eval.fields()) {
            builder.put(alias.toAttribute(), alias.child());
        }
        return builder.build();
    }

    /**
     * Partitions {@code predicates}, preserving their order.
     *
     * @param nonPushableCheck returns true for predicates that must NOT be pushed
     * @return a tuple of (pushable, non-pushable) predicates
     */
    private static Tuple<List<Expression>, List<Expression>> splitPushableAndNonPushablePredicates(
        List<Expression> predicates,
        Predicate<Expression> nonPushableCheck
    ) {
        List<Expression> pushable = new ArrayList<>();
        List<Expression> nonPushable = new ArrayList<>();
        for (Expression exp : predicates) {
            if (nonPushableCheck.test(exp)) {
                nonPushable.add(exp);
            } else {
                pushable.add(exp);
            }
        }
        return Tuple.tuple(pushable, nonPushable);
    }

    /**
     * Re-resolves every predicate's attributes by name against {@code output}. Predicates without references are kept
     * as they are.
     *
     * @return the resolved predicates, in order, or {@code null} if any predicate that has references could not be
     *         resolved (the resolver returned {@code null} or the very same instance)
     */
    private static List<Expression> resolvePushableAgainstOutput(List<Expression> pushable, List<? extends NamedExpression> output) {
        List<Expression> resolved = new ArrayList<>(pushable.size());
        for (Expression exp : pushable) {
            if (exp.references().isEmpty()) {
                resolved.add(exp);
                continue;
            }
            Expression resolvedExp = resolveUnionAllOutputByName(exp, output);
            if (resolvedExp == null || resolvedExp == exp) {
                return null;
            }
            resolved.add(resolvedExp);
        }
        return resolved;
    }

    /**
     * Replaces every attribute in {@code expr} whose name matches one of {@code namedExpressions} with that named
     * expression's attribute. Attributes that do not match any name are left as they are.
     * <p>
     * The work is done in two passes. First, matching attributes are renamed to a temporary name. Second, that name is
     * mapped back to the target attribute.
     * NOTE(review): the intermediate rename presumably guarantees a new expression instance whenever any attribute
     * matched, which {@link #resolvePushableAgainstOutput} relies on via its identity check. Confirm before
     * simplifying this to one pass.
     */
    private static Expression resolveUnionAllOutputByName(Expression expr, List<? extends NamedExpression> namedExpressions) {
        Expression renamed = expr.transformUp(Attribute.class, attr -> {
            for (NamedExpression ne : namedExpressions) {
                if (ne.name().equals(attr.name())) {
                    return attr.withName(Attribute.rawTemporaryName(UNIONALL, ne.name()));
                }
            }
            return attr;
        });
        return renamed.transformUp(Attribute.class, attr -> {
            String name = attr.name();
            String originalName = name.startsWith(UNIONALL_PREFIX) ? name.substring(UNIONALL_PREFIX.length()) : name;
            for (NamedExpression ne : namedExpressions) {
                if (ne.name().equals(originalName)) {
                    return ne.toAttribute();
                }
            }
            return attr;
        });
    }

    /**
     * Rewrites {@code Limit -> Subquery -> X} into {@code Subquery -> Limit -> X}. It also rewrites
     * {@code Limit -> Filter -> Subquery -> X} into {@code Subquery -> Limit -> Filter -> X}. Any other shape is
     * returned unchanged.
     */
    private static LogicalPlan pushLimitAndFilterPastSubquery(Limit limit) {
        LogicalPlan child = limit.child();
        if (child instanceof Subquery subquery) {
            Limit newLimit = limit.replaceChild(subquery.child());
            return subquery.replaceChild(newLimit);
        }
        if (child instanceof Filter filter && filter.child() instanceof Subquery subquery) {
            Filter newFilter = filter.replaceChild(subquery.child());
            Limit newLimit = limit.replaceChild(newFilter);
            return subquery.replaceChild(newLimit);
        }
        return limit;
    }
}

