/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.esql.planner;

import org.elasticsearch.compute.aggregation.Aggregator;
import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.aggregation.FilteredAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.GroupingAggregator;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.operator.AggregationOperator;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.compute.operator.HashAggregationOperator.HashAggregationOperatorFactory;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.index.analysis.AnalysisRegistry;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.InvalidArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.NameId;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.evaluator.EvalMapper;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesAggregateExec;
import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.LocalExecutionPlannerContext;
import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.PhysicalOperation;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;

import static java.util.Collections.emptyList;

public abstract class AbstractPhysicalOperationProviders implements PhysicalOperationProviders {

    /** Fold context handed to {@link EvalMapper#toEvaluator} when building the evaluator of an aggregate filter. */
    private final FoldContext foldContext;
    /** Analysis registry forwarded to the {@link HashAggregationOperatorFactory} created for grouping aggregations. */
    private final AnalysisRegistry analysisRegistry;

    /**
     * Stores the collaborators shared by all aggregation planning done by this provider.
     *
     * @param foldContext      used when turning aggregate filters into evaluators
     * @param analysisRegistry passed as-is to the hash aggregation operator factory
     */
    AbstractPhysicalOperationProviders(FoldContext foldContext, AnalysisRegistry analysisRegistry) {
        this.foldContext = foldContext;
        this.analysisRegistry = analysisRegistry;
    }

    /**
     * Plans the operator that executes {@code aggregateExec} on top of {@code source}.
     * <p>
     *     Without groupings an {@link AggregationOperator} factory is created. With groupings either the time-series operator
     *     factory (for a {@link TimeSeriesAggregateExec}) or a {@link HashAggregationOperatorFactory} is used.
     * </p>
     *
     * @param aggregateExec the aggregation being planned
     * @param source        the operation producing the input of the aggregation, including its layout
     * @param context       the local planner context
     * @return {@code source} extended with the aggregation operator and the layout produced by the aggregation
     * @throws EsqlIllegalArgumentException if a grouping cannot be resolved to an attribute or if no operator factory was created
     */
    @Override
    public final PhysicalOperation groupingPhysicalOperation(
        AggregateExec aggregateExec,
        PhysicalOperation source,
        LocalExecutionPlannerContext context
    ) {
        // Describes the columns this operation will emit.
        final Layout.Builder outputLayout = new Layout.Builder();
        final AggregatorMode mode = aggregateExec.getMode();
        final var aggregates = aggregateExec.aggregates();
        final var inputLayout = source.layout;
        final Operator.OperatorFactory factory;

        if (aggregateExec.groupings().isEmpty()) {
            // plain aggregation, no grouping keys
            final List<Aggregator.Factory> plainFactories = new ArrayList<>();

            outputLayout.append(aggregateExec.output());

            aggregatesToFactory(
                aggregateExec,
                aggregates,
                mode,
                inputLayout,
                false, // non-grouping
                supplierCtx -> plainFactories.add(supplierCtx.supplier().aggregatorFactory(supplierCtx.mode(), supplierCtx.channels())),
                context
            );

            // without a single aggregator there is no operator to create
            factory = plainFactories.isEmpty() ? null : new AggregationOperator.AggregationOperatorFactory(plainFactories, mode);
        } else {
            // aggregation with grouping keys
            final List<GroupingAggregator.Factory> groupingFactories = new ArrayList<>();
            final List<GroupSpec> groupSpecs = new ArrayList<>(aggregateExec.groupings().size());

            for (Expression group : aggregateExec.groupings()) {
                final Attribute groupAttribute = Expressions.attribute(group);
                // With `... BY groupAttribute = CATEGORIZE(sourceGroupAttribute)` the non-partial input is the categorized field
                // rather than the grouping attribute itself.
                Attribute sourceGroupAttribute = groupAttribute;
                if (mode.isInputPartial() == false
                    && group instanceof Alias categorizeAlias
                    && categorizeAlias.child() instanceof Categorize categorize) {
                    sourceGroupAttribute = Expressions.attribute(categorize.field());
                }
                if (sourceGroupAttribute == null) {
                    throw new EsqlIllegalArgumentException("Unexpected non-named expression[{}] as grouping in [{}]", group, aggregateExec);
                }

                final boolean aliasedCategorize = group instanceof Alias groupAlias && groupAlias.child() instanceof Categorize;
                final Layout.ChannelSet groupChannels = new Layout.ChannelSet(new HashSet<>(), sourceGroupAttribute.dataType());
                groupChannels.nameIds().add(aliasedCategorize ? groupAttribute.id() : sourceGroupAttribute.id());

                /*
                 * Aliases in the aggregates can point at the grouping in two situations (project + stats being combined):
                 *  - before stats (keep x = a | stats by x): the partial input has to read a's channel
                 *  - after  stats (stats by a | keep x = a): the output layout has to expose the follow-up alias as well
                 */
                // TODO: This is likely required only for pre-8.14 node compatibility; confirm and remove if possible.
                // Since https://github.com/elastic/elasticsearch/pull/104958, it shouldn't be possible to have aliases in the aggregates
                // which the groupings refer to. Except for `BY CATEGORIZE(field)`, which remains as alias in the grouping, all aliases
                // should've become EVALs before or after the STATS.
                for (NamedExpression agg : aggregates) {
                    if (agg instanceof Alias aggAlias && aggAlias.child() instanceof Attribute aliased) {
                        if (sourceGroupAttribute.id().equals(aliased.id())) {
                            // the alias is another name for the group: expose it on the same channel
                            groupChannels.nameIds().add(aggAlias.id());
                            // TODO: investigate whether a break could be used since it shouldn't be possible to have multiple
                            // attributes pointing to the same attribute
                        } else if (mode.isOutputPartial() && sourceGroupAttribute.semanticEquals(aggAlias.toAttribute())) {
                            // Partial output only: the group refers to an alias declared in the aggregates, so read the alias child.
                            // The final reduction does not need this since the intermediate data is already in the output form.
                            sourceGroupAttribute = aliased;
                            break;
                        }
                    }
                }

                outputLayout.append(groupChannels);

                final Layout.ChannelAndType groupInput = inputLayout.get(sourceGroupAttribute.id());
                Integer groupChannel = null;
                if (groupInput != null) {
                    groupChannel = groupInput.channel();
                }
                groupSpecs.add(new GroupSpec(groupChannel, sourceGroupAttribute, group));
            }

            if (mode.isOutputPartial()) {
                // partial output: everything after the grouping attributes is intermediate state
                final List<Attribute> execOutput = aggregateExec.output();
                for (int idx = aggregateExec.groupings().size(); idx < execOutput.size(); idx++) {
                    outputLayout.append(execOutput.get(idx));
                }
            } else {
                // final output: only the actual aggregate functions add columns
                for (NamedExpression agg : aggregates) {
                    if (Alias.unwrap(agg) instanceof AggregateFunction) {
                        outputLayout.append(agg);
                    }
                }
            }

            aggregatesToFactory(
                aggregateExec,
                aggregates,
                mode,
                inputLayout,
                true, // grouping
                supplierCtx -> groupingFactories.add(
                    supplierCtx.supplier().groupingAggregatorFactory(supplierCtx.mode(), supplierCtx.channels())
                ),
                context
            );

            final List<BlockHash.GroupSpec> hashGroupSpecs = groupSpecs.stream().map(GroupSpec::toHashGroupSpec).toList();
            if (aggregateExec instanceof TimeSeriesAggregateExec ts) {
                // time-series aggregation
                factory = timeSeriesAggregatorOperatorFactory(ts, mode, groupingFactories, hashGroupSpecs, context);
            } else {
                factory = new HashAggregationOperatorFactory(
                    hashGroupSpecs,
                    mode,
                    groupingFactories,
                    context.pageSize(aggregateExec, aggregateExec.estimatedRowSize()),
                    analysisRegistry
                );
            }
        }

        if (factory == null) {
            throw new EsqlIllegalArgumentException("no operator factory");
        }
        return source.with(factory, outputLayout.build());
    }

