/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.planner;

import java.util.HashSet;
import java.util.List;
import java.util.stream.Stream;
import org.elasticsearch.common.Strings;
import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.IntermediateStateDesc;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeMap;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.planner.ToAggregator;

/**
 * Translates aggregate expressions into the {@link NamedExpression}s that describe their
 * intermediate state, as reported by each aggregate's {@link AggregatorFunctionSupplier}.
 * All methods are static.
 */
public final class AggregateMapper {

    /**
     * Maps aggregates to their intermediate-state attributes for an aggregation without grouping keys.
     *
     * @param aggregates the aggregate expressions, possibly wrapped in {@link Alias}
     * @return one attribute per intermediate-state entry of each distinct aggregate
     * @throws EsqlIllegalArgumentException if an aggregate's intermediate state cannot be determined or mapped
     */
    public static List<NamedExpression> mapNonGrouping(List<? extends NamedExpression> aggregates) {
        return doMapping(aggregates, false);
    }

    /**
     * Maps aggregates to their intermediate-state attributes for an aggregation with grouping keys.
     *
     * @param aggregates the aggregate expressions, possibly wrapped in {@link Alias}
     * @return one attribute per intermediate-state entry of each distinct aggregate
     * @throws EsqlIllegalArgumentException if an aggregate's intermediate state cannot be determined or mapped
     */
    public static List<NamedExpression> mapGrouping(List<? extends NamedExpression> aggregates) {
        return doMapping(aggregates, true);
    }

    /**
     * Shared implementation of {@link #mapNonGrouping} and {@link #mapGrouping}.
     * Aggregates whose unwrapped expression is equal to one already seen are skipped, so the alias
     * name of the first occurrence is the one used to name the intermediate-state attributes.
     */
    private static List<NamedExpression> doMapping(List<? extends NamedExpression> aggregates, boolean grouping) {
        var seen = new HashSet<Expression>();
        AttributeMap.Builder<NamedExpression> attrToExpressionsBuilder = AttributeMap.builder();
        for (NamedExpression agg : aggregates) {
            Expression inner = Alias.unwrap(agg);
            if (seen.add(inner)) {
                for (NamedExpression ne : computeEntryForAgg(agg.name(), inner, grouping)) {
                    attrToExpressionsBuilder.put(ne.toAttribute(), ne);
                }
            }
        }
        // NOTE(review): result order is whatever AttributeMap#values() yields (presumably insertion order) — confirm.
        return attrToExpressionsBuilder.build().values().stream().toList();
    }

    /**
     * Returns the intermediate state description of an aggregate function.
     *
     * @param fn       the aggregate function; must implement {@link ToAggregator}
     * @param grouping whether to return the grouping or the non-grouping intermediate state
     * @return the intermediate state entries reported by the function's supplier
     * @throws EsqlIllegalArgumentException if {@code fn} does not implement {@link ToAggregator}
     */
    public static List<IntermediateStateDesc> intermediateStateDesc(AggregateFunction fn, boolean grouping) {
        if (fn instanceof ToAggregator toAggregator) {
            AggregatorFunctionSupplier supplier = toAggregator.supplier();
            return grouping ? supplier.groupingIntermediateStateDesc() : supplier.nonGroupingIntermediateStateDesc();
        }
        throw new EsqlIllegalArgumentException("Aggregate has no defined intermediate state: " + fn);
    }

    /**
     * Computes the intermediate-state attributes for a single (already unwrapped) aggregate.
     * Plain field, metadata and reference attributes contribute no intermediate state of their own
     * (presumably they are grouping keys or pass-through values — confirm with callers).
     *
     * @throws EsqlIllegalArgumentException for any other expression type
     */
    private static List<NamedExpression> computeEntryForAgg(String aggAlias, Expression aggregate, boolean grouping) {
        if (aggregate instanceof AggregateFunction aggregateFunction) {
            return entryForAgg(aggAlias, aggregateFunction, grouping);
        }
        if (aggregate instanceof FieldAttribute || aggregate instanceof MetadataAttribute || aggregate instanceof ReferenceAttribute) {
            return List.of();
        }
        throw new EsqlIllegalArgumentException("unknown agg: " + aggregate.getClass() + ": " + aggregate);
    }

    /**
     * Builds the intermediate-state attributes of an aggregate function, naming them after {@code aggAlias}.
     * Delegates the supplier lookup (and its error) to {@link #intermediateStateDesc}.
     */
    private static List<NamedExpression> entryForAgg(String aggAlias, AggregateFunction aggregateFunction, boolean grouping) {
        List<IntermediateStateDesc> intermediateState = intermediateStateDesc(aggregateFunction, grouping);
        return intermediateStateToNamedExpressions(intermediateState, aggAlias).toList();
    }

    /**
     * Converts each intermediate state entry into a {@link ReferenceAttribute} whose name is built by
     * {@link Attribute#rawTemporaryName} from the alias and the entry name. The data type comes from the
     * entry's explicit data type when it is non-empty, otherwise from its element type.
     */
    private static Stream<NamedExpression> intermediateStateToNamedExpressions(
        List<IntermediateStateDesc> intermediateStateDescs,
        String aggAlias
    ) {
        return intermediateStateDescs.stream().map(is -> {
            DataType dataType = Strings.isEmpty(is.dataType()) ? toDataType(is.type()) : DataType.fromEs(is.dataType());
            return new ReferenceAttribute(Source.EMPTY, null, Attribute.rawTemporaryName(aggAlias, is.name()), dataType);
        });
    }

    /**
     * Maps a block element type to the ES|QL data type used for an intermediate-state attribute.
     * The switch is exhaustive over {@link ElementType}, so the compiler flags any newly added constant.
     *
     * @throws EsqlIllegalArgumentException for element types that have no intermediate-state mapping
     */
    private static DataType toDataType(ElementType elementType) {
        return switch (elementType) {
            case BOOLEAN -> DataType.BOOLEAN;
            case BYTES_REF -> DataType.KEYWORD;
            case INT -> DataType.INTEGER;
            case LONG -> DataType.LONG;
            case DOUBLE -> DataType.DOUBLE;
            case DOC -> DataType.DOC_DATA_TYPE;
            case FLOAT, NULL, COMPOSITE, AGGREGATE_METRIC_DOUBLE, UNKNOWN -> throw new EsqlIllegalArgumentException(
                "unsupported agg type: " + elementType
            );
        };
    }
}

