/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.index.codec.vectors.diskbbq;

import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.search.vectors.IVFKnnSearchStrategy;

/**
 * Base class for readers of the IVF vector index written under the {@code "ES920DiskBBQVectorsFormat"} codec name.
 *
 * <p>On construction it reads per-field metadata from the segment's {@code mivf} file. It then opens the
 * centroid ({@code cenivf}) and posting-list ({@code clivf}) data files. Raw vector values are served by the
 * {@link FlatVectorsReader}s supplied by the caller, keyed by raw-format name. Subclasses provide the centroid
 * traversal ({@link #getCentroidIterator}) and posting-list scoring ({@link #getPostingVisitor}).
 */
public abstract class IVFVectorsReader extends KnnVectorsReader {
    private static final String CODEC_NAME = "ES920DiskBBQVectorsFormat";
    private static final int VERSION_START = 0;
    private static final int VERSION_CURRENT = 0;
    /** Sentinel visit ratio meaning "derive the ratio from the number of vectors and k". */
    private static final float DYNAMIC_VISIT_RATIO = 0.0f;

    private final IndexInput ivfCentroids;
    private final IndexInput ivfClusters;
    private final SegmentReadState state;
    private final FieldInfos fieldInfos;
    protected final IntObjectHashMap<FieldEntry> fields;
    private final Map<String, FlatVectorsReader> rawVectorReaders;

