/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.compute.operator;

import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.elasticsearch.common.Rounding;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.IntArray;
import org.elasticsearch.compute.Describable;
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.aggregation.GroupingAggregator;
import org.elasticsearch.compute.aggregation.GroupingAggregatorEvaluationContext;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.aggregation.TimeSeriesGroupingAggregatorEvaluationContext;
import org.elasticsearch.compute.aggregation.WindowGroupingAggregatorFunction;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.aggregation.blockhash.BytesRefLongBlockHash;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.HashAggregationOperator;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.core.AbstractRefCounted;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.index.mapper.DateFieldMapper;

/**
 * A {@link HashAggregationOperator} for time-series aggregations whose groups are keyed by a
 * time-series id and a time bucket.
 * <p>
 * When at least one aggregator is a {@link WindowGroupingAggregatorFunction} and no aggregator
 * produces partial output, {@link #finish()} first expands every existing group over the preceding
 * buckets covered by the largest window, adding the missing (tsid, bucket) groups to the hash.
 * Values-style aggregators (see {@link #isValuesAggregator}) are then evaluated for those added
 * groups using the group they were expanded from.
 */
public class TimeSeriesAggregationOperator extends HashAggregationOperator {
    private final Rounding.Prepared timeBucket;
    private final DateFieldMapper.Resolution timeResolution;
    /** Groups added by {@link #expandWindowBuckets()}; {@code null} until an expansion has happened. */
    private ExpandingGroups expandingGroups = null;
    /** Fully-qualified names of the aggregator functions that are evaluated on the source group of an expanded group. */
    private static final Set<String> VALUES_CLASSES = Set.of(
        "org.elasticsearch.compute.aggregation.ValuesBooleanGroupingAggregatorFunction",
        "org.elasticsearch.compute.aggregation.ValuesBytesRefGroupingAggregatorFunction",
        "org.elasticsearch.compute.aggregation.ValuesIntGroupingAggregatorFunction",
        "org.elasticsearch.compute.aggregation.ValuesLongGroupingAggregatorFunction",
        "org.elasticsearch.compute.aggregation.ValuesDoubleGroupingAggregatorFunction",
        "org.elasticsearch.compute.aggregation.DimensionValuesByteRefGroupingAggregatorFunction"
    );

    /**
     * @param timeBucket     rounding that defines the time buckets of the long (timestamp) group key
     * @param timeResolution resolution (millis or nanos) of the timestamp key values
     * @param aggregators    aggregator factories, passed to {@link HashAggregationOperator}
     * @param blockHash      supplier of the grouping hash, passed to {@link HashAggregationOperator}
     * @param driverContext  driver context, passed to {@link HashAggregationOperator}
     */
    public TimeSeriesAggregationOperator(Rounding.Prepared timeBucket, DateFieldMapper.Resolution timeResolution, List<GroupingAggregator.Factory> aggregators, Supplier<BlockHash> blockHash, DriverContext driverContext) {
        super(aggregators, blockHash, driverContext);
        this.timeBucket = timeBucket;
        this.timeResolution = timeResolution;
    }

    /**
     * Expands window buckets (if applicable) before letting the parent operator finish.
     */
    @Override
    public void finish() {
        this.expandWindowBuckets();
        super.finish();
    }

    /**
     * Returns the largest window, in milliseconds, across all window aggregators, or
     * {@link Long#MIN_VALUE} when there are none.
     */
    private long largestWindowMillis() {
        long largestWindow = Long.MIN_VALUE;
        for (GroupingAggregator aggregator : this.aggregators) {
            if (aggregator.aggregatorFunction() instanceof WindowGroupingAggregatorFunction windowFunction) {
                largestWindow = Math.max(largestWindow, windowFunction.window().toMillis());
            }
        }
        return largestWindow;
    }

