/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.logical;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.function.Function;
import java.util.function.Predicate;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeMap;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
import org.elasticsearch.xpack.esql.core.util.CollectionUtils;
import org.elasticsearch.xpack.esql.expression.predicate.Predicates;
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownUtils;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
import org.elasticsearch.xpack.esql.plan.logical.Eval;
import org.elasticsearch.xpack.esql.plan.logical.Filter;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.RegexExtract;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin;
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes;

/**
 * Logical optimizer rule that merges a {@link Filter} into a child {@link Filter} and pushes filter
 * conditions below other plan nodes where that is safe.
 * <p>
 * Handling depends on the filter's child:
 * <ul>
 *   <li>{@link Filter}: both conditions are AND-ed into the child filter.</li>
 *   <li>{@link Eval}, {@link RegexExtract}, {@link InferencePlan}, {@link Enrich}: conjuncts that do not
 *   reference attributes produced by the child are pushed below it; the other conjuncts stay on top.</li>
 *   <li>{@link Project}: delegated to {@link PushDownUtils#pushDownPastProject}.</li>
 *   <li>{@link OrderBy}: the whole filter is moved below the sort.</li>
 *   <li>{@link Join}: only {@code LEFT} joins are considered; conjuncts are scoped to the join sides.</li>
 * </ul>
 * Any other child leaves the filter where it is.
 */
public final class PushDownAndCombineFilters extends OptimizerRules.ParameterizedOptimizerRule<Filter, LogicalOptimizerContext> {

    /** Identity rewrite for children whose output cannot be resolved back to their input attributes. */
    private static final Function<Expression, Expression> NO_OP = Function.identity();

    public PushDownAndCombineFilters() {
        super(OptimizerRules.TransformDirection.DOWN);
    }

    @Override
    protected LogicalPlan rule(Filter filter, LogicalOptimizerContext ctx) {
        LogicalPlan child = filter.child();
        Expression condition = filter.condition();

        if (child instanceof Filter f) {
            // Collapse adjacent filters; the child's condition is kept first in the conjunction.
            return f.with(Predicates.combineAnd(List.of(f.condition(), condition)));
        }
        if (child instanceof Eval eval) {
            return pushDownPastEval(filter, eval);
        }
        if (child instanceof RegexExtract re) {
            // Conjuncts that reference the extracted fields must stay above the extraction.
            AttributeSet extracted = AttributeSet.of(Expressions.asAttributes(re.extractedFields()));
            return maybePushDownPastUnary(filter, re, extracted::contains, NO_OP);
        }
        if (child instanceof InferencePlan<?> inferencePlan) {
            // Conjuncts that reference attributes generated by the inference step must stay above it.
            AttributeSet generated = AttributeSet.of(inferencePlan.generatedAttributes());
            return maybePushDownPastUnary(filter, inferencePlan, generated::contains, NO_OP);
        }
        if (child instanceof Enrich enrich) {
            // Conjuncts that reference the enrich fields must stay above the enrichment.
            AttributeSet enriched = AttributeSet.of(Expressions.asAttributes(enrich.enrichFields()));
            return maybePushDownPastUnary(filter, enrich, enriched::contains, NO_OP);
        }
        if (child instanceof Project) {
            return PushDownUtils.pushDownPastProject(filter);
        }
        if (child instanceof OrderBy orderBy) {
            // Filtering does not depend on row order, so the whole filter moves below the sort.
            return orderBy.replaceChild(filter.with(orderBy.child(), condition));
        }
        if (child instanceof Join join) {
            return pushDownPastJoin(filter, join, ctx.foldCtx());
        }
        return filter;
    }

    /**
     * Pushes the conjuncts of {@code filter} below {@code eval} when they do not reference any field the
     * {@link Eval} defines.
     * <p>
     * Before checking, references to Eval fields that alias an attribute already present in the Eval's input
     * are replaced by that input attribute. Such renames therefore do not block the push-down.
     */
    private static LogicalPlan pushDownPastEval(Filter filter, Eval eval) {
        AttributeMap.Builder<Expression> aliasesBuilder = AttributeMap.builder();
        for (Alias alias : eval.fields()) {
            aliasesBuilder.put(alias.toAttribute(), alias.child());
        }
        AttributeMap<Expression> evalAliases = aliasesBuilder.build();

        Function<Expression, Expression> resolveRenames = expr -> expr.transformDown(ReferenceAttribute.class, r -> {
            Expression resolved = evalAliases.resolve(r, null);
            // Only unwrap to an attribute that exists below the Eval; anything else is left untouched.
            if (resolved instanceof Attribute && eval.inputSet().contains(resolved)) {
                return resolved;
            }
            return r;
        });

        return maybePushDownPastUnary(filter, eval, evalAliases::containsKey, resolveRenames);
    }

