/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.simdvec.internal.vectorization;

import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.IntVector;
import jdk.incubator.vector.LongVector;
import jdk.incubator.vector.Vector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorShape;
import jdk.incubator.vector.VectorSpecies;
import org.apache.lucene.util.Constants;
import org.elasticsearch.simdvec.internal.vectorization.DefaultESVectorUtilSupport;
import org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport;

/**
 * {@link ESVectorUtilSupport} implementation backed by the Panama Vector API
 * ({@code jdk.incubator.vector}). Operations without a vectorized path delegate to
 * {@link DefaultESVectorUtilSupport}.
 */
public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {

    /** Bit size of the platform's preferred vector shape. */
    static final int VECTOR_BITSIZE = VectorShape.preferredShape().vectorBitSize();

    private static final VectorSpecies<Float> FLOAT_SPECIES = VectorSpecies.of(float.class, VectorShape.forBitSize(VECTOR_BITSIZE));

    /**
     * {@code false} on amd64 when the preferred vector shape is narrower than 256 bits (presumably a
     * CPU without AVX2); in that case {@link #ipByteBinByte} uses the scalar implementation.
     */
    static final boolean HAS_FAST_INTEGER_VECTORS = !(Constants.OS_ARCH.equals("amd64") && VECTOR_BITSIZE < 256);

    private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
    private static final VectorSpecies<Byte> BYTE_SPECIES_256 = ByteVector.SPECIES_256;

    /** Returns {@code a * b + c} lane-wise, fused only when Lucene reports fast vector FMA. */
    private static FloatVector fma(FloatVector a, FloatVector b, FloatVector c) {
        if (Constants.HAS_FAST_VECTOR_FMA) {
            return a.fma(b, c);
        }
        return a.mul(b).add(c);
    }

    /** Returns {@code a * b + c}, fused only when Lucene reports fast scalar FMA. */
    private static float fma(float a, float b, float c) {
        if (Constants.HAS_FAST_SCALAR_FMA) {
            return Math.fma(a, b, c);
        }
        return a * b + c;
    }

    /**
     * Computes the weighted bitwise inner product described on {@link #ipByteBin256}. Uses a 256- or
     * 128-bit vector kernel when {@code d} has at least 16 bytes and integer vectors are considered
     * fast; otherwise falls back to the scalar default implementation.
     */
    @Override
    public long ipByteBinByte(byte[] q, byte[] d) {
        if (d.length >= 16 && HAS_FAST_INTEGER_VECTORS) {
            if (VECTOR_BITSIZE >= 256) {
                return ipByteBin256(q, d);
            }
            if (VECTOR_BITSIZE == 128) {
                return ipByteBin128(q, d);
            }
        }
        return DefaultESVectorUtilSupport.ipByteBinByteImpl(q, d);
    }

    /** Delegates to the scalar default implementation. */
    @Override
    public int ipByteBit(byte[] q, byte[] d) {
        return DefaultESVectorUtilSupport.ipByteBitImpl(q, d);
    }

    /** Delegates to the scalar default implementation. */
    @Override
    public float ipFloatBit(float[] q, byte[] d) {
        return DefaultESVectorUtilSupport.ipFloatBitImpl(q, d);
    }

    /**
     * Writes {@code centered[i] = vector[i] - centroid[i]} and statistics of {@code centered} into
     * {@code stats}: [0] mean, [1] population variance, [2] squared L2 norm, [3] minimum, [4] maximum.
     */
    @Override
    public void centerAndCalculateOSQStatsEuclidean(float[] vector, float[] centroid, float[] centered, float[] stats) {
        assert vector.length == centroid.length;
        assert vector.length == centered.length;
        float vecMean = 0.0f;
        float vecM2 = 0.0f;
        float norm2 = 0.0f;
        float min = Float.MAX_VALUE;
        float max = -Float.MAX_VALUE;
        int vectCount = 0;
        int i = 0;
        if (vector.length > 2 * FLOAT_SPECIES.length()) {
            FloatVector vecMeanVec = FloatVector.zero(FLOAT_SPECIES);
            FloatVector m2Vec = FloatVector.zero(FLOAT_SPECIES);
            FloatVector norm2Vec = FloatVector.zero(FLOAT_SPECIES);
            FloatVector minVec = FloatVector.broadcast(FLOAT_SPECIES, Float.MAX_VALUE);
            FloatVector maxVec = FloatVector.broadcast(FLOAT_SPECIES, -Float.MAX_VALUE);
            int count = 0;
            int bound = FLOAT_SPECIES.loopBound(vector.length);
            for (; i < bound; i += FLOAT_SPECIES.length()) {
                ++count;
                FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i);
                FloatVector c = FloatVector.fromArray(FLOAT_SPECIES, centroid, i);
                FloatVector centeredVec = v.sub(c);
                // Welford update run independently in every lane.
                FloatVector deltaVec = centeredVec.sub(vecMeanVec);
                norm2Vec = fma(centeredVec, centeredVec, norm2Vec);
                vecMeanVec = vecMeanVec.add(deltaVec.div((float) count));
                FloatVector delta2Vec = centeredVec.sub(vecMeanVec);
                m2Vec = fma(deltaVec, delta2Vec, m2Vec);
                minVec = minVec.min(centeredVec);
                maxVec = maxVec.max(centeredVec);
                centeredVec.intoArray(centered, i);
            }
            min = minVec.reduceLanes(VectorOperators.MIN);
            max = maxVec.reduceLanes(VectorOperators.MAX);
            norm2 = norm2Vec.reduceLanes(VectorOperators.ADD);
            // Every lane saw the same number of values, so the overall mean is the mean of lane means.
            vecMean = vecMeanVec.reduceLanes(VectorOperators.ADD) / FLOAT_SPECIES.length();
            vecM2 = combineLaneM2(vecMeanVec, m2Vec, vecMean, count);
            vectCount = count * FLOAT_SPECIES.length();
        }
        float tailMean = 0.0f;
        float tailM2 = 0.0f;
        int tailCount = 0;
        for (; i < vector.length; i++) {
            float x = vector[i] - centroid[i];
            centered[i] = x;
            float delta = x - tailMean;
            tailMean += delta / ++tailCount;
            float delta2 = x - tailMean;
            tailM2 = fma(delta, delta2, tailM2);
            min = Math.min(min, x);
            max = Math.max(max, x);
            norm2 = fma(x, x, norm2);
        }
        writeMeanAndVariance(vecMean, vecM2, vectCount, tailMean, tailM2, tailCount, vector.length, stats);
        stats[2] = norm2;
        stats[3] = min;
        stats[4] = max;
    }

    /**
     * Writes {@code centered[i] = vector[i] - centroid[i]} and statistics into {@code stats}:
     * [0] mean of {@code centered}, [1] its population variance, [2] its squared L2 norm,
     * [3] its minimum, [4] its maximum, and [5] the dot product of {@code vector} with
     * {@code centroid}.
     */
    @Override
    public void centerAndCalculateOSQStatsDp(float[] vector, float[] centroid, float[] centered, float[] stats) {
        assert vector.length == centroid.length;
        assert vector.length == centered.length;
        float vecMean = 0.0f;
        float vecM2 = 0.0f;
        float norm2 = 0.0f;
        float min = Float.MAX_VALUE;
        float max = -Float.MAX_VALUE;
        float centroidDot = 0.0f;
        int vectCount = 0;
        int i = 0;
        if (vector.length > 2 * FLOAT_SPECIES.length()) {
            FloatVector vecMeanVec = FloatVector.zero(FLOAT_SPECIES);
            FloatVector m2Vec = FloatVector.zero(FLOAT_SPECIES);
            FloatVector norm2Vec = FloatVector.zero(FLOAT_SPECIES);
            FloatVector minVec = FloatVector.broadcast(FLOAT_SPECIES, Float.MAX_VALUE);
            FloatVector maxVec = FloatVector.broadcast(FLOAT_SPECIES, -Float.MAX_VALUE);
            FloatVector centroidDotVec = FloatVector.zero(FLOAT_SPECIES);
            int count = 0;
            int bound = FLOAT_SPECIES.loopBound(vector.length);
            for (; i < bound; i += FLOAT_SPECIES.length()) {
                ++count;
                FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i);
                FloatVector c = FloatVector.fromArray(FLOAT_SPECIES, centroid, i);
                centroidDotVec = fma(v, c, centroidDotVec);
                FloatVector centeredVec = v.sub(c);
                // Welford update run independently in every lane.
                FloatVector deltaVec = centeredVec.sub(vecMeanVec);
                norm2Vec = fma(centeredVec, centeredVec, norm2Vec);
                vecMeanVec = vecMeanVec.add(deltaVec.div((float) count));
                FloatVector delta2Vec = centeredVec.sub(vecMeanVec);
                m2Vec = fma(deltaVec, delta2Vec, m2Vec);
                minVec = minVec.min(centeredVec);
                maxVec = maxVec.max(centeredVec);
                centeredVec.intoArray(centered, i);
            }
            min = minVec.reduceLanes(VectorOperators.MIN);
            max = maxVec.reduceLanes(VectorOperators.MAX);
            norm2 = norm2Vec.reduceLanes(VectorOperators.ADD);
            centroidDot = centroidDotVec.reduceLanes(VectorOperators.ADD);
            // Every lane saw the same number of values, so the overall mean is the mean of lane means.
            vecMean = vecMeanVec.reduceLanes(VectorOperators.ADD) / FLOAT_SPECIES.length();
            vecM2 = combineLaneM2(vecMeanVec, m2Vec, vecMean, count);
            vectCount = count * FLOAT_SPECIES.length();
        }
        float tailMean = 0.0f;
        float tailM2 = 0.0f;
        int tailCount = 0;
        for (; i < vector.length; i++) {
            centroidDot = fma(vector[i], centroid[i], centroidDot);
            float x = vector[i] - centroid[i];
            centered[i] = x;
            float delta = x - tailMean;
            tailMean += delta / ++tailCount;
            float delta2 = x - tailMean;
            tailM2 = fma(delta, delta2, tailM2);
            min = Math.min(min, x);
            max = Math.max(max, x);
            norm2 = fma(x, x, norm2);
        }
        writeMeanAndVariance(vecMean, vecM2, vectCount, tailMean, tailM2, tailCount, vector.length, stats);
        stats[2] = norm2;
        stats[3] = min;
        stats[4] = max;
        stats[5] = centroidDot;
    }

    /**
     * Merges per-lane Welford states into the sum of squared deviations (M2) over all lanes.
     * Every lane saw {@code count} values, so each lane contributes its own M2 plus
     * {@code count * (laneMean - mean)^2}.
     *
     * @param laneMeans per-lane means
     * @param laneM2s   per-lane sums of squared deviations from the lane mean
     * @param mean      mean over all lanes
     * @param count     number of values accumulated into each lane
     * @return M2 over all {@code count * lanes} values
     */
    private static float combineLaneM2(FloatVector laneMeans, FloatVector laneM2s, float mean, int count) {
        FloatVector dMean = laneMeans.sub(mean);
        return fma(dMean, dMean.mul((float) count), laneM2s).reduceLanes(VectorOperators.ADD);
    }

    /**
     * Combines the vector-prefix statistics with the scalar-tail statistics (Chan et al. parallel
     * update). Writes the overall mean to {@code stats[0]} and the population variance
     * (M2 / {@code length}) to {@code stats[1]}.
     *
     * @param vecMean   mean of the {@code vectCount} values handled by the vector loop
     * @param vecM2     sum of squared deviations from {@code vecMean} of those values
     * @param vectCount number of values handled by the vector loop (may be 0)
     * @param tailMean  mean of the {@code tailCount} values handled by the scalar loop
     * @param tailM2    sum of squared deviations from {@code tailMean} of those values
     * @param tailCount number of values handled by the scalar loop (may be 0)
     * @param length    total number of values; must equal {@code vectCount + tailCount}
     * @param stats     output array; only indices 0 and 1 are written
     */
    private static void writeMeanAndVariance(
        float vecMean,
        float vecM2,
        int vectCount,
        float tailMean,
        float tailM2,
        int tailCount,
        int length,
        float[] stats
    ) {
        assert vectCount + tailCount == length;
        float mean;
        float m2;
        if (vectCount == 0) {
            mean = tailMean;
            m2 = tailM2;
        } else if (tailCount == 0) {
            mean = vecMean;
            m2 = vecM2;
        } else {
            float alpha = (float) vectCount / length;
            float beta = 1.0f - alpha;
            mean = alpha * vecMean + beta * tailMean;
            float dVec = vecMean - mean;
            float dTail = tailMean - mean;
            // Each partition's deviations must be re-centred on the combined mean, weighted by its size.
            m2 = vecM2 + tailM2 + vectCount * dVec * dVec + tailCount * dTail * dTail;
        }
        stats[0] = mean;
        stats[1] = m2 / length;
    }

    /**
     * For each {@code target[i]}, clamps it to {@code [interval[0], interval[1]]}. It then rounds
     * {@code (clamped - interval[0]) * invStep} to an integer {@code k} and sets
     * {@code s = k / (points - 1)}. The following sums are accumulated into {@code pts}:
     * [0] sum (1-s)^2, [1] sum (1-s)s, [2] sum s^2, [3] sum (1-s)·target, [4] sum s·target.
     */
    @Override
    public void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts) {
        float a = interval[0];
        float b = interval[1];
        float daa = 0.0f;
        float dab = 0.0f;
        float dbb = 0.0f;
        float dax = 0.0f;
        float dbx = 0.0f;
        int i = 0;
        if (target.length > 2 * FLOAT_SPECIES.length()) {
            FloatVector daaVec = FloatVector.zero(FLOAT_SPECIES);
            FloatVector dabVec = FloatVector.zero(FLOAT_SPECIES);
            FloatVector dbbVec = FloatVector.zero(FLOAT_SPECIES);
            FloatVector daxVec = FloatVector.zero(FLOAT_SPECIES);
            FloatVector dbxVec = FloatVector.zero(FLOAT_SPECIES);
            FloatVector ones = FloatVector.broadcast(FLOAT_SPECIES, 1.0f);
            FloatVector pmOnes = FloatVector.broadcast(FLOAT_SPECIES, points - 1.0f);
            int bound = FLOAT_SPECIES.loopBound(target.length);
            for (; i < bound; i += FLOAT_SPECIES.length()) {
                FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, target, i);
                FloatVector vClamped = v.max(a).min(b);
                // Round half up via +0.5 and truncation, mirroring Math.round in the scalar tail
                // for the non-negative values produced when invStep > 0.
                Vector<Integer> xiqint = vClamped.sub(a).mul(invStep).add(0.5f).convert(VectorOperators.F2I, 0);
                FloatVector kVec = xiqint.convert(VectorOperators.I2F, 0).reinterpretAsFloats();
                FloatVector sVec = kVec.div(pmOnes);
                FloatVector smVec = ones.sub(sVec);
                daaVec = fma(smVec, smVec, daaVec);
                dabVec = fma(smVec, sVec, dabVec);
                dbbVec = fma(sVec, sVec, dbbVec);
                daxVec = fma(v, smVec, daxVec);
                dbxVec = fma(v, sVec, dbxVec);
            }
            daa = daaVec.reduceLanes(VectorOperators.ADD);
            dab = dabVec.reduceLanes(VectorOperators.ADD);
            dbb = dbbVec.reduceLanes(VectorOperators.ADD);
            dax = daxVec.reduceLanes(VectorOperators.ADD);
            dbx = dbxVec.reduceLanes(VectorOperators.ADD);
        }
        for (; i < target.length; i++) {
            float k = Math.round((Math.min(Math.max(target[i], a), b) - a) * invStep);
            float s = k / (float) (points - 1);
            float ms = 1.0f - s;
            daa = fma(ms, ms, daa);
            dab = fma(ms, s, dab);
            dbb = fma(s, s, dbb);
            dax = fma(ms, target[i], dax);
            dbx = fma(s, target[i], dbx);
        }
        pts[0] = daa;
        pts[1] = dab;
        pts[2] = dbb;
        pts[3] = dax;
        pts[4] = dbx;
    }

    /**
     * Quantizes each {@code target[i]} onto the grid {@code interval[0] + k * step}, where
     * {@code k} is the rounded value of {@code (clamped - interval[0]) * invStep}. Let
     * {@code xe = sum target·err} and {@code e = sum err^2}, where err is the quantization error.
     *
     * @return {@code (1 - lambda) * xe^2 / norm2 + lambda * e}
     */
    @Override
    public float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda) {
        float a = interval[0];
        float b = interval[1];
        float xe = 0.0f;
        float e = 0.0f;
        int i = 0;
        if (target.length > 2 * FLOAT_SPECIES.length()) {
            FloatVector xeVec = FloatVector.zero(FLOAT_SPECIES);
            FloatVector eVec = FloatVector.zero(FLOAT_SPECIES);
            int bound = FLOAT_SPECIES.loopBound(target.length);
            for (; i < bound; i += FLOAT_SPECIES.length()) {
                FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, target, i);
                FloatVector vClamped = v.max(a).min(b);
                // Round half up via +0.5 and truncation, mirroring Math.round in the scalar tail.
                Vector<Integer> xiqint = vClamped.sub(a).mul(invStep).add(0.5f).convert(VectorOperators.F2I, 0);
                FloatVector xiq = xiqint.convert(VectorOperators.I2F, 0).reinterpretAsFloats().mul(step).add(a);
                FloatVector xiiq = v.sub(xiq);
                xeVec = fma(v, xiiq, xeVec);
                eVec = fma(xiiq, xiiq, eVec);
            }
            e = eVec.reduceLanes(VectorOperators.ADD);
            xe = xeVec.reduceLanes(VectorOperators.ADD);
        }
        for (; i < target.length; i++) {
            float xiq = fma(step, Math.round((Math.min(Math.max(target[i], a), b) - a) * invStep), a);
            float xiiq = target[i] - xiq;
            e = fma(xiiq, xiiq, e);
            xe = fma(target[i], xiiq, xe);
        }
        return (1.0f - lambda) * xe * xe / norm2 + lambda * e;
    }

    /**
     * 256-bit kernel for {@link #ipByteBinByte}.
     *
     * <p>{@code q} is read as four consecutive segments of {@code d.length} bytes. For segment
     * {@code k}, the popcount of {@code q[k * d.length + i] & d[i]} is summed over {@code i} and
     * weighted by {@code 2^k}. Bytes that do not fill a 256-bit vector are handled with 128-bit
     * vectors, and the remainder with scalar code.
     */
    static long ipByteBin256(byte[] q, byte[] d) {
        long subRet0 = 0L;
        long subRet1 = 0L;
        long subRet2 = 0L;
        long subRet3 = 0L;
        int i = 0;
        if (d.length >= ByteVector.SPECIES_256.vectorByteSize() * 2) {
            int limit = ByteVector.SPECIES_256.loopBound(d.length);
            LongVector sum0 = LongVector.zero(LongVector.SPECIES_256);
            LongVector sum1 = LongVector.zero(LongVector.SPECIES_256);
            LongVector sum2 = LongVector.zero(LongVector.SPECIES_256);
            LongVector sum3 = LongVector.zero(LongVector.SPECIES_256);
            for (; i < limit; i += ByteVector.SPECIES_256.length()) {
                LongVector vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs();
                LongVector vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + d.length).reinterpretAsLongs();
                LongVector vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + d.length * 2).reinterpretAsLongs();
                LongVector vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + d.length * 3).reinterpretAsLongs();
                LongVector vd = ByteVector.fromArray(BYTE_SPECIES_256, d, i).reinterpretAsLongs();
                sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT));
                sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT));
                sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT));
                sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT));
            }
            subRet0 += sum0.reduceLanes(VectorOperators.ADD);
            subRet1 += sum1.reduceLanes(VectorOperators.ADD);
            subRet2 += sum2.reduceLanes(VectorOperators.ADD);
            subRet3 += sum3.reduceLanes(VectorOperators.ADD);
        }
        if (d.length - i >= ByteVector.SPECIES_128.vectorByteSize()) {
            LongVector sum0 = LongVector.zero(LongVector.SPECIES_128);
            LongVector sum1 = LongVector.zero(LongVector.SPECIES_128);
            LongVector sum2 = LongVector.zero(LongVector.SPECIES_128);
            LongVector sum3 = LongVector.zero(LongVector.SPECIES_128);
            int limit = ByteVector.SPECIES_128.loopBound(d.length);
            for (; i < limit; i += ByteVector.SPECIES_128.length()) {
                LongVector vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs();
                LongVector vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length).reinterpretAsLongs();
                LongVector vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 2).reinterpretAsLongs();
                LongVector vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 3).reinterpretAsLongs();
                LongVector vd = ByteVector.fromArray(BYTE_SPECIES_128, d, i).reinterpretAsLongs();
                sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT));
                sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT));
                sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT));
                sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT));
            }
            subRet0 += sum0.reduceLanes(VectorOperators.ADD);
            subRet1 += sum1.reduceLanes(VectorOperators.ADD);
            subRet2 += sum2.reduceLanes(VectorOperators.ADD);
            subRet3 += sum3.reduceLanes(VectorOperators.ADD);
        }
        for (; i < d.length; i++) {
            // Mask with 0xFF so sign extension of the byte promotion does not add spurious bits.
            subRet0 += Integer.bitCount(q[i] & d[i] & 0xFF);
            subRet1 += Integer.bitCount(q[i + d.length] & d[i] & 0xFF);
            subRet2 += Integer.bitCount(q[i + 2 * d.length] & d[i] & 0xFF);
            subRet3 += Integer.bitCount(q[i + 3 * d.length] & d[i] & 0xFF);
        }
        return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
    }

    /**
     * 128-bit kernel for {@link #ipByteBinByte}, with the same {@code q} layout and weighting as
     * {@link #ipByteBin256}. Per-lane popcounts are accumulated in int lanes; remaining bytes are
     * handled with scalar code.
     */
    public static long ipByteBin128(byte[] q, byte[] d) {
        long subRet0 = 0L;
        long subRet1 = 0L;
        long subRet2 = 0L;
        long subRet3 = 0L;
        IntVector sum0 = IntVector.zero(IntVector.SPECIES_128);
        IntVector sum1 = IntVector.zero(IntVector.SPECIES_128);
        IntVector sum2 = IntVector.zero(IntVector.SPECIES_128);
        IntVector sum3 = IntVector.zero(IntVector.SPECIES_128);
        int limit = ByteVector.SPECIES_128.loopBound(d.length);
        int i = 0;
        for (; i < limit; i += ByteVector.SPECIES_128.length()) {
            IntVector vd = ByteVector.fromArray(BYTE_SPECIES_128, d, i).reinterpretAsInts();
            IntVector vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts();
            IntVector vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length).reinterpretAsInts();
            IntVector vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 2).reinterpretAsInts();
            IntVector vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + d.length * 3).reinterpretAsInts();
            sum0 = sum0.add(vd.and(vq0).lanewise(VectorOperators.BIT_COUNT));
            sum1 = sum1.add(vd.and(vq1).lanewise(VectorOperators.BIT_COUNT));
            sum2 = sum2.add(vd.and(vq2).lanewise(VectorOperators.BIT_COUNT));
            sum3 = sum3.add(vd.and(vq3).lanewise(VectorOperators.BIT_COUNT));
        }
        subRet0 += sum0.reduceLanes(VectorOperators.ADD);
        subRet1 += sum1.reduceLanes(VectorOperators.ADD);
        subRet2 += sum2.reduceLanes(VectorOperators.ADD);
        subRet3 += sum3.reduceLanes(VectorOperators.ADD);
        for (; i < d.length; i++) {
            byte dValue = d[i];
            // Mask with 0xFF so sign extension of the byte promotion does not add spurious bits.
            subRet0 += Integer.bitCount(dValue & q[i] & 0xFF);
            subRet1 += Integer.bitCount(dValue & q[i + d.length] & 0xFF);
            subRet2 += Integer.bitCount(dValue & q[i + 2 * d.length] & 0xFF);
            subRet3 += Integer.bitCount(dValue & q[i + 3 * d.length] & 0xFF);
        }
        return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
    }
}

