/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.physical.local;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.DimensionValues;
import org.elasticsearch.xpack.esql.expression.function.aggregate.FirstDocId;
import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerRules;
import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec;
import org.elasticsearch.xpack.esql.plan.physical.EvalExec;
import org.elasticsearch.xpack.esql.plan.physical.FieldExtractExec;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.plan.physical.ProjectExec;
import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesAggregateExec;
import org.elasticsearch.xpack.esql.planner.AggregateMapper;

/**
 * Local physical optimizer rule that moves the loading of dimension fields in a time-series aggregation
 * to after that aggregation.
 * <p>
 * The rule only touches a {@link TimeSeriesAggregateExec} running in {@link AggregatorMode#INITIAL} mode
 * whose input contains a doc attribute (see {@link EsQueryExec#isDocAttribute}). It targets every
 * {@link DimensionValues} aggregation that meets all of these conditions:
 * <ul>
 *   <li>it has no filter;</li>
 *   <li>its field is a dimension {@link FieldAttribute};</li>
 *   <li>that field is not already part of the aggregation's input.</li>
 * </ul>
 * Each such aggregation is removed, and a single {@link FirstDocId} aggregation over the doc attribute is
 * added instead. The resulting plan is:
 * <pre>
 * ProjectExec(original intermediate attributes)
 *   EvalExec(each dimension field aliased to the id of the intermediate attribute it replaces)
 *     FieldExtractExec(dimension fields)
 *       TimeSeriesAggregateExec(remaining aggregates + FirstDocId(doc))
 * </pre>
 * The dimension fields are therefore extracted from the aggregation's output rows rather than its input
 * rows. The final projection restores the original intermediate layout, and the doc attribute is not
 * projected.
 */
