/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.simdvec.internal;

import java.io.IOException;
import java.lang.foreign.MemorySegment;
import java.util.Optional;
import org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.FilterIndexInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizer;
import org.elasticsearch.simdvec.internal.Similarities;

/*
 * Multiple versions of this class in jar - see https://www.benf.org/other/cfr/multi-version-jar.html
 */
public abstract sealed class Int7SQVectorScorer
extends RandomVectorScorer.AbstractRandomVectorScorer {
    final int vectorByteSize;
    final MemorySegmentAccessInput input;
    final MemorySegment query;
    final float scoreCorrectionConstant;
    final float queryCorrection;
    byte[] scratch;

    public static Optional<RandomVectorScorer> create(VectorSimilarityFunction sim, QuantizedByteVectorValues values, float[] queryVector) {
        Int7SQVectorScorer.checkDimensions(queryVector.length, values.dimension());
        IndexInput input = values.getSlice();
        if (input == null) {
            return Optional.empty();
        }
        if (!((input = FilterIndexInput.unwrapOnlyTest((IndexInput)input)) instanceof MemorySegmentAccessInput)) {
            return Optional.empty();
        }
        MemorySegmentAccessInput msInput = (MemorySegmentAccessInput)input;
        Int7SQVectorScorer.checkInvariants(values.size(), values.dimension(), input);
        ScalarQuantizer scalarQuantizer = values.getScalarQuantizer();
        byte[] quantizedQuery = new byte[queryVector.length];
        float queryCorrection = ScalarQuantizedVectorScorer.quantizeQuery((float[])queryVector, (byte[])quantizedQuery, (VectorSimilarityFunction)sim, (ScalarQuantizer)scalarQuantizer);
        return switch (sim) {
            default -> throw new MatchException(null, null);
            case VectorSimilarityFunction.COSINE, VectorSimilarityFunction.DOT_PRODUCT -> Optional.of(new DotProductScorer(msInput, values, quantizedQuery, queryCorrection));
            case VectorSimilarityFunction.EUCLIDEAN -> Optional.of(new EuclideanScorer(msInput, values, quantizedQuery, queryCorrection));
            case VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT -> Optional.of(new MaxInnerProductScorer(msInput, values, quantizedQuery, queryCorrection));
        };
    }

    Int7SQVectorScorer(MemorySegmentAccessInput input, QuantizedByteVectorValues values, byte[] queryVector, float queryCorrection) {
        super((KnnVectorValues)values);
        this.input = input;
        assert (queryVector.length == values.getVectorByteLength());
        this.vectorByteSize = values.getVectorByteLength();
        this.query = MemorySegment.ofArray(queryVector);
        this.queryCorrection = queryCorrection;
        this.scoreCorrectionConstant = values.getScalarQuantizer().getConstantMultiplier();
    }

    final MemorySegment getSegment(int ord) throws IOException {
        this.checkOrdinal(ord);
        long byteOffset = (long)ord * (long)(this.vectorByteSize + 4);
        MemorySegment seg = this.input.segmentSliceOrNull(byteOffset, (long)this.vectorByteSize);
        if (seg == null) {
            if (this.scratch == null) {
                this.scratch = new byte[this.vectorByteSize];
            }
            this.input.readBytes(byteOffset, this.scratch, 0, this.vectorByteSize);
            seg = MemorySegment.ofArray(this.scratch);
        }
        return seg;
    }

    static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) {
        if (input.length() < (long)vectorByteLength * (long)maxOrd) {
            throw new IllegalArgumentException("input length is less than expected vector data");
        }
    }

    final void checkOrdinal(int ord) {
        if (ord < 0 || ord >= this.maxOrd()) {
            throw new IllegalArgumentException("illegal ordinal: " + ord);
        }
    }

    static void checkDimensions(int queryLen, int fieldLen) {
        if (queryLen != fieldLen) {
            throw new IllegalArgumentException("vector query dimension: " + queryLen + " differs from field dimension: " + fieldLen);
        }
    }

    public static final class DotProductScorer
    extends Int7SQVectorScorer {
        public DotProductScorer(MemorySegmentAccessInput in, QuantizedByteVectorValues values, byte[] query, float correction) {
            super(in, values, query, correction);
        }

        public float score(int node) throws IOException {
            this.checkOrdinal(node);
            int dotProduct = Similarities.dotProduct7u(this.query, this.getSegment(node), this.vectorByteSize);
            assert (dotProduct >= 0);
            long byteOffset = (long)node * (long)(this.vectorByteSize + 4);
            float nodeCorrection = Float.intBitsToFloat(this.input.readInt(byteOffset + (long)this.vectorByteSize));
            float adjustedDistance = (float)dotProduct * this.scoreCorrectionConstant + this.queryCorrection + nodeCorrection;
            return Math.max((1.0f + adjustedDistance) / 2.0f, 0.0f);
        }
    }

    public static final class EuclideanScorer
    extends Int7SQVectorScorer {
        public EuclideanScorer(MemorySegmentAccessInput in, QuantizedByteVectorValues values, byte[] query, float correction) {
            super(in, values, query, correction);
        }

        public float score(int node) throws IOException {
            this.checkOrdinal(node);
            int sqDist = Similarities.squareDistance7u(this.query, this.getSegment(node), this.vectorByteSize);
            float adjustedDistance = (float)sqDist * this.scoreCorrectionConstant;
            return 1.0f / (1.0f + adjustedDistance);
        }
    }

    public static final class MaxInnerProductScorer
    extends Int7SQVectorScorer {
        public MaxInnerProductScorer(MemorySegmentAccessInput in, QuantizedByteVectorValues values, byte[] query, float corr) {
            super(in, values, query, corr);
        }

        public float score(int node) throws IOException {
            this.checkOrdinal(node);
            int dotProduct = Similarities.dotProduct7u(this.query, this.getSegment(node), this.vectorByteSize);
            assert (dotProduct >= 0);
            long byteOffset = (long)node * (long)(this.vectorByteSize + 4);
            float nodeCorrection = Float.intBitsToFloat(this.input.readInt(byteOffset + (long)this.vectorByteSize));
            float adjustedDistance = (float)dotProduct * this.scoreCorrectionConstant + this.queryCorrection + nodeCorrection;
            if (adjustedDistance < 0.0f) {
                return 1.0f / (1.0f + -1.0f * adjustedDistance);
            }
            return adjustedDistance + 1.0f;
        }
    }
}

