/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.logical.local;

import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeMap;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.util.CollectionUtils;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce;
import org.elasticsearch.xpack.esql.expression.predicate.Predicates;
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNotNull;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.rule.Rule;

/**
 * Logical optimizer rule that expands {@link IsNotNull} predicates so that they also require the
 * root attributes referenced by the tested expression to be non-null.
 * <p>
 * For example, {@code IS NOT NULL(a + b)} is rewritten to
 * {@code IS NOT NULL(a) AND IS NOT NULL(b) AND IS NOT NULL(a + b)}. Aliases declared in the plan are
 * followed, so a reference to an aliased attribute is replaced by the attributes its aliased
 * expression references. {@link Coalesce} and {@link Case} expressions are kept whole rather than
 * broken down (see {@link #skipExpression(Expression)}).
 */
public class InferIsNotNull extends Rule<LogicalPlan, LogicalPlan> {

    @Override
    public LogicalPlan apply(LogicalPlan plan) {
        // A single alias map is shared by the whole traversal. transformUp is expected to visit
        // children before parents, so aliases declared lower in the plan are already recorded
        // when a parent's IsNotNull predicates are inspected.
        AttributeMap.Builder<Expression> aliasesBuilder = AttributeMap.builder();
        return plan.transformUp(p -> inspectPlan(p, aliasesBuilder));
    }

    /**
     * Records the aliases declared by this plan node, then rewrites every {@link IsNotNull} found in
     * the node's expressions using all aliases collected so far.
     *
     * @param plan           the plan node being visited
     * @param aliasesBuilder shared accumulator of alias attribute to aliased expression; this method adds to it
     * @return the plan node with its {@code IsNotNull} predicates passed through {@link #inferNotNullable}
     */
    private LogicalPlan inspectPlan(LogicalPlan plan, AttributeMap.Builder<Expression> aliasesBuilder) {
        // Only this node's own expressions are inspected; children were handled earlier in the traversal.
        plan.forEachExpression(Alias.class, a -> aliasesBuilder.put(a.toAttribute(), a.child()));
        return plan.transformExpressionsOnlyUp(IsNotNull.class, inn -> inferNotNullable(inn, aliasesBuilder.build()));
    }

    /**
     * Expands a single {@code IS NOT NULL(field)} into a conjunction that also asserts non-nullness
     * of every root expression {@code field} resolves to.
     *
     * @param inn     the predicate to expand
     * @param aliases known aliases (alias attribute to aliased expression)
     * @return {@code inn} itself when nothing could be broken down; otherwise
     *         {@code IS NOT NULL(r1) AND ... AND IS NOT NULL(rn) AND inn}
     */
    private Expression inferNotNullable(IsNotNull inn, AttributeMap<Expression> aliases) {
        Set<Expression> refs = resolveExpressionAsRootAttributes(inn.field(), aliases);
        if (refs.isEmpty()) {
            // No refs found, or the field could not be broken down: keep the original predicate.
            return inn;
        }
        // The new per-root checks come first and the original predicate is appended last.
        List<Expression> innList = CollectionUtils.combine(
            refs.stream().<Expression>map(r -> new IsNotNull(inn.source(), r)).toList(),
            inn
        );
        return Predicates.combineAnd(innList);
    }

    /**
     * Unrolls an expression through its references (following aliases) down to the root
     * attributes, or to skipped expressions, that matter for null filtering.
     *
     * @param exp     the expression to unroll
     * @param aliases known aliases (alias attribute to aliased expression)
     * @return the resolved expressions in discovery order, or an empty (immutable) set if
     *         {@code exp} could not be broken down into anything different from itself
     */
    protected Set<Expression> resolveExpressionAsRootAttributes(Expression exp, AttributeMap<Expression> aliases) {
        Set<Expression> resolvedExpressions = new LinkedHashSet<>();
        boolean changed = doResolve(exp, aliases, resolvedExpressions);
        return changed ? resolvedExpressions : Collections.emptySet();
    }

    /**
     * Recursive worker for {@link #resolveExpressionAsRootAttributes}.
     *
     * @param exp                 the expression currently being unrolled
     * @param aliases             known aliases (alias attribute to aliased expression)
     * @param resolvedExpressions output set that collects root attributes and skipped expressions
     * @return {@code true} if at least one reference was resolved to something other than {@code exp}
     *         itself, either here or in a recursive call
     */
    private boolean doResolve(Expression exp, AttributeMap<Expression> aliases, Set<Expression> resolvedExpressions) {
        if (skipExpression(exp)) {
            // Kept whole. Adding it does not by itself count as a change.
            resolvedExpressions.add(exp);
            return false;
        }
        boolean changed = false;
        for (Expression e : exp.references()) {
            Expression resolved = aliases.resolve(e, e);
            if (resolved instanceof Attribute a && resolved == e) {
                // Root attribute (not an alias for anything else): stop here.
                resolvedExpressions.add(a);
                // Don't report a change when the expression was not actually broken down,
                // presumably the case of exp being an un-aliased attribute referencing itself.
                changed |= resolved != exp;
            } else {
                // An alias for another expression, or a non-attribute: keep unrolling.
                changed |= doResolve(resolved, aliases, resolvedExpressions);
            }
        }
        return changed;
    }

    /**
     * Returns whether {@code e} must be kept whole instead of being broken down into its references.
     * COALESCE and CASE can produce a non-null value even when some of their inputs are null, so a
     * non-null result says nothing about their individual arguments.
     * NOTE(review): any other function that can yield non-null from null inputs would need the same
     * treatment; confirm this list is complete.
     */
    private static boolean skipExpression(Expression e) {
        return e instanceof Coalesce || e instanceof Case;
    }
}

