/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.compute.lucene;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.ReaderUtil;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TopFieldCollectorManager;
import org.apache.lucene.search.TopScoreDocCollectorManager;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DocBlock;
import org.elasticsearch.compute.data.DocVector;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.DoubleVector;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.lucene.DataPartitioning;
import org.elasticsearch.compute.lucene.LuceneOperator;
import org.elasticsearch.compute.lucene.LuceneSliceQueue;
import org.elasticsearch.compute.lucene.ShardContext;
import org.elasticsearch.compute.lucene.ShardRefCounted;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.SourceOperator;
import org.elasticsearch.core.RefCounted;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.sort.SortAndFormats;
import org.elasticsearch.search.sort.SortBuilder;

/**
 * Source operator that finds the top {@code limit} documents of each shard according to {@code sorts}
 * (optionally also computing scores) and emits them as pages of (shard, segment, doc) references.
 * <p>
 * The operator alternates between two phases:
 * <ul>
 *   <li><em>collecting</em>: matching documents of the current shard are fed to a per-shard
 *   {@link TopDocsCollector}, {@link #NUM_DOCS_INTERVAL} documents per {@link #getCheckedOutput()} call;</li>
 *   <li><em>emitting</em>: once the current scorer is done and the next scorer (if any) belongs to a
 *   different shard, the collected top docs are emitted in pages of at most {@code maxPageSize} rows.</li>
 * </ul>
 * Memory for {@code limit} rows is reserved on the circuit breaker up front and released in
 * {@link #additionalClose()}.
 */
