/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.logical;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.elasticsearch.index.IndexMode;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.CollectionUtils;
import org.elasticsearch.xpack.esql.core.util.Holder;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.DimensionValues;
import org.elasticsearch.xpack.esql.expression.function.aggregate.LastOverTime;
import org.elasticsearch.xpack.esql.expression.function.aggregate.TimeSeriesAggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
import org.elasticsearch.xpack.esql.expression.function.grouping.TBucket;
import org.elasticsearch.xpack.esql.expression.function.scalar.internal.PackDimension;
import org.elasticsearch.xpack.esql.expression.function.scalar.internal.UnpackDimension;
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
import org.elasticsearch.xpack.esql.plan.logical.Eval;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;

/**
 * Logical optimizer rule that splits a {@link TimeSeriesAggregate} into two aggregation phases.
 * <p>
 * The first phase is a new {@link TimeSeriesAggregate} grouped by {@code _tsid} (plus the time bucket, when a
 * {@link Bucket} or {@link TBucket} over {@code @timestamp} is found in the child plan). It computes:
 * <ul>
 *   <li>the per-time-series form of every {@link TimeSeriesAggregateFunction} nested in the aggregates;</li>
 *   <li>for aggregate functions with no nested time-series function, a {@link LastOverTime} of their field;</li>
 *   <li>for every non-time-bucket grouping, a {@link Values} (or {@link DimensionValues}) aggregate.</li>
 * </ul>
 * The second phase is a plain {@link Aggregate} over the first-phase output that applies the original outer
 * aggregate functions. When there are non-time-bucket groupings, their collected values are wrapped with
 * {@link PackDimension} before the second phase and grouped on. They are then restored with
 * {@link UnpackDimension} and projected as: second-phase aggregates, followed by the groupings.
 */
