/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.analysis;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.type.EsField;
import org.elasticsearch.xpack.esql.expression.function.Functions;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.TimeSeriesAggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;
import org.elasticsearch.xpack.esql.rule.Rule;

public class TimeSeriesGroupByAll
extends Rule<LogicalPlan, LogicalPlan> {
    @Override
    public LogicalPlan apply(LogicalPlan logicalPlan) {
        return (LogicalPlan)logicalPlan.transformUp(node -> node instanceof TimeSeriesAggregate, this::rule);
    }

    public LogicalPlan rule(TimeSeriesAggregate aggregate) {
        TimeSeriesAggregateFunction lastTSAggFunction = null;
        AggregateFunction lastNonTSAggFunction = null;
        ArrayList<Object> newAggregateFunctions = new ArrayList<Object>(aggregate.aggregates().size());
        for (NamedExpression namedExpression : aggregate.aggregates()) {
            Object alias;
            Expression expression;
            if (namedExpression instanceof Alias && (expression = (alias = (Alias)namedExpression).child()) instanceof AggregateFunction) {
                AggregateFunction af = (AggregateFunction)expression;
                if (af instanceof TimeSeriesAggregateFunction) {
                    TimeSeriesAggregateFunction tsAgg = (TimeSeriesAggregateFunction)af;
                    newAggregateFunctions.add(new Alias(alias.source(), alias.name(), (Expression)new Values(tsAgg.source(), (Expression)tsAgg)));
                    lastTSAggFunction = tsAgg;
                    continue;
                }
                newAggregateFunctions.add(namedExpression);
                lastNonTSAggFunction = af;
                continue;
            }
            newAggregateFunctions.add(namedExpression);
        }
        if (lastTSAggFunction == null) {
            return aggregate;
        }
        if (lastNonTSAggFunction != null) {
            throw new EsqlIllegalArgumentException("Cannot mix time-series aggregate [{}] and regular aggregate [{}] in the same TimeSeriesAggregate.", lastTSAggFunction.sourceText(), lastNonTSAggFunction.sourceText());
        }
        FieldAttribute timeSeries = new FieldAttribute(aggregate.source(), null, null, "_timeseries", new EsField("_timeseries", DataType.KEYWORD, Map.of(), false, EsField.TimeSeriesFieldType.DIMENSION));
        ArrayList<Expression> arrayList = new ArrayList<Expression>();
        arrayList.add((Expression)timeSeries);
        for (Expression grouping : aggregate.groupings()) {
            if (!Functions.isGrouping(Alias.unwrap((Expression)grouping))) {
                throw new EsqlIllegalArgumentException("Cannot mix time-series aggregate and grouping attributes. Found [{}].", grouping.sourceText());
            }
            arrayList.add(grouping);
        }
        TimeSeriesAggregate newStats = new TimeSeriesAggregate(aggregate.source(), aggregate.child(), arrayList, newAggregateFunctions, null);
        return (LogicalPlan)newStats.transformDown(EsRelation.class, r -> {
            ArrayList<Attribute> attributes = new ArrayList<Attribute>(r.output());
            attributes.add((Attribute)timeSeries);
            return new EsRelation(r.source(), r.indexPattern(), r.indexMode(), r.originalIndices(), r.concreteIndices(), r.indexNameWithModes(), attributes);
        });
    }
}