public final class ExtractDimensionFieldsAfterAggregation
    extends PhysicalOptimizerRules.ParameterizedOptimizerRule<PhysicalPlan, LocalPhysicalOptimizerContext> {

    /**
     * Applies the rewrite to INITIAL-mode time-series aggregations and leaves every other plan untouched.
     *
     * @param plan    the plan node being visited
     * @param context the local physical optimizer context
     * @return the rewritten plan, or {@code plan} itself when the rule does not apply
     */
    @Override
    public PhysicalPlan rule(PhysicalPlan plan, LocalPhysicalOptimizerContext context) {
        if (plan instanceof TimeSeriesAggregateExec oldAgg && oldAgg.getMode() == AggregatorMode.INITIAL) {
            return rule(oldAgg, context);
        }
        return plan;
    }

    /**
     * Rewrites a single INITIAL-mode time-series aggregation.
     *
     * @param oldAgg  the aggregation to rewrite
     * @param context supplies the field-extract preference from the configuration pragmas
     * @return {@code oldAgg} unchanged if its input has no doc attribute or no aggregate qualifies;
     *         otherwise a {@link ProjectExec} whose output is {@code oldAgg}'s original intermediate attributes
     * @throws IllegalStateException if a qualifying aggregate does not have exactly one intermediate attribute
     */
    private PhysicalPlan rule(TimeSeriesAggregateExec oldAgg, LocalPhysicalOptimizerContext context) {
        AttributeSet inputAttributes = oldAgg.inputSet();
        Attribute sourceAttr = inputAttributes.stream().filter(EsQueryExec::isDocAttribute).findFirst().orElse(null);
        if (sourceAttr == null) {
            // Without a doc attribute nothing can be extracted after the aggregation.
            return oldAgg;
        }

        int groupingCount = oldAgg.groupings().size();
        List<Attribute> oldIntermediates = oldAgg.intermediateAttributes();

        List<NamedExpression> newAggregates = new ArrayList<>();
        List<Attribute> dimensionFields = new ArrayList<>();
        List<Alias> aliases = new ArrayList<>();
        HashSet<AggregateFunction> seen = new HashSet<>();

        // The grouping intermediates come first and are kept unchanged.
        List<Attribute> newIntermediates = new ArrayList<>(oldIntermediates.subList(0, groupingCount));
        int intermediateOffset = groupingCount;

        for (NamedExpression namedExpression : oldAgg.aggregates()) {
            FieldAttribute dimensionField = null;
            if (Alias.unwrap(namedExpression) instanceof AggregateFunction af) {
                dimensionField = valuesOfDimensionField(af, inputAttributes);
                // The offset only advances on the first occurrence of an equal function. This assumes the
                // intermediate attributes are laid out once per distinct function (TODO confirm with AggregateMapper).
                if (seen.add(af)) {
                    int size = intermediateStateSize(af);
                    if (dimensionField != null) {
                        if (size != 1) {
                            throw new IllegalStateException(
                                "expected one intermediate attribute for [" + af + "] but got [" + size + "]"
                            );
                        }
                        Attribute oldAttr = oldIntermediates.get(intermediateOffset);
                        // Reuse the old intermediate id so the final projection still finds this attribute.
                        aliases.add(new Alias(namedExpression.source(), namedExpression.name(), dimensionField, oldAttr.id()));
                        dimensionFields.add(dimensionField);
                    } else {
                        newIntermediates.addAll(oldIntermediates.subList(intermediateOffset, intermediateOffset + size));
                    }
                    intermediateOffset += size;
                }
            }
            // Qualifying dimension aggregations are dropped; they are re-created after the aggregation.
            if (dimensionField == null) {
                newAggregates.add(namedExpression);
            }
        }

        if (dimensionFields.isEmpty()) {
            return oldAgg;
        }

        // Carry the doc attribute through the aggregation so the fields can be loaded afterwards.
        newIntermediates.add(new ReferenceAttribute(oldAgg.source(), sourceAttr.qualifier(), sourceAttr.name(), sourceAttr.dataType()));
        newAggregates.add(new Alias(oldAgg.source(), sourceAttr.name(), new FirstDocId(oldAgg.source(), sourceAttr)));

        TimeSeriesAggregateExec newAgg = new TimeSeriesAggregateExec(
            oldAgg.source(),
            oldAgg.child(),
            oldAgg.groupings(),
            newAggregates,
            oldAgg.getMode(),
            newIntermediates,
            oldAgg.estimatedRowSize(),
            oldAgg.timeBucket()
        );
        FieldExtractExec fieldExtractExec = new FieldExtractExec(
            oldAgg.source(),
            newAgg,
            dimensionFields,
            context.configuration().pragmas().fieldExtractPreference()
        );
        EvalExec evalExec = new EvalExec(oldAgg.source(), fieldExtractExec, aliases);
        // Restore the original intermediate layout; the doc attribute is not part of it.
        return new ProjectExec(oldAgg.source(), evalExec, oldIntermediates);
    }

    /**
     * Checks whether an aggregate qualifies for deferred extraction.
     *
     * @param af              the aggregate function to inspect
     * @param inputAttributes the attributes already available as input to the aggregation
     * @return the dimension field if {@code af} is an unfiltered {@link DimensionValues} over a dimension
     *         {@link FieldAttribute} not already in {@code inputAttributes}; {@code null} otherwise
     */
    private static FieldAttribute valuesOfDimensionField(AggregateFunction af, AttributeSet inputAttributes) {
        // Fields already in the input are skipped, presumably because they are loaded before the
        // aggregation anyway and deferring them would save nothing.
        if (af instanceof DimensionValues values
            && values.hasFilter() == false
            && values.field() instanceof FieldAttribute fa
            && fa.isDimension()
            && inputAttributes.contains(fa) == false) {
            return fa;
        }
        return null;
    }

    /**
     * Returns how many intermediate state entries {@link AggregateMapper} reports for {@code af}.
     * The {@code true} argument's meaning is defined by {@code AggregateMapper.intermediateStateDesc};
     * it presumably selects the grouping variant (TODO confirm).
     */
    private static int intermediateStateSize(AggregateFunction af) {
        return AggregateMapper.intermediateStateDesc(af, true).size();
    }
}