    /**
     * For every group present in the hash, adds a group for the same tsid in each bucket that
     * starts within the largest window before the group's own bucket. The bucket range is clamped
     * below by the hash's minimum timestamp key. For every group newly added this way, the id of
     * the group it was expanded from is recorded in {@link #expandingGroups}.
     * <p>
     * This is a no-op when any aggregator emits partial output, when there is no positive window,
     * or when the hash is empty. It is also a no-op once an expansion has already happened, so a
     * repeated {@link #finish()} neither leaks the previous {@link ExpandingGroups} nor discards
     * its mapping.
     */
    private void expandWindowBuckets() {
        if (this.expandingGroups != null) {
            return;
        }
        for (GroupingAggregator aggregator : this.aggregators) {
            if (aggregator.mode().isOutputPartial()) {
                return;
            }
        }
        long windowMillis = this.largestWindowMillis();
        if (windowMillis <= 0L) {
            return;
        }
        BytesRefLongBlockHash tsBlockHash = (BytesRefLongBlockHash) this.blockHash;
        long numGroups = tsBlockHash.numGroups();
        if (numGroups == 0L) {
            return;
        }
        // Loop-invariant: the window expressed in the resolution of the timestamp keys.
        long window = this.timeResolution.convert(windowMillis);
        this.expandingGroups = new ExpandingGroups(this.driverContext.bigArrays());
        // numGroups is captured up front so groups added below are not themselves expanded.
        for (long groupId = 0L; groupId < numGroups; ++groupId) {
            long tsid = tsBlockHash.getBytesRefKeyFromGroup(groupId);
            long endTimestamp = tsBlockHash.getLongKeyFromGroup(groupId);
            long bucket = Math.max(this.timeBucket.nextRoundingValue(endTimestamp - window), tsBlockHash.getMinLongKey());
            while (bucket < endTimestamp) {
                // NOTE(review): assumes addGroup returns a non-negative id only when the group is new.
                if (tsBlockHash.addGroup(tsid, bucket) >= 0L) {
                    this.expandingGroups.addGroup(Math.toIntExact(groupId));
                }
                bucket = this.timeBucket.nextRoundingValue(bucket);
            }
        }
    }

    /**
     * Evaluates {@code aggregator}. For values-style aggregators, when groups were added by window
     * expansion, the selection is rewritten first (see {@link #selectedForValuesAggregator}), so
     * that each expanded group is evaluated on the group it was expanded from.
     */
    @Override
    protected void evaluateAggregator(GroupingAggregator aggregator, Block[] blocks, int offset, IntVector selected, GroupingAggregatorEvaluationContext evaluationContext) {
        boolean remapExpanded = this.expandingGroups != null
            && this.expandingGroups.count > 0
            && isValuesAggregator(aggregator.aggregatorFunction());
        if (remapExpanded == false) {
            super.evaluateAggregator(aggregator, blocks, offset, selected, evaluationContext);
            return;
        }
        try (IntVector valuesSelected = selectedForValuesAggregator(this.driverContext.blockFactory(), selected, this.expandingGroups)) {
            super.evaluateAggregator(aggregator, blocks, offset, valuesSelected, evaluationContext);
        }
    }

    /**
     * Builds a copy of {@code selected} whose last {@code expandingGroups.count} positions are
     * replaced by the recorded source group ids. The first positions are copied unchanged.
     * <p>
     * NOTE(review): assumes the expanded groups occupy the tail of {@code selected} in the order
     * they were added — confirm against the parent operator's selection.
     */
    private static IntVector selectedForValuesAggregator(BlockFactory blockFactory, IntVector selected, ExpandingGroups expandingGroups) {
        int positionCount = selected.getPositionCount();
        int firstExpanded = positionCount - expandingGroups.count;
        try (IntVector.FixedBuilder builder = blockFactory.newIntVectorFixedBuilder(positionCount)) {
            for (int p = 0; p < firstExpanded; p++) {
                builder.appendInt(p, selected.getInt(p));
            }
            for (int i = 0; i < expandingGroups.count; i++) {
                builder.appendInt(firstExpanded + i, expandingGroups.getGroup(i));
            }
            return builder.build();
        }
    }

    /**
     * Returns whether {@code aggregatorFunction}'s runtime class is one of {@link #VALUES_CLASSES}.
     */
    static boolean isValuesAggregator(GroupingAggregatorFunction aggregatorFunction) {
        return VALUES_CLASSES.contains(aggregatorFunction.getClass().getName());
    }