    /**
     * Creates a standard layout for intermediate aggregations, typically used across exchanges.
     * Puts the groups first, followed by the intermediate attributes of each aggregation.
     * <p>
     *     The ordering mirrors the one used by {@link #groupingPhysicalOperation} but no operator factory is created.
     * </p>
     *
     * @param aggregates the aggregates of the aggregation
     * @param groupings  the grouping expressions; when empty the non-grouping intermediate attributes are returned
     * @return the grouping attributes (if any) followed by the intermediate attributes of the aggregates
     * @throws EsqlIllegalArgumentException if a grouping cannot be resolved to an attribute
     */
    public static List<Attribute> intermediateAttributes(List<? extends NamedExpression> aggregates, List<? extends Expression> groupings) {
        // TODO: This should take CATEGORIZE into account:
        // it currently works because the CATEGORIZE intermediate state is just 1 block with the same type as the function return,
        // so the attribute generated here is the expected one

        // no groups
        if (groupings.isEmpty()) {
            return Expressions.asAttributes(AggregateMapper.mapNonGrouping(aggregates));
        }

        // groups: the grouping attributes come first ...
        List<Attribute> attrs = new ArrayList<>();
        for (Expression group : groupings) {
            Attribute groupAttribute = Expressions.attribute(group);
            if (groupAttribute == null) {
                throw new EsqlIllegalArgumentException("Unexpected non-named expression[{}] as grouping", group);
            }
            attrs.add(groupAttribute);
        }
        // ... followed by the intermediate state of every aggregate
        attrs.addAll(Expressions.asAttributes(AggregateMapper.mapGrouping(aggregates)));
        return attrs;
    }

