/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.compute.lucene;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.function.Function;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.compute.lucene.DataPartitioning;
import org.elasticsearch.compute.lucene.LuceneSlice;
import org.elasticsearch.compute.lucene.PartialLeafReaderContext;
import org.elasticsearch.compute.lucene.ShardContext;
import org.elasticsearch.core.Nullable;

public final class LuceneSliceQueue {
    public static final int MAX_DOCS_PER_SLICE = 250000;
    public static final int MAX_SEGMENTS_PER_SLICE = 5;
    private final int totalSlices;
    private final Queue<LuceneSlice> slices;
    private final Map<String, PartitioningStrategy> partitioningStrategies;

    private LuceneSliceQueue(List<LuceneSlice> slices, Map<String, PartitioningStrategy> partitioningStrategies) {
        this.totalSlices = slices.size();
        this.slices = new ConcurrentLinkedQueue<LuceneSlice>(slices);
        this.partitioningStrategies = partitioningStrategies;
    }

    @Nullable
    public LuceneSlice nextSlice() {
        return this.slices.poll();
    }

    public int totalSlices() {
        return this.totalSlices;
    }

    public Map<String, PartitioningStrategy> partitioningStrategies() {
        return this.partitioningStrategies;
    }

    public Collection<String> remainingShardsIdentifiers() {
        return this.slices.stream().map(slice -> slice.shardContext().shardIdentifier()).toList();
    }

    public static LuceneSliceQueue create(List<? extends ShardContext> contexts, Function<ShardContext, List<QueryAndTags>> queryFunction, DataPartitioning dataPartitioning, Function<Query, PartitioningStrategy> autoStrategy, int taskConcurrency, Function<ShardContext, ScoreMode> scoreModeFunction) {
        ArrayList<LuceneSlice> slices = new ArrayList<LuceneSlice>();
        HashMap<String, PartitioningStrategy> partitioningStrategies = new HashMap<String, PartitioningStrategy>(contexts.size());
        for (ShardContext shardContext : contexts) {
            for (QueryAndTags queryAndExtra : queryFunction.apply(shardContext)) {
                ScoreMode scoreMode = scoreModeFunction.apply(shardContext);
                Query query = queryAndExtra.query;
                query = scoreMode.needsScores() ? query : new ConstantScoreQuery(query);
                try {
                    query = shardContext.searcher().rewrite(query);
                }
                catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
                PartitioningStrategy partitioning = PartitioningStrategy.pick(dataPartitioning, autoStrategy, shardContext, query);
                partitioningStrategies.put(shardContext.shardIdentifier(), partitioning);
                List<List<PartialLeafReaderContext>> groups = partitioning.groups(shardContext.searcher(), taskConcurrency);
                Weight weight = LuceneSliceQueue.weight(shardContext, query, scoreMode);
                for (List<PartialLeafReaderContext> group : groups) {
                    if (group.isEmpty()) continue;
                    slices.add(new LuceneSlice(shardContext, group, weight, queryAndExtra.tags));
                }
            }
        }
        return new LuceneSliceQueue(slices, partitioningStrategies);
    }