    /**
     * Returns a time-series evaluation context when there are at least two key blocks. Otherwise it
     * falls back to the parent's context. The timestamp keys are whichever of the first two key
     * blocks has {@link ElementType#LONG} (the first one if it is LONG, else the second).
     */
    @Override
    protected GroupingAggregatorEvaluationContext evaluationContext(BlockHash blockHash, Block[] keys) {
        if (keys.length < 2) {
            return super.evaluationContext(blockHash, keys);
        }
        final BytesRefLongBlockHash hash = (BytesRefLongBlockHash) blockHash;
        final LongBlock timestamps = keys[0].elementType() == ElementType.LONG ? (LongBlock) keys[0] : (LongBlock) keys[1];
        return new TimeSeriesGroupingAggregatorEvaluationContext(this.driverContext) {

            /** Start of the group's bucket, rounded down to millis. */
            @Override
            public long rangeStartInMillis(int groupId) {
                return TimeSeriesAggregationOperator.this.timeResolution.roundDownToMillis(timestamps.getLong(groupId));
            }

            /** Start of the bucket following the group's bucket, rounded down to millis. */
            @Override
            public long rangeEndInMillis(int groupId) {
                long nextBucket = TimeSeriesAggregationOperator.this.timeBucket.nextRoundingValue(timestamps.getLong(groupId));
                return TimeSeriesAggregationOperator.this.timeResolution.roundDownToMillis(nextBucket);
            }

            /**
             * Returns {@code startingGroupId} followed by the ids of the groups for the same tsid
             * in each later bucket that starts before {@code bucket + window}. Buckets without a
             * group are skipped.
             */
            @Override
            public List<Integer> groupIdsFromWindow(int startingGroupId, Duration window) {
                long tsid = hash.getBytesRefKeyFromGroup(startingGroupId);
                long bucket = hash.getLongKeyFromGroup(startingGroupId);
                List<Integer> results = new ArrayList<>();
                results.add(startingGroupId);
                long endTimestamp = bucket + TimeSeriesAggregationOperator.this.timeResolution.convert(window.toMillis());
                while ((bucket = TimeSeriesAggregationOperator.this.timeBucket.nextRoundingValue(bucket)) < endTimestamp) {
                    long nextGroupId = hash.getGroupId(tsid, bucket);
                    if (nextGroupId != -1L) {
                        results.add(Math.toIntExact(nextGroupId));
                    }
                }
                return results;
            }
        };
    }

    /**
     * Releases the expanding-groups bookkeeping (if any) and then the parent operator.
     */
    @Override
    public void close() {
        // Releasables.close skips a null expandingGroups.
        Releasables.close(this.expandingGroups, super::close);
    }

    /**
     * Growable list of source group ids, one entry per group added by window expansion, in the
     * order the groups were added. Backed by a {@link BigArrays} int array that is released
     * when the last reference is dropped.
     */
    static class ExpandingGroups extends AbstractRefCounted implements Releasable {
        private final BigArrays bigArrays;
        private IntArray newGroups;
        private int count;

        ExpandingGroups(BigArrays bigArrays) {
            this.bigArrays = bigArrays;
            this.newGroups = bigArrays.newIntArray(128L);
        }

        /** Appends {@code groupId}, growing the backing array if needed. */
        void addGroup(int groupId) {
            this.newGroups = this.bigArrays.grow(this.newGroups, (long) (this.count + 1));
            this.newGroups.set((long) this.count++, groupId);
        }

        /** Returns the source group id recorded at {@code index}. */
        int getGroup(int index) {
            return this.newGroups.get((long) index);
        }

        @Override
        protected void closeInternal() {
            this.newGroups.close();
        }

        @Override
        public void close() {
            this.decRef();
        }
    }

    /**
     * Factory for {@link TimeSeriesAggregationOperator}. {@code dateNanos} selects nanosecond
     * rather than millisecond timestamp resolution.
     */
    public record Factory(Rounding.Prepared timeBucket, boolean dateNanos, List<BlockHash.GroupSpec> groups, AggregatorMode aggregatorMode, List<GroupingAggregator.Factory> aggregators, int maxPageSize) implements Operator.OperatorFactory {
        @Override
        public Operator get(DriverContext driverContext) {
            DateFieldMapper.Resolution resolution = this.dateNanos ? DateFieldMapper.Resolution.NANOSECONDS : DateFieldMapper.Resolution.MILLISECONDS;
            return new TimeSeriesAggregationOperator(this.timeBucket, resolution, this.aggregators, () -> BlockHash.build(this.groups, driverContext.blockFactory(), this.maxPageSize, true), driverContext);
        }

        @Override
        public String describe() {
            return "TimeSeriesAggregationOperator[mode = <not-needed>, aggs = " + this.aggregators.stream().map(Describable::describe).collect(Collectors.joining(", ")) + "]";
        }
    }
}