    /**
     * Reads the {@code mivf} metadata and opens the {@code cenivf} / {@code clivf} data files.
     *
     * <p>If anything fails, this reader (including the supplied raw vector readers) is closed
     * while the original exception propagates.
     *
     * @param state            segment being read
     * @param rawVectorReaders raw vector readers keyed by the raw-format name recorded per field
     * @throws IOException on I/O failure, header/version mismatch or checksum failure
     */
    protected IVFVectorsReader(SegmentReadState state, Map<String, FlatVectorsReader> rawVectorReaders) throws IOException {
        this.state = state;
        this.fieldInfos = state.fieldInfos;
        this.fields = new IntObjectHashMap<>();
        this.rawVectorReaders = rawVectorReaders;
        String metaFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, "mivf");
        int versionMeta = -1;
        boolean success = false;
        try (ChecksumIndexInput ivfMeta = state.directory.openChecksumInput(metaFileName)) {
            Throwable priorE = null;
            try {
                versionMeta = CodecUtil.checkIndexHeader(
                    ivfMeta,
                    CODEC_NAME,
                    VERSION_START,
                    VERSION_CURRENT,
                    state.segmentInfo.getId(),
                    state.segmentSuffix
                );
                readFields(ivfMeta);
            } catch (Throwable exception) {
                // Defer the failure so checkFooter can report it together with any checksum problem.
                priorE = exception;
            } finally {
                CodecUtil.checkFooter(ivfMeta, priorE);
            }
            this.ivfCentroids = openDataInput(state, versionMeta, "cenivf", CODEC_NAME, state.context);
            this.ivfClusters = openDataInput(state, versionMeta, "clivf", CODEC_NAME, state.context);
            success = true;
        } finally {
            if (!success) {
                IOUtils.closeWhileHandlingException(this);
            }
        }
    }

    /**
     * Creates the iterator that yields posting-list locations to visit for {@code target}.
     *
     * @param fieldInfo        field being searched
     * @param numCentroids     number of centroids recorded for the field
     * @param centroids        slice of the centroid file for this field
     * @param target           query vector
     * @param postingListSlice slice of the posting-list file for this field
     * @param visitRatio       fraction of the field's vectors the search intends to visit
     */
    abstract CentroidIterator getCentroidIterator(
        FieldInfo fieldInfo,
        int numCentroids,
        IndexInput centroids,
        float[] target,
        IndexInput postingListSlice,
        float visitRatio
    ) throws IOException;

    /**
     * Opens a data file and verifies its header, version and checksum footer position.
     * The input is closed on failure.
     *
     * @throws CorruptIndexException if the file's version differs from {@code versionMeta}
     */
    private static IndexInput openDataInput(
        SegmentReadState state,
        int versionMeta,
        String fileExtension,
        String codecName,
        IOContext context
    ) throws IOException {
        String fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
        IndexInput in = state.directory.openInput(fileName, context);
        boolean success = false;
        try {
            int versionVectorData = CodecUtil.checkIndexHeader(
                in,
                codecName,
                VERSION_START,
                VERSION_CURRENT,
                state.segmentInfo.getId(),
                state.segmentSuffix
            );
            if (versionMeta != versionVectorData) {
                throw new CorruptIndexException(
                    "Format versions mismatch: meta=" + versionMeta + ", " + codecName + "=" + versionVectorData,
                    in
                );
            }
            CodecUtil.retrieveChecksum(in);
            success = true;
            return in;
        } finally {
            if (!success) {
                IOUtils.closeWhileHandlingException(in);
            }
        }
    }

    /** Reads field entries until the {@code -1} terminator and stores them by field number. */
    private void readFields(ChecksumIndexInput meta) throws IOException {
        for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) {
            FieldInfo info = fieldInfos.fieldInfo(fieldNumber);
            if (info == null) {
                throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
            }
            fields.put(info.number, readField(meta, info));
        }
    }

    /**
     * Reads one field's metadata.
     *
     * <p>Posting-list offset/length, the global centroid and its dot product are only present when
     * {@code centroidLength > 0}. Otherwise the offsets stay {@code -1}, the centroid is all zeros
     * and the dot product is {@code 0}.
     *
     * @throws IllegalStateException if the stored similarity differs from the field's
     */
    private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException {
        String rawVectorFormat = input.readString();
        VectorEncoding vectorEncoding = readVectorEncoding(input);
        VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
        if (similarityFunction != info.getVectorSimilarityFunction()) {
            throw new IllegalStateException(
                "Inconsistent vector similarity function for field=\""
                    + info.name
                    + "\"; "
                    + similarityFunction
                    + " != "
                    + info.getVectorSimilarityFunction()
            );
        }
        int numCentroids = input.readInt();
        long centroidOffset = input.readLong();
        long centroidLength = input.readLong();
        float[] globalCentroid = new float[info.getVectorDimension()];
        long postingListOffset = -1L;
        long postingListLength = -1L;
        float globalCentroidDp = 0.0f;
        if (centroidLength > 0L) {
            postingListOffset = input.readLong();
            postingListLength = input.readLong();
            input.readFloats(globalCentroid, 0, globalCentroid.length);
            globalCentroidDp = Float.intBitsToFloat(input.readInt());
        }
        return new FieldEntry(
            rawVectorFormat,
            similarityFunction,
            vectorEncoding,
            numCentroids,
            centroidOffset,
            centroidLength,
            postingListOffset,
            postingListLength,
            globalCentroid,
            globalCentroidDp
        );
    }

    /** Decodes a similarity-function ordinal using Lucene99HnswVectorsReader's table. */
    private static VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException {
        int i = input.readInt();
        if (i < 0 || i >= Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS.size()) {
            throw new IllegalArgumentException("invalid distance function: " + i);
        }
        return Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS.get(i);
    }

    /** Decodes a {@link VectorEncoding} ordinal, rejecting out-of-range values as corruption. */
    private static VectorEncoding readVectorEncoding(DataInput input) throws IOException {
        int encodingId = input.readInt();
        VectorEncoding[] encodings = VectorEncoding.values();
        if (encodingId < 0 || encodingId >= encodings.length) {
            throw new CorruptIndexException("Invalid vector encoding id: " + encodingId, input);
        }
        return encodings[encodingId];
    }

    /** Verifies the raw vector readers, then checksums the entire centroid and posting-list files. */
    @Override
    public final void checkIntegrity() throws IOException {
        for (FlatVectorsReader reader : rawVectorReaders.values()) {
            reader.checkIntegrity();
        }
        CodecUtil.checksumEntireFile(ivfCentroids);
        CodecUtil.checksumEntireFile(ivfClusters);
    }

    /** Returns the metadata entry for {@code field}, or throws {@link IllegalArgumentException} if there is none. */
    private FieldEntry getFieldEntryOrThrow(String field) {
        FieldInfo info = fieldInfos.fieldInfo(field);
        FieldEntry entry = info == null ? null : fields.get(info.number);
        if (entry == null) {
            throw new IllegalArgumentException("field=\"" + field + "\" not found");
        }
        return entry;
    }

    /** Returns the raw vector reader registered for {@code field}'s raw format, or throws {@link IllegalArgumentException}. */
    private FlatVectorsReader getReaderForField(String field) {
        String formatName = getFieldEntryOrThrow(field).rawVectorFormatName();
        FlatVectorsReader reader = rawVectorReaders.get(formatName);
        if (reader == null) {
            throw new IllegalArgumentException("Could not find raw vector format [" + formatName + "] for field [" + field + "]");
        }
        return reader;
    }

    @Override
    public final FloatVectorValues getFloatVectorValues(String field) throws IOException {
        return getReaderForField(field).getFloatVectorValues(field);
    }

    @Override
    public final ByteVectorValues getByteVectorValues(String field) throws IOException {
        return getReaderForField(field).getByteVectorValues(field);
    }

    /**
     * Approximate kNN search over the IVF structures for float32 fields.
     *
     * <p>Non-float32 fields are delegated to the raw vector reader. Phase 1 visits posting lists
     * until the visit budget is spent. It keeps going while the collector's minimum competitive
     * similarity is still negative infinity. When accept-docs bits are present, phase 2 keeps
     * visiting until the number of scored docs reaches an estimate derived from the filter's
     * selectivity, or at least {@code k}.
     *
     * @throws IllegalArgumentException if {@code target} does not match the field's dimension
     */
    @Override
    public final void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
        FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field);
        if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) {
            getReaderForField(field).search(field, target, knnCollector, acceptDocs);
            return;
        }
        if (fieldInfo.getVectorDimension() != target.length) {
            throw new IllegalArgumentException(
                "vector query dimension: " + target.length + " differs from field dimension: " + fieldInfo.getVectorDimension()
            );
        }
        int numVectors = getReaderForField(field).getFloatVectorValues(field).size();
        if (numVectors == 0) {
            // Nothing can be collected; also avoids dividing by zero in the ratios below.
            return;
        }
        float percentFiltered = Math.max(0.0f, Math.min(1.0f, (float) acceptDocs.cost() / numVectors));
        float visitRatio = DYNAMIC_VISIT_RATIO;
        KnnSearchStrategy searchStrategy = knnCollector.getSearchStrategy();
        if (searchStrategy instanceof IVFKnnSearchStrategy ivfSearchStrategy) {
            visitRatio = ivfSearchStrategy.getVisitRatio();
        }
        FieldEntry entry = fields.get(fieldInfo.number);
        if (visitRatio == DYNAMIC_VISIT_RATIO) {
            // Heuristic budget of round(log10(n)^2 * k) vectors, expressed as a fraction of the field's vectors.
            double log10Vectors = Math.log10(numVectors);
            float estimated = Math.round(log10Vectors * log10Vectors * knnCollector.k());
            visitRatio = estimated / numVectors;
        }
        // NOTE(review): the factor of 2 presumably allows for vectors assigned to more than one posting list — confirm.
        long maxVectorVisited = (long) (2.0 * visitRatio * numVectors);
        IndexInput postingListSlice = entry.postingListSlice(ivfClusters);
        CentroidIterator centroidIterator = getCentroidIterator(
            fieldInfo,
            entry.numCentroids(),
            entry.centroidSlice(ivfCentroids),
            target,
            postingListSlice,
            visitRatio
        );
        Bits acceptDocsBits = acceptDocs.bits();
        PostingVisitor scorer = getPostingVisitor(fieldInfo, postingListSlice, target, acceptDocsBits);
        long expectedDocs = 0L;
        long actualDocs = 0L;
        // Phase 1: spend the visit budget, continuing while the collector's bar is still negative infinity.
        while (centroidIterator.hasNext()
            && (maxVectorVisited > expectedDocs || knnCollector.minCompetitiveSimilarity() == Float.NEGATIVE_INFINITY)) {
            CentroidOffsetAndLength offsetAndLength = centroidIterator.nextPostingListOffsetAndLength();
            expectedDocs += scorer.resetPostingsScorer(offsetAndLength.offset());
            actualDocs += scorer.visit(knnCollector);
            nextVectorsBlock(knnCollector);
        }
        // Phase 2 (filtered only): compensate for docs rejected by the filter during phase 1.
        if (acceptDocsBits != null) {
            float unfilteredRatioVisited = (float) expectedDocs / numVectors;
            int filteredVectors = (int) Math.ceil(numVectors * percentFiltered);
            // Multiply in float so that 2 * filteredVectors cannot overflow int.
            float expectedScored = Math.min(2.0f * filteredVectors * unfilteredRatioVisited, expectedDocs / 2.0f);
            while (centroidIterator.hasNext() && (actualDocs < expectedScored || actualDocs < knnCollector.k())) {
                CentroidOffsetAndLength offsetAndLength = centroidIterator.nextPostingListOffsetAndLength();
                scorer.resetPostingsScorer(offsetAndLength.offset());
                actualDocs += scorer.visit(knnCollector);
                nextVectorsBlock(knnCollector);
            }
        }
    }

    /** Notifies the collector's search strategy, if it has one, that another block of vectors was visited. */
    private static void nextVectorsBlock(KnnCollector knnCollector) throws IOException {
        KnnSearchStrategy strategy = knnCollector.getSearchStrategy();
        if (strategy != null) {
            strategy.nextVectorsBlock();
        }
    }

    /**
     * Exhaustive search over the raw byte vectors of {@code field}.
     *
     * <p>Each vector is scored with the field's similarity function. Documents rejected by the
     * accept-docs bits are skipped. Scanning stops as soon as the collector reports early termination.
     */
    @Override
    public final void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
        FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field);
        ByteVectorValues values = getReaderForField(field).getByteVectorValues(field);
        // Honour the same accept-docs bits the float path hands to its posting visitor.
        Bits acceptDocsBits = acceptDocs == null ? null : acceptDocs.bits();
        for (int ord = 0; ord < values.size(); ord++) {
            int doc = values.ordToDoc(ord);
            if (acceptDocsBits != null && !acceptDocsBits.get(doc)) {
                continue;
            }
            float score = fieldInfo.getVectorSimilarityFunction().compare(target, values.vectorValue(ord));
            knnCollector.collect(doc, score);
            if (knnCollector.earlyTerminated()) {
                return;
            }
        }
    }

    /**
     * Reports off-heap usage for {@code fieldInfo}.
     *
     * <p>Currently only the raw vector reader's figures are returned. The centroid and posting-list
     * files are not accounted for.
     *
     * @throws IllegalArgumentException if the field has no IVF entry (from {@code getReaderForField})
     */
    @Override
    public Map<String, Long> getOffHeapByteSize(FieldInfo fieldInfo) {
        return getReaderForField(fieldInfo.name).getOffHeapByteSize(fieldInfo);
    }

    /**
     * Closes the raw vector readers and the centroid/posting-list inputs.
     *
     * <p>Also invoked from the constructor on failure, when the IVF inputs may still be null.
     * IOUtils.close is expected to skip null entries.
     */
    @Override
    public void close() throws IOException {
        ArrayList<Closeable> closeables = new ArrayList<>(rawVectorReaders.values());
        Collections.addAll(closeables, ivfCentroids, ivfClusters);
        IOUtils.close(closeables);
    }

    /**
     * Creates the visitor that scores the documents of one posting list at a time.
     *
     * @param fieldInfo    field being searched
     * @param postingLists slice of the posting-list file for this field
     * @param target       query vector
     * @param acceptDocs   accepted documents, or {@code null}
     */
    abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput postingLists, float[] target, Bits acceptDocs)
        throws IOException;

    /** Per-field metadata read from the {@code mivf} file. */
    protected record FieldEntry(
        String rawVectorFormatName,
        VectorSimilarityFunction similarityFunction,
        VectorEncoding vectorEncoding,
        int numCentroids,
        long centroidOffset,
        long centroidLength,
        long postingListOffset,
        long postingListLength,
        float[] globalCentroid,
        float globalCentroidDp
    ) {
        /** Slices this field's region out of the centroid file. */
        IndexInput centroidSlice(IndexInput centroidFile) throws IOException {
            return centroidFile.slice("centroids", centroidOffset, centroidLength);
        }

        /** Slices this field's region out of the posting-list file. */
        IndexInput postingListSlice(IndexInput postingListFile) throws IOException {
            return postingListFile.slice("postingLists", postingListOffset, postingListLength);
        }
    }

    /** Yields the locations of posting lists to visit, in the order chosen by the implementation. */
    interface CentroidIterator {
        boolean hasNext();

        CentroidOffsetAndLength nextPostingListOffsetAndLength() throws IOException;
    }

    /** Scores the documents of one posting list at a time. */
    interface PostingVisitor {
        /** Positions the visitor on the posting list at {@code offset}; the returned count is accumulated as "expected docs". */
        int resetPostingsScorer(long offset) throws IOException;

        /** Scores the current posting list into {@code knnCollector}; the returned count is accumulated as "actual docs". */
        int visit(KnnCollector knnCollector) throws IOException;
    }

    /** Offset and length of one posting list within the posting-list slice. */
    record CentroidOffsetAndLength(long offset, long length) {
    }
}