public final class LuceneTopNSourceOperator
extends LuceneOperator {
    /** Range size handed to {@code scoreNextRange} on each collect step. */
    private static final int NUM_DOCS_INTERVAL = 4096;
    private final CircuitBreaker breaker;
    private final List<SortBuilder<?>> sorts;
    private final long estimatedPerRowSortSize;
    private final int limit;
    private final boolean needsScore;
    /** Top docs of the shard currently being emitted; {@code null} when not emitting. */
    private ScoreDoc[] topDocs;
    /** Ref-counted handle of the shard currently being emitted, attached to each emitted DocVector. */
    private ShardRefCounted shardRefCounted;
    /** Index into {@link #topDocs} of the next document to emit. */
    private int offset = 0;
    /** Collector for the shard currently being collected; released once its results have been emitted. */
    private PerShardCollector perShardCollector;
    /** Shallow size of one {@link FieldDoc}, used to estimate the breaker reservation per row. */
    private static final int FIELD_DOC_SIZE = Math.toIntExact(RamUsageEstimator.shallowSizeOf(FieldDoc.class));

    /**
     * Creates the operator and reserves {@code limit * (FIELD_DOC_SIZE + estimatedPerRowSortSize)} bytes
     * on {@code breaker}. The reservation may trip the breaker.
     */
    public LuceneTopNSourceOperator(List<? extends ShardContext> contexts, CircuitBreaker breaker, BlockFactory blockFactory, int maxPageSize, List<SortBuilder<?>> sorts, long estimatedPerRowSortSize, int limit, LuceneSliceQueue sliceQueue, boolean needsScore) {
        super(contexts, blockFactory, maxPageSize, sliceQueue);
        this.breaker = breaker;
        this.sorts = sorts;
        this.estimatedPerRowSortSize = estimatedPerRowSortSize;
        this.limit = limit;
        this.needsScore = needsScore;
        breaker.addEstimateBytesAndMaybeBreak(this.reserveSize(), "esql lucene topn");
    }

    /** Finished once collection is complete and no collected docs remain to be emitted. */
    @Override
    public boolean isFinished() {
        return this.doneCollecting && !this.isEmitting();
    }

    /** Stops the operator early, discarding any collected but not yet emitted documents. */
    @Override
    public void finish() {
        this.doneCollecting = true;
        this.topDocs = null;
        this.shardRefCounted = null;
        this.perShardCollector = null;
        assert (this.isFinished());
    }

    /**
     * Performs one step of work: emits the next page when emitting, otherwise collects one range of docs.
     * Time spent is added to {@code processingNanos}.
     *
     * @return the next page, or {@code null} if this step produced no output
     */
    @Override
    public Page getCheckedOutput() throws IOException {
        if (this.isFinished()) {
            return null;
        }
        long start = System.nanoTime();
        try {
            return this.isEmitting() ? this.emit() : this.collect();
        }
        finally {
            this.processingNanos += System.nanoTime() - start;
        }
    }

    /**
     * Scores the next range of the current scorer into the current shard's collector. Switches to emitting
     * when there are no more scorers, or when the current scorer is done and the next one is on another shard.
     */
    private Page collect() throws IOException {
        assert (!this.doneCollecting);
        LuceneOperator.LuceneScorer scorer = this.getCurrentOrLoadNextScorer();
        if (scorer == null) {
            this.doneCollecting = true;
            this.startEmitting();
            return this.emit();
        }
        try {
            if (!scorer.tags().isEmpty()) {
                throw new UnsupportedOperationException("tags not supported by " + String.valueOf(this.getClass()));
            }
            if (this.perShardCollector == null || this.perShardCollector.shardContext.index() != scorer.shardContext().index()) {
                this.perShardCollector = LuceneTopNSourceOperator.newPerShardCollector(scorer.shardContext(), this.sorts, this.needsScore, this.limit);
            }
            LeafCollector leafCollector = this.perShardCollector.getLeafCollector(scorer.leafReaderContext());
            scorer.scoreNextRange(leafCollector, scorer.leafReaderContext().reader().getLiveDocs(), NUM_DOCS_INTERVAL);
        }
        catch (CollectionTerminatedException cte) {
            // The collector signalled that it needs no more docs from this scorer.
            scorer.markAsDone();
        }
        if (scorer.isDone()) {
            LuceneOperator.LuceneScorer nextScorer = this.getCurrentOrLoadNextScorer();
            if (nextScorer == null || nextScorer.shardContext().index() != scorer.shardContext().index()) {
                this.startEmitting();
                return this.emit();
            }
        }
        return null;
    }

    private boolean isEmitting() {
        return this.topDocs != null;
    }

    /**
     * Snapshots the current collector's top docs for emission. Emits nothing when no collector exists.
     * The collector is released in {@link #stopEmitting()}, so {@code topDocs()} is read at most once per collector.
     */
    private void startEmitting() {
        assert (!this.isEmitting()) : "offset=" + this.offset + " score_docs=" + Arrays.toString(this.topDocs);
        this.offset = 0;
        if (this.perShardCollector != null) {
            this.topDocs = this.perShardCollector.collector.topDocs().scoreDocs;
            int shardId = this.perShardCollector.shardContext.index();
            this.shardRefCounted = new ShardRefCounted.Single(shardId, this.shardContextCounters.get(shardId));
        } else {
            this.topDocs = new ScoreDoc[0];
        }
    }

    /**
     * Leaves the emitting phase. Also drops the drained per-shard collector, so that the next collect
     * step (even on the same shard index) starts with a fresh collector instead of reusing a spent one.
     */
    private void stopEmitting() {
        this.topDocs = null;
        this.shardRefCounted = null;
        this.perShardCollector = null;
    }

    /**
     * Builds the next page of at most {@code maxPageSize} rows from {@link #topDocs}, or returns
     * {@code null} (and stops emitting) when all docs have been emitted.
     */
    private Page emit() {
        if (this.offset >= this.topDocs.length) {
            this.stopEmitting();
            return null;
        }
        int size = Math.min(this.maxPageSize, this.topDocs.length - this.offset);
        // Intermediate blocks are released in the finally clause unless ownership moved into the page.
        IntBlock shard = null;
        IntVector segments = null;
        IntVector docs = null;
        DocBlock docBlock = null;
        DoubleBlock scores = null;
        Page page = null;
        try (IntVector.FixedBuilder currentSegmentBuilder = this.blockFactory.newIntVectorFixedBuilder(size);
             IntVector.FixedBuilder currentDocsBuilder = this.blockFactory.newIntVectorFixedBuilder(size);
             DoubleVector.Builder currentScoresBuilder = this.scoreVectorOrNull(size)) {
            int start = this.offset;
            this.offset += size;
            List<LeafReaderContext> leafContexts = this.perShardCollector.shardContext.searcher().getLeafContexts();
            for (int i = start; i < this.offset; ++i) {
                ScoreDoc scoreDoc = this.topDocs[i];
                // Translate the shard-global doc id into (segment ordinal, segment-local doc id).
                int segment = ReaderUtil.subIndex(scoreDoc.doc, leafContexts);
                currentSegmentBuilder.appendInt(segment);
                currentDocsBuilder.appendInt(scoreDoc.doc - leafContexts.get(segment).docBase);
                if (currentScoresBuilder != null) {
                    currentScoresBuilder.appendDouble(this.getScore(scoreDoc));
                }
                // Drop the reference so emitted docs can be garbage collected early.
                this.topDocs[i] = null;
            }
            int shardId = this.perShardCollector.shardContext.index();
            shard = this.blockFactory.newConstantIntBlockWith(shardId, size);
            segments = currentSegmentBuilder.build();
            docs = currentDocsBuilder.build();
            docBlock = new DocVector(this.shardRefCounted, shard.asVector(), segments, docs, null).asBlock();
            // Ownership of shard/segments/docs moved into docBlock.
            shard = null;
            segments = null;
            docs = null;
            if (currentScoresBuilder == null) {
                page = new Page(size, docBlock);
            } else {
                scores = currentScoresBuilder.build().asBlock();
                page = new Page(size, docBlock, scores);
            }
        }
        finally {
            if (page == null) {
                Releasables.closeExpectNoException(shard, segments, docs, docBlock, scores);
            }
        }
        return page;
    }

    /**
     * Extracts the score of a collected doc. For field-sorted collection, the sort is the user sort followed
     * by {@code FIELD_DOC} and {@code FIELD_SCORE} (see {@code newPerShardCollector}), so when
     * {@code FieldDoc.score} is NaN the score is read from sort value {@code sorts.size() + 1}.
     */
    private float getScore(ScoreDoc scoreDoc) {
        if (scoreDoc instanceof FieldDoc) {
            FieldDoc fieldDoc = (FieldDoc)scoreDoc;
            if (Float.isNaN(fieldDoc.score)) {
                if (this.sorts != null) {
                    return ((Float)fieldDoc.fields[this.sorts.size() + 1]).floatValue();
                }
                return ((Float)fieldDoc.fields[0]).floatValue();
            }
            return fieldDoc.score;
        }
        return scoreDoc.score;
    }

    /** Returns a fixed-size score builder when scores are requested, otherwise {@code null}. */
    private DoubleVector.Builder scoreVectorOrNull(int size) {
        if (this.needsScore) {
            return this.blockFactory.newDoubleVectorFixedBuilder(size);
        }
        return null;
    }

    @Override
    protected void describe(StringBuilder sb) {
        sb.append(", limit = ").append(this.limit);
        sb.append(", needsScore = ").append(this.needsScore);
        String notPrettySorts = this.sorts.stream().map(Strings::toString).collect(Collectors.joining(","));
        sb.append(", sorts = [").append(notPrettySorts).append("]");
    }

    /** Returns the bytes reserved on the breaker in the constructor. */
    @Override
    protected void additionalClose() {
        Releasables.close(() -> this.breaker.addWithoutBreaking(-this.reserveSize()));
    }

    /** Estimated memory for {@code limit} collected rows. */
    private long reserveSize() {
        long perRowSize = (long)FIELD_DOC_SIZE + this.estimatedPerRowSortSize;
        return (long)this.limit * perRowSize;
    }

    /**
     * Returns a function giving, per shard, the {@link ScoreMode} of the collector this operator would use.
     * It builds a throwaway collector with limit 1 and reports its score mode.
     */
    private static Function<ShardContext, ScoreMode> scoreModeFunction(List<SortBuilder<?>> sorts, boolean needsScore) {
        return ctx -> {
            try {
                return LuceneTopNSourceOperator.newPerShardCollector(ctx, sorts, needsScore, 1).collector.scoreMode();
            }
            catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        };
    }

    /**
     * Builds the top-docs collector for one shard.
     * <ul>
     *   <li>Without scores: a field-sorted collector on the shard's sort.</li>
     *   <li>With scores and a pure relevance sort: a score-ordered collector.</li>
     *   <li>With scores otherwise: a field-sorted collector whose sort is extended with
     *   {@code FIELD_DOC} and {@code FIELD_SCORE}, so the score is recorded in each {@link FieldDoc}.</li>
     * </ul>
     *
     * @throws IllegalStateException if the shard context yields no sort for {@code sorts}
     */
    private static PerShardCollector newPerShardCollector(ShardContext context, List<SortBuilder<?>> sorts, boolean needsScore, int limit) throws IOException {
        Optional<SortAndFormats> sortAndFormats = context.buildSort(sorts);
        if (sortAndFormats.isEmpty()) {
            throw new IllegalStateException("sorts must not be disabled in TopN");
        }
        if (!needsScore) {
            return new NonScoringPerShardCollector(context, sortAndFormats.get().sort, limit);
        }
        Sort sort = sortAndFormats.get().sort;
        if (Sort.RELEVANCE.equals(sort)) {
            return new ScoringPerShardCollector(context, new TopScoreDocCollectorManager(limit, null, 0).newCollector());
        }
        List<SortField> sortFields = new ArrayList<SortField>(Arrays.asList(sort.getSort()));
        sortFields.add(SortField.FIELD_DOC);
        sortFields.add(SortField.FIELD_SCORE);
        sort = new Sort(sortFields.toArray(SortField[]::new));
        return new ScoringPerShardCollector(context, new TopFieldCollectorManager(sort, limit, null, 0).newCollector());
    }

    /**
     * A shard's top-docs collector plus a cached leaf collector. The leaf collector is recreated whenever the
     * segment ordinal or the calling thread changes.
     */
    static abstract class PerShardCollector {
        private final ShardContext shardContext;
        private final TopDocsCollector<?> collector;
        private int leafIndex;
        private LeafCollector leafCollector;
        private Thread currentThread;

        PerShardCollector(ShardContext shardContext, TopDocsCollector<?> collector) {
            this.shardContext = shardContext;
            this.collector = collector;
        }

        /**
         * Returns the leaf collector for {@code leafReaderContext}, reusing the cached one for consecutive calls
         * on the same segment from the same thread.
         * NOTE(review): the thread check presumably guards against reusing a leaf collector after the driver
         * moved to another thread — confirm.
         */
        LeafCollector getLeafCollector(LeafReaderContext leafReaderContext) throws IOException {
            if (this.currentThread != Thread.currentThread() || this.leafIndex != leafReaderContext.ord) {
                this.leafCollector = this.collector.getLeafCollector(leafReaderContext);
                this.leafIndex = leafReaderContext.ord;
                this.currentThread = Thread.currentThread();
            }
            return this.leafCollector;
        }
    }

    /** Field-sorted collector used when scores are not needed. */
    static final class NonScoringPerShardCollector
    extends PerShardCollector {
        NonScoringPerShardCollector(ShardContext shardContext, Sort sort, int limit) {
            super(shardContext, new TopFieldCollectorManager(sort, limit, null, 0).newCollector());
        }
    }

    /** Collector used when scores are needed; the concrete collector is chosen by {@code newPerShardCollector}. */
    static final class ScoringPerShardCollector
    extends PerShardCollector {
        ScoringPerShardCollector(ShardContext shardContext, TopDocsCollector<?> topDocsCollector) {
            super(shardContext, topDocsCollector);
        }
    }

    /** Factory creating {@link LuceneTopNSourceOperator}s that share one slice queue partitioned by shard. */
    public static class Factory
    extends LuceneOperator.Factory {
        private final List<? extends ShardContext> contexts;
        private final int maxPageSize;
        private final List<SortBuilder<?>> sorts;
        private final long estimatedPerRowSortSize;

        public Factory(List<? extends ShardContext> contexts, Function<ShardContext, List<LuceneSliceQueue.QueryAndTags>> queryFunction, DataPartitioning dataPartitioning, int taskConcurrency, int maxPageSize, int limit, List<SortBuilder<?>> sorts, long estimatedPerRowSortSize, boolean needsScore) {
            super(contexts, queryFunction, dataPartitioning, query -> LuceneSliceQueue.PartitioningStrategy.SHARD, taskConcurrency, limit, needsScore, LuceneTopNSourceOperator.scoreModeFunction(sorts, needsScore));
            this.contexts = contexts;
            this.maxPageSize = maxPageSize;
            this.sorts = sorts;
            this.estimatedPerRowSortSize = estimatedPerRowSortSize;
        }

        @Override
        public SourceOperator get(DriverContext driverContext) {
            return new LuceneTopNSourceOperator(this.contexts, driverContext.breaker(), driverContext.blockFactory(), this.maxPageSize, this.sorts, this.estimatedPerRowSortSize, this.limit, this.sliceQueue, this.needsScore);
        }

        public int maxPageSize() {
            return this.maxPageSize;
        }

        @Override
        public String describe() {
            String notPrettySorts = this.sorts.stream().map(Strings::toString).collect(Collectors.joining(","));
            return "LuceneTopNSourceOperator[dataPartitioning = " + this.dataPartitioning + ", maxPageSize = " + this.maxPageSize + ", limit = " + this.limit + ", needsScore = " + this.needsScore + ", sorts = [" + notPrettySorts + "]]";
        }
    }
}

