/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.gpu.codec;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.index.DocIDMerger;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.QuantizedVectorsReader;
import org.apache.lucene.util.quantization.ScalarQuantizer;

/**
 * {@link QuantizedByteVectorValues} view over every segment being merged for one field, exposing the
 * segments' quantized vectors in merged doc-id order (driven by a {@link DocIDMerger}).
 *
 * <p>Each source segment contributes one {@link QuantizedByteVectorValueSub}. If the segment has no
 * usable quantization state, if the target quantizer uses 4 bits or fewer, or if the segment's
 * quantiles are too far from the target quantizer's, its raw float vectors are re-quantized
 * ({@link QuantizedFloatVectorValues}). Otherwise its existing quantized bytes are reused and only the
 * score correction constant is recomputed ({@link OffsetCorrectedQuantizedByteVectorValues}).
 *
 * <p>This view is forward-only: {@link #vectorValue(int)} and {@link #getScoreCorrectionConstant(int)}
 * ignore their {@code ord} argument and read whatever entry the merge iteration is currently
 * positioned on. Every iterator returned by {@link #iterator()} advances the same shared
 * {@link DocIDMerger}, so the values can be walked only once.
 */
class MergedQuantizedVectorValues extends QuantizedByteVectorValues {
    /**
     * Multiplier applied to (new upper quantile - new lower quantile) / 128 to obtain the tolerance
     * used by {@link #shouldRequantize}.
     */
    private static final float REQUANTIZATION_LIMIT = 0.2f;

    private final List<QuantizedByteVectorValueSub> subs;
    private final DocIDMerger<QuantizedByteVectorValueSub> docIdMerger;
    private final int size;
    /** Sub the merge iteration is currently positioned on; {@code null} before start / after the end. */
    private QuantizedByteVectorValueSub current;

