/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.index.codec.vectors;

import java.io.Closeable;
import java.io.IOException;
import java.util.function.IntPredicate;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.AbstractKnnCollector;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.search.vectors.IVFKnnSearchStrategy;

/**
 * Base {@link KnnVectorsReader} for IVF-style vector indices.
 *
 * <p>On construction it reads per-field metadata from the segment's {@code mivf} file and opens the
 * {@code cenivf} (centroids) and {@code clivf} (clusters / posting lists) data files, checking that
 * their header versions match the metadata's. Raw vectors are served by the supplied
 * {@link FlatVectorsReader}. Approximate float search probes the highest-ranked posting lists;
 * centroid scoring and posting-list traversal are supplied by subclasses through
 * {@link #getCentroidScorer}, {@link #scorePostingLists} and {@link #getPostingVisitor}.
 */
public abstract class IVFVectorsReader extends KnnVectorsReader {
    private final IndexInput ivfCentroids;
    private final IndexInput ivfClusters;
    private final SegmentReadState state;
    private final FieldInfos fieldInfos;
    /** Per-field metadata keyed by {@link FieldInfo#number}. */
    protected final IntObjectHashMap<FieldEntry> fields;
    private final FlatVectorsReader rawVectorsReader;

    /**
     * Reads the IVF metadata of a segment and opens its centroid and cluster data files.
     *
     * @param state the segment being read
     * @param rawVectorsReader reader for the raw vectors; it is closed by {@link #close()}, and also
     *     when this constructor fails
     * @throws IOException if a file cannot be opened or fails header, version or checksum validation
     */
    protected IVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException {
        this.state = state;
        this.fieldInfos = state.fieldInfos;
        this.rawVectorsReader = rawVectorsReader;
        this.fields = new IntObjectHashMap<>();
        String meta = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, "mivf");
        int versionMeta = -1;
        boolean success = false;
        try (ChecksumIndexInput ivfMeta = state.directory.openChecksumInput(meta)) {
            Throwable priorE = null;
            try {
                versionMeta = CodecUtil.checkIndexHeader(ivfMeta, "IVFVectorsFormat", 0, 0, state.segmentInfo.getId(), state.segmentSuffix);
                readFields(ivfMeta);
            } catch (Throwable exception) {
                // Defer the failure so checkFooter can rethrow it alongside any checksum problem.
                priorE = exception;
            } finally {
                CodecUtil.checkFooter(ivfMeta, priorE);
            }
            ivfCentroids = openDataInput(state, versionMeta, "cenivf", "IVFVectorsFormat", state.context);
            ivfClusters = openDataInput(state, versionMeta, "clivf", "IVFVectorsFormat", state.context);
            success = true;
        } finally {
            if (!success) {
                // Release whatever was opened so far (including rawVectorsReader).
                IOUtils.closeWhileHandlingException(this);
            }
        }
    }

    /**
     * Creates the scorer used to rank this field's centroids against the query.
     *
     * @param fieldInfo the field being searched
     * @param numPostingLists number of posting lists recorded for the field
     * @param centroids slice of the centroid file holding this field's centroids
     * @param target the query vector
     */
    abstract CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numPostingLists, IndexInput centroids, float[] target) throws IOException;

    /**
     * Opens a segment data file and validates its header version against the metadata version.
     *
     * @return the opened input, positioned right after the header
     * @throws CorruptIndexException if the data file version differs from {@code versionMeta}
     */
    private static IndexInput openDataInput(SegmentReadState state, int versionMeta, String fileExtension, String codecName, IOContext context) throws IOException {
        String fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
        IndexInput in = state.directory.openInput(fileName, context);
        boolean success = false;
        try {
            int versionVectorData = CodecUtil.checkIndexHeader(in, codecName, 0, 0, state.segmentInfo.getId(), state.segmentSuffix);
            if (versionMeta != versionVectorData) {
                throw new CorruptIndexException("Format versions mismatch: meta=" + versionMeta + ", " + codecName + "=" + versionVectorData, in);
            }
            // Only validates the footer structure; the full checksum is verified in checkIntegrity().
            CodecUtil.retrieveChecksum(in);
            success = true;
            return in;
        } finally {
            if (!success) {
                IOUtils.closeWhileHandlingException(in);
            }
        }
    }

    /** Reads field entries until the {@code -1} terminator, registering each in {@link #fields}. */
    private void readFields(ChecksumIndexInput meta) throws IOException {
        for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) {
            FieldInfo info = fieldInfos.fieldInfo(fieldNumber);
            if (info == null) {
                throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
            }
            fields.put(info.number, readField(meta, info));
        }
    }

    /**
     * Reads one field's metadata record.
     *
     * @throws IllegalStateException if the stored similarity differs from the field's configured one
     */
    private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException {
        VectorEncoding vectorEncoding = readVectorEncoding(input);
        VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
        long centroidOffset = input.readLong();
        long centroidLength = input.readLong();
        int numPostingLists = input.readVInt();
        long[] postingListOffsets = new long[numPostingLists];
        for (int i = 0; i < numPostingLists; ++i) {
            postingListOffsets[i] = input.readLong();
        }
        float[] globalCentroid = new float[info.getVectorDimension()];
        float globalCentroidDp = 0.0f;
        // The global centroid is only written when the field has at least one posting list.
        if (numPostingLists > 0) {
            input.readFloats(globalCentroid, 0, globalCentroid.length);
            globalCentroidDp = Float.intBitsToFloat(input.readInt());
        }
        if (similarityFunction != info.getVectorSimilarityFunction()) {
            throw new IllegalStateException("Inconsistent vector similarity function for field=\"" + info.name + "\"; " + similarityFunction + " != " + info.getVectorSimilarityFunction());
        }
        return new FieldEntry(similarityFunction, vectorEncoding, centroidOffset, centroidLength, postingListOffsets, globalCentroid, globalCentroidDp);
    }

    private static VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException {
        int i = input.readInt();
        if (i < 0 || i >= Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS.size()) {
            throw new IllegalArgumentException("invalid distance function: " + i);
        }
        return Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS.get(i);
    }

    private static VectorEncoding readVectorEncoding(DataInput input) throws IOException {
        int encodingId = input.readInt();
        VectorEncoding[] encodings = VectorEncoding.values();
        if (encodingId < 0 || encodingId >= encodings.length) {
            throw new CorruptIndexException("Invalid vector encoding id: " + encodingId, input);
        }
        return encodings[encodingId];
    }

    @Override
    public final void checkIntegrity() throws IOException {
        rawVectorsReader.checkIntegrity();
        CodecUtil.checksumEntireFile(ivfCentroids);
        CodecUtil.checksumEntireFile(ivfClusters);
    }

    @Override
    public final FloatVectorValues getFloatVectorValues(String field) throws IOException {
        return rawVectorsReader.getFloatVectorValues(field);
    }

    @Override
    public final ByteVectorValues getByteVectorValues(String field) throws IOException {
        return rawVectorsReader.getByteVectorValues(field);
    }

    /**
     * Approximate float search over the IVF structure.
     *
     * <p>Non-FLOAT32 fields are delegated to the raw vectors reader. Otherwise posting lists are
     * visited in the order returned by {@link #scorePostingLists}. Visiting continues until at least
     * {@code nProbe} lists were visited and the collector holds {@code k} hits. When a filter is
     * present, further lists are visited based on the estimated filter selectivity.
     *
     * @throws IllegalArgumentException if the query dimension differs from the field dimension
     */
    @Override
    public final void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
        FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field);
        if (!fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
            rawVectorsReader.search(field, target, knnCollector, acceptDocs);
            return;
        }
        if (fieldInfo.getVectorDimension() != target.length) {
            throw new IllegalArgumentException("vector query dimension: " + target.length + " differs from field dimension: " + fieldInfo.getVectorDimension());
        }
        // Fraction of documents accepted by the filter; only estimable when the filter is a BitSet.
        float percentFiltered = 1.0f;
        if (acceptDocs instanceof BitSet bitSet) {
            percentFiltered = Math.max(0.0f, Math.min(1.0f, (float) bitSet.approximateCardinality() / (float) bitSet.length()));
        }
        int numVectors = rawVectorsReader.getFloatVectorValues(field).size();
        FixedBitSet visitedDocs = new FixedBitSet(state.segmentInfo.maxDoc() + 1);
        // Scores a doc only if the filter accepts it and it has not been scored already, guarding
        // against double-scoring should a doc appear in more than one posting list.
        IntPredicate needsScoring = docId -> {
            if (acceptDocs != null && !acceptDocs.get(docId)) {
                return false;
            }
            return !visitedDocs.getAndSet(docId);
        };
        assert knnCollector instanceof AbstractKnnCollector;
        AbstractKnnCollector knnCollectorImpl = (AbstractKnnCollector) knnCollector;
        // -1: no explicit nProbe from the search strategy; derived below from centroid count and k.
        int nProbe = -1;
        if (knnCollector.getSearchStrategy() instanceof IVFKnnSearchStrategy ivfSearchStrategy) {
            nProbe = ivfSearchStrategy.getNProbe();
        }
        FieldEntry entry = fields.get(fieldInfo.number);
        CentroidQueryScorer centroidQueryScorer = getCentroidScorer(fieldInfo, entry.postingListOffsets.length, entry.centroidSlice(ivfCentroids), target);
        if (nProbe == -1) {
            nProbe = (int) Math.round(Math.log10(centroidQueryScorer.size()) * Math.sqrt(knnCollector.k()));
            nProbe = Math.max(Math.min(nProbe, centroidQueryScorer.size()), 1);
        }
        NeighborQueue centroidQueue = scorePostingLists(fieldInfo, knnCollector, centroidQueryScorer, nProbe);
        PostingVisitor scorer = getPostingVisitor(fieldInfo, ivfClusters, target, needsScoring);
        int centroidsVisited = 0;
        long expectedDocs = 0L;
        long actualDocs = 0L;
        while (centroidQueue.size() > 0 && (centroidsVisited < nProbe || knnCollectorImpl.numCollected() < knnCollector.k())) {
            ++centroidsVisited;
            int centroidOrdinal = centroidQueue.pop();
            expectedDocs += scorer.resetPostingsScorer(centroidOrdinal, centroidQueryScorer.centroid(centroidOrdinal));
            actualDocs += scorer.visit(knnCollector);
        }
        if (acceptDocs != null) {
            // With a filter, fewer docs than expected may have been scored; keep visiting lists
            // until the scored count reaches an estimate derived from the filter selectivity.
            float unfilteredRatioVisited = (float) expectedDocs / (float) numVectors;
            int filteredVectors = (int) Math.ceil((float) numVectors * percentFiltered);
            float expectedScored = Math.min((float) (2 * filteredVectors) * unfilteredRatioVisited, (float) expectedDocs / 2.0f);
            while (centroidQueue.size() > 0 && ((float) actualDocs < expectedScored || actualDocs < (long) knnCollector.k())) {
                int centroidOrdinal = centroidQueue.pop();
                scorer.resetPostingsScorer(centroidOrdinal, centroidQueryScorer.centroid(centroidOrdinal));
                actualDocs += scorer.visit(knnCollector);
            }
        }
    }

    /**
     * Exhaustive byte search: scores every stored byte vector whose document is accepted by
     * {@code acceptDocs}, stopping early when the collector reports early termination.
     */
    @Override
    public final void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
        FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field);
        VectorSimilarityFunction similarityFunction = fieldInfo.getVectorSimilarityFunction();
        ByteVectorValues values = rawVectorsReader.getByteVectorValues(field);
        for (int ord = 0; ord < values.size(); ++ord) {
            int doc = values.ordToDoc(ord);
            // Honor the filter / live-docs contract: rejected documents must not be collected.
            if (acceptDocs != null && !acceptDocs.get(doc)) {
                continue;
            }
            float score = similarityFunction.compare(target, values.vectorValue(ord));
            knnCollector.collect(doc, score);
            if (knnCollector.earlyTerminated()) {
                return;
            }
        }
    }

    /**
     * Ranks this field's posting lists for the query.
     *
     * @return a queue whose {@code pop()} order determines the order posting lists are visited
     */
    abstract NeighborQueue scorePostingLists(FieldInfo fieldInfo, KnnCollector knnCollector, CentroidQueryScorer centroidQueryScorer, int nProbe) throws IOException;

    /** Closes the raw vectors reader and both IVF data files. */
    @Override
    public void close() throws IOException {
        IOUtils.close(rawVectorsReader, ivfCentroids, ivfClusters);
    }

    /**
     * Creates the visitor used to score documents of individual posting lists.
     *
     * @param postingsInput the cluster data file
     * @param needsScoring predicate a visitor should consult before scoring a document
     */
    abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput postingsInput, float[] target, IntPredicate needsScoring) throws IOException;

    /**
     * Per-field IVF metadata read from the {@code mivf} file.
     *
     * @param postingListOffsets one offset per posting list (its length is the posting list count)
     * @param globalCentroid all zeros when the field has no posting lists
     */
    protected record FieldEntry(VectorSimilarityFunction similarityFunction, VectorEncoding vectorEncoding, long centroidOffset, long centroidLength, long[] postingListOffsets, float[] globalCentroid, float globalCentroidDp) {
        /** Returns the slice of {@code centroidFile} that holds this field's centroids. */
        IndexInput centroidSlice(IndexInput centroidFile) throws IOException {
            return centroidFile.slice("centroids", centroidOffset, centroidLength);
        }
    }

    /** Scores the query against a field's centroids; implemented by subclasses. */
    interface CentroidQueryScorer {
        /** Number of centroids available. */
        int size();

        /** Returns the centroid vector for the given ordinal. */
        float[] centroid(int centroidOrdinal) throws IOException;

        /** Scores centroids into {@code queue} (presumably all of them — confirm in implementations). */
        void bulkScore(NeighborQueue queue) throws IOException;
    }

    /** Visits the documents of one posting list at a time. */
    interface PostingVisitor {
        /**
         * Positions the visitor on a posting list.
         *
         * @return the number of documents the list contains (accumulated as "expected docs")
         */
        int resetPostingsScorer(int centroidOrdinal, float[] centroid) throws IOException;

        /**
         * Scores the current posting list into {@code knnCollector}.
         *
         * @return the number of documents actually scored
         */
        int visit(KnnCollector knnCollector) throws IOException;
    }
}

