/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.search.vectors;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.FilteredDocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.elasticsearch.search.profile.query.QueryProfiler;
import org.elasticsearch.search.vectors.IVFKnnSearchStrategy;
import org.elasticsearch.search.vectors.KnnScoreDocQuery;
import org.elasticsearch.search.vectors.QueryProfilerProvider;

/**
 * Base class for approximate k-nearest-neighbour queries that search an IVF vector index.
 *
 * <p>The query is resolved entirely in {@link #rewrite(IndexSearcher)}. Each segment is searched
 * (through the searcher's {@link TaskExecutor}), the per-segment hits are merged down to the global
 * top {@code k}, and the result is returned as a {@link KnnScoreDocQuery}. Subclasses supply the
 * actual per-segment search via {@link #approximateSearch}.
 */
abstract class AbstractIVFKnnVectorQuery extends Query implements QueryProfilerProvider {

    /** Shared empty result used when a segment has no document that can match the filter. */
    static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;

    /** Name of the vector field being searched. */
    protected final String field;
    /** Probe count handed to {@link IVFKnnSearchStrategy}; validated to be {@code >= 1} or exactly {@code -1}. */
    protected final int nProbe;
    /** Number of hits kept after merging all segment results. */
    protected final int k;
    /** Collector size used for each segment by default; validated to be {@code >= k}. */
    protected final int numCands;
    /** Optional pre-filter restricting which documents may be returned; {@code null} means no filter. */
    protected final Query filter;
    /** Search strategy built from {@link #nProbe}; presumably consumed by subclasses in {@link #approximateSearch}. */
    protected final KnnSearchStrategy searchStrategy;
    /**
     * Value reported to the profiler. It is reset and then written by {@link #rewrite}, so it is
     * mutable state on the query instance. Concurrent rewrites of the same instance would race on it.
     */
    protected int vectorOpsCount;