    private MergedQuantizedVectorValues(List<QuantizedByteVectorValueSub> subs, MergeState mergeState) throws IOException {
        this.subs = subs;
        this.docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort);
        int totalSize = 0;
        for (QuantizedByteVectorValueSub sub : subs) {
            totalSize += sub.values.size();
        }
        this.size = totalSize;
    }

    /**
     * Returns the quantized vector at the current merge position.
     *
     * @param ord ignored; the vector is taken from the sub the iterator is currently positioned on
     */
    @Override
    public byte[] vectorValue(int ord) throws IOException {
        return current.values.vectorValue(current.index());
    }

    /**
     * Returns an iterator over the merged documents. Note that all returned iterators share this
     * instance's single {@link DocIDMerger}.
     */
    @Override
    public KnnVectorValues.DocIndexIterator iterator() {
        return new CompositeIterator();
    }

    /** Total number of vectors across all merged segments. */
    @Override
    public int size() {
        return size;
    }

    /**
     * Dimension of the vectors, read from the first sub. Assumes at least one segment contributed
     * values; with no subs this throws {@link IndexOutOfBoundsException}.
     */
    @Override
    public int dimension() {
        return subs.get(0).values.dimension();
    }

    /**
     * Returns the score correction constant for the vector at the current merge position.
     *
     * @param ord ignored; the constant is taken from the sub the iterator is currently positioned on
     */
    @Override
    public float getScoreCorrectionConstant(int ord) throws IOException {
        return current.values.getScoreCorrectionConstant(current.index());
    }

    /**
     * Unwraps a per-field reader (if any) and returns it as a {@link QuantizedVectorsReader}, or
     * {@code null} when the reader for {@code fieldName} does not implement that interface.
     */
    private static QuantizedVectorsReader getQuantizedKnnVectorsReader(KnnVectorsReader vectorsReader, String fieldName) {
        KnnVectorsReader candidate = vectorsReader;
        if (candidate instanceof PerFieldKnnVectorsFormat.FieldsReader) {
            candidate = ((PerFieldKnnVectorsFormat.FieldsReader) candidate).getFieldReader(fieldName);
        }
        return candidate instanceof QuantizedVectorsReader ? (QuantizedVectorsReader) candidate : null;
    }

    /**
     * Builds the merged quantized view of {@code fieldInfo} over all segments in {@code mergeState}
     * that have vector values for that field.
     *
     * @param fieldInfo the field being merged; must have vector values
     * @param mergeState the merge state supplying per-segment readers, field infos and doc maps
     * @param scalarQuantizer the quantizer for the merged segment; must be non-null when any segment
     *     has values for the field
     * @return the merged, forward-only quantized vector values
     * @throws IOException if reading a segment's vectors fails
     */
    static MergedQuantizedVectorValues mergeQuantizedByteVectorValues(
        FieldInfo fieldInfo,
        MergeState mergeState,
        ScalarQuantizer scalarQuantizer
    ) throws IOException {
        assert fieldInfo != null && fieldInfo.hasVectorValues();
        final String fieldName = fieldInfo.name;
        final VectorSimilarityFunction similarity = fieldInfo.getVectorSimilarityFunction();
        final List<QuantizedByteVectorValueSub> subs = new ArrayList<>();
        for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
            if (KnnVectorsWriter.MergedVectorValues.hasVectorValues(mergeState.fieldInfos[i], fieldName) == false) {
                continue;
            }
            QuantizedVectorsReader reader = getQuantizedKnnVectorsReader(mergeState.knnVectorsReaders[i], fieldName);
            assert scalarQuantizer != null;
            // Fetch the segment's existing quantization state once instead of re-querying the reader.
            ScalarQuantizer existingQuantizer = reader == null ? null : reader.getQuantizationState(fieldName);
            final QuantizedByteVectorValues values;
            if (existingQuantizer == null || scalarQuantizer.getBits() <= 4 || shouldRequantize(existingQuantizer, scalarQuantizer)) {
                // Never quantized, low-bit target, or quantiles drifted: quantize from the raw floats.
                FloatVectorValues toQuantize = mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldName);
                if (similarity == VectorSimilarityFunction.COSINE) {
                    toQuantize = new NormalizedFloatVectorValues(toQuantize);
                }
                values = new QuantizedFloatVectorValues(toQuantize, similarity, scalarQuantizer);
            } else {
                // Quantiles are close enough: reuse the bytes, only recompute the correction constant.
                values = new OffsetCorrectedQuantizedByteVectorValues(
                    reader.getQuantizedVectorValues(fieldName),
                    similarity,
                    scalarQuantizer,
                    existingQuantizer
                );
            }
            subs.add(new QuantizedByteVectorValueSub(mergeState.docMaps[i], values));
        }
        return new MergedQuantizedVectorValues(subs, mergeState);
    }

    /**
     * Returns {@code true} when either quantile of {@code existingQuantiles} differs from the
     * corresponding quantile of {@code newQuantiles} by more than
     * {@code REQUANTIZATION_LIMIT * (newUpper - newLower) / 128}.
     */
    private static boolean shouldRequantize(ScalarQuantizer existingQuantiles, ScalarQuantizer newQuantiles) {
        float tol = REQUANTIZATION_LIMIT * (newQuantiles.getUpperQuantile() - newQuantiles.getLowerQuantile()) / 128.0f;
        if (Math.abs(existingQuantiles.getUpperQuantile() - newQuantiles.getUpperQuantile()) > tol) {
            return true;
        }
        return Math.abs(existingQuantiles.getLowerQuantile() - newQuantiles.getLowerQuantile()) > tol;
    }

    /** One segment's quantized values plus the iterator the {@link DocIDMerger} advances. */
    private static class QuantizedByteVectorValueSub extends DocIDMerger.Sub {
        private final QuantizedByteVectorValues values;
        private final KnnVectorValues.DocIndexIterator iterator;

        QuantizedByteVectorValueSub(MergeState.DocMap docMap, QuantizedByteVectorValues values) {
            super(docMap);
            this.values = values;
            this.iterator = values.iterator();
            assert iterator.docID() == -1;
        }

        @Override
        public int nextDoc() throws IOException {
            return iterator.nextDoc();
        }

        /** Index (ordinal within {@link #values}) of the document the iterator is positioned on. */
        public int index() {
            return iterator.index();
        }
    }

    /**
     * Forward-only iterator over the merged documents. It advances the enclosing instance's shared
     * {@link DocIDMerger} and updates {@code current}; {@link #advance(int)} is unsupported.
     */
    private class CompositeIterator extends KnnVectorValues.DocIndexIterator {
        private int docId = -1;
        private int ord = -1;

        CompositeIterator() {}

        @Override
        public int index() {
            return ord;
        }

        @Override
        public int docID() {
            return docId;
        }

        @Override
        public int nextDoc() throws IOException {
            current = docIdMerger.next();
            if (current == null) {
                docId = NO_MORE_DOCS;
                ord = NO_MORE_DOCS;
            } else {
                docId = current.mappedDocID;
                ++ord;
            }
            return docId;
        }

        @Override
        public int advance(int target) throws IOException {
            throw new UnsupportedOperationException();
        }

        @Override
        public long cost() {
            return size;
        }
    }

    /**
     * Float vector view that L2-normalizes each vector on access. The returned array is a single
     * buffer owned by this instance and is overwritten by the next {@link #vectorValue(int)} call.
     */
    private static final class NormalizedFloatVectorValues extends FloatVectorValues {
        private final FloatVectorValues values;
        private final float[] normalizedVector;

        NormalizedFloatVectorValues(FloatVectorValues values) {
            this.values = values;
            this.normalizedVector = new float[values.dimension()];
        }

        @Override
        public int dimension() {
            return values.dimension();
        }

        @Override
        public int size() {
            return values.size();
        }

        @Override
        public int ordToDoc(int ord) {
            return values.ordToDoc(ord);
        }

        @Override
        public float[] vectorValue(int ord) throws IOException {
            System.arraycopy(values.vectorValue(ord), 0, normalizedVector, 0, normalizedVector.length);
            VectorUtil.l2normalize(normalizedVector);
            return normalizedVector;
        }

        @Override
        public KnnVectorValues.DocIndexIterator iterator() {
            return values.iterator();
        }

        @Override
        public NormalizedFloatVectorValues copy() throws IOException {
            return new NormalizedFloatVectorValues(values.copy());
        }
    }

    /**
     * Quantizes float vectors on the fly with a given {@link ScalarQuantizer}. Only the most recently
     * requested ordinal is cached: the returned byte array is a shared buffer, and
     * {@link #getScoreCorrectionConstant(int)} is only valid for the ordinal last passed to
     * {@link #vectorValue(int)}.
     */
    private static class QuantizedFloatVectorValues extends QuantizedByteVectorValues {
        private final FloatVectorValues values;
        private final ScalarQuantizer quantizer;
        private final byte[] quantizedVector;
        private int lastOrd = -1;
        private float offsetValue = 0.0f;
        private final VectorSimilarityFunction vectorSimilarityFunction;

        QuantizedFloatVectorValues(FloatVectorValues values, VectorSimilarityFunction vectorSimilarityFunction, ScalarQuantizer quantizer) {
            this.values = values;
            this.quantizer = quantizer;
            this.quantizedVector = new byte[values.dimension()];
            this.vectorSimilarityFunction = vectorSimilarityFunction;
        }

        /**
         * @throws IllegalStateException if {@code ord} is not the ordinal most recently quantized by
         *     {@link #vectorValue(int)}
         */
        @Override
        public float getScoreCorrectionConstant(int ord) {
            if (ord != lastOrd) {
                throw new IllegalStateException(
                    "attempt to retrieve score correction for different ord " + ord + " than the quantization was done for: " + lastOrd
                );
            }
            return offsetValue;
        }

        @Override
        public int dimension() {
            return values.dimension();
        }

        @Override
        public int size() {
            return values.size();
        }

        @Override
        public byte[] vectorValue(int ord) throws IOException {
            // Re-quantize only when the ordinal changes; repeated reads reuse the buffer.
            if (ord != lastOrd) {
                offsetValue = quantize(ord);
                lastOrd = ord;
            }
            return quantizedVector;
        }

        public VectorScorer scorer(float[] target) throws IOException {
            throw new UnsupportedOperationException();
        }

        /** Quantizes the float vector at {@code ord} into {@link #quantizedVector}; returns its offset. */
        private float quantize(int ord) throws IOException {
            return quantizer.quantize(values.vectorValue(ord), quantizedVector, vectorSimilarityFunction);
        }

        @Override
        public int ordToDoc(int ord) {
            return values.ordToDoc(ord);
        }

        @Override
        public KnnVectorValues.DocIndexIterator iterator() {
            return values.iterator();
        }
    }

    /**
     * Reuses already-quantized bytes from a segment, recomputing each vector's score correction
     * constant for the new quantizer via
     * {@link ScalarQuantizer#recalculateCorrectiveOffset(byte[], ScalarQuantizer, VectorSimilarityFunction)}.
     */
    private static final class OffsetCorrectedQuantizedByteVectorValues extends QuantizedByteVectorValues {
        private final QuantizedByteVectorValues in;
        private final VectorSimilarityFunction vectorSimilarityFunction;
        private final ScalarQuantizer scalarQuantizer;
        private final ScalarQuantizer oldScalarQuantizer;

        OffsetCorrectedQuantizedByteVectorValues(
            QuantizedByteVectorValues in,
            VectorSimilarityFunction vectorSimilarityFunction,
            ScalarQuantizer scalarQuantizer,
            ScalarQuantizer oldScalarQuantizer
        ) {
            this.in = in;
            this.vectorSimilarityFunction = vectorSimilarityFunction;
            this.scalarQuantizer = scalarQuantizer;
            this.oldScalarQuantizer = oldScalarQuantizer;
        }

        @Override
        public float getScoreCorrectionConstant(int ord) throws IOException {
            return scalarQuantizer.recalculateCorrectiveOffset(in.vectorValue(ord), oldScalarQuantizer, vectorSimilarityFunction);
        }

        @Override
        public int dimension() {
            return in.dimension();
        }

        @Override
        public int size() {
            return in.size();
        }

        @Override
        public byte[] vectorValue(int ord) throws IOException {
            return in.vectorValue(ord);
        }

        @Override
        public int ordToDoc(int ord) {
            return in.ordToDoc(ord);
        }

        @Override
        public KnnVectorValues.DocIndexIterator iterator() {
            return in.iterator();
        }
    }
}