    /**
     * Splits {@code filters} by the join side whose output covers all of a filter's references.
     * <p>
     * The left side is checked first. A filter with no references therefore goes to the left.
     *
     * @return filters over the left output only, filters over the right output only, and the rest
     *         (the last group is stored as {@code commonFilters})
     */
    private static ScopedFilter scopeFilter(List<Expression> filters, LogicalPlan left, LogicalPlan right) {
        List<Expression> rest = new ArrayList<>(filters);
        List<Expression> leftFilters = new ArrayList<>();
        List<Expression> rightFilters = new ArrayList<>();
        AttributeSet leftOutput = left.outputSet();
        AttributeSet rightOutput = right.outputSet();
        // add() always returns true for an ArrayList, so a filter is moved iff its references fit the side.
        rest.removeIf(f -> f.references().subsetOf(leftOutput) && leftFilters.add(f));
        rest.removeIf(f -> f.references().subsetOf(rightOutput) && rightFilters.add(f));
        return new ScopedFilter(rest, leftFilters, rightFilters);
    }

    /**
     * Scopes {@code filters} for an {@link InlineJoin}.
     * <ul>
     *   <li>Filters whose references fall within both the left and the right output become the
     *   {@code leftFilters} of the result.</li>
     *   <li>Filters that reference only right-output attributes other than the join's right key fields become
     *   {@code rightFilters}.</li>
     *   <li>Everything else becomes {@code commonFilters}.</li>
     * </ul>
     */
    private static ScopedFilter scopeInlineStatsFilter(List<Expression> filters, InlineJoin ij) {
        List<Expression> remaining = new ArrayList<>(filters);
        List<Expression> onBothSides = new ArrayList<>();
        List<Expression> rightOnly = new ArrayList<>();
        AttributeSet leftOutputSet = ij.left().outputSet();
        AttributeSet rightOutputSet = ij.right().outputSet();
        AttributeSet rightOutputSetWithoutKeys = rightOutputSet.subtract(AttributeSet.of(ij.config().rightFields()));

        remaining.removeIf(f -> {
            AttributeSet refs = f.references();
            if (!refs.subsetOf(rightOutputSet)) {
                return false;
            }
            if (refs.subsetOf(leftOutputSet)) {
                onBothSides.add(f);
                return true;
            }
            if (refs.subsetOf(rightOutputSetWithoutKeys)) {
                rightOnly.add(f);
                return true;
            }
            return false;
        });
        return new ScopedFilter(remaining, onBothSides, rightOnly);
    }

    /**
     * Pushes the conjuncts of {@code filter} into the sides of a {@code LEFT} join where possible.
     * <ul>
     *   <li>Left-scoped conjuncts are wrapped in a new {@link Filter} over the left child and removed from the
     *   top.</li>
     *   <li>For joins other than {@link InlineJoin}, right-scoped conjuncts that pass
     *   {@link #isRightPushableFilter} are also added to the right child. They are kept above the join as well.</li>
     *   <li>If the right child is already a {@link Filter}, only conjuncts not semantically equal to an existing
     *   one are appended.</li>
     * </ul>
     *
     * @return the rewritten plan, or {@code filter} unchanged when the join is not {@code LEFT} or nothing was pushed
     */
    private static LogicalPlan pushDownPastJoin(Filter filter, Join join, FoldContext foldCtx) {
        if (join.config().type() != JoinTypes.LEFT) {
            return filter;
        }

        List<Expression> conjunctions = Predicates.splitAnd(filter.condition());
        ScopedFilter scoped = join instanceof InlineJoin ij
            ? scopeInlineStatsFilter(conjunctions, ij)
            : scopeFilter(conjunctions, join.left(), join.right());

        boolean optimizationApplied = false;

        if (!scoped.leftFilters().isEmpty()) {
            LogicalPlan left = join.left();
            join = (Join) join.replaceLeft(new Filter(left.source(), left, Predicates.combineAnd(scoped.leftFilters())));
            optimizationApplied = true;
        }

        if (!scoped.rightFilters().isEmpty() && !(join instanceof InlineJoin)) {
            List<Expression> rightPushableFilters = buildRightPushableFilters(scoped.rightFilters(), foldCtx);
            if (!rightPushableFilters.isEmpty()) {
                LogicalPlan right = join.right();
                if (right instanceof Filter existingRightFilter) {
                    List<Expression> merged = new ArrayList<>(Predicates.splitAnd(existingRightFilter.condition()));
                    int sizeBefore = merged.size();
                    for (Expression candidate : rightPushableFilters) {
                        // Skip conjuncts the right side already filters on; duplicates among candidates are skipped too.
                        if (merged.stream().noneMatch(existing -> existing.semanticEquals(candidate))) {
                            merged.add(candidate);
                        }
                    }
                    if (merged.size() != sizeBefore) {
                        join = (Join) join.replaceRight(existingRightFilter.with(Predicates.combineAnd(merged)));
                        optimizationApplied = true;
                    }
                } else {
                    join = (Join) join.replaceRight(new Filter(right.source(), right, Predicates.combineAnd(rightPushableFilters)));
                    optimizationApplied = true;
                }
            }
        }

        if (!optimizationApplied) {
            return filter;
        }
        // Left-scoped conjuncts are now fully evaluated below the join.
        // Right-scoped ones are still re-applied on top, whether or not they were copied into the right side.
        Expression remainingFilter = Predicates.combineAnd(CollectionUtils.combine(scoped.commonFilters(), scoped.rightFilters()));
        return remainingFilter != null ? filter.with(join, remainingFilter) : join;
    }

