/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.simdvec;

import java.io.IOException;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;

/**
 * Scores 1-bit quantized vectors, read sequentially from an {@link IndexInput}, against a query
 * that is supplied as four bit planes.
 *
 * <p>Each stored vector occupies {@link #length} bytes. The query array holds four consecutive
 * planes of {@code length} bytes each; plane {@code k} contributes to the raw dot product with
 * weight {@code 2^k} (see {@link #quantizeScore(byte[])}).
 *
 * <p>Instances advance the shared input and reuse mutable scratch arrays in
 * {@link #scoreBulk}, so they are presumably not safe for concurrent use.
 */
public class ES91OSQVectorsScorer {
    /** Number of vectors consumed and scored by one call to {@link #scoreBulk}. */
    public static final int BULK_SIZE = 16;

    /** Scale applied to the query interval width (1/15, i.e. one step of a 4-bit 0..15 grid). */
    protected static final float FOUR_BIT_SCALE = 0.06666667f;

    /** Source of the stored vectors and their per-vector corrections. */
    protected final IndexInput in;

    /** Number of bytes occupied by one stored 1-bit vector (and by one query bit plane). */
    protected final int length;

    /** Vector dimensionality. */
    protected final int dimensions;

    // Scratch buffers filled by scoreBulk with the corrections of the current bulk.
    protected final float[] lowerIntervals = new float[BULK_SIZE];
    protected final float[] upperIntervals = new float[BULK_SIZE];
    protected final int[] targetComponentSums = new int[BULK_SIZE];
    protected final float[] additionalCorrections = new float[BULK_SIZE];

    /**
     * Creates a scorer over {@code in} for vectors of the given dimensionality.
     *
     * @param in input positioned at the stored vectors
     * @param dimensions number of dimensions per vector
     */
    public ES91OSQVectorsScorer(IndexInput in, int dimensions) {
        this.in = in;
        this.dimensions = dimensions;
        // One bit per dimension; discretize presumably pads up to a multiple of 64 bits.
        this.length = OptimizedScalarQuantizer.discretize(dimensions, 64) / 8;
    }