    /**
     * @param field    the vector field to search
     * @param nProbe   probe count for the IVF search strategy; must be {@code >= 1}, or {@code -1}
     *                 (NOTE(review): {@code -1} presumably means "let the strategy pick a default" — confirm)
     * @param k        number of results to return; must be {@code >= 1}
     * @param numCands number of candidates collected per segment; must be {@code >= k}
     * @param filter   optional filter query, may be {@code null}
     * @throws IllegalArgumentException if any of the numeric constraints above is violated
     */
    protected AbstractIVFKnnVectorQuery(String field, int nProbe, int k, int numCands, Query filter) {
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1, got: " + k);
        }
        if (nProbe < 1 && nProbe != -1) {
            throw new IllegalArgumentException("nProbe must be at least 1 or exactly -1, got: " + nProbe);
        }
        if (numCands < k) {
            throw new IllegalArgumentException("numCands must be at least k, got: " + numCands);
        }
        this.field = field;
        this.nProbe = nProbe;
        this.k = k;
        this.filter = filter;
        this.numCands = numCands;
        this.searchStrategy = new IVFKnnSearchStrategy(nProbe);
    }

    @Override
    public void visit(QueryVisitor visitor) {
        if (visitor.acceptField(field)) {
            visitor.visitLeaf(this);
        }
    }

    /**
     * Two queries are equal when they have the same concrete class and the same field, filter,
     * {@code k}, {@code nProbe} and {@code numCands}. {@code numCands} is included because it sizes
     * the per-segment collectors and therefore affects the results.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        AbstractIVFKnnVectorQuery that = (AbstractIVFKnnVectorQuery) o;
        return k == that.k
            && nProbe == that.nProbe
            && numCands == that.numCands
            && Objects.equals(field, that.field)
            && Objects.equals(filter, that.filter);
    }

    @Override
    public int hashCode() {
        return Objects.hash(field, k, filter, nProbe, numCands);
    }

    /**
     * Runs the approximate search on every segment and rewrites this query into a
     * {@link KnnScoreDocQuery} holding the global top {@code k} hits.
     *
     * @return a {@link MatchNoDocsQuery} if the filter can never match or nothing was found,
     *         otherwise a {@link KnnScoreDocQuery} over the merged hits
     * @throws IOException if searching any segment fails
     */
    @Override
    public Query rewrite(IndexSearcher indexSearcher) throws IOException {
        vectorOpsCount = 0;
        IndexReader reader = indexSearcher.getIndexReader();

        final Weight filterWeight;
        if (filter != null) {
            // Only documents that both match the filter and actually have a vector are candidates.
            BooleanQuery booleanQuery = new BooleanQuery.Builder()
                .add(filter, BooleanClause.Occur.FILTER)
                .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
                .build();
            Query rewritten = indexSearcher.rewrite(booleanQuery);
            if (rewritten.getClass() == MatchNoDocsQuery.class) {
                // The filter can never match: skip the vector search entirely.
                return rewritten;
            }
            filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1.0f);
        } else {
            filterWeight = null;
        }

        // By default each segment collects numCands candidates; the merge below trims to k.
        KnnCollectorManager knnCollectorManager = getKnnCollectorManager(numCands, indexSearcher);
        TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
        List<LeafReaderContext> leafReaderContexts = reader.leaves();
        List<Callable<TopDocs>> tasks = new ArrayList<>(leafReaderContexts.size());
        for (LeafReaderContext context : leafReaderContexts) {
            tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager));
        }
        TopDocs[] perLeafResults = taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new);
        TopDocs topK = TopDocs.merge(k, perLeafResults);

        // TopDocs.merge sums the per-segment totalHits. NOTE(review): presumably approximateSearch
        // reports visited vectors there — confirm. Saturate instead of letting the long wrap negative.
        vectorOpsCount = (int) Math.min(Integer.MAX_VALUE, topK.totalHits.value());
        if (topK.scoreDocs.length == 0) {
            return new MatchNoDocsQuery();
        }
        return new KnnScoreDocQuery(topK.scoreDocs, reader);
    }

    /**
     * Searches one segment and converts its segment-local doc ids to top-level ids by adding the
     * segment's {@code docBase} in place.
     */
    private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) throws IOException {
        TopDocs results = getLeafResults(ctx, filterWeight, knnCollectorManager);
        if (ctx.docBase > 0) {
            for (ScoreDoc scoreDoc : results.scoreDocs) {
                scoreDoc.doc += ctx.docBase;
            }
        }
        return results;
    }

    /**
     * Computes the accepted documents for a segment and delegates to {@link #approximateSearch}.
     *
     * <p>Without a filter, the segment's live docs are used with an unbounded visit limit. With a
     * filter, the accepted set is the filter matches intersected with live docs. The visit limit is
     * then the number of accepted docs plus one.
     *
     * @return the segment-local hits, or {@link #NO_RESULTS} if the filter matches nothing here
     */
    TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) throws IOException {
        LeafReader reader = ctx.reader();
        Bits liveDocs = reader.getLiveDocs();
        if (filterWeight == null) {
            return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, knnCollectorManager);
        }
        Scorer scorer = filterWeight.scorer(ctx);
        if (scorer == null) {
            return NO_RESULTS;
        }
        BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, reader.maxDoc());
        int cost = acceptDocs.cardinality();
        return approximateSearch(ctx, acceptDocs, cost + 1, knnCollectorManager);
    }

    /**
     * Performs the approximate vector search on a single segment.
     *
     * @param context             the segment to search
     * @param acceptDocs          documents allowed in the result; may be {@code null}
     *                            (it is the segment's live docs when there is no filter)
     * @param visitedLimit        visit budget passed through from {@link #getLeafResults}
     * @param knnCollectorManager factory for the per-segment collector
     * @return segment-local hits; doc ids are rebased by the caller
     */
    abstract TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException;

    /**
     * Creates the collector manager used for every segment.
     *
     * @param k        collector size (called with {@link #numCands} from {@link #rewrite})
     * @param searcher the searcher, available for overriding implementations; unused here
     */
    protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
        return new IVFCollectorManager(k);
    }

    @Override
    public final void profile(QueryProfiler queryProfiler) {
        queryProfiler.addVectorOpsCount(vectorOpsCount);
    }

    /**
     * Materializes the documents produced by {@code iterator} that are also live.
     *
     * <p>When there are no deletions and the iterator is already backed by a {@link BitSet}, that
     * bit set is returned directly instead of being copied.
     */
    BitSet createBitSet(DocIdSetIterator iterator, final Bits liveDocs, int maxDoc) throws IOException {
        if (liveDocs == null && iterator instanceof BitSetIterator) {
            BitSetIterator bitSetIterator = (BitSetIterator) iterator;
            return bitSetIterator.getBitSet();
        }
        FilteredDocIdSetIterator filterIterator = new FilteredDocIdSetIterator(iterator) {
            @Override
            protected boolean match(int doc) {
                return liveDocs == null || liveDocs.get(doc);
            }
        };
        return BitSet.of(filterIterator, maxDoc);
    }

    /** Collector manager that gives each segment a {@link TopKnnCollector} of a fixed size. */
    static class IVFCollectorManager implements KnnCollectorManager {
        private final int k;

        IVFCollectorManager(int k) {
            this.k = k;
        }

        @Override
        public KnnCollector newCollector(int visitedLimit, KnnSearchStrategy searchStrategy, LeafReaderContext context) throws IOException {
            return new TopKnnCollector(k, visitedLimit, searchStrategy);
        }
    }
}

