/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.compute.lucene;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.atomic.AtomicReferenceArray;
import java.util.function.Function;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.compute.lucene.DataPartitioning;
import org.elasticsearch.compute.lucene.IndexedByShardId;
import org.elasticsearch.compute.lucene.LuceneSlice;
import org.elasticsearch.compute.lucene.PartialLeafReaderContext;
import org.elasticsearch.compute.lucene.ShardContext;
import org.elasticsearch.core.Nullable;

public final class LuceneSliceQueue {
    public static final int MAX_DOCS_PER_SLICE = 250000;
    public static final int MAX_SEGMENTS_PER_SLICE = 5;
    private final int totalSlices;
    private final Map<String, PartitioningStrategy> partitioningStrategies;
    private final AtomicReferenceArray<LuceneSlice> slices;
    private final Queue<Integer> queryHeads;
    private final Queue<Integer> segmentHeads;
    private final Queue<Integer> stealableSlices;

    LuceneSliceQueue(List<LuceneSlice> sliceList, Map<String, PartitioningStrategy> partitioningStrategies) {
        this.totalSlices = sliceList.size();
        this.slices = new AtomicReferenceArray(sliceList.size());
        for (int i = 0; i < sliceList.size(); ++i) {
            this.slices.set(i, sliceList.get(i));
        }
        this.partitioningStrategies = partitioningStrategies;
        this.queryHeads = ConcurrentCollections.newQueue();
        this.segmentHeads = ConcurrentCollections.newQueue();
        this.stealableSlices = ConcurrentCollections.newQueue();
        for (LuceneSlice slice : sliceList) {
            if (slice.queryHead()) {
                this.queryHeads.add(slice.slicePosition());
                continue;
            }
            if (slice.getLeaf(0).minDoc() == 0) {
                this.segmentHeads.add(slice.slicePosition());
                continue;
            }
            this.stealableSlices.add(slice.slicePosition());
        }
    }

    @Nullable
    public LuceneSlice nextSlice(LuceneSlice prev) {
        LuceneSlice slice;
        int nextId;
        if (prev != null && (nextId = prev.slicePosition() + 1) < this.totalSlices && (slice = (LuceneSlice)this.slices.getAndSet(nextId, null)) != null) {
            return slice;
        }
        for (Queue<Integer> ids : List.of(this.queryHeads, this.segmentHeads, this.stealableSlices)) {
            Integer nextId2;
            while ((nextId2 = ids.poll()) != null) {
                LuceneSlice slice2 = this.slices.getAndSet(nextId2, null);
                if (slice2 == null) continue;
                return slice2;
            }
        }
        return null;
    }

    public int totalSlices() {
        return this.totalSlices;
    }

    public Map<String, PartitioningStrategy> partitioningStrategies() {
        return this.partitioningStrategies;
    }

    public static LuceneSliceQueue create(IndexedByShardId<? extends ShardContext> contexts, Function<ShardContext, List<QueryAndTags>> queryFunction, DataPartitioning dataPartitioning, Function<Query, PartitioningStrategy> autoStrategy, int taskConcurrency, Function<ShardContext, ScoreMode> scoreModeFunction) {
        ArrayList<LuceneSlice> slices = new ArrayList<LuceneSlice>();
        HashMap<String, PartitioningStrategy> partitioningStrategies = new HashMap<String, PartitioningStrategy>();
        int nextSliceId = 0;
        for (ShardContext shardContext : contexts.collection()) {
            for (QueryAndTags queryAndExtra : queryFunction.apply(shardContext)) {
                ScoreMode scoreMode = scoreModeFunction.apply(shardContext);
                Query query = queryAndExtra.query;
                query = scoreMode.needsScores() ? query : new ConstantScoreQuery(query);
                try {
                    query = shardContext.searcher().rewrite(query);
                }
                catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
                PartitioningStrategy partitioning = PartitioningStrategy.pick(dataPartitioning, autoStrategy, shardContext, query);
                partitioningStrategies.put(shardContext.shardIdentifier(), partitioning);
                List<List<PartialLeafReaderContext>> groups = partitioning.groups(shardContext.searcher(), taskConcurrency);
                Weight weight = LuceneSliceQueue.weight(shardContext, query, scoreMode);
                boolean queryHead = true;
                for (List<PartialLeafReaderContext> group : groups) {
                    if (group.isEmpty()) continue;
                    int slicePosition = nextSliceId++;
                    slices.add(new LuceneSlice(slicePosition, queryHead, shardContext, group, weight, queryAndExtra.tags));
                    queryHead = false;
                }
            }
        }
        return new LuceneSliceQueue(slices, partitioningStrategies);
    }