    /**
     * Everything a caller of {@link #aggregatesToFactory} needs to create the aggregator factory of one aggregate function.
     *
     * @param supplier the supplier of the aggregate function, wrapped with the aggregate's filter where one applies
     * @param channels the channels of the aggregate's inputs, looked up in the layout given to {@link #aggregatesToFactory}
     * @param mode     the mode the aggregation runs in
     */
    private record AggFunctionSupplierContext(AggregatorFunctionSupplier supplier, List<Integer> channels, AggregatorMode mode) {}

    /**
     * Hands out, for each aggregate function, the slice of the child's output holding its intermediate state.
     * Slices are assigned in request order, starting right after the grouping attributes; asking again for an
     * equal aggregate function returns the slice it was assigned the first time.
     */
    private static class IntermediateInputs {
        private final Map<AggregateFunction, Integer> offsets = new HashMap<>();
        private final List<Attribute> inputAttributes;
        private int nextOffset;

        IntermediateInputs(AggregateExec aggregateExec) {
            this.inputAttributes = aggregateExec.child().output();
            // the grouping attributes lead the input, the first state starts after them
            this.nextOffset = aggregateExec.groupings().size();
        }

        /**
         * Returns the input attributes carrying the intermediate state of {@code af}.
         *
         * @param af       the aggregate function whose intermediate state is requested
         * @param grouping whether the aggregation is a grouping one, which determines the intermediate state description
         * @return a view over the child's output covering the intermediate state of {@code af}
         */
        List<Attribute> nextInputAttributes(AggregateFunction af, boolean grouping) {
            final int stateSize = AggregateMapper.intermediateStateDesc(af, grouping).size();
            Integer start = offsets.get(af);
            if (start == null) {
                // first request for this function: reserve the next free range
                start = nextOffset;
                offsets.put(af, start);
                nextOffset += stateSize;
            }
            return inputAttributes.subList(start, start + stateSize);
        }
    }

    /**
     * Turns every aliased {@link AggregateFunction} found in {@code aggregates} into an {@link AggFunctionSupplierContext}
     * and hands it to {@code consumer}. Entries that are not an alias over an aggregate function are skipped.
     *
     * @param aggregateExec the aggregation being planned
     * @param aggregates    the aggregates of the aggregation
     * @param mode          the aggregation mode; decides whether the inputs are raw attributes or intermediate state
     * @param layout        the layout of the input, used to resolve input channels and to build filter evaluators
     * @param grouping      whether the aggregation is a grouping one
     * @param consumer      receives one context per aggregate function, in iteration order
     * @param context       the local planner context, providing the shard contexts for filter evaluators
     * @throws EsqlIllegalArgumentException if an input of an aggregate is not an attribute or has no channel in {@code layout}
     * @throws InvalidArgumentException     if an aggregate other than {@link Count} is applied to a foldable field
     */
    private void aggregatesToFactory(
        AggregateExec aggregateExec,
        List<? extends NamedExpression> aggregates,
        AggregatorMode mode,
        Layout layout,
        boolean grouping,
        Consumer<AggFunctionSupplierContext> consumer,
        LocalExecutionPlannerContext context
    ) {
        IntermediateInputs intermediateInputs = mode.isInputPartial() ? new IntermediateInputs(aggregateExec) : null;
        // extract filtering channels - and wrap the aggregation with the new evaluator expression only during the init phase
        for (NamedExpression ne : aggregates) {
            // a filter can only appear on aggregate function, not on the grouping columns
            if (ne instanceof Alias alias && alias.child() instanceof AggregateFunction aggregateFunction) {
                final List<Attribute> sourceAttr = mode.isInputPartial()
                    ? intermediateInputs.nextInputAttributes(aggregateFunction, grouping)
                    : initialInputAttributes(aggregateExec, aggregateFunction);

                AggregatorFunctionSupplier aggSupplier = supplier(aggregateFunction);

                List<Integer> inputChannels = sourceAttr.stream().map(attr -> inputChannel(layout, attr, aggregateFunction)).toList();
                assert inputChannels.stream().allMatch(i -> i >= 0) : inputChannels;

                // apply the filter only in the initial phase - as the rest of the data is already filtered
                if (aggregateFunction.hasFilter() && mode.isInputPartial() == false) {
                    EvalOperator.ExpressionEvaluator.Factory evalFactory = EvalMapper.toEvaluator(
                        foldContext,
                        aggregateFunction.filter(),
                        layout,
                        context.shardContexts()
                    );
                    aggSupplier = new FilteredAggregatorFunctionSupplier(aggSupplier, evalFactory);
                }
                consumer.accept(new AggFunctionSupplierContext(aggSupplier, inputChannels, mode));
            }
        }
    }

