/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.index.codec.vectors.es93;

import java.io.IOException;
import java.util.List;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.lucene95.HasIndexSlice;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.ConjunctionUtils;
import org.apache.lucene.search.DocAndFloatFeatureBuffer;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.elasticsearch.index.codec.vectors.BulkScorableFloatVectorValues;
import org.elasticsearch.index.codec.vectors.BulkScorableVectorValues;
import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat;
import org.elasticsearch.index.codec.vectors.MergeReaderWrapper;

/**
 * Flat vectors format that reads and writes the {@code Lucene99} flat vector layout via
 * {@code Lucene99FlatVectorsReader} / {@code Lucene99FlatVectorsWriter}, and that can optionally
 * open the vector data through a direct-I/O read context.
 *
 * <p>When direct I/O is requested and allowed, {@code fieldsReader(SegmentReadState, boolean)}
 * returns a {@code MergeReaderWrapper} that combines a direct-I/O reader (whose float vectors are
 * bulk-scorable) with a second reader opened with the caller's original read context.
 */
public class DirectIOCapableLucene99FlatVectorsFormat extends DirectIOCapableFlatVectorsFormat {
    static final String NAME = "Lucene99FlatVectorsFormat";

    /** Number of vectors handed to each {@code RandomVectorScorer.bulkScore} call during bulk scoring. */
    private static final int BULK_SIZE = 32;

    private final FlatVectorsScorer vectorsScorer;

    /**
     * @param vectorsScorer scorer passed to every reader and writer this format creates
     */
    public DirectIOCapableLucene99FlatVectorsFormat(FlatVectorsScorer vectorsScorer) {
        super(NAME);
        this.vectorsScorer = vectorsScorer;
    }

    @Override
    public FlatVectorsScorer flatVectorsScorer() {
        return this.vectorsScorer;
    }

    @Override
    protected FlatVectorsReader createReader(SegmentReadState state) throws IOException {
        return new Lucene99FlatVectorsReader(state, this.vectorsScorer);
    }

