/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.search.vectors;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.ConjunctionUtils;
import org.apache.lucene.search.DocAndFloatFeatureBuffer;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.Weight;
import org.elasticsearch.index.codec.vectors.BulkScorableFloatVectorValues;
import org.elasticsearch.index.codec.vectors.BulkScorableVectorValues;
import org.elasticsearch.search.profile.query.QueryProfiler;
import org.elasticsearch.search.vectors.AbstractIVFKnnVectorQuery;
import org.elasticsearch.search.vectors.KnnScoreDocQuery;
import org.elasticsearch.search.vectors.QueryProfilerProvider;

/**
 * A query that re-scores the hits of an inner query against the float vectors stored in
 * {@link #fieldName} and keeps the best {@link #k} of them.
 *
 * <p>Instances are created through {@link #fromInnerQuery}, which picks one of two strategies:
 * <ul>
 *   <li>an "inline" strategy that re-scores every hit of the inner query directly, and</li>
 *   <li>a "late" strategy that first collects the top {@code rescoreK} hits of the inner query and
 *       re-scores only those.</li>
 * </ul>
 * Both strategies do their work in {@link Query#rewrite(IndexSearcher)} and rewrite to a
 * {@link KnnScoreDocQuery} over the final hits.
 *
 * <p>Note that {@link #vectorOperations} is a plain mutable field written during rewrite; no
 * synchronization is performed here.
 */
public abstract class RescoreKnnVectorQuery extends Query implements QueryProfilerProvider {
    /** Name of the vector field whose stored float vectors are used for re-scoring. */
    protected final String fieldName;
    /** Query vector the stored vectors are compared against. */
    protected final float[] floatTarget;
    /**
     * Similarity function this query was created with. It takes part in equals/hashCode/toString;
     * the re-scoring code in this file reads the similarity from the field's info (or uses the
     * bulk rescorer) rather than from this field.
     */
    protected final VectorSimilarityFunction vectorSimilarityFunction;
    /** Number of hits kept after re-scoring. */
    protected final int k;
    /** Query that produces the candidates to re-score. */
    protected final Query innerQuery;
    /** Count reported to the profiler; set from the total hit count of a search run in rewrite. */
    protected long vectorOperations = 0L;

    private RescoreKnnVectorQuery(
        String fieldName,
        float[] floatTarget,
        VectorSimilarityFunction vectorSimilarityFunction,
        int k,
        Query innerQuery
    ) {
        this.fieldName = fieldName;
        this.floatTarget = floatTarget;
        this.vectorSimilarityFunction = vectorSimilarityFunction;
        this.k = k;
        this.innerQuery = innerQuery;
    }

    /**
     * Creates the re-scoring query appropriate for {@code innerQuery}.
     *
     * <p>When {@code innerQuery} is a {@link KnnFloatVectorQuery}, a {@link KnnByteVectorQuery} or an
     * {@link AbstractIVFKnnVectorQuery} whose own {@code k} equals {@code rescoreK}, the inline
     * strategy is used (the inner query is re-scored directly). In every other case the late
     * strategy is used, which explicitly collects the top {@code rescoreK} inner hits first.
     *
     * @param fieldName                name of the vector field to read float vectors from
     * @param floatTarget              query vector
     * @param vectorSimilarityFunction similarity function recorded on the query
     * @param k                        number of hits to keep after re-scoring
     * @param rescoreK                 number of candidates to take from the inner query
     * @param innerQuery               query producing the candidates
     * @return a new re-scoring query
     */
    public static RescoreKnnVectorQuery fromInnerQuery(
        String fieldName,
        float[] floatTarget,
        VectorSimilarityFunction vectorSimilarityFunction,
        int k,
        int rescoreK,
        Query innerQuery
    ) {
        boolean innerAlreadyLimitedToRescoreK =
            (innerQuery instanceof KnnFloatVectorQuery && ((KnnFloatVectorQuery) innerQuery).getK() == rescoreK)
                || (innerQuery instanceof KnnByteVectorQuery && ((KnnByteVectorQuery) innerQuery).getK() == rescoreK)
                || (innerQuery instanceof AbstractIVFKnnVectorQuery && ((AbstractIVFKnnVectorQuery) innerQuery).k == rescoreK);
        if (innerAlreadyLimitedToRescoreK) {
            return new InlineRescoreQuery(fieldName, floatTarget, vectorSimilarityFunction, k, innerQuery);
        }
        return new LateRescoreQuery(fieldName, floatTarget, vectorSimilarityFunction, k, rescoreK, innerQuery);
    }