    static Weight weight(ShardContext ctx, Query query, ScoreMode scoreMode) {
        IndexSearcher searcher = ctx.searcher();
        try {
            Query actualQuery = scoreMode.needsScores() ? query : new ConstantScoreQuery(query);
            return searcher.createWeight(actualQuery, scoreMode, 1.0f);
        }
        catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public record QueryAndTags(Query query, List<Object> tags) {
    }

    public static enum PartitioningStrategy implements Writeable
    {
        SHARD(0){

            @Override
            List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int taskConcurrency) {
                return List.of(searcher.getLeafContexts().stream().map(PartialLeafReaderContext::new).toList());
            }
        }
        ,
        SEGMENT(1){

            @Override
            List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int taskConcurrency) {
                IndexSearcher.LeafSlice[] gs = IndexSearcher.slices((List)searcher.getLeafContexts(), (int)250000, (int)5, (boolean)false);
                return Arrays.stream(gs).map(g -> Arrays.stream(g.partitions).map(PartialLeafReaderContext::new).toList()).toList();
            }
        }
        ,
        DOC(2){

            @Override
            List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int taskConcurrency) {
                int totalDocCount = searcher.getIndexReader().maxDoc();
                int desiredSliceSize = Math.clamp((long)Math.ceilDiv(totalDocCount, taskConcurrency), 1, 250000);
                return new AdaptivePartitioner(Math.max(1, desiredSliceSize), 5).partition(searcher.getLeafContexts());
            }
        };

        private final byte id;
        private static final int SMALL_INDEX_BOUNDARY = 250000;

        private PartitioningStrategy(int id) {
            this.id = (byte)id;
        }

        public static PartitioningStrategy readFrom(StreamInput in) throws IOException {
            byte id = in.readByte();
            return switch (id) {
                case 0 -> SHARD;
                case 1 -> SEGMENT;
                case 2 -> DOC;
                default -> throw new IllegalArgumentException("invalid PartitioningStrategyId [" + id + "]");
            };
        }

        public void writeTo(StreamOutput out) throws IOException {
            out.writeByte(this.id);
        }

        abstract List<List<PartialLeafReaderContext>> groups(IndexSearcher var1, int var2);

        private static PartitioningStrategy pick(DataPartitioning dataPartitioning, Function<Query, PartitioningStrategy> autoStrategy, ShardContext ctx, Query query) {
            return switch (dataPartitioning) {
                default -> throw new MatchException(null, null);
                case DataPartitioning.SHARD -> SHARD;
                case DataPartitioning.SEGMENT -> SEGMENT;
                case DataPartitioning.DOC -> DOC;
                case DataPartitioning.AUTO -> PartitioningStrategy.forAuto(autoStrategy, ctx, query);
            };
        }

        private static PartitioningStrategy forAuto(Function<Query, PartitioningStrategy> autoStrategy, ShardContext ctx, Query query) {
            if (ctx.searcher().getIndexReader().maxDoc() < 250000) {
                return SHARD;
            }
            return autoStrategy.apply(query);
        }
    }

    static final class AdaptivePartitioner {
        final int desiredDocsPerSlice;
        final int maxDocsPerSlice;
        final int maxSegmentsPerSlice;

        AdaptivePartitioner(int desiredDocsPerSlice, int maxSegmentsPerSlice) {
            this.desiredDocsPerSlice = desiredDocsPerSlice;
            this.maxDocsPerSlice = desiredDocsPerSlice * 5 / 4;
            this.maxSegmentsPerSlice = maxSegmentsPerSlice;
        }

        List<List<PartialLeafReaderContext>> partition(List<LeafReaderContext> leaves) {
            ArrayList<LeafReaderContext> smallSegments = new ArrayList<LeafReaderContext>();
            ArrayList<LeafReaderContext> largeSegments = new ArrayList<LeafReaderContext>();
            ArrayList<List<PartialLeafReaderContext>> results = new ArrayList<List<PartialLeafReaderContext>>();
            for (LeafReaderContext leaf : leaves) {
                if (leaf.reader().maxDoc() >= 5 * this.desiredDocsPerSlice) {
                    largeSegments.add(leaf);
                    continue;
                }
                smallSegments.add(leaf);
            }
            largeSegments.sort(Collections.reverseOrder(Comparator.comparingInt(l -> l.reader().maxDoc())));
            for (LeafReaderContext segment : largeSegments) {
                results.addAll(this.partitionOneLargeSegment(segment));
            }
            results.addAll(this.partitionSmallSegments(smallSegments));
            return results;
        }

        List<List<PartialLeafReaderContext>> partitionOneLargeSegment(LeafReaderContext leaf) {
            int numDocsInLeaf = leaf.reader().maxDoc();
            int numSlices = Math.max(1, numDocsInLeaf / this.desiredDocsPerSlice);
            while (Math.ceilDiv(numDocsInLeaf, numSlices) > this.maxDocsPerSlice) {
                ++numSlices;
            }
            int docPerSlice = numDocsInLeaf / numSlices;
            int leftoverDocs = numDocsInLeaf % numSlices;
            int minDoc = 0;
            ArrayList<List<PartialLeafReaderContext>> results = new ArrayList<List<PartialLeafReaderContext>>();
            while (minDoc < numDocsInLeaf) {
                int docsToUse = docPerSlice;
                if (leftoverDocs > 0) {
                    --leftoverDocs;
                    ++docsToUse;
                }
                int maxDoc = Math.min(minDoc + docsToUse, numDocsInLeaf);
                results.add(List.of(new PartialLeafReaderContext(leaf, minDoc, maxDoc)));
                minDoc = maxDoc;
            }
            assert (leftoverDocs == 0) : leftoverDocs;
            assert (results.stream().allMatch(s -> s.size() == 1)) : "must have one partial leaf per slice";
            assert (results.stream().flatMapToInt(ss -> ss.stream().mapToInt(s -> s.maxDoc() - s.minDoc())).sum() == numDocsInLeaf);
            return results;
        }

        List<List<PartialLeafReaderContext>> partitionSmallSegments(List<LeafReaderContext> leaves) {
            IndexSearcher.LeafSlice[] slices = IndexSearcher.slices(leaves, (int)this.maxDocsPerSlice, (int)this.maxSegmentsPerSlice, (boolean)true);
            return Arrays.stream(slices).map(g -> Arrays.stream(g.partitions).map(PartialLeafReaderContext::new).toList()).toList();
        }
    }
}