    /** Returns a {@code Lucene99FlatVectorsWriter} that uses this format's scorer. */
    @Override
    public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
        return new Lucene99FlatVectorsWriter(state, this.vectorsScorer);
    }

    /**
     * Opens a reader for the segment's flat vectors.
     *
     * <p>The direct-I/O path is only taken when the segment is opened with the {@code DEFAULT}
     * context, {@code useDirectIO} is {@code true}, and {@code canUseDirectIO(state)} allows it. In
     * that case two underlying readers are opened: one with a {@code DirectIOContext} (wrapped so its
     * float vectors support bulk scoring) and one with the original context. Otherwise a plain
     * {@code Lucene99FlatVectorsReader} is returned.
     *
     * @param state       the segment to read
     * @param useDirectIO whether the caller asks for direct I/O
     * @return the reader for this segment
     * @throws IOException if opening an underlying reader fails; any reader already opened by this
     *                     call is closed before the exception propagates
     */
    @Override
    public FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectIO) throws IOException {
        if (state.context.context() == IOContext.Context.DEFAULT && useDirectIO && canUseDirectIO(state)) {
            SegmentReadState directIOState = new SegmentReadState(
                state.directory,
                state.segmentInfo,
                state.fieldInfos,
                new DirectIOCapableFlatVectorsFormat.DirectIOContext(state.context.hints()),
                state.segmentSuffix
            );
            // NOTE(review): presumably MergeReaderWrapper serves searches from the first reader and
            // merges from the second — confirm against MergeReaderWrapper.
            Lucene99FlatVectorsReader directIOReader = new Lucene99FlatVectorsReader(directIOState, this.vectorsScorer);
            Lucene99FlatVectorsReader defaultContextReader = null;
            try {
                defaultContextReader = new Lucene99FlatVectorsReader(state, this.vectorsScorer);
                return new MergeReaderWrapper(
                    new Lucene99FlatBulkScoringVectorsReader(directIOState, directIOReader, this.vectorsScorer),
                    defaultContextReader
                );
            } catch (Throwable t) {
                // Without this, a failure opening the second reader would leak the first one.
                closeSuppressing(t, directIOReader);
                closeSuppressing(t, defaultContextReader);
                throw t;
            }
        }
        return new Lucene99FlatVectorsReader(state, this.vectorsScorer);
    }

    /**
     * Closes {@code reader} if it is non-null. Any failure is recorded as suppressed on
     * {@code primary}, so the original exception is the one that propagates.
     */
    private static void closeSuppressing(Throwable primary, FlatVectorsReader reader) {
        if (reader == null) {
            return;
        }
        try {
            reader.close();
        } catch (Throwable suppressed) {
            primary.addSuppressed(suppressed);
        }
    }

    /**
     * {@code FlatVectorsReader} that delegates everything to {@code inner}. The one exception is
     * non-empty float vector values, which are wrapped in {@code RescorerOffHeapVectorValues} so
     * they can be bulk scored.
     */
    static class Lucene99FlatBulkScoringVectorsReader extends FlatVectorsReader {
        private final Lucene99FlatVectorsReader inner;
        private final SegmentReadState state;

        Lucene99FlatBulkScoringVectorsReader(SegmentReadState state, Lucene99FlatVectorsReader inner, FlatVectorsScorer scorer) {
            super(scorer);
            this.inner = inner;
            this.state = state;
        }

        @Override
        public void close() throws IOException {
            this.inner.close();
        }

        @Override
        public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
            return this.inner.getRandomVectorScorer(field, target);
        }

        @Override
        public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException {
            return this.inner.getRandomVectorScorer(field, target);
        }

        @Override
        public void checkIntegrity() throws IOException {
            this.inner.checkIntegrity();
        }

        /**
         * Returns {@code null} when {@code inner} has no float vectors for {@code field} or has zero
         * of them. Otherwise returns them wrapped for bulk scoring with the field's similarity
         * function.
         */
        @Override
        public FloatVectorValues getFloatVectorValues(String field) throws IOException {
            FloatVectorValues vectorValues = this.inner.getFloatVectorValues(field);
            if (vectorValues == null || vectorValues.size() == 0) {
                return null;
            }
            FieldInfo info = this.state.fieldInfos.fieldInfo(field);
            return new RescorerOffHeapVectorValues(vectorValues, info.getVectorSimilarityFunction(), this.vectorScorer);
        }

        @Override
        public ByteVectorValues getByteVectorValues(String field) throws IOException {
            return this.inner.getByteVectorValues(field);
        }

        @Override
        public long ramBytesUsed() {
            return this.inner.ramBytesUsed();
        }
    }

    /**
     * Scores the documents of {@code matchingDocs} in batches of {@code bulkSize}. While batch k is
     * being scored, it issues prefetch hints for the vectors of batch k+1.
     *
     * <p>{@code indexIterator} is expected to be positioned on the same document as
     * {@code matchingDocs}, because {@code index()} is read as that document's vector ordinal. The
     * only caller builds {@code matchingDocs} from (or as) {@code indexIterator}.
     */
    private static class FloatBulkScorer implements BulkScorableVectorValues.BulkVectorScorer.BulkScorer {
        private final KnnVectorValues.DocIndexIterator indexIterator;
        private final DocIdSetIterator matchingDocs;
        private final RandomVectorScorer inner;
        private final int bulkSize;
        /** Vector data used only for prefetch hints; may be {@code null}, in which case nothing is prefetched. */
        private final IndexInput inputSlice;
        /** Size in bytes of one stored vector, used to turn an ordinal into a slice offset. */
        private final int byteSize;
        private final int[] docBuffer;
        private final float[] scoreBuffer;

        FloatBulkScorer(
            RandomVectorScorer fvv,
            IndexInput inputSlice,
            int byteSize,
            int bulkSize,
            KnnVectorValues.DocIndexIterator iterator,
            DocIdSetIterator matchingDocs
        ) {
            this.indexIterator = iterator;
            this.matchingDocs = matchingDocs;
            this.inner = fvv;
            this.bulkSize = bulkSize;
            this.inputSlice = inputSlice;
            this.docBuffer = new int[bulkSize];
            this.scoreBuffer = new float[bulkSize];
            this.byteSize = byteSize;
        }

        /**
         * Fills {@code buffer} with up to {@code nextCount} live matching documents and their scores.
         * On return, {@code buffer.docs} holds document ids (ordinals are used internally while
         * scoring) and {@code matchingDocs} is positioned on the first document not consumed.
         */
        @Override
        public void nextDocsAndScores(int nextCount, Bits liveDocs, DocAndFloatFeatureBuffer buffer) throws IOException {
            buffer.growNoCopy(nextCount);
            // Phase 1: collect vector ordinals of live matching docs into buffer.docs.
            int size = 0;
            int doc = this.matchingDocs.docID();
            while (doc != DocIdSetIterator.NO_MORE_DOCS && size < nextCount) {
                if (liveDocs == null || liveDocs.get(doc)) {
                    buffer.docs[size++] = this.indexIterator.index();
                }
                doc = this.matchingDocs.nextDoc();
            }

            // Phase 2: score full batches, prefetching the following batch before scoring each one.
            prefetch(buffer.docs, 0, Math.min(this.bulkSize, size));
            int fullBatchesEnd = size - size % this.bulkSize;
            int start = 0;
            for (; start < fullBatchesEnd; start += this.bulkSize) {
                int nextStart = start + this.bulkSize;
                prefetch(buffer.docs, nextStart, Math.min(this.bulkSize, size - nextStart));
                scoreBatch(buffer, start, this.bulkSize);
            }
            int remaining = size - start;
            if (remaining > 0) {
                scoreBatch(buffer, start, remaining);
            }

            // Phase 3: translate ordinals back to document ids.
            buffer.size = size;
            for (int j = 0; j < size; ++j) {
                buffer.docs[j] = this.inner.ordToDoc(buffer.docs[j]);
            }
        }

        /** Issues prefetch hints for the vectors at ordinals {@code ords[from, from + count)}. Does nothing when there is no slice. */
        private void prefetch(int[] ords, int from, int count) throws IOException {
            if (this.inputSlice == null) {
                return;
            }
            for (int j = from; j < from + count; ++j) {
                long ord = ords[j];
                this.inputSlice.prefetch(ord * this.byteSize, this.byteSize);
            }
        }

        /** Scores the ordinals {@code buffer.docs[from, from + count)} and writes the results to {@code buffer.features}. */
        private void scoreBatch(DocAndFloatFeatureBuffer buffer, int from, int count) throws IOException {
            System.arraycopy(buffer.docs, from, this.docBuffer, 0, count);
            this.inner.bulkScore(this.docBuffer, this.scoreBuffer, count);
            System.arraycopy(this.scoreBuffer, 0, buffer.features, from, count);
        }
    }

    /**
     * Bulk vector scorer over {@code indexIterator}. It scores single documents through
     * {@code inner}, and its bulk scoring restricts iteration to the optional {@code matchingDocs}.
     */
    private record PreFetchingFloatBulkScorer(
        RandomVectorScorer inner,
        KnnVectorValues.DocIndexIterator indexIterator,
        IndexInput inputSlice,
        int byteSize
    ) implements BulkScorableVectorValues.BulkVectorScorer {

        public float score() throws IOException {
            return this.inner.score(this.indexIterator.index());
        }

        public DocIdSetIterator iterator() {
            return this.indexIterator;
        }

        @Override
        public BulkScorableVectorValues.BulkVectorScorer.BulkScorer bulkScore(DocIdSetIterator matchingDocs) throws IOException {
            // intersectIterators returns a plain DocIdSetIterator, so the conjunction must be typed as one.
            DocIdSetIterator conjunction = matchingDocs == null
                ? this.indexIterator
                : ConjunctionUtils.intersectIterators(List.of(matchingDocs, this.indexIterator));
            // FloatBulkScorer starts from docID(), so position the iterator on its first document.
            if (conjunction.docID() == -1) {
                conjunction.nextDoc();
            }
            return new FloatBulkScorer(this.inner, this.inputSlice, this.byteSize, BULK_SIZE, this.indexIterator, conjunction);
        }
    }

    /**
     * {@code FloatVectorValues} that delegates to {@code inner} and adds bulk scoring. When
     * {@code inner} exposes its index slice, that slice is used for prefetch hints.
     */
    static class RescorerOffHeapVectorValues extends FloatVectorValues implements BulkScorableFloatVectorValues {
        private final VectorSimilarityFunction similarityFunction;
        private final FloatVectorValues inner;
        /** {@code null} when {@code inner} is not a {@code HasIndexSlice}. */
        private final IndexInput inputSlice;
        private final FlatVectorsScorer scorer;

        RescorerOffHeapVectorValues(FloatVectorValues inner, VectorSimilarityFunction similarityFunction, FlatVectorsScorer scorer) {
            this.inner = inner;
            this.inputSlice = inner instanceof HasIndexSlice slice ? slice.getSlice() : null;
            this.similarityFunction = similarityFunction;
            this.scorer = scorer;
        }

        @Override
        public float[] vectorValue(int ord) throws IOException {
            return this.inner.vectorValue(ord);
        }

        @Override
        public int dimension() {
            return this.inner.dimension();
        }

        @Override
        public int size() {
            return this.inner.size();
        }

        @Override
        public KnnVectorValues.DocIndexIterator iterator() {
            return this.inner.iterator();
        }

        @Override
        public RescorerOffHeapVectorValues copy() throws IOException {
            return new RescorerOffHeapVectorValues(this.inner.copy(), this.similarityFunction, this.scorer);
        }

        @Override
        public BulkScorableVectorValues.BulkVectorScorer bulkRescorer(float[] target) throws IOException {
            return this.bulkScorer(target);
        }

        @Override
        public BulkScorableVectorValues.BulkVectorScorer bulkScorer(float[] target) throws IOException {
            KnnVectorValues.DocIndexIterator indexIterator = this.inner.iterator();
            RandomVectorScorer randomScorer = this.scorer.getRandomVectorScorer(this.similarityFunction, this.inner, target);
            // Each stored vector is dimension() floats.
            return new PreFetchingFloatBulkScorer(randomScorer, indexIterator, this.inputSlice, this.dimension() * Float.BYTES);
        }

        @Override
        public VectorScorer scorer(float[] target) throws IOException {
            return this.inner.scorer(target);
        }
    }
}