    /** @return the query that produces the candidates to re-score */
    public Query innerQuery() {
        return this.innerQuery;
    }

    /** @return the number of hits kept after re-scoring */
    public int k() {
        return this.k;
    }

    /**
     * Forwards profiling to the inner query when it is itself a {@link QueryProfilerProvider}, then
     * adds this query's {@link #vectorOperations} to the profiler.
     */
    @Override
    public void profile(QueryProfiler queryProfiler) {
        if (this.innerQuery instanceof QueryProfilerProvider) {
            ((QueryProfilerProvider) this.innerQuery).profile(queryProfiler);
        }
        queryProfiler.addVectorOpsCount(this.vectorOperations);
    }

    /** Visits the inner query as a {@code MUST} sub-clause of this query. */
    @Override
    public void visit(QueryVisitor visitor) {
        this.innerQuery.visit(visitor.getSubVisitor(BooleanClause.Occur.MUST, this));
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        RescoreKnnVectorQuery that = (RescoreKnnVectorQuery) o;
        return this.k == that.k
            && this.vectorSimilarityFunction == that.vectorSimilarityFunction
            && Objects.equals(this.fieldName, that.fieldName)
            && Arrays.equals(this.floatTarget, that.floatTarget)
            && Objects.equals(this.innerQuery, that.innerQuery);
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.fieldName, Arrays.hashCode(this.floatTarget), this.vectorSimilarityFunction, this.k, this.innerQuery);
    }

    /**
     * Renders the query with only the first component of the target vector followed by
     * {@code "..."}. An empty target is rendered as {@code []} and a null target as {@code null}
     * so that this method never throws.
     */
    @Override
    public String toString(String field) {
        String targetPreview = (this.floatTarget == null || this.floatTarget.length == 0)
            ? Arrays.toString(this.floatTarget)
            : this.floatTarget[0] + "...";
        return this.getClass().getSimpleName()
            + "{fieldName='" + this.fieldName
            + "', floatTarget=" + targetPreview
            + ", vectorSimilarityFunction=" + String.valueOf((Object) this.vectorSimilarityFunction)
            + ", k=" + this.k
            + ", vectorQuery=" + String.valueOf(this.innerQuery)
            + "}";
    }

    /** Strategy that re-scores every hit of the inner query directly and keeps the top {@code k}. */
    private static class InlineRescoreQuery extends RescoreKnnVectorQuery {
        private InlineRescoreQuery(
            String fieldName,
            float[] floatTarget,
            VectorSimilarityFunction vectorSimilarityFunction,
            int k,
            Query innerQuery
        ) {
            super(fieldName, floatTarget, vectorSimilarityFunction, k, innerQuery);
        }

        /**
         * Searches a {@link DirectRescoreKnnVectorQuery} wrapping the inner query for the top
         * {@code k} hits, records the total hit count of that search as the vector-operation count,
         * and returns a {@link KnnScoreDocQuery} over the resulting hits.
         */
        @Override
        public Query rewrite(IndexSearcher searcher) throws IOException {
            Query rescoreQuery = new DirectRescoreKnnVectorQuery(this.fieldName, this.floatTarget, this.innerQuery);
            TopDocs rescored = searcher.search(rescoreQuery, this.k);
            this.vectorOperations = rescored.totalHits.value();
            return new KnnScoreDocQuery(rescored.scoreDocs, searcher.getIndexReader());
        }

        @Override
        public boolean equals(Object o) {
            // The parent already performs the identity, null and exact-class checks.
            return super.equals(o);
        }

        @Override
        public int hashCode() {
            return super.hashCode();
        }
    }

    /**
     * Strategy that first collects the top {@code rescoreK} hits of the inner query, then re-scores
     * only those and keeps the top {@code k}.
     */
    private static class LateRescoreQuery extends RescoreKnnVectorQuery {
        /** Number of inner-query hits collected before re-scoring. */
        final int rescoreK;

        private LateRescoreQuery(
            String fieldName,
            float[] floatTarget,
            VectorSimilarityFunction vectorSimilarityFunction,
            int k,
            int rescoreK,
            Query innerQuery
        ) {
            super(fieldName, floatTarget, vectorSimilarityFunction, k, innerQuery);
            this.rescoreK = rescoreK;
        }

        /**
         * Runs the inner query for its top {@code rescoreK} hits (recording that search's total hit
         * count as the vector-operation count), re-scores those hits through a
         * {@link DirectRescoreKnnVectorQuery}, and returns a {@link KnnScoreDocQuery} over the top
         * {@code k} re-scored hits.
         */
        @Override
        public Query rewrite(IndexSearcher searcher) throws IOException {
            TopDocs candidates = searcher.search(this.innerQuery, this.rescoreK);
            this.vectorOperations = candidates.totalHits.value();
            Query candidatesQuery = new KnnScoreDocQuery(candidates.scoreDocs, searcher.getIndexReader());
            Query rescoreQuery = new DirectRescoreKnnVectorQuery(this.fieldName, this.floatTarget, candidatesQuery);
            TopDocs rescored = searcher.search(rescoreQuery.rewrite(searcher), this.k);
            return new KnnScoreDocQuery(rescored.scoreDocs, searcher.getIndexReader());
        }

        @Override
        public boolean equals(Object o) {
            // super.equals guarantees o is a non-null LateRescoreQuery before the cast is evaluated.
            return super.equals(o) && this.rescoreK == ((LateRescoreQuery) o).rescoreK;
        }

        @Override
        public int hashCode() {
            return Objects.hash(super.hashCode(), this.rescoreK);
        }
    }

    /**
     * Internal query that, on rewrite, scores every document matched by {@link #innerQuery} against
     * {@link #floatTarget} using the float vectors stored in {@link #fieldName}, and rewrites to a
     * {@link KnnScoreDocQuery} holding all of the resulting (doc, score) pairs.
     */
    private static class DirectRescoreKnnVectorQuery extends Query {
        private final float[] floatTarget;
        private final String fieldName;
        private final Query innerQuery;

        DirectRescoreKnnVectorQuery(String fieldName, float[] floatTarget, Query innerQuery) {
            this.fieldName = fieldName;
            this.floatTarget = floatTarget;
            this.innerQuery = innerQuery;
        }

        @Override
        public String toString(String field) {
            return "DirectRescoreKnnVectorQuery[" + String.valueOf(this.innerQuery) + "]";
        }

        /**
         * Rewrites the inner query and re-scores its matches leaf by leaf.
         *
         * <p>Leaves without float vectors for the field, and leaves where the inner query has no
         * scorer, contribute nothing. Vector values that implement
         * {@link BulkScorableFloatVectorValues} are scored through their bulk rescorer; all others
         * are scored one document at a time with the similarity function from the field's info.
         *
         * @return a {@link MatchNoDocsQuery} when the inner query rewrites to one, otherwise a
         *         {@link KnnScoreDocQuery} over every re-scored hit
         * @throws IllegalArgumentException if the target's length differs from the field dimension
         */
        @Override
        public Query rewrite(IndexSearcher indexSearcher) throws IOException {
            Query innerRewritten = this.innerQuery.rewrite(indexSearcher);
            if (innerRewritten.getClass() == MatchNoDocsQuery.class) {
                return new MatchNoDocsQuery();
            }
            assert (innerRewritten.getClass() != MatchAllDocsQuery.class);

            List<ScoreDoc> results = new ArrayList<ScoreDoc>(10);
            // The weight does not depend on the leaf, so it is created once (on first use) and reused.
            Weight weight = null;
            for (LeafReaderContext leaf : indexSearcher.getIndexReader().leaves()) {
                FloatVectorValues vectorValues = leaf.reader().getFloatVectorValues(this.fieldName);
                if (vectorValues == null) {
                    continue;
                }
                if (vectorValues.dimension() != this.floatTarget.length) {
                    throw new IllegalArgumentException("vector query dimension: " + this.floatTarget.length + " differs from field dimension: " + vectorValues.dimension());
                }
                if (weight == null) {
                    weight = innerRewritten.createWeight(indexSearcher, ScoreMode.COMPLETE_NO_SCORES, 1.0f);
                }
                Scorer scorer = weight.scorer(leaf);
                if (scorer == null) {
                    continue;
                }
                DocIdSetIterator filterIterator = scorer.iterator();
                if (vectorValues instanceof BulkScorableFloatVectorValues) {
                    this.rescoreBulk(leaf.docBase, (BulkScorableFloatVectorValues) vectorValues, results, filterIterator);
                } else {
                    VectorSimilarityFunction similarity = leaf.reader()
                        .getFieldInfos()
                        .fieldInfo(this.fieldName)
                        .getVectorSimilarityFunction();
                    this.rescoreIndividually(leaf.docBase, vectorValues, similarity, results, filterIterator);
                }
            }
            return new KnnScoreDocQuery(results.toArray(new ScoreDoc[0]), indexSearcher.getIndexReader());
        }

        /** Reports this query as a leaf when the visitor accepts {@link #fieldName}. */
        @Override
        public void visit(QueryVisitor visitor) {
            if (visitor.acceptField(this.fieldName)) {
                visitor.visitLeaf(this);
            }
        }

        /**
         * Two instances are equal only when field, target vector and inner query all match; the
         * target and field determine the scores produced, so they must take part in equality.
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || this.getClass() != obj.getClass()) {
                return false;
            }
            DirectRescoreKnnVectorQuery that = (DirectRescoreKnnVectorQuery) obj;
            return Objects.equals(this.fieldName, that.fieldName)
                && Arrays.equals(this.floatTarget, that.floatTarget)
                && Objects.equals(this.innerQuery, that.innerQuery);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.getClass(), this.fieldName, Arrays.hashCode(this.floatTarget), this.innerQuery);
        }

        /**
         * Scores the documents of {@code filterIterator} in batches through the bulk rescorer and
         * appends them to {@code queue} with segment-local doc ids shifted by {@code docBase}.
         */
        private void rescoreBulk(
            int docBase,
            BulkScorableFloatVectorValues rescorableVectorValues,
            List<ScoreDoc> queue,
            DocIdSetIterator filterIterator
        ) throws IOException {
            BulkScorableVectorValues.BulkVectorScorer vectorReScorer = rescorableVectorValues.bulkRescorer(this.floatTarget);
            DocIdSetIterator iterator = vectorReScorer.iterator();
            BulkScorableVectorValues.BulkVectorScorer.BulkScorer bulkScorer = vectorReScorer.bulkScore(filterIterator);
            DocAndFloatFeatureBuffer buffer = new DocAndFloatFeatureBuffer();
            // Loop until the rescorer's own iterator is exhausted; presumably each call below
            // advances it — confirm against BulkScorableVectorValues.
            while (iterator.docID() != DocIdSetIterator.NO_MORE_DOCS) {
                // 64 is presumably the batch size and null the live-docs filter — TODO confirm.
                bulkScorer.nextDocsAndScores(64, null, buffer);
                // NOTE(review): unlike rescoreIndividually, NaN scores are not filtered here.
                for (int i = 0; i < buffer.size; ++i) {
                    queue.add(new ScoreDoc(buffer.docs[i] + docBase, buffer.features[i]));
                }
            }
        }

        /**
         * Scores, one at a time, every document present in both the vector values and
         * {@code filterIterator}, appending each to {@code queue} with its doc id shifted by
         * {@code docBase}. Documents whose score is NaN are skipped.
         */
        private void rescoreIndividually(
            int docBase,
            FloatVectorValues knnVectorValues,
            VectorSimilarityFunction function,
            List<ScoreDoc> queue,
            DocIdSetIterator filterIterator
        ) throws IOException {
            KnnVectorValues.DocIndexIterator vectorIterator = knnVectorValues.iterator();
            DocIdSetIterator conjunction = ConjunctionUtils.intersectIterators(List.of(vectorIterator, filterIterator));
            for (int doc = conjunction.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = conjunction.nextDoc()) {
                assert (doc == vectorIterator.docID());
                float[] vector = knnVectorValues.vectorValue(vectorIterator.index());
                float score = function.compare(this.floatTarget, vector);
                if (!Float.isNaN(score)) {
                    queue.add(new ScoreDoc(doc + docBase, score));
                }
            }
        }
    }
}