    /** Returns, in order, the expressions of {@code expressions} accepted by {@link #isRightPushableFilter}. */
    private static List<Expression> buildRightPushableFilters(List<Expression> expressions, FoldContext foldCtx) {
        return expressions.stream().filter(x -> isRightPushableFilter(x, foldCtx)).toList();
    }

    /**
     * Returns {@code true} when {@code filter} folds to {@code null} or {@code false} once every attribute in it
     * is replaced by a {@code null} literal of the same data type.
     * <p>
     * Such a filter rejects a row in which all its attributes are null. The join logic also keeps the filter
     * above the join, so copying it into the right side only drops right rows that could not pass it anyway.
     * Returns {@code false} if the nullified expression is not foldable.
     */
    private static boolean isRightPushableFilter(Expression filter, FoldContext foldCtx) {
        Expression nullifiedFilter = filter.transformUp(Attribute.class, r -> new Literal(r.source(), null, r.dataType()));
        if (!nullifiedFilter.foldable()) {
            return false;
        }
        Object folded = nullifiedFilter.fold(foldCtx);
        return folded == null || Boolean.FALSE.equals(folded);
    }

    /**
     * Pushes the conjuncts of {@code filter} below {@code unary}.
     * <ul>
     *   <li>A conjunct stays above if, after {@code resolveRenames}, any node in it matches {@code cannotPush}.</li>
     *   <li>Pushed conjuncts are used in their resolved form.</li>
     *   <li>Conjuncts kept above use their original form.</li>
     * </ul>
     *
     * @return {@code filter} unchanged if nothing can be pushed. Otherwise {@code unary} over the pushed filter,
     *         wrapped in a filter of the remaining conjuncts if any remain.
     */
    private static LogicalPlan maybePushDownPastUnary(
        Filter filter,
        UnaryPlan unary,
        Predicate<Expression> cannotPush,
        Function<Expression, Expression> resolveRenames
    ) {
        List<Expression> pushable = new ArrayList<>();
        List<Expression> nonPushable = new ArrayList<>();
        for (Expression exp : Predicates.splitAnd(filter.condition())) {
            Expression resolvedExp = resolveRenames.apply(exp);
            if (resolvedExp.anyMatch(cannotPush)) {
                nonPushable.add(exp);
            } else {
                pushable.add(resolvedExp);
            }
        }
        if (pushable.isEmpty()) {
            return filter;
        }
        UnaryPlan withPushed = unary.replaceChild(filter.with(unary.child(), Predicates.combineAnd(pushable)));
        return nonPushable.isEmpty() ? withPushed : filter.with(withPushed, Predicates.combineAnd(nonPushable));
    }

    /**
     * Conjuncts of a filter grouped by where they may be evaluated relative to a join.
     *
     * @param commonFilters conjuncts that stay above the join
     * @param leftFilters   conjuncts to push into the left child
     * @param rightFilters  conjuncts scoped to the right child (candidates for copying into it)
     */
    private record ScopedFilter(List<Expression> commonFilters, List<Expression> leftFilters, List<Expression> rightFilters) {
    }
}