    static Weight weight(ShardContext ctx, Query query, ScoreMode scoreMode) {
        IndexSearcher searcher = ctx.searcher();
        try {
            Query actualQuery = scoreMode.needsScores() ? query : new ConstantScoreQuery(query);
            return searcher.createWeight(actualQuery, scoreMode, 1.0f);
        }
        catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public record QueryAndTags(Query query, List<Object> tags) {
    }

    public static enum PartitioningStrategy implements Writeable
    {
        SHARD(0){

            @Override
            List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices) {
                return List.of(searcher.getLeafContexts().stream().map(PartialLeafReaderContext::new).toList());
            }
        }
        ,
        SEGMENT(1){

            @Override
            List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices) {
                IndexSearcher.LeafSlice[] gs = IndexSearcher.slices((List)searcher.getLeafContexts(), (int)250000, (int)5, (boolean)false);
                return Arrays.stream(gs).map(g -> Arrays.stream(g.partitions).map(PartialLeafReaderContext::new).toList()).toList();
            }
        }
        ,
        DOC(2){

            @Override
            List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices) {
                int totalDocCount = searcher.getIndexReader().maxDoc();
                int normalMaxDocsPerSlice = totalDocCount / requestedNumSlices;
                int extraDocsInFirstSlice = totalDocCount % requestedNumSlices;
                ArrayList<List<PartialLeafReaderContext>> slices = new ArrayList<List<PartialLeafReaderContext>>();
                int docsAllocatedInCurrentSlice = 0;
                ArrayList<PartialLeafReaderContext> currentSlice = null;
                int maxDocsPerSlice = normalMaxDocsPerSlice + extraDocsInFirstSlice;
                for (LeafReaderContext ctx : searcher.getLeafContexts()) {
                    int numDocsToUse;
                    int numDocsInLeaf = ctx.reader().maxDoc();
                    for (int minDoc = 0; minDoc < numDocsInLeaf && (numDocsToUse = Math.min(maxDocsPerSlice - docsAllocatedInCurrentSlice, numDocsInLeaf - minDoc)) > 0; minDoc += numDocsToUse) {
                        if (currentSlice == null) {
                            currentSlice = new ArrayList<PartialLeafReaderContext>();
                        }
                        currentSlice.add(new PartialLeafReaderContext(ctx, minDoc, minDoc + numDocsToUse));
                        if ((docsAllocatedInCurrentSlice += numDocsToUse) != maxDocsPerSlice) continue;
                        slices.add(currentSlice);
                        maxDocsPerSlice = normalMaxDocsPerSlice;
                        currentSlice = null;
                        docsAllocatedInCurrentSlice = 0;
                    }
                }
                if (currentSlice != null) {
                    slices.add(currentSlice);
                }
                if (requestedNumSlices < totalDocCount && slices.size() != requestedNumSlices) {
                    throw new IllegalStateException("wrong number of slices, expected " + requestedNumSlices + " but got " + slices.size());
                }
                if (slices.stream().flatMapToInt(l -> l.stream().mapToInt(partialLeafReaderContext -> partialLeafReaderContext.maxDoc() - partialLeafReaderContext.minDoc())).sum() != totalDocCount) {
                    throw new IllegalStateException("wrong doc count");
                }
                return slices;
            }
        };

        private final byte id;
        private static final int SMALL_INDEX_BOUNDARY = 250000;

        private PartitioningStrategy(int id) {
            this.id = (byte)id;
        }

        public static PartitioningStrategy readFrom(StreamInput in) throws IOException {
            byte id = in.readByte();
            return switch (id) {
                case 0 -> SHARD;
                case 1 -> SEGMENT;
                case 2 -> DOC;
                default -> throw new IllegalArgumentException("invalid PartitioningStrategyId [" + id + "]");
            };
        }

        public void writeTo(StreamOutput out) throws IOException {
            out.writeByte(this.id);
        }

        abstract List<List<PartialLeafReaderContext>> groups(IndexSearcher var1, int var2);

        private static PartitioningStrategy pick(DataPartitioning dataPartitioning, Function<Query, PartitioningStrategy> autoStrategy, ShardContext ctx, Query query) {
            return switch (dataPartitioning) {
                default -> throw new MatchException(null, null);
                case DataPartitioning.SHARD -> SHARD;
                case DataPartitioning.SEGMENT -> SEGMENT;
                case DataPartitioning.DOC -> DOC;
                case DataPartitioning.AUTO -> PartitioningStrategy.forAuto(autoStrategy, ctx, query);
            };
        }

        private static PartitioningStrategy forAuto(Function<Query, PartitioningStrategy> autoStrategy, ShardContext ctx, Query query) {
            if (ctx.searcher().getIndexReader().maxDoc() < 250000) {
                return SHARD;
            }
            return autoStrategy.apply(query);
        }
    }
}