    /**
     * Computes the raw integer dot product between the next stored 1-bit vector and the
     * bit-plane query, consuming {@link #length} bytes from {@link #in}.
     *
     * @param q the query: four bit planes of {@code length} bytes each ({@code 4 * length} total)
     * @return {@code sum_k 2^k * popcount(plane_k AND stored)} for {@code k = 0..3}
     * @throws IOException if reading from the input fails
     */
    public long quantizeScore(byte[] q) throws IOException {
        assert q.length == length * 4;
        final int size = length;
        long subRet0 = 0L;
        long subRet1 = 0L;
        long subRet2 = 0L;
        long subRet3 = 0L;
        int r = 0;
        // Main loop, 8 bytes per step. VarHandle.get is signature-polymorphic: the (long) cast
        // is what types the access as a primitive long (without it the expression is Object).
        for (int upperBound = size & -8; r < upperBound; r += Long.BYTES) {
            final long value = in.readLong();
            subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r) & value);
            subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + size) & value);
            subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + 2 * size) & value);
            subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + 3 * size) & value);
        }
        // 4-byte tail (only reachable if length is not a multiple of 8).
        for (int upperBound = size & -4; r < upperBound; r += Integer.BYTES) {
            final int value = in.readInt();
            subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r) & value);
            subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + size) & value);
            subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + 2 * size) & value);
            subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + 3 * size) & value);
        }
        // Single-byte tail; mask to 8 bits so sign extension of the bytes is not counted.
        for (; r < size; r++) {
            final byte value = in.readByte();
            subRet0 += Integer.bitCount((q[r] & value) & 0xFF);
            subRet1 += Integer.bitCount((q[r + size] & value) & 0xFF);
            subRet2 += Integer.bitCount((q[r + 2 * size] & value) & 0xFF);
            subRet3 += Integer.bitCount((q[r + 3 * size] & value) & 0xFF);
        }
        // Recombine the planes with their binary weights 1, 2, 4, 8.
        return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
    }

    /**
     * Computes {@link #quantizeScore(byte[])} for the next {@code count} stored vectors.
     *
     * @param q the bit-plane query
     * @param count number of consecutive stored vectors to score
     * @param scores receives the raw dot products (widened to float) in slots {@code [0, count)}
     * @throws IOException if reading from the input fails
     */
    public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException {
        for (int i = 0; i < count; ++i) {
            scores[i] = quantizeScore(q);
        }
    }

    /**
     * Converts a raw quantized dot product into a similarity score using the query and target
     * correction terms.
     *
     * <p>The core term {@code ax*ay*dims + ay*lx*targetSum + ax*ly*querySum + lx*ly*qcDist} is
     * the expansion of {@code sum_i (ax + lx*x_i) * (ay + ly*y_i)}, assuming the component sums
     * and {@code qcDist} are the sums their names suggest.
     *
     * @param queryLowerInterval lower bound of the query quantization interval
     * @param queryUpperInterval upper bound of the query quantization interval
     * @param queryComponentSum sum of the query's quantized components
     * @param queryAdditionalCorrection query-side additive correction
     * @param similarityFunction selects the final score transformation
     * @param centroidDp subtracted for non-EUCLIDEAN similarities
     * @param lowerInterval lower bound of the target quantization interval
     * @param upperInterval upper bound of the target quantization interval
     * @param targetComponentSum sum of the target's quantized components
     * @param additionalCorrection target-side additive correction
     * @param qcDist raw dot product as returned by {@link #quantizeScore(byte[])}
     * @return the similarity score; EUCLIDEAN and the default branch are clamped at 0
     */
    public float score(
        float queryLowerInterval,
        float queryUpperInterval,
        int queryComponentSum,
        float queryAdditionalCorrection,
        VectorSimilarityFunction similarityFunction,
        float centroidDp,
        float lowerInterval,
        float upperInterval,
        int targetComponentSum,
        float additionalCorrection,
        float qcDist
    ) {
        final float ax = lowerInterval;
        // Target bits are 0/1, so the raw interval width needs no extra scaling.
        final float lx = upperInterval - ax;
        final float ay = queryLowerInterval;
        final float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
        final float y1 = queryComponentSum;
        float score = ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
        if (similarityFunction == VectorSimilarityFunction.EUCLIDEAN) {
            // Turn the dot product into a distance-like quantity, then invert it into a score.
            score = queryAdditionalCorrection + additionalCorrection - 2.0f * score;
            return Math.max(1.0f / (1.0f + score), 0.0f);
        }
        score += queryAdditionalCorrection + additionalCorrection - centroidDp;
        if (similarityFunction == VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT) {
            return VectorUtil.scaleMaxInnerProductScore(score);
        }
        return Math.max((1.0f + score) / 2.0f, 0.0f);
    }

    /**
     * Scores the next {@link #BULK_SIZE} stored vectors.
     *
     * <p>Reads, in order: the {@code BULK_SIZE} quantized vectors, then {@code BULK_SIZE} lower
     * intervals, {@code BULK_SIZE} upper intervals, {@code BULK_SIZE} unsigned 16-bit component
     * sums, and {@code BULK_SIZE} additional corrections.
     *
     * @param q the bit-plane query
     * @param queryLowerInterval lower bound of the query quantization interval
     * @param queryUpperInterval upper bound of the query quantization interval
     * @param queryComponentSum sum of the query's quantized components
     * @param queryAdditionalCorrection query-side additive correction
     * @param similarityFunction selects the final score transformation
     * @param centroidDp subtracted for non-EUCLIDEAN similarities
     * @param scores output array of at least {@code BULK_SIZE} elements; receives the final scores
     * @return the maximum score in the bulk ({@code -Infinity} if every score is NaN)
     * @throws IOException if reading from the input fails
     */
    public float scoreBulk(
        byte[] q,
        float queryLowerInterval,
        float queryUpperInterval,
        int queryComponentSum,
        float queryAdditionalCorrection,
        VectorSimilarityFunction similarityFunction,
        float centroidDp,
        float[] scores
    ) throws IOException {
        quantizeScoreBulk(q, BULK_SIZE, scores);
        in.readFloats(lowerIntervals, 0, BULK_SIZE);
        in.readFloats(upperIntervals, 0, BULK_SIZE);
        for (int i = 0; i < BULK_SIZE; ++i) {
            targetComponentSums[i] = Short.toUnsignedInt(in.readShort());
        }
        in.readFloats(additionalCorrections, 0, BULK_SIZE);
        float maxScore = Float.NEGATIVE_INFINITY;
        for (int i = 0; i < BULK_SIZE; ++i) {
            scores[i] = score(
                queryLowerInterval,
                queryUpperInterval,
                queryComponentSum,
                queryAdditionalCorrection,
                similarityFunction,
                centroidDp,
                lowerIntervals[i],
                upperIntervals[i],
                targetComponentSums[i],
                additionalCorrections[i],
                scores[i]
            );
            // Plain comparison (not Math.max) so NaN scores never become the maximum.
            if (scores[i] > maxScore) {
                maxScore = scores[i];
            }
        }
        return maxScore;
    }
}

