/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.simdvec.internal.vectorization;

import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.Constants;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.simdvec.internal.vectorization.ByteArrayUtils;
import org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport;

/**
 * Scalar (plain-loop, non-vectorized) implementation of {@link ESVectorUtilSupport}.
 *
 * <p>Floating-point accumulations go through {@link #fma(float, float, float)}, which uses a fused multiply-add
 * only when {@link Constants#HAS_FAST_SCALAR_FMA} is set, so results may differ in the last bit between platforms.
 * The public static {@code ...Impl} helpers can be called without an instance.
 */
final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {

    /** Number of stripes of {@code q} (each {@code d.length} bytes long) read by {@link #ipByteBinByteImpl}. */
    private static final int BIN_BYTE_QUERY_STRIPES = 4;

    DefaultESVectorUtilSupport() {}

    /**
     * Returns {@code a * b + c}: fused via {@link Math#fma} when {@link Constants#HAS_FAST_SCALAR_FMA} is set,
     * otherwise as an ordinary multiply followed by an add.
     */
    private static float fma(float a, float b, float c) {
        if (Constants.HAS_FAST_SCALAR_FMA) {
            return Math.fma(a, b, c);
        }
        return a * b + c;
    }

    @Override
    public long ipByteBinByte(byte[] q, byte[] d) {
        return ipByteBinByteImpl(q, d);
    }

    @Override
    public int ipByteBit(byte[] q, byte[] d) {
        return ipByteBitImpl(q, d);
    }

    @Override
    public float ipFloatBit(float[] q, byte[] d) {
        return ipFloatBitImpl(q, d);
    }

    @Override
    public float ipFloatByte(float[] q, byte[] d) {
        return ipFloatByteImpl(q, d);
    }

    /**
     * Quantizes {@code target} onto the grid {@code low + k * step} and returns the weighted quantization loss.
     *
     * <p>Each component is clamped to {@code [low, high]}, converted to {@code round((x - low) * invStep)} and
     * written to {@code quantize}. With {@code x'} the dequantized value {@code low + step * k}, the result is
     * {@code (1 - lambda) * (sum x * (x - x'))^2 / norm2 + lambda * sum (x - x')^2}.
     *
     * <p>NOTE(review): {@code norm2 == 0} makes the first term NaN or infinite; callers presumably avoid it.
     */
    @Override
    public float calculateOSQLoss(float[] target, float low, float high, float step, float invStep, float norm2, float lambda, int[] quantize) {
        float xe = 0f;
        float e = 0f;
        for (int i = 0; i < target.length; i++) {
            float xi = target[i];
            quantize[i] = Math.round((Math.min(Math.max(xi, low), high) - low) * invStep);
            float dequantized = fma(step, quantize[i], low);
            float residual = xi - dequantized;
            e = fma(residual, residual, e);
            xe = fma(xi, residual, xe);
        }
        return (1f - lambda) * xe * xe / norm2 + lambda * e;
    }

    /**
     * Accumulates the sums needed to refit a quantization interval from integer codes.
     *
     * <p>With {@code s = quantize[i] / (points - 1)} and {@code x = target[i]}, writes
     * {@code pts[0] = sum (1-s)^2}, {@code pts[1] = sum (1-s)*s}, {@code pts[2] = sum s^2},
     * {@code pts[3] = sum (1-s)*x} and {@code pts[4] = sum s*x}.
     */
    @Override
    public void calculateOSQGridPoints(float[] target, int[] quantize, int points, float[] pts) {
        float daa = 0f;
        float dab = 0f;
        float dbb = 0f;
        float dax = 0f;
        float dbx = 0f;
        float invPmOnes = 1f / (points - 1f);
        for (int i = 0; i < target.length; i++) {
            float v = target[i];
            float s = quantize[i] * invPmOnes;
            float ms = 1f - s;
            daa = fma(ms, ms, daa);
            dab = fma(ms, s, dab);
            dbb = fma(s, s, dbb);
            dax = fma(ms, v, dax);
            dbx = fma(s, v, dbx);
        }
        pts[0] = daa;
        pts[1] = dab;
        pts[2] = dbb;
        pts[3] = dax;
        pts[4] = dbx;
    }

    /**
     * Writes {@code target - centroid} into {@code centered} and records statistics of the centered vector.
     *
     * <p>{@code stats[0]} = mean, {@code stats[1]} = variance (divided by the length), {@code stats[2]} = squared
     * L2 norm, {@code stats[3]} = minimum, {@code stats[4]} = maximum. An empty {@code target} yields a NaN variance.
     */
    @Override
    public void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats) {
        float vecMean = 0f;
        float vecVar = 0f;
        float norm2 = 0f;
        float min = Float.MAX_VALUE;
        float max = -Float.MAX_VALUE;
        for (int i = 0; i < target.length; i++) {
            float c = target[i] - centroid[i];
            centered[i] = c;
            min = Math.min(min, c);
            max = Math.max(max, c);
            norm2 = fma(c, c, norm2);
            // Welford's online update: running mean plus sum of squared deviations.
            float delta = c - vecMean;
            vecMean += delta / (i + 1);
            vecVar = fma(delta, c - vecMean, vecVar);
        }
        stats[0] = vecMean;
        stats[1] = vecVar / target.length;
        stats[2] = norm2;
        stats[3] = min;
        stats[4] = max;
    }

    /**
     * Same as {@link #centerAndCalculateOSQStatsEuclidean}, and additionally stores the dot product of the
     * (uncentered) {@code target} with {@code centroid} in {@code stats[5]}.
     */
    @Override
    public void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats) {
        float vecMean = 0f;
        float vecVar = 0f;
        float norm2 = 0f;
        float centroidDot = 0f;
        float min = Float.MAX_VALUE;
        float max = -Float.MAX_VALUE;
        for (int i = 0; i < target.length; i++) {
            centroidDot = fma(target[i], centroid[i], centroidDot);
            float c = target[i] - centroid[i];
            centered[i] = c;
            min = Math.min(min, c);
            max = Math.max(max, c);
            norm2 = fma(c, c, norm2);
            // Welford's online update: running mean plus sum of squared deviations.
            float delta = c - vecMean;
            vecMean += delta / (i + 1);
            vecVar = fma(delta, c - vecMean, vecVar);
        }
        stats[0] = vecMean;
        stats[1] = vecVar / target.length;
        stats[2] = norm2;
        stats[3] = min;
        stats[4] = max;
        stats[5] = centroidDot;
    }

    /**
     * Returns {@code ||v1 - centroid||^2 + soarLambda * ((v1 - centroid) . originalResidual)^2 / rnorm}.
     */
    @Override
    public float soarDistance(float[] v1, float[] centroid, float[] originalResidual, float soarLambda, float rnorm) {
        assert v1.length == centroid.length;
        assert v1.length == originalResidual.length;
        float dsq = VectorUtil.squareDistance(v1, centroid);
        float proj = 0f;
        for (int i = 0; i < v1.length; i++) {
            float djk = v1[i] - centroid[i];
            proj = fma(djk, originalResidual[i], proj);
        }
        return dsq + soarLambda * proj * proj / rnorm;
    }

    /** Equivalent to {@link #ipByteBitImpl(byte[], byte[], int)} with {@code start == 0}. */
    public static int ipByteBitImpl(byte[] q, byte[] d) {
        return ipByteBitImpl(q, d, 0);
    }

    /**
     * Inner product of byte vector {@code q} with bit vector {@code d}: the sum of {@code q[j]} over every set bit
     * {@code j}, where bit {@code j} is bit {@code 7 - j % 8} of {@code d[j / 8]} (most significant bit first).
     * Only bytes {@code d[start..]} (and the matching {@code q[start * 8..]}) contribute.
     */
    public static int ipByteBitImpl(byte[] q, byte[] d, int start) {
        assert q.length == d.length * 8;
        // Four independent accumulators, as in the original, to shorten dependency chains.
        int acc0 = 0;
        int acc1 = 0;
        int acc2 = 0;
        int acc3 = 0;
        for (int i = start; i < d.length; i++) {
            byte mask = d[i];
            int base = i * 8;
            acc0 += q[base] * (mask >> 7 & 1);
            acc1 += q[base + 1] * (mask >> 6 & 1);
            acc2 += q[base + 2] * (mask >> 5 & 1);
            acc3 += q[base + 3] * (mask >> 4 & 1);
            acc0 += q[base + 4] * (mask >> 3 & 1);
            acc1 += q[base + 5] * (mask >> 2 & 1);
            acc2 += q[base + 6] * (mask >> 1 & 1);
            acc3 += q[base + 7] * (mask & 1);
        }
        return acc0 + acc1 + acc2 + acc3;
    }

    /** Equivalent to {@link #ipFloatBitImpl(float[], byte[], int)} with {@code start == 0}. */
    public static float ipFloatBitImpl(float[] q, byte[] d) {
        return ipFloatBitImpl(q, d, 0);
    }

    /**
     * Float counterpart of {@link #ipByteBitImpl(byte[], byte[], int)}: sums {@code q[j]} for every set bit of
     * {@code d} (most significant bit first), starting at byte {@code start}.
     */
    static float ipFloatBitImpl(float[] q, byte[] d, int start) {
        assert q.length == d.length * 8;
        float acc0 = 0f;
        float acc1 = 0f;
        float acc2 = 0f;
        float acc3 = 0f;
        for (int i = start; i < d.length; i++) {
            byte mask = d[i];
            int base = i * 8;
            acc0 = fma(q[base], mask >> 7 & 1, acc0);
            acc1 = fma(q[base + 1], mask >> 6 & 1, acc1);
            acc2 = fma(q[base + 2], mask >> 5 & 1, acc2);
            acc3 = fma(q[base + 3], mask >> 4 & 1, acc3);
            acc0 = fma(q[base + 4], mask >> 3 & 1, acc0);
            acc1 = fma(q[base + 5], mask >> 2 & 1, acc1);
            acc2 = fma(q[base + 6], mask >> 1 & 1, acc2);
            acc3 = fma(q[base + 7], mask & 1, acc3);
        }
        return acc0 + acc1 + acc2 + acc3;
    }

    /**
     * Bitwise inner product of a striped query {@code q} with a bit vector {@code d}.
     *
     * <p>{@code q} is read as {@value #BIN_BYTE_QUERY_STRIPES} consecutive stripes of {@code d.length} bytes each.
     * Stripe {@code s} contributes {@code popcount(stripe_s & d) << s}.
     */
    public static long ipByteBinByteImpl(byte[] q, byte[] d) {
        final int size = d.length;
        // Largest multiple of Integer.BYTES <= size: the prefix processed one int (4 bytes) at a time.
        final int upperBound = size & -Integer.BYTES;
        long ret = 0L;
        for (int s = 0; s < BIN_BYTE_QUERY_STRIPES; s++) {
            final int offset = s * size;
            long stripeRet = 0L;
            int r = 0;
            for (; r < upperBound; r += Integer.BYTES) {
                // VarHandle.get is signature-polymorphic: the (int) casts fix its return type.
                stripeRet += Integer.bitCount((int) BitUtil.VH_NATIVE_INT.get(q, offset + r) & (int) BitUtil.VH_NATIVE_INT.get(d, r));
            }
            for (; r < size; r++) {
                stripeRet += Integer.bitCount(q[offset + r] & d[r] & 0xFF);
            }
            ret += stripeRet << s;
        }
        return ret;
    }

    /** Dot product of {@code q} with the signed byte values of {@code d}, over the length of {@code q}. */
    public static float ipFloatByteImpl(float[] q, byte[] d) {
        float ret = 0f;
        for (int i = 0; i < q.length; i++) {
            ret += q[i] * d[i];
        }
        return ret;
    }

    /**
     * Quantizes {@code vector} to integer codes in {@code [0, 2^bits - 1]} over the interval
     * {@code [lowInterval, upperInterval]}. Writes the codes to {@code destination} and returns their sum.
     */
    @Override
    public int quantizeVectorWithIntervals(float[] vector, int[] destination, float lowInterval, float upperInterval, byte bits) {
        float nSteps = (1 << bits) - 1;
        float invStep = nSteps / (upperInterval - lowInterval);
        int sumQuery = 0;
        for (int h = 0; h < vector.length; h++) {
            float xi = Math.min(Math.max(vector[h], lowInterval), upperInterval);
            int assignment = Math.round((xi - lowInterval) * invStep);
            sumQuery += assignment;
            destination[h] = assignment;
        }
        return sumQuery;
    }

    /** Stores the squared Euclidean distance from {@code query} to {@code v0..v3} in {@code distances[0..3]}. */
    @Override
    public void squareDistanceBulk(float[] query, float[] v0, float[] v1, float[] v2, float[] v3, float[] distances) {
        distances[0] = VectorUtil.squareDistance(query, v0);
        distances[1] = VectorUtil.squareDistance(query, v1);
        distances[2] = VectorUtil.squareDistance(query, v2);
        distances[3] = VectorUtil.squareDistance(query, v3);
    }

    /** Stores {@link #soarDistance} of {@code v1} against each of {@code c0..c3} in {@code distances[0..3]}. */
    @Override
    public void soarDistanceBulk(float[] v1, float[] c0, float[] c1, float[] c2, float[] c3, float[] originalResidual, float soarLambda, float rnorm, float[] distances) {
        distances[0] = soarDistance(v1, c0, originalResidual, soarLambda, rnorm);
        distances[1] = soarDistance(v1, c1, originalResidual, soarLambda, rnorm);
        distances[2] = soarDistance(v1, c2, originalResidual, soarLambda, rnorm);
        distances[3] = soarDistance(v1, c3, originalResidual, soarLambda, rnorm);
    }

    @Override
    public void packDibit(int[] vector, byte[] packed) {
        packDibitImpl(vector, packed);
    }

    @Override
    public void packAsBinary(int[] vector, byte[] packed) {
        packAsBinaryImpl(vector, packed);
    }

    /**
     * Returns true if every value in {@code values[from, from + count)} lies in {@code [min, max]}.
     * Intended only for use inside {@code assert}, so it costs nothing when assertions are disabled.
     */
    private static boolean allInRange(int[] values, int from, int count, int min, int max) {
        for (int k = from; k < from + count; k++) {
            if (values[k] < min || values[k] > max) {
                return false;
            }
        }
        return true;
    }

    /**
     * Gathers bit {@code shift} of the eight values {@code values[from..from+7]} into one byte (as an int),
     * placing the first value in the most significant bit.
     */
    private static int packBitPlane(int[] values, int from, int shift) {
        return (values[from] >> shift & 1) << 7
            | (values[from + 1] >> shift & 1) << 6
            | (values[from + 2] >> shift & 1) << 5
            | (values[from + 3] >> shift & 1) << 4
            | (values[from + 4] >> shift & 1) << 3
            | (values[from + 5] >> shift & 1) << 2
            | (values[from + 6] >> shift & 1) << 1
            | (values[from + 7] >> shift & 1);
    }

    /**
     * Packs 2-bit values (0..3) into two bit planes, eight values per byte, most significant bit first.
     * Low bits go to {@code packed[0..]} and high bits to {@code packed[packed.length / 2..]}.
     * A trailing group of fewer than eight values is left-aligned and zero-padded.
     */
    public static void packDibitImpl(int[] vector, byte[] packed) {
        final int upperOffset = packed.length / 2;
        final int limit = vector.length - 7;
        int i = 0;
        int index = 0;
        for (; i < limit; i += 8, index++) {
            assert allInRange(vector, i, 8, 0, 3);
            packed[index] = (byte) packBitPlane(vector, i, 0);
            packed[index + upperOffset] = (byte) packBitPlane(vector, i, 1);
        }
        if (i == vector.length) {
            return;
        }
        int lowerByte = 0;
        int upperByte = 0;
        for (int j = 7; i < vector.length; i++, j--) {
            assert vector[i] >= 0 && vector[i] <= 3;
            lowerByte |= (vector[i] & 1) << j;
            upperByte |= (vector[i] >> 1 & 1) << j;
        }
        packed[index] = (byte) lowerByte;
        packed[index + upperOffset] = (byte) upperByte;
    }

    /**
     * Packs 0/1 values into bytes, eight per byte, most significant bit first. A trailing group of fewer than
     * eight values is left-aligned and zero-padded.
     */
    public static void packAsBinaryImpl(int[] vector, byte[] packed) {
        final int limit = vector.length - 7;
        int i = 0;
        int index = 0;
        for (; i < limit; i += 8, index++) {
            assert allInRange(vector, i, 8, 0, 1);
            // Values are used unmasked, relying on them being 0 or 1 (checked above when assertions are on).
            int result = vector[i] << 7 | vector[i + 1] << 6 | vector[i + 2] << 5 | vector[i + 3] << 4
                | vector[i + 4] << 3 | vector[i + 5] << 2 | vector[i + 6] << 1 | vector[i + 7];
            packed[index] = (byte) result;
        }
        if (i == vector.length) {
            return;
        }
        int result = 0;
        for (int j = 7; j >= 0 && i < vector.length; i++, j--) {
            assert vector[i] == 0 || vector[i] == 1;
            result |= (vector[i] & 1) << j;
        }
        // Explicit narrowing: assigning the int accumulator to a byte element requires a cast.
        packed[index] = (byte) result;
    }

    @Override
    public void transposeHalfByte(int[] q, byte[] quantQueryByte) {
        transposeHalfByteImpl(q, quantQueryByte);
    }

    /**
     * Transposes 4-bit values (0..15) into four bit planes, eight values per byte, most significant bit first.
     * Bit 0 goes to offset {@code 0}, bit 1 to {@code length / 4}, bit 2 to {@code length / 2} and bit 3 to
     * {@code 3 * length / 4} of {@code quantQueryByte}. A trailing partial group is left-aligned and zero-padded.
     */
    public static void transposeHalfByteImpl(int[] q, byte[] quantQueryByte) {
        final int lowerMiddleOffset = quantQueryByte.length / 4;
        final int upperMiddleOffset = quantQueryByte.length / 2;
        final int upperOffset = 3 * quantQueryByte.length / 4;
        final int limit = q.length - 7;
        int i = 0;
        int index = 0;
        for (; i < limit; i += 8, index++) {
            assert allInRange(q, i, 8, 0, 15);
            quantQueryByte[index] = (byte) packBitPlane(q, i, 0);
            quantQueryByte[index + lowerMiddleOffset] = (byte) packBitPlane(q, i, 1);
            quantQueryByte[index + upperMiddleOffset] = (byte) packBitPlane(q, i, 2);
            quantQueryByte[index + upperOffset] = (byte) packBitPlane(q, i, 3);
        }
        if (i == q.length) {
            return;
        }
        int lowerByte = 0;
        int lowerMiddleByte = 0;
        int upperMiddleByte = 0;
        int upperByte = 0;
        for (int j = 7; i < q.length; i++, j--) {
            assert q[i] >= 0 && q[i] <= 15;
            lowerByte |= (q[i] & 1) << j;
            lowerMiddleByte |= (q[i] >> 1 & 1) << j;
            upperMiddleByte |= (q[i] >> 2 & 1) << j;
            upperByte |= (q[i] >> 3 & 1) << j;
        }
        quantQueryByte[index] = (byte) lowerByte;
        quantQueryByte[index + lowerMiddleOffset] = (byte) lowerMiddleByte;
        quantQueryByte[index + upperMiddleOffset] = (byte) upperMiddleByte;
        quantQueryByte[index + upperOffset] = (byte) upperByte;
    }

    @Override
    public int indexOf(byte[] bytes, int offset, int length, byte marker) {
        return ByteArrayUtils.indexOf(bytes, offset, length, marker);
    }
}

