/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.simdvec;

import java.io.IOException;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.VectorUtil;

/**
 * Scores a quantized query vector against 4-bit quantized vectors that are read sequentially from an
 * {@link IndexInput}, then applies interval/correction terms to turn the raw integer dot product into
 * a similarity score.
 *
 * <p>Every scoring call consumes data from {@link #in} and writes into per-instance scratch arrays,
 * so an instance must not be used by several threads at the same time.
 */
public class ES91Int4VectorsScorer {
    /** Number of vectors handled by one call to {@link #scoreBulk}. */
    public static final int BULK_SIZE = 16;
    /** Scale for one quantization step of a 4-bit value: {@code 1 / (2^4 - 1)}, i.e. 1/15. */
    protected static final float FOUR_BIT_SCALE = 0.06666667f;
    protected final IndexInput in;
    protected final int dimensions;
    /** Receives the bytes of the stored vector currently being scored. */
    protected byte[] scratch;
    // Per-bulk correction buffers, sized by BULK_SIZE so they always match scoreBulk's loop bounds.
    protected final float[] lowerIntervals = new float[BULK_SIZE];
    protected final float[] upperIntervals = new float[BULK_SIZE];
    protected final int[] targetComponentSums = new int[BULK_SIZE];
    protected final float[] additionalCorrections = new float[BULK_SIZE];

    /**
     * Creates a scorer reading vectors of {@code dimensions} bytes from {@code in}.
     *
     * @param in the input positioned at the vector data to score
     * @param dimensions number of bytes per stored vector
     */
    public ES91Int4VectorsScorer(IndexInput in, int dimensions) {
        this.in = in;
        this.dimensions = dimensions;
        this.scratch = new byte[dimensions];
    }

    /**
     * Reads the next stored vector ({@code dimensions} bytes) from the input and returns its dot
     * product with {@code b}.
     *
     * @param b the quantized query vector; must have at least {@code dimensions} entries
     * @return the dot product of the stored vector and {@code b}
     * @throws IOException if reading from the input fails
     */
    public long int4DotProduct(byte[] b) throws IOException {
        in.readBytes(scratch, 0, dimensions);
        // Accumulate in a long to match the declared return type, so the sum cannot wrap before
        // being widened.
        long total = 0;
        for (int i = 0; i < dimensions; i++) {
            total += scratch[i] * b[i];
        }
        return total;
    }

    /**
     * Reads the next {@code count} stored vectors and writes each one's dot product with {@code b}
     * into {@code scores[0..count)}.
     *
     * @param b the quantized query vector
     * @param count number of consecutive stored vectors to read
     * @param scores destination array; must have at least {@code count} entries
     * @throws IOException if reading from the input fails
     */
    public void int4DotProductBulk(byte[] b, int count, float[] scores) throws IOException {
        for (int i = 0; i < count; i++) {
            scores[i] = int4DotProduct(b);
        }
    }

    /**
     * Scores a single stored vector.
     *
     * <p>Reads the vector bytes, then three floats (lower interval, upper interval, additional
     * correction), then an unsigned short (component sum), and passes everything to
     * {@link #applyCorrections}.
     *
     * @return the corrected similarity score
     * @throws IOException if reading from the input fails
     */
    public float score(
        byte[] q,
        float queryLowerInterval,
        float queryUpperInterval,
        int queryComponentSum,
        float queryAdditionalCorrection,
        VectorSimilarityFunction similarityFunction,
        float centroidDp
    ) throws IOException {
        float score = int4DotProduct(q);
        // lowerIntervals is reused as a 3-float scratch: [0]=lower, [1]=upper, [2]=additional correction.
        in.readFloats(lowerIntervals, 0, 3);
        int addition = Short.toUnsignedInt(in.readShort());
        return applyCorrections(
            queryLowerInterval,
            queryUpperInterval,
            queryComponentSum,
            queryAdditionalCorrection,
            similarityFunction,
            centroidDp,
            lowerIntervals[0],
            lowerIntervals[1],
            addition,
            lowerIntervals[2],
            score
        );
    }

    /**
     * Scores {@link #BULK_SIZE} consecutive stored vectors and writes the corrected scores into
     * {@code scores[0..BULK_SIZE)}.
     *
     * @param scores destination array; must have at least {@link #BULK_SIZE} entries
     * @throws IOException if reading from the input fails
     */
    public void scoreBulk(
        byte[] q,
        float queryLowerInterval,
        float queryUpperInterval,
        int queryComponentSum,
        float queryAdditionalCorrection,
        VectorSimilarityFunction similarityFunction,
        float centroidDp,
        float[] scores
    ) throws IOException {
        int4DotProductBulk(q, BULK_SIZE, scores);
        // After the vectors, the corrections follow grouped by kind: all lower intervals, all upper
        // intervals, all component sums (unsigned shorts), then all additional corrections.
        in.readFloats(lowerIntervals, 0, BULK_SIZE);
        in.readFloats(upperIntervals, 0, BULK_SIZE);
        for (int i = 0; i < BULK_SIZE; i++) {
            targetComponentSums[i] = Short.toUnsignedInt(in.readShort());
        }
        in.readFloats(additionalCorrections, 0, BULK_SIZE);
        for (int i = 0; i < BULK_SIZE; i++) {
            scores[i] = applyCorrections(
                queryLowerInterval,
                queryUpperInterval,
                queryComponentSum,
                queryAdditionalCorrection,
                similarityFunction,
                centroidDp,
                lowerIntervals[i],
                upperIntervals[i],
                targetComponentSums[i],
                additionalCorrections[i],
                scores[i]
            );
        }
    }

    /**
     * Converts a raw quantized dot product into a similarity score using the query's and the stored
     * vector's interval and correction terms.
     *
     * <p>For {@code EUCLIDEAN} the result is {@code max(1 / (1 + d), 0)}, where
     * {@code d = queryAdditionalCorrection + additionalCorrection - 2 * estimate}. For the other
     * functions the correction terms are added, {@code centroidDp} is subtracted, and the result is
     * passed to {@link VectorUtil#scaleMaxInnerProductScore} ({@code MAXIMUM_INNER_PRODUCT}) or
     * mapped to {@code max((1 + s) / 2, 0)}.
     *
     * @param qcDist the raw dot product between the quantized query and stored vector
     * @return the similarity score
     */
    public float applyCorrections(
        float queryLowerInterval,
        float queryUpperInterval,
        int queryComponentSum,
        float queryAdditionalCorrection,
        VectorSimilarityFunction similarityFunction,
        float centroidDp,
        float lowerInterval,
        float upperInterval,
        int targetComponentSum,
        float additionalCorrection,
        float qcDist
    ) {
        float ax = lowerInterval;
        float lx = (upperInterval - ax) * FOUR_BIT_SCALE;
        float ay = queryLowerInterval;
        float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
        float y1 = queryComponentSum;
        // Expand the dequantized dot product (ax + lx*x) . (ay + ly*y) term by term.
        float score = ax * ay * dimensions + ay * lx * targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
        if (similarityFunction == VectorSimilarityFunction.EUCLIDEAN) {
            score = queryAdditionalCorrection + additionalCorrection - 2.0f * score;
            return Math.max(1.0f / (1.0f + score), 0.0f);
        }
        score += queryAdditionalCorrection + additionalCorrection - centroidDp;
        if (similarityFunction == VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT) {
            return VectorUtil.scaleMaxInnerProductScore(score);
        }
        return Math.max((1.0f + score) / 2.0f, 0.0f);
    }
}

