/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.plan.physical;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.VersionId;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.tree.Node;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.physical.EstimatesRowSize;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;

public class AggregateExec
extends UnaryExec
implements EstimatesRowSize {
    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(PhysicalPlan.class, "AggregateExec", AggregateExec::new);
    private final List<? extends Expression> groupings;
    private final List<? extends NamedExpression> aggregates;
    private final List<Attribute> intermediateAttributes;
    private final AggregatorMode mode;
    private final Integer estimatedRowSize;

    public AggregateExec(Source source, PhysicalPlan child, List<? extends Expression> groupings, List<? extends NamedExpression> aggregates, AggregatorMode mode, List<Attribute> intermediateAttributes, Integer estimatedRowSize) {
        super(source, child);
        this.groupings = groupings;
        this.aggregates = aggregates;
        this.mode = mode;
        this.intermediateAttributes = intermediateAttributes;
        this.estimatedRowSize = estimatedRowSize;
    }

    protected AggregateExec(StreamInput in) throws IOException {
        this(Source.readFrom((StreamInput)((PlanStreamInput)in)), (PhysicalPlan)in.readNamedWriteable(PhysicalPlan.class), in.readNamedWriteableCollectionAsList(Expression.class), in.readNamedWriteableCollectionAsList(NamedExpression.class), (AggregatorMode)in.readEnum(AggregatorMode.class), in.readNamedWriteableCollectionAsList(Attribute.class), in.readOptionalVInt());
    }

    public void writeTo(StreamOutput out) throws IOException {
        Source.EMPTY.writeTo(out);
        out.writeNamedWriteable((NamedWriteable)this.child());
        out.writeNamedWriteableCollection(this.groupings());
        out.writeNamedWriteableCollection(this.aggregates());
        if (out.getTransportVersion().onOrAfter((VersionId)TransportVersions.V_8_16_0)) {
            out.writeEnum((Enum)this.getMode());
            out.writeNamedWriteableCollection(this.intermediateAttributes());
        } else {
            out.writeEnum((Enum)Mode.fromAggregatorMode(this.getMode()));
        }
        out.writeOptionalVInt(this.estimatedRowSize());
    }

    public String getWriteableName() {
        return AggregateExec.ENTRY.name;
    }

    protected NodeInfo<AggregateExec> info() {
        return NodeInfo.create((Node)this, AggregateExec::new, (Object)((Object)this.child()), this.groupings, this.aggregates, (Object)this.mode, this.intermediateAttributes, (Object)this.estimatedRowSize);
    }

    @Override
    public AggregateExec replaceChild(PhysicalPlan newChild) {
        return new AggregateExec(this.source(), newChild, this.groupings, this.aggregates, this.mode, this.intermediateAttributes, this.estimatedRowSize);
    }

    public List<? extends Expression> groupings() {
        return this.groupings;
    }

    public List<? extends NamedExpression> aggregates() {
        return this.aggregates;
    }

    public AggregateExec withAggregates(List<? extends NamedExpression> newAggregates) {
        return new AggregateExec(this.source(), this.child(), this.groupings, newAggregates, this.mode, this.intermediateAttributes, this.estimatedRowSize);
    }

    public AggregateExec withMode(AggregatorMode newMode) {
        return new AggregateExec(this.source(), this.child(), this.groupings, this.aggregates, newMode, this.intermediateAttributes, this.estimatedRowSize);
    }

    public Integer estimatedRowSize() {
        return this.estimatedRowSize;
    }

    @Override
    public PhysicalPlan estimateRowSize(EstimatesRowSize.State state) {
        state.add(false, this.aggregates);
        int size = state.consumeAllFields(true);
        size = Math.max(size, 1);
        return Objects.equals(this.estimatedRowSize, size) ? this : this.withEstimatedSize(size);
    }

    protected AggregateExec withEstimatedSize(int estimatedRowSize) {
        return new AggregateExec(this.source(), this.child(), this.groupings, this.aggregates, this.mode, this.intermediateAttributes, estimatedRowSize);
    }

    public AggregatorMode getMode() {
        return this.mode;
    }

    public List<Attribute> intermediateAttributes() {
        return this.intermediateAttributes;
    }

    @Override
    public List<Attribute> output() {
        return this.mode.isOutputPartial() ? this.intermediateAttributes : Aggregate.output(this.aggregates);
    }

    @Override
    protected AttributeSet computeReferences() {
        return this.mode.isInputPartial() ? AttributeSet.of(this.intermediateAttributes) : Aggregate.computeReferences(this.aggregates, this.groupings).subtract(AttributeSet.of(this.ordinalAttributes()));
    }

    public List<Attribute> ordinalAttributes() {
        ArrayList<Attribute> orginalAttributs = new ArrayList<Attribute>(this.groupings.size());
        if (this.groupings().size() == 1 && !this.groupings.get(0).anyMatch(e -> e instanceof Categorize)) {
            HashSet leaves = new HashSet();
            this.aggregates.stream().filter(a -> !this.groupings.contains(a)).forEach(a -> leaves.addAll(a.collectLeaves()));
            this.groupings.forEach(g -> {
                if (!leaves.contains(g)) {
                    orginalAttributs.add((Attribute)g);
                }
            });
        }
        return orginalAttributs;
    }

    @Override
    public int hashCode() {
        return Objects.hash(new Object[]{this.groupings, this.aggregates, this.mode, this.intermediateAttributes, this.estimatedRowSize, this.child()});
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || this.getClass() != obj.getClass()) {
            return false;
        }
        AggregateExec other = (AggregateExec)obj;
        return Objects.equals(this.groupings, other.groupings) && Objects.equals(this.aggregates, other.aggregates) && Objects.equals(this.mode, other.mode) && Objects.equals(this.intermediateAttributes, other.intermediateAttributes) && Objects.equals(this.estimatedRowSize, other.estimatedRowSize) && Objects.equals((Object)this.child(), (Object)other.child());
    }

    @Deprecated
    private static enum Mode {
        SINGLE,
        PARTIAL,
        FINAL;


        static Mode fromAggregatorMode(AggregatorMode aggregatorMode) {
            return switch (aggregatorMode) {
                default -> throw new MatchException(null, null);
                case AggregatorMode.SINGLE -> SINGLE;
                case AggregatorMode.INITIAL -> PARTIAL;
                case AggregatorMode.FINAL -> FINAL;
                case AggregatorMode.INTERMEDIATE -> throw new UnsupportedOperationException("cannot turn intermediate aggregation into single, partial or final.");
            };
        }
    }
}