public final class TranslateTimeSeriesAggregate
    extends OptimizerRules.ParameterizedOptimizerRule<TimeSeriesAggregate, LogicalOptimizerContext> {

    /** Creates the rule; it is registered with {@link OptimizerRules.TransformDirection#UP}. */
    public TranslateTimeSeriesAggregate() {
        super(OptimizerRules.TransformDirection.UP);
    }

    /**
     * Rewrites one {@link TimeSeriesAggregate} into the two-phase plan described on the class.
     *
     * @param aggregate the time-series aggregate to translate
     * @param context optimizer context; its minimum transport version decides whether dimension groupings use
     *                {@link DimensionValues} or {@link Values}
     * @return the rewritten plan: an {@link Aggregate} directly over the first phase when there are no packed
     *         groupings, otherwise a {@link Project} over an unpacking {@link Eval}
     * @throws IllegalArgumentException if no {@code @timestamp} attribute is found under {@code aggregate}, or if
     *         more than one time bucket over {@code @timestamp} is found in the child plan
     * @throws IllegalStateException if the inline filter of an aggregate and of its nested time-series
     *         aggregation disagree
     * @throws EsqlIllegalArgumentException if a grouping is not an {@link Attribute}
     */
    @Override
    protected LogicalPlan rule(TimeSeriesAggregate aggregate, LogicalOptimizerContext context) {
        // Locate the _tsid and @timestamp attributes exposed by the relation(s) below this aggregate.
        Holder<Attribute> tsid = new Holder<>();
        Holder<Attribute> timestamp = new Holder<>();
        aggregate.forEachDown(EsRelation.class, r -> {
            for (Attribute attr : r.output()) {
                if (attr.name().equals("_tsid")) {
                    tsid.set(attr);
                }
                if (attr.name().equals("@timestamp")) {
                    timestamp.set(attr);
                }
            }
        });
        if (tsid.get() == null) {
            // _tsid is not part of the relation output yet; synthesize it (it is appended to the relation below).
            tsid.set(new MetadataAttribute(aggregate.source(), "_tsid", DataType.KEYWORD, false));
        }
        if (timestamp.get() == null) {
            throw new IllegalArgumentException("_tsid or @timestamp field are missing from the time-series source");
        }

        // First-stage aggregations are deduplicated by their function so identical ones are computed only once.
        Map<AggregateFunction, Alias> timeSeriesAggs = new HashMap<>();
        List<Alias> firstPassAggs = new ArrayList<>();
        List<Alias> secondPassAggs = new ArrayList<>();
        // Becomes TRUE if any time-series function needs the relation to keep its (time-series) index mode.
        Holder<Boolean> requiredTimeSeriesSource = new Holder<>(Boolean.FALSE);
        InternalNames internalNames = new InternalNames();

        for (NamedExpression agg : aggregate.aggregates()) {
            // Only aliased aggregate functions are rewritten here; grouping outputs are re-appended by
            // mergeExpressions further down.
            if (!(agg instanceof Alias)) {
                continue;
            }
            Alias alias = (Alias) agg;
            if (!(alias.child() instanceof AggregateFunction)) {
                continue;
            }
            AggregateFunction af = (AggregateFunction) alias.child();

            // Strip the inline filter from the outer aggregate; it must be evaluated in the first phase instead.
            final Expression inlineFilter;
            if (af.hasFilter()) {
                inlineFilter = af.filter();
                af = af.withFilter(Literal.TRUE);
            } else {
                inlineFilter = null;
            }

            // Replace every nested time-series function with a reference to its first-phase result.
            Holder<Boolean> changed = new Holder<>(Boolean.FALSE);
            Expression outerAgg = af.transformDown(TimeSeriesAggregateFunction.class, tsAgg -> {
                // The filter is expected to already be present on the inner time-series aggregation
                // (presumably propagated by an earlier rewrite - not visible here); enforce consistency.
                if (inlineFilter != null) {
                    if (!tsAgg.hasFilter()) {
                        throw new IllegalStateException("inline filter isn't propagated to time-series aggregation");
                    }
                } else if (tsAgg.hasFilter()) {
                    throw new IllegalStateException("unexpected inline filter in time-series aggregation");
                }
                changed.set(Boolean.TRUE);
                if (tsAgg.requiredTimeSeriesSource()) {
                    requiredTimeSeriesSource.set(Boolean.TRUE);
                }
                AggregateFunction firstStageFn = tsAgg.perTimeSeriesAggregation();
                Alias newAgg = timeSeriesAggs.computeIfAbsent(firstStageFn, k -> {
                    Alias firstStageAlias = new Alias(tsAgg.source(), internalNames.next(tsAgg.functionName()), firstStageFn);
                    firstPassAggs.add(firstStageAlias);
                    return firstStageAlias;
                });
                return newAgg.toAttribute();
            });
            if (changed.get()) {
                secondPassAggs.add(new Alias(alias.source(), alias.name(), outerAgg, alias.id()));
                continue;
            }

            // No time-series function inside: aggregate the last value of the field per time series first.
            Expression aggField = af.field();
            LastOverTime lastOverTime = new LastOverTime(af.source(), aggField, timestamp.get());
            // withFilter returns a plain AggregateFunction, so the first-stage function is typed as such.
            AggregateFunction lastOverTimeFn = inlineFilter != null
                ? lastOverTime.perTimeSeriesAggregation().withFilter(inlineFilter)
                : lastOverTime.perTimeSeriesAggregation();
            Alias lastOverTimeAgg = timeSeriesAggs.computeIfAbsent(lastOverTimeFn, k -> {
                Alias firstStageAlias = new Alias(
                    lastOverTime.source(),
                    internalNames.next(lastOverTime.functionName()),
                    lastOverTimeFn
                );
                firstPassAggs.add(firstStageAlias);
                return firstStageAlias;
            });
            // Point the outer aggregate at the first-phase column and drop its filter (already applied in phase one).
            secondPassAggs.add((Alias) alias.transformUp(f -> f == aggField || f instanceof AggregateFunction, e -> {
                if (e == aggField) {
                    return lastOverTimeAgg.toAttribute();
                }
                if (e instanceof AggregateFunction) {
                    return ((AggregateFunction) e).withFilter(Literal.TRUE);
                }
                return e;
            }));
        }

        List<Expression> firstPassGroupings = new ArrayList<>();
        firstPassGroupings.add(tsid.get());
        List<Alias> packDimensions = new ArrayList<>();
        List<Expression> secondPassGroupings = new ArrayList<>();
        List<Alias> unpackDimensions = new ArrayList<>();

        // Find the (single) named expression in the child plan that buckets @timestamp.
        Holder<NamedExpression> timeBucketRef = new Holder<>();
        aggregate.child().forEachExpressionUp(NamedExpression.class, e -> {
            for (Expression child : e.children()) {
                if (child instanceof Bucket && ((Bucket) child).field().equals(timestamp.get())) {
                    if (timeBucketRef.get() != null) {
                        throw new IllegalArgumentException("expected at most one time bucket");
                    }
                    timeBucketRef.set(e);
                } else if (child instanceof TBucket && ((TBucket) child).field().equals(timestamp.get())) {
                    if (timeBucketRef.get() != null) {
                        throw new IllegalArgumentException("expected at most one time tbucket");
                    }
                    // TBucket is expressed through its Bucket surrogate, keeping the original expression id.
                    Bucket bucket = (Bucket) ((TBucket) child).surrogate();
                    timeBucketRef.set(new Alias(e.source(), bucket.functionName(), bucket, e.id()));
                }
            }
        });
        NamedExpression timeBucket = timeBucketRef.get();

        for (Expression group : aggregate.groupings()) {
            if (!(group instanceof Attribute)) {
                throw new EsqlIllegalArgumentException("expected named expression for grouping; got " + group);
            }
            Attribute g = (Attribute) group;
            if (timeBucket != null && g.id().equals(timeBucket.id())) {
                // The time bucket is grouped on in both phases.
                Attribute newFinalGroup = timeBucket.toAttribute();
                firstPassGroupings.add(newFinalGroup);
                secondPassGroupings.add(new Alias(g.source(), g.name(), newFinalGroup.toAttribute(), g.id()));
                continue;
            }
            // Other groupings: collect values per time series, pack them, group on the packed value, unpack after.
            Alias valuesAgg = new Alias(g.source(), g.name(), valuesAggregate(context, g));
            firstPassAggs.add(valuesAgg);
            Alias pack = new Alias(
                g.source(),
                internalNames.next("pack" + g.name()),
                new PackDimension(g.source(), valuesAgg.toAttribute())
            );
            packDimensions.add(pack);
            Alias grouping = new Alias(g.source(), internalNames.next("group" + g.name()), pack.toAttribute());
            secondPassGroupings.add(grouping);
            Alias unpack = new Alias(
                g.source(),
                g.name(),
                new UnpackDimension(g.source(), grouping.toAttribute(), g.dataType().noText()),
                g.id()
            );
            unpackDimensions.add(unpack);
        }

        // Keep the relation's index mode only when a time-series function requires it, and make sure it emits _tsid.
        boolean keepIndexMode = requiredTimeSeriesSource.get();
        LogicalPlan newChild = aggregate.child().transformUp(EsRelation.class, r -> {
            IndexMode indexMode = keepIndexMode ? r.indexMode() : IndexMode.STANDARD;
            if (!r.output().contains(tsid.get())) {
                return new EsRelation(
                    r.source(),
                    r.indexPattern(),
                    indexMode,
                    r.indexNameWithModes(),
                    CollectionUtils.combine(r.output(), tsid.get())
                );
            }
            return new EsRelation(r.source(), r.indexPattern(), indexMode, r.indexNameWithModes(), r.output());
        });

        TimeSeriesAggregate firstPhase = new TimeSeriesAggregate(
            newChild.source(),
            newChild,
            firstPassGroupings,
            mergeExpressions(firstPassAggs, firstPassGroupings),
            (Bucket) Alias.unwrap(timeBucket)
        );
        if (packDimensions.isEmpty()) {
            return new Aggregate(
                firstPhase.source(),
                firstPhase,
                secondPassGroupings,
                mergeExpressions(secondPassAggs, secondPassGroupings)
            );
        }
        Eval packValues = new Eval(firstPhase.source(), firstPhase, packDimensions);
        Aggregate secondPhase = new Aggregate(
            firstPhase.source(),
            packValues,
            secondPassGroupings,
            mergeExpressions(secondPassAggs, secondPassGroupings)
        );
        Eval unpackValues = new Eval(secondPhase.source(), secondPhase, unpackDimensions);

        // Project aggregates first, then groupings; packed groupings are replaced by their unpacked attribute.
        List<Attribute> projects = new ArrayList<>();
        for (Alias secondPassAgg : secondPassAggs) {
            projects.add(Expressions.attribute(secondPassAgg));
        }
        int pos = 0;
        for (Expression group : secondPassGroupings) {
            Attribute g = Expressions.attribute(group);
            if (timeBucket != null && g.id().equals(timeBucket.id())) {
                projects.add(g);
                continue;
            }
            projects.add(unpackDimensions.get(pos++).toAttribute());
        }
        return new Project(newChild.source(), unpackValues, projects);
    }

    /**
     * Builds an aggregate output list: all {@code aggregates}, followed by the attribute of each grouping.
     *
     * @param aggregates aggregate expressions, kept in order
     * @param groupings grouping expressions whose attributes are appended
     * @return a new mutable list
     */
    private static List<? extends NamedExpression> mergeExpressions(
        List<? extends NamedExpression> aggregates,
        List<Expression> groupings
    ) {
        // Declared with a concrete element type: a `? extends` list would reject add/addAll.
        List<NamedExpression> merged = new ArrayList<>(aggregates.size() + groupings.size());
        merged.addAll(aggregates);
        for (Expression grouping : groupings) {
            merged.add(Expressions.attribute(grouping));
        }
        return merged;
    }

    /**
     * Chooses the first-phase aggregate used to carry a non-time-bucket grouping through to the second phase.
     *
     * @return {@link DimensionValues} for dimension attributes when every node supports it, otherwise {@link Values}
     */
    private AggregateFunction valuesAggregate(LogicalOptimizerContext context, Attribute group) {
        if (group.isDimension() && context.minimumVersion().supports(DimensionValues.DIMENSION_VALUES_VERSION)) {
            return new DimensionValues(group.source(), group);
        }
        return new Values(group.source(), group);
    }

    /** Generates per-prefix unique internal column names of the form {@code prefix_$N}, with N starting at 1. */
    private static final class InternalNames {
        private final Map<String, Integer> counters = new HashMap<>();

        String next(String prefix) {
            int id = counters.merge(prefix, 1, Integer::sum);
            return prefix + "_$" + id;
        }
    }
}

