/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.compute.aggregation;

import java.util.List;
import org.elasticsearch.compute.aggregation.GroupingAggregatorEvaluationContext;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.aggregation.IntermediateStateDesc;
import org.elasticsearch.compute.aggregation.LongArrayState;
import org.elasticsearch.compute.aggregation.SeenGroupIds;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BooleanBlock;
import org.elasticsearch.compute.data.BooleanVector;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntArrayBlock;
import org.elasticsearch.compute.data.IntBigArrayBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.LongVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.data.Vector;
import org.elasticsearch.compute.operator.DriverContext;

/**
 * Grouping aggregator implementing {@code COUNT} per group.
 * <p>
 * When created with no input channels it counts rows: every row adds 1 to each group id it maps
 * to. Otherwise it counts the values found in the first input channel. Null positions are
 * skipped, and a multi-valued position contributes one count per value.
 * <p>
 * The intermediate state is described by {@link #intermediateStateDesc()}: a {@code count}
 * column (LONG) followed by a {@code seen} column (BOOLEAN).
 */
public class CountGroupingAggregatorFunction
implements GroupingAggregatorFunction {
    private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
        new IntermediateStateDesc("count", ElementType.LONG),
        new IntermediateStateDesc("seen", ElementType.BOOLEAN)
    );

    /** Per-group running counts; created with an initial value of {@code 0L} by {@link #create}. */
    private final LongArrayState state;
    /** Input channels; an empty list means "count all rows". */
    private final List<Integer> channels;
    private final DriverContext driverContext;
    /** {@code true} when no input channel was supplied, i.e. rows rather than values are counted. */
    private final boolean countAll;

    /**
     * Creates a new count aggregator backed by a fresh {@link LongArrayState}.
     *
     * @param driverContext supplies the {@code bigArrays()} used to allocate the per-group state
     * @param inputChannels the channel whose values are counted, or an empty list to count rows
     * @return a new aggregator; the caller must {@link #close()} it to release its state
     */
    public static CountGroupingAggregatorFunction create(DriverContext driverContext, List<Integer> inputChannels) {
        return new CountGroupingAggregatorFunction(inputChannels, new LongArrayState(driverContext.bigArrays(), 0L), driverContext);
    }

    /**
     * Describes the intermediate columns produced and consumed by this aggregator.
     *
     * @return an immutable list: {@code count} (LONG) then {@code seen} (BOOLEAN)
     */
    public static List<IntermediateStateDesc> intermediateStateDesc() {
        return INTERMEDIATE_STATE_DESC;
    }

    private CountGroupingAggregatorFunction(List<Integer> channels, LongArrayState state, DriverContext driverContext) {
        this.channels = channels;
        this.state = state;
        this.driverContext = driverContext;
        this.countAll = channels.isEmpty();
    }

    /**
     * Returns the page block index holding the values to count. When counting all rows there
     * is no value channel, so block 0 is used as a placeholder.
     */
    private int blockIndex() {
        return this.countAll ? 0 : this.channels.get(0);
    }

    @Override
    public int intermediateBlockCount() {
        return CountGroupingAggregatorFunction.intermediateStateDesc().size();
    }

    /**
     * Prepares an {@link GroupingAggregatorFunction.AddInput} that folds the rows of {@code page}
     * into the per-group counts.
     *
     * @param seenGroupIds passed to the state's group-id tracking when the values may contain nulls
     * @param page the raw input page
     * @return an input consumer specialised for the shape of the values block
     */
    @Override
    public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) {
        // Typed as Block (not Object) so asVector()/mayHaveNulls() resolve; it is effectively
        // final and therefore captured directly by the anonymous AddInput below.
        Block valuesBlock = page.getBlock(this.blockIndex());
        if (!this.countAll) {
            Vector valuesVector = valuesBlock.asVector();
            if (valuesVector == null) {
                // Not a plain vector: positions are inspected one by one (nulls skipped,
                // multi-valued positions counted per value) in addRawInput(..., Block).
                if (valuesBlock.mayHaveNulls()) {
                    // NOTE(review): presumably lets the state distinguish groups that only
                    // received null values from groups that were never seen — confirm.
                    this.state.enableGroupIdTracking(seenGroupIds);
                }
                return new GroupingAggregatorFunction.AddInput() {
                    @Override
                    public void add(int positionOffset, IntArrayBlock groupIds) {
                        CountGroupingAggregatorFunction.this.addRawInput(positionOffset, groupIds, valuesBlock);
                    }

                    @Override
                    public void add(int positionOffset, IntBigArrayBlock groupIds) {
                        CountGroupingAggregatorFunction.this.addRawInput(positionOffset, groupIds, valuesBlock);
                    }

                    @Override
                    public void add(int positionOffset, IntVector groupIds) {
                        CountGroupingAggregatorFunction.this.addRawInput(positionOffset, groupIds, valuesBlock);
                    }

                    @Override
                    public void close() {
                    }
                };
            }
        }
        // Counting all rows, or the values are a vector: every row contributes exactly 1 to
        // each group id it maps to, so the values themselves never need to be read.
        return new GroupingAggregatorFunction.AddInput() {

            @Override
            public void add(int positionOffset, IntArrayBlock groupIds) {
                CountGroupingAggregatorFunction.this.addRawInput(groupIds);
            }

            @Override
            public void add(int positionOffset, IntBigArrayBlock groupIds) {
                CountGroupingAggregatorFunction.this.addRawInput(groupIds);
            }

            @Override
            public void add(int positionOffset, IntVector groupIds) {
                CountGroupingAggregatorFunction.this.addRawInput(groupIds);
            }

            @Override
            public void close() {
            }
        };
    }

    /**
     * Adds the number of non-null values at each position to its single group.
     * Group position {@code i} corresponds to value position {@code positionOffset + i}.
     */
    private void addRawInput(int positionOffset, IntVector groups, Block values) {
        int position = positionOffset;
        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); ++groupPosition, ++position) {
            if (values.isNull(position)) {
                continue;
            }
            int groupId = groups.getInt(groupPosition);
            this.state.increment(groupId, values.getValueCount(position));
        }
    }

    /**
     * Adds the number of values at each non-null position to every group id listed at the
     * matching group position; null group positions are skipped.
     */
    private void addRawInput(int positionOffset, IntArrayBlock groups, Block values) {
        int position = positionOffset;
        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); ++groupPosition, ++position) {
            if (groups.isNull(groupPosition) || values.isNull(position)) {
                continue;
            }
            int groupStart = groups.getFirstValueIndex(groupPosition);
            int groupEnd = groupStart + groups.getValueCount(groupPosition);
            for (int g = groupStart; g < groupEnd; ++g) {
                int groupId = groups.getInt(g);
                this.state.increment(groupId, values.getValueCount(position));
            }
        }
    }

    /**
     * Same as the {@link IntArrayBlock} overload, for big-array-backed group ids. Kept as a
     * separate overload so each call site stays specialised on a concrete block type.
     */
    private void addRawInput(int positionOffset, IntBigArrayBlock groups, Block values) {
        int position = positionOffset;
        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); ++groupPosition, ++position) {
            if (groups.isNull(groupPosition) || values.isNull(position)) {
                continue;
            }
            int groupStart = groups.getFirstValueIndex(groupPosition);
            int groupEnd = groupStart + groups.getValueCount(groupPosition);
            for (int g = groupStart; g < groupEnd; ++g) {
                int groupId = groups.getInt(g);
                this.state.increment(groupId, values.getValueCount(position));
            }
        }
    }

    /**
     * Adds 1 per row to the row's group. A constant vector maps every row to the same group,
     * so that group is incremented once by the row count instead of row by row.
     */
    private void addRawInput(IntVector groups) {
        if (groups.isConstant()) {
            this.state.increment(groups.getInt(0), groups.getPositionCount());
        } else {
            for (int groupPosition = 0; groupPosition < groups.getPositionCount(); ++groupPosition) {
                int groupId = groups.getInt(groupPosition);
                this.state.increment(groupId, 1L);
            }
        }
    }

    /** Adds 1 per row to every group id listed at that row; null group positions are skipped. */
    private void addRawInput(IntArrayBlock groups) {
        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); ++groupPosition) {
            if (groups.isNull(groupPosition)) {
                continue;
            }
            int groupStart = groups.getFirstValueIndex(groupPosition);
            int groupEnd = groupStart + groups.getValueCount(groupPosition);
            for (int g = groupStart; g < groupEnd; ++g) {
                int groupId = groups.getInt(g);
                this.state.increment(groupId, 1L);
            }
        }
    }

    /** Same as the {@link IntArrayBlock} overload, for big-array-backed group ids. */
    private void addRawInput(IntBigArrayBlock groups) {
        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); ++groupPosition) {
            if (groups.isNull(groupPosition)) {
                continue;
            }
            int groupStart = groups.getFirstValueIndex(groupPosition);
            int groupEnd = groupStart + groups.getValueCount(groupPosition);
            for (int g = groupStart; g < groupEnd; ++g) {
                int groupId = groups.getInt(g);
                this.state.increment(groupId, 1L);
            }
        }
    }

    @Override
    public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
        this.state.enableGroupIdTracking(seenGroupIds);
    }

    /**
     * Merges partial counts read from the {@code count} channel of {@code page} into every group
     * id listed at each group position; null group positions are skipped. The {@code seen}
     * channel is read only to check its length against {@code count} in an assertion.
     */
    @Override
    public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) {
        assert (this.channels.size() == this.intermediateBlockCount());
        assert (page.getBlockCount() >= this.blockIndex() + CountGroupingAggregatorFunction.intermediateStateDesc().size());
        this.state.enableGroupIdTracking(new SeenGroupIds.Empty());
        LongVector count = ((LongBlock) page.getBlock(this.channels.get(0))).asVector();
        BooleanVector seen = ((BooleanBlock) page.getBlock(this.channels.get(1))).asVector();
        assert (count.getPositionCount() == seen.getPositionCount());
        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); ++groupPosition) {
            if (groups.isNull(groupPosition)) {
                continue;
            }
            int groupStart = groups.getFirstValueIndex(groupPosition);
            int groupEnd = groupStart + groups.getValueCount(groupPosition);
            for (int g = groupStart; g < groupEnd; ++g) {
                int groupId = groups.getInt(g);
                this.state.increment(groupId, count.getLong(groupPosition + positionOffset));
            }
        }
    }

    /** Same as the {@link IntArrayBlock} overload, for big-array-backed group ids. */
    @Override
    public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) {
        assert (this.channels.size() == this.intermediateBlockCount());
        assert (page.getBlockCount() >= this.blockIndex() + CountGroupingAggregatorFunction.intermediateStateDesc().size());
        this.state.enableGroupIdTracking(new SeenGroupIds.Empty());
        LongVector count = ((LongBlock) page.getBlock(this.channels.get(0))).asVector();
        BooleanVector seen = ((BooleanBlock) page.getBlock(this.channels.get(1))).asVector();
        assert (count.getPositionCount() == seen.getPositionCount());
        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); ++groupPosition) {
            if (groups.isNull(groupPosition)) {
                continue;
            }
            int groupStart = groups.getFirstValueIndex(groupPosition);
            int groupEnd = groupStart + groups.getValueCount(groupPosition);
            for (int g = groupStart; g < groupEnd; ++g) {
                int groupId = groups.getInt(g);
                this.state.increment(groupId, count.getLong(groupPosition + positionOffset));
            }
        }
    }

    /** Merges one partial count per group position into that position's single group id. */
    @Override
    public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
        assert (this.channels.size() == this.intermediateBlockCount());
        assert (page.getBlockCount() >= this.blockIndex() + CountGroupingAggregatorFunction.intermediateStateDesc().size());
        this.state.enableGroupIdTracking(new SeenGroupIds.Empty());
        LongVector count = ((LongBlock) page.getBlock(this.channels.get(0))).asVector();
        BooleanVector seen = ((BooleanBlock) page.getBlock(this.channels.get(1))).asVector();
        assert (count.getPositionCount() == seen.getPositionCount());
        for (int groupPosition = 0; groupPosition < groups.getPositionCount(); ++groupPosition) {
            this.state.increment(groups.getInt(groupPosition), count.getLong(groupPosition + positionOffset));
        }
    }

    /**
     * Delegates to {@link LongArrayState#toIntermediate} to write the intermediate columns for
     * the {@code selected} groups into {@code blocks}, starting at {@code offset}.
     */
    @Override
    public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
        this.state.toIntermediate(blocks, offset, selected, this.driverContext);
    }

    /**
     * Writes the final count of each selected group into {@code blocks[offset]}; a group for
     * which the state holds no value yields {@code 0}.
     */
    @Override
    public void evaluateFinal(Block[] blocks, int offset, IntVector selected, GroupingAggregatorEvaluationContext evaluationContext) {
        try (LongVector.FixedBuilder builder = evaluationContext.blockFactory().newLongVectorFixedBuilder(selected.getPositionCount())) {
            for (int i = 0; i < selected.getPositionCount(); ++i) {
                int si = selected.getInt(i);
                builder.appendLong(this.state.hasValue(si) ? this.state.get(si) : 0L);
            }
            blocks[offset] = builder.build().asBlock();
        }
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(this.getClass().getSimpleName()).append("[");
        sb.append("channels=").append(this.channels);
        sb.append("]");
        return sb.toString();
    }

    /** Releases the per-group state. */
    @Override
    public void close() {
        this.state.close();
    }
}

