/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.logical.local;

import java.util.LinkedHashSet;
import java.util.List;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AllFirst;
import org.elasticsearch.xpack.esql.expression.predicate.Predicates;
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNotNull;
import org.elasticsearch.xpack.esql.optimizer.LocalLogicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.Filter;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;
import org.elasticsearch.xpack.esql.stats.SearchStats;

/**
 * Local logical optimizer rule that places an "is not null" {@link Filter} underneath an
 * ungrouped {@link Aggregate}.
 *
 * <p>The rule fires only when every aggregate function in the plan targets a non-foldable
 * {@link FieldAttribute} that {@link SearchStats#isIndexed} reports as indexed. In that case
 * the aggregate's child is wrapped in a filter whose condition is the disjunction of
 * {@link IsNotNull} over all of those fields. In every other case the aggregate is returned
 * unchanged.
 */
public class InferNonNullAggConstraint
extends OptimizerRules.ParameterizedOptimizerRule<Aggregate, LocalLogicalOptimizerContext> {
    public InferNonNullAggConstraint() {
        super(OptimizerRules.TransformDirection.UP);
    }

    /**
     * Rewrites {@code aggregate} so that its input is pre-filtered on the aggregated fields.
     *
     * @param aggregate the aggregate node being visited
     * @param context   local optimizer context; supplies the {@link SearchStats} used to check
     *                  whether a field is indexed
     * @return {@code aggregate} itself when the rule does not apply, otherwise a copy of it whose
     *         child is a {@link Filter} over the original child
     */
    @Override
    protected LogicalPlan rule(Aggregate aggregate, LocalLogicalOptimizerContext context) {
        // Only aggregates without groupings are handled; time-series aggregates are left alone.
        if (!aggregate.groupings().isEmpty() || aggregate instanceof TimeSeriesAggregate) {
            return aggregate;
        }
        SearchStats stats = context.searchStats();
        List<? extends NamedExpression> aggs = aggregate.aggregates();
        // Insertion-ordered so the generated disjunction follows the order of the aggregates.
        LinkedHashSet<Expression> nonNullAggFields = Sets.newLinkedHashSetWithExpectedSize(aggs.size());
        for (NamedExpression agg : aggs) {
            if (!(Alias.unwrap(agg) instanceof AggregateFunction af)) {
                // Entries that are not aggregate functions do not take part in the decision.
                continue;
            }
            // NOTE(review): AllFirst always disables the rule - presumably because it must still
            // observe null values; confirm against the function's semantics.
            if (af instanceof AllFirst) {
                return aggregate;
            }
            Expression field = af.field();
            if (!field.foldable() && field instanceof FieldAttribute fa && stats.isIndexed(fa.fieldName())) {
                nonNullAggFields.add(field);
            } else {
                // The disjunction has to cover the field of every aggregate function; a partial
                // filter would drop rows that an uncovered aggregate still needs. Bail out.
                return aggregate;
            }
        }
        if (nonNullAggFields.isEmpty()) {
            return aggregate;
        }
        Expression condition = Predicates.combineOr(
            nonNullAggFields.stream().<Expression>map(f -> new IsNotNull(aggregate.source(), f)).toList()
        );
        return aggregate.replaceChild(new Filter(aggregate.source(), aggregate.child(), condition));
    }
}