    /**
     * Resolves the attributes an aggregate function reads when it consumes raw (non-partial) input.
     */
    private static List<Attribute> initialInputAttributes(AggregateExec aggregateExec, AggregateFunction aggregateFunction) {
        // TODO: this needs to be made more reliable - use casting to blow up when dealing with expressions (e+1)
        Expression field = aggregateFunction.field();
        // Only count can now support literals - all the other aggs should be optimized away
        if (field.foldable()) {
            if (aggregateFunction instanceof Count) {
                return emptyList();
            }
            throw new InvalidArgumentException("Does not support yet aggregations over constants - [{}]", aggregateFunction.sourceText());
        }
        // extra dependencies like TS ones (that require a timestamp)
        List<Attribute> inputs = new ArrayList<>();
        for (Expression input : aggregateFunction.aggregateInputReferences(aggregateExec.child()::output)) {
            Attribute attr = Expressions.attribute(input);
            if (attr == null) {
                throw new EsqlIllegalArgumentException(
                    "Cannot work with target field [{}] for agg [{}]",
                    input.sourceText(),
                    aggregateFunction.sourceText()
                );
            }
            inputs.add(attr);
        }
        return inputs;
    }

    /**
     * Looks up the channel of {@code attr} in {@code layout}, failing with a descriptive error instead of a
     * {@link NullPointerException} when the layout does not contain the attribute.
     */
    private static int inputChannel(Layout layout, Attribute attr, AggregateFunction aggregateFunction) {
        Layout.ChannelAndType input = layout.get(attr.id());
        if (input == null) {
            throw new EsqlIllegalArgumentException(
                "Cannot find input channel for attribute [{}] used by agg [{}]",
                attr,
                aggregateFunction.sourceText()
            );
        }
        return input.channel();
    }

    /**
     * Obtains the aggregator function supplier of the given aggregate function.
     *
     * @param aggregateFunction the function to get the supplier for
     * @return the supplier provided by the function through {@link ToAggregator}
     * @throws EsqlIllegalArgumentException if the function does not implement {@link ToAggregator}
     */
    private static AggregatorFunctionSupplier supplier(AggregateFunction aggregateFunction) {
        if (aggregateFunction instanceof ToAggregator == false) {
            throw new EsqlIllegalArgumentException("aggregate functions must extend ToAggregator");
        }
        return ((ToAggregator) aggregateFunction).supplier();
    }

    /**
     * Describes where one grouping key of the aggregation comes from.
     *
     * @param channel    the channel of the group in the source layout, {@code null} when the source layout has none for it
     * @param attribute  the attribute the group is read from
     * @param expression the grouping expression itself
     */
    private record GroupSpec(Integer channel, Attribute attribute, Expression expression) {

        /**
         * Converts this spec into the one understood by {@link BlockHash}.
         *
         * @throws EsqlIllegalArgumentException if no source channel is known for the group
         */
        BlockHash.GroupSpec toHashGroupSpec() {
            if (channel == null) {
                throw new EsqlIllegalArgumentException("planned to use ordinals but tried to use the hash instead");
            }
            final ElementType type = elementType();
            // only a (possibly aliased) CATEGORIZE grouping carries a categorize definition
            final var categorizeDef = Alias.unwrap(expression) instanceof Categorize categorize ? categorize.categorizeDef() : null;
            return new BlockHash.GroupSpec(channel, type, categorizeDef, null);
        }

        /** The element type matching the data type of the group's source attribute. */
        ElementType elementType() {
            return PlannerUtils.toElementType(attribute.dataType());
        }
    }

    /**
     * Creates the operator factory for a grouping aggregation planned from a {@link TimeSeriesAggregateExec}.
     * Called by {@link #groupingPhysicalOperation} in place of creating a {@link HashAggregationOperatorFactory};
     * a {@code null} result makes that method fail with "no operator factory".
     *
     * @param ts                  the time-series aggregation being planned
     * @param aggregatorMode      the mode of the aggregation
     * @param aggregatorFactories the grouping aggregator factories created for the aggregates
     * @param groupSpecs          the grouping keys, already converted to {@link BlockHash.GroupSpec}
     * @param context             the local planner context
     * @return the operator factory executing the time-series aggregation
     */
    public abstract Operator.OperatorFactory timeSeriesAggregatorOperatorFactory(
        TimeSeriesAggregateExec ts,
        AggregatorMode aggregatorMode,
        List<GroupingAggregator.Factory> aggregatorFactories,
        List<BlockHash.GroupSpec> groupSpecs,
        LocalExecutionPlannerContext context
    );
}
