/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.compute.aggregation;

import java.util.stream.IntStream;
import org.elasticsearch.compute.aggregation.GroupingAggregatorEvaluationContext;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.aggregation.SeenGroupIds;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BooleanBlock;
import org.elasticsearch.compute.data.BooleanVector;
import org.elasticsearch.compute.data.IntArrayBlock;
import org.elasticsearch.compute.data.IntBigArrayBlock;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.data.ToMask;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;

/**
 * A {@link GroupingAggregatorFunction} that evaluates {@code filter} on each raw input page and only
 * forwards the rows it selects to the wrapped {@code next} aggregator. Rows the filter rejects have
 * their group ids masked out (via {@code keepMask}) before reaching {@code next}.
 * <p>
 * Intermediate input, evaluation and the intermediate block count are delegated to {@code next} unchanged.
 *
 * @param next   the aggregator that receives the filtered raw input
 * @param filter evaluator whose result, cast to a {@link BooleanBlock}, selects the rows to keep
 */
record FilteredGroupingAggregatorFunction(GroupingAggregatorFunction next, EvalOperator.ExpressionEvaluator filter) implements GroupingAggregatorFunction
{
    FilteredGroupingAggregatorFunction {
        // Filtering can remove every row of a group, so the delegate is told once, up front, that the
        // selected groups may include groups it has never seen values for.
        next.selectedMayContainUnseenGroups(new SeenGroupIds.Empty());
    }

    /**
     * Evaluates the filter against {@code page} and wraps the delegate's {@code AddInput} so that
     * group ids of rows rejected by the filter are masked before being added.
     *
     * @param seenGroupIds passed through to {@code next}
     * @param page         the raw input page; the filter is evaluated on it
     * @return an {@code AddInput} that owns both the filter mask and the delegate's {@code AddInput}
     */
    @Override
    public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) {
        try (BooleanBlock filterResult = (BooleanBlock) this.filter.eval(page)) {
            ToMask mask = filterResult.toMask();
            GroupingAggregatorFunction.AddInput nextAdd = null;
            try {
                nextAdd = this.next.prepareProcessRawInputPage(seenGroupIds, page);
                GroupingAggregatorFunction.AddInput result = new FilteredAddInput(mask.mask(), nextAdd, page.getPositionCount());
                // Ownership has moved into `result`; null the locals so the finally block doesn't release them.
                mask = null;
                nextAdd = null;
                return result;
            } finally {
                // On failure this releases whatever was acquired so far; on success both are null (no-op).
                Releasables.close(mask, nextAdd);
            }
        }
    }

    @Override
    public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
        // Intentionally empty: the constructor already put `next` into this state.
    }

    @Override
    public void addIntermediateInput(int positionOffset, IntVector groupIdVector, Page page) {
        this.next.addIntermediateInput(positionOffset, groupIdVector, page);
    }

    /**
     * Delegates to {@code next}. The filter is not applied to intermediate state.
     * {@code input} must itself be a {@code FilteredGroupingAggregatorFunction};
     * its wrapped aggregator is what gets passed on.
     */
    @Override
    public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) {
        this.next.addIntermediateRowInput(groupId, ((FilteredGroupingAggregatorFunction) input).next(), position);
    }

    @Override
    public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
        this.next.evaluateIntermediate(blocks, offset, selected);
    }

    @Override
    public void evaluateFinal(Block[] blocks, int offset, IntVector selected, GroupingAggregatorEvaluationContext evaluationContext) {
        this.next.evaluateFinal(blocks, offset, selected, evaluationContext);
    }

    @Override
    public int intermediateBlockCount() {
        return this.next.intermediateBlockCount();
    }

    /** Releases both the delegate aggregator and the filter evaluator. */
    @Override
    public void close() {
        Releasables.closeExpectNoException(this.next, this.filter);
    }

    /**
     * Wraps the delegate's {@code AddInput}, masking incoming group ids with the page's filter result
     * before forwarding them.
     *
     * @param mask          the filter result for every position of the page (owned by this record)
     * @param nextAdd       the delegate's {@code AddInput} (owned by this record)
     * @param positionCount position count of the page the mask was built from
     */
    private record FilteredAddInput(BooleanVector mask, GroupingAggregatorFunction.AddInput nextAdd, int positionCount) implements GroupingAggregatorFunction.AddInput
    {
        @Override
        public void add(int positionOffset, IntArrayBlock groupIds) {
            this.addBlock(positionOffset, groupIds);
        }

        @Override
        public void add(int positionOffset, IntBigArrayBlock groupIds) {
            this.addBlock(positionOffset, groupIds);
        }

        @Override
        public void add(int positionOffset, IntVector groupIds) {
            this.addBlock(positionOffset, groupIds.asBlock());
        }

        /**
         * Masks {@code groupIds} with the slice of the page mask that covers positions
         * {@code [positionOffset, positionOffset + groupIds.getPositionCount())}, then forwards the
         * result to {@code nextAdd}. Each chunk is forwarded exactly once.
         */
        private void addBlock(int positionOffset, IntBlock groupIds) {
            if (positionOffset == 0) {
                // The chunk starts at the page start, so the page mask lines up without slicing.
                // NOTE(review): relies on keepMask accepting a mask at least as long as groupIds.
                try (IntBlock filtered = groupIds.keepMask(this.mask)) {
                    this.nextAdd.add(positionOffset, filtered);
                }
                // Without this return the chunk would also go through the slicing path below
                // and be added to `nextAdd` a second time.
                return;
            }
            int[] positions = IntStream.range(positionOffset, positionOffset + groupIds.getPositionCount()).toArray();
            try (BooleanVector offsetMask = this.mask.filter(positions); IntBlock filtered = groupIds.keepMask(offsetMask)) {
                this.nextAdd.add(positionOffset, filtered);
            }
        }

        /** Releases the mask and the delegate's {@code AddInput}. */
        @Override
        public void close() {
            Releasables.close(this.mask, this.nextAdd);
        }
    }
}

