/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.simdvec.internal.vectorization;

import java.io.IOException;
import java.lang.foreign.MemorySegment;
import java.nio.ByteOrder;
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.IntVector;
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.Vector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorShape;
import jdk.incubator.vector.VectorSpecies;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
import org.elasticsearch.simdvec.internal.vectorization.PanamaESVectorUtilSupport;

/**
 * {@link ES91Int4VectorsScorer} whose dot products and score corrections are computed with the
 * Panama Vector API, loading the vectorized part of the stored data straight from a
 * {@link MemorySegment}.
 *
 * <p>Every vectorized read addresses the segment relative to {@code in.getFilePointer()} and then
 * repositions the input with {@code in.seek(...)}, so the segment and the input must describe the
 * same bytes at the same offsets. NOTE(review): presumably the segment is the memory mapping that
 * backs {@code in}; confirm with the code that constructs this class.
 *
 * <p>The kernel is chosen from {@code PanamaESVectorUtilSupport.VECTOR_BITSIZE}: 512 bits or more
 * and exactly 256 bits have dedicated kernels, everything else uses the 128-bit kernel. Whenever
 * the vector is too short or {@code PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS} is false,
 * the work is done by scalar code (either here or in the superclass).
 */
public final class MemorySegmentES91Int4VectorsScorer extends ES91Int4VectorsScorer {
    private static final VectorSpecies<Byte> BYTE_SPECIES_64 = ByteVector.SPECIES_64;
    private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
    private static final VectorSpecies<Short> SHORT_SPECIES_128 = ShortVector.SPECIES_128;
    private static final VectorSpecies<Short> SHORT_SPECIES_256 = ShortVector.SPECIES_256;
    private static final VectorSpecies<Integer> INT_SPECIES_128 = IntVector.SPECIES_128;
    private static final VectorSpecies<Integer> INT_SPECIES_256 = IntVector.SPECIES_256;
    private static final VectorSpecies<Integer> INT_SPECIES_512 = IntVector.SPECIES_512;

    /** Float species with the platform's preferred bit size; used for the score corrections. */
    private static final VectorSpecies<Float> FLOAT_SPECIES = VectorSpecies.of(
        float.class,
        VectorShape.forBitSize(PanamaESVectorUtilSupport.VECTOR_BITSIZE)
    );

    /**
     * Short species with the same bit size as {@link #FLOAT_SPECIES}. It holds twice as many lanes,
     * so only part 0 of a load is consumed per float vector.
     */
    private static final VectorSpecies<Short> SHORT_SPECIES = VectorSpecies.of(
        short.class,
        VectorShape.forBitSize(PanamaESVectorUtilSupport.VECTOR_BITSIZE)
    );

    /**
     * Number of vectors scored by one {@link #scoreBulk} call.
     * NOTE(review): presumably mirrors a bulk-size constant of the superclass; confirm and reuse it.
     */
    private static final int SCORES_PER_BULK = 16;

    /**
     * Factor applied to an interval width {@code (upper - lower)}.
     * NOTE(review): the value looks like 1/15, i.e. the step between the 16 levels of a 4-bit
     * value; the literal is kept so the exact float is unchanged.
     */
    private static final float INTERVAL_SCALE = 0.06666667f;

    /**
     * Number of bytes the 128-bit kernel accumulates in 16-bit lanes before widening to 32 bits.
     * NOTE(review): presumably sized so the short lanes cannot overflow for 4-bit inputs
     * (64 additions of at most 15 * 15 per lane); confirm the input value range.
     */
    private static final int SHORT_ACCUMULATION_CHUNK = 1024;

    // Byte layout of the correction data that follows a bulk of quantized vectors, relative to the
    // file pointer after the last vector: float[16], float[16], short[16], float[16].
    // NOTE(review): the first two arrays are named lower/upper intervals by analogy with the
    // query-side parameters; confirm against the writer of this format.
    private static final long UPPER_INTERVALS_OFFSET = (long) Float.BYTES * SCORES_PER_BULK;
    private static final long COMPONENT_SUMS_OFFSET = UPPER_INTERVALS_OFFSET + (long) Float.BYTES * SCORES_PER_BULK;
    private static final long ADDITIONAL_CORRECTIONS_OFFSET = COMPONENT_SUMS_OFFSET + (long) Short.BYTES * SCORES_PER_BULK;
    private static final long CORRECTIONS_LENGTH = ADDITIONAL_CORRECTIONS_OFFSET + (long) Float.BYTES * SCORES_PER_BULK;

    private final MemorySegment memorySegment;

    /**
     * @param in            input positioned on the stored vectors; forwarded to the superclass
     * @param dimensions    number of components per vector; forwarded to the superclass
     * @param memorySegment segment the vectorized kernels load from, addressed by the file pointer
     *                      of {@code in}
     */
    public MemorySegmentES91Int4VectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) {
        super(in, dimensions);
        this.memorySegment = memorySegment;
    }

    /** Returns true when a dedicated 256-bit or 512-bit kernel should be used. */
    private static boolean usesWideVectors() {
        return PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 512 || PanamaESVectorUtilSupport.VECTOR_BITSIZE == 256;
    }

    /**
     * Computes the dot product of {@code q} and the vector at the current position of the input,
     * leaving the input positioned after that vector.
     *
     * @param q query bytes, indexed up to {@code dimensions}
     * @return the sum of the component-wise products
     * @throws IOException if reading or seeking the input fails
     */
    @Override
    public long int4DotProduct(byte[] q) throws IOException {
        if (usesWideVectors()) {
            return dotProduct(q);
        }
        int i = 0;
        int res = 0;
        if (dimensions >= 32 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
            i = BYTE_SPECIES_128.loopBound(dimensions);
            res = int4DotProductBody128(q, i);
        }
        // Scalar tail (or the whole vector when the kernel was skipped), read through the input.
        in.readBytes(scratch, i, dimensions - i);
        for (; i < dimensions; i++) {
            res += scratch[i] * q[i];
        }
        return res;
    }

    /**
     * 128-bit kernel: multiplies {@code limit} bytes of {@code q} with the bytes stored at the
     * current file pointer and advances the input by {@code limit}.
     *
     * <p>Products are formed in 8-bit lanes, widened to 16 bits and masked with {@code 0xFF} to
     * undo the sign extension of the widening, i.e. each product is treated as an unsigned byte.
     *
     * @param limit number of bytes to process; the loop consumes 16 bytes per step, so callers pass
     *              a multiple of 16
     */
    private int int4DotProductBody128(byte[] q, int limit) throws IOException {
        int sum = 0;
        final long offset = in.getFilePointer();
        final int half = BYTE_SPECIES_64.length();
        for (int i = 0; i < limit; i += SHORT_ACCUMULATION_CHUNK) {
            ShortVector acc0 = ShortVector.zero(SHORT_SPECIES_128);
            ShortVector acc1 = ShortVector.zero(SHORT_SPECIES_128);
            final int innerLimit = Math.min(limit - i, SHORT_ACCUMULATION_CHUNK);
            for (int j = 0; j < innerLimit; j += BYTE_SPECIES_128.length()) {
                // First 8 bytes of the 16-byte step.
                ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i + j);
                ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i + j, ByteOrder.LITTLE_ENDIAN);
                ByteVector prod8 = va8.mul(vb8);
                ShortVector prod16 = prod8.convertShape(VectorOperators.B2S, SHORT_SPECIES_128, 0).reinterpretAsShorts();
                acc0 = acc0.add(prod16.and((short) 0xFF));

                // Second 8 bytes of the 16-byte step.
                va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i + j + half);
                vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i + j + half, ByteOrder.LITTLE_ENDIAN);
                prod8 = va8.mul(vb8);
                prod16 = prod8.convertShape(VectorOperators.B2S, SHORT_SPECIES_128, 0).reinterpretAsShorts();
                acc1 = acc1.add(prod16.and((short) 0xFF));
            }
            // Widen both halves of each 16-bit accumulator to 32 bits before reducing.
            IntVector intAcc0 = acc0.convertShape(VectorOperators.S2I, INT_SPECIES_128, 0).reinterpretAsInts();
            IntVector intAcc1 = acc0.convertShape(VectorOperators.S2I, INT_SPECIES_128, 1).reinterpretAsInts();
            IntVector intAcc2 = acc1.convertShape(VectorOperators.S2I, INT_SPECIES_128, 0).reinterpretAsInts();
            IntVector intAcc3 = acc1.convertShape(VectorOperators.S2I, INT_SPECIES_128, 1).reinterpretAsInts();
            sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(VectorOperators.ADD);
        }
        in.seek(offset + limit);
        return sum;
    }

    /**
     * Dot product for 256-bit and 512-bit platforms. Falls back to the superclass when the vector
     * has fewer than 16 dimensions or fast integer vectors are unavailable.
     */
    private long dotProduct(byte[] q) throws IOException {
        if (dimensions >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
            int i;
            int res;
            if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 512) {
                i = BYTE_SPECIES_128.loopBound(dimensions);
                res = dotProductBody512(q, i);
            } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 256) {
                i = BYTE_SPECIES_64.loopBound(dimensions);
                res = dotProductBody256(q, i);
            } else {
                throw new IllegalArgumentException("Unreacheable statement");
            }
            // Scalar tail, read byte by byte from the input the kernel left after its last lane.
            for (; i < q.length; i++) {
                res += in.readByte() * q[i];
            }
            return res;
        }
        return super.int4DotProduct(q);
    }

    /**
     * 512-bit kernel: 16 bytes per step, multiplied in 16-bit lanes and accumulated in 32-bit
     * lanes. Advances the input by {@code limit} bytes.
     */
    private int dotProductBody512(byte[] q, int limit) throws IOException {
        IntVector acc = IntVector.zero(INT_SPECIES_512);
        final long offset = in.getFilePointer();
        for (int i = 0; i < limit; i += BYTE_SPECIES_128.length()) {
            ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_128, q, i);
            ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_128, memorySegment, offset + i, ByteOrder.LITTLE_ENDIAN);
            Vector<Short> va16 = va8.convertShape(VectorOperators.B2S, SHORT_SPECIES_256, 0);
            Vector<Short> vb16 = vb8.convertShape(VectorOperators.B2S, SHORT_SPECIES_256, 0);
            Vector<Short> prod16 = va16.mul(vb16);
            Vector<Integer> prod32 = prod16.convertShape(VectorOperators.S2I, INT_SPECIES_512, 0);
            acc = acc.add(prod32);
        }
        in.seek(offset + limit);
        return acc.reduceLanes(VectorOperators.ADD);
    }

    /**
     * 256-bit kernel: 8 bytes per step, multiplied and accumulated in 32-bit lanes. Advances the
     * input by {@code limit} bytes.
     */
    private int dotProductBody256(byte[] q, int limit) throws IOException {
        IntVector acc = IntVector.zero(INT_SPECIES_256);
        final long offset = in.getFilePointer();
        for (int i = 0; i < limit; i += BYTE_SPECIES_64.length()) {
            ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
            ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, ByteOrder.LITTLE_ENDIAN);
            Vector<Integer> va32 = va8.convertShape(VectorOperators.B2I, INT_SPECIES_256, 0);
            Vector<Integer> vb32 = vb8.convertShape(VectorOperators.B2I, INT_SPECIES_256, 0);
            acc = acc.add(va32.mul(vb32));
        }
        in.seek(offset + limit);
        return acc.reduceLanes(VectorOperators.ADD);
    }

    /**
     * Computes the dot product of {@code q} with each of the next {@code count} vectors of the
     * input and stores result {@code k} in {@code scores[k]}.
     *
     * @param q      query bytes
     * @param count  number of consecutive vectors to score
     * @param scores output array; the first {@code count} entries are overwritten
     * @throws IOException if reading or seeking the input fails
     */
    @Override
    public void int4DotProductBulk(byte[] q, int count, float[] scores) throws IOException {
        if (usesWideVectors()) {
            dotProductBulk(q, count, scores);
        } else if (dimensions >= 32 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
            int4DotProductBody128Bulk(q, count, scores);
        } else {
            super.int4DotProductBulk(q, count, scores);
        }
    }

    /** Bulk variant of the 128-bit path: vector kernel plus scalar tail, once per stored vector. */
    private void int4DotProductBody128Bulk(byte[] q, int count, float[] scores) throws IOException {
        final int limit = BYTE_SPECIES_128.loopBound(dimensions);
        for (int iter = 0; iter < count; iter++) {
            // The kernel leaves the input positioned right after the bytes it consumed.
            int sum = int4DotProductBody128(q, limit);
            in.readBytes(scratch, limit, dimensions - limit);
            for (int j = limit; j < dimensions; j++) {
                sum += scratch[j] * q[j];
            }
            scores[iter] = sum;
        }
    }

    /**
     * Bulk dot product for 256-bit and 512-bit platforms. Falls back to the superclass when the
     * vectors have fewer than 16 dimensions or fast integer vectors are unavailable.
     */
    private void dotProductBulk(byte[] q, int count, float[] scores) throws IOException {
        if (dimensions >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
            if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 512) {
                dotProductBody512Bulk(q, count, scores);
            } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 256) {
                dotProductBody256Bulk(q, count, scores);
            } else {
                throw new IllegalArgumentException("Unreacheable statement");
            }
            return;
        }
        super.int4DotProductBulk(q, count, scores);
    }

    /** Bulk variant of the 512-bit path: vector kernel plus scalar tail, once per stored vector. */
    private void dotProductBody512Bulk(byte[] q, int count, float[] scores) throws IOException {
        final int limit = BYTE_SPECIES_128.loopBound(dimensions);
        for (int iter = 0; iter < count; iter++) {
            long res = dotProductBody512(q, limit);
            for (int i = limit; i < q.length; i++) {
                res += in.readByte() * q[i];
            }
            scores[iter] = res;
        }
    }

    /** Bulk variant of the 256-bit path: vector kernel plus scalar tail, once per stored vector. */
    private void dotProductBody256Bulk(byte[] q, int count, float[] scores) throws IOException {
        // The kernel steps by 8 bytes, so bound the loop with the 64-bit species (as dotProduct
        // does); only the remaining 0-7 components go through the scalar tail.
        final int limit = BYTE_SPECIES_64.loopBound(dimensions);
        for (int iter = 0; iter < count; iter++) {
            long res = dotProductBody256(q, limit);
            for (int i = limit; i < q.length; i++) {
                res += in.readByte() * q[i];
            }
            scores[iter] = res;
        }
    }

    /**
     * Scores the next 16 stored vectors against {@code q}: first the raw dot products are written
     * to {@code scores}, then they are replaced in place by the corrected similarity scores using
     * the correction data that follows the vectors in the input.
     *
     * @param q                         query bytes
     * @param queryLowerInterval        lower interval of the query
     * @param queryUpperInterval        upper interval of the query
     * @param queryComponentSum         component sum of the query
     * @param queryAdditionalCorrection additive correction term of the query
     * @param similarityFunction        selects the final score transformation
     * @param centroidDp                value subtracted from the score for non-Euclidean similarities
     * @param scores                    output array; its first 16 entries are overwritten
     * @throws IOException if reading or seeking the input fails
     */
    @Override
    public void scoreBulk(
        byte[] q,
        float queryLowerInterval,
        float queryUpperInterval,
        int queryComponentSum,
        float queryAdditionalCorrection,
        VectorSimilarityFunction similarityFunction,
        float centroidDp,
        float[] scores
    ) throws IOException {
        int4DotProductBulk(q, SCORES_PER_BULK, scores);
        applyCorrectionsBulk(
            queryLowerInterval,
            queryUpperInterval,
            queryComponentSum,
            queryAdditionalCorrection,
            similarityFunction,
            centroidDp,
            scores
        );
    }

    /**
     * Turns the 16 raw dot products in {@code scores} into similarity scores, in place, and
     * advances the input past the 224 bytes of correction data.
     *
     * <p>Per lane, with {@code ax}/{@code lx} read from the segment and {@code ay}/{@code ly}
     * derived from the query, the base value is
     * {@code ax*ay*dimensions + lx*ay*targetComponentSum + ax*ly*queryComponentSum + lx*ly*dotProduct}.
     */
    private void applyCorrectionsBulk(
        float queryLowerInterval,
        float queryUpperInterval,
        int queryComponentSum,
        float queryAdditionalCorrection,
        VectorSimilarityFunction similarityFunction,
        float centroidDp,
        float[] scores
    ) throws IOException {
        final int limit = FLOAT_SPECIES.loopBound(SCORES_PER_BULK);
        final long offset = in.getFilePointer();
        final float ay = queryLowerInterval;
        final float ly = (queryUpperInterval - ay) * INTERVAL_SCALE;
        final float y1 = queryComponentSum;
        for (int i = 0; i < limit; i += FLOAT_SPECIES.length()) {
            final long floatIndex = (long) i * Float.BYTES;
            FloatVector ax = FloatVector.fromMemorySegment(FLOAT_SPECIES, memorySegment, offset + floatIndex, ByteOrder.LITTLE_ENDIAN);
            FloatVector lx = FloatVector.fromMemorySegment(
                FLOAT_SPECIES,
                memorySegment,
                offset + UPPER_INTERVALS_OFFSET + floatIndex,
                ByteOrder.LITTLE_ENDIAN
            ).sub(ax).mul(INTERVAL_SCALE);
            // Stored as 16-bit values: widen part 0 to int, mask to read them as unsigned, then
            // convert to float.
            Vector<Float> targetComponentSums = ShortVector.fromMemorySegment(
                SHORT_SPECIES,
                memorySegment,
                offset + COMPONENT_SUMS_OFFSET + (long) i * Short.BYTES,
                ByteOrder.LITTLE_ENDIAN
            ).convert(VectorOperators.S2I, 0).reinterpretAsInts().and(0xFFFF).convert(VectorOperators.I2F, 0);
            FloatVector additionalCorrections = FloatVector.fromMemorySegment(
                FLOAT_SPECIES,
                memorySegment,
                offset + ADDITIONAL_CORRECTIONS_OFFSET + floatIndex,
                ByteOrder.LITTLE_ENDIAN
            );
            FloatVector qcDist = FloatVector.fromArray(FLOAT_SPECIES, scores, i);

            FloatVector res1 = ax.mul(ay).mul((float) dimensions);
            FloatVector res2 = lx.mul(ay).mul(targetComponentSums);
            FloatVector res3 = ax.mul(ly).mul(y1);
            FloatVector res4 = lx.mul(ly).mul(qcDist);
            FloatVector res = res1.add(res2).add(res3).add(res4);

            if (similarityFunction == VectorSimilarityFunction.EUCLIDEAN) {
                // score = max(0, 1 / (1 + corrections - 2 * res))
                res = res.mul(-2.0f).add(additionalCorrections).add(queryAdditionalCorrection).add(1.0f);
                res = FloatVector.broadcast(FLOAT_SPECIES, 1.0f).div(res).max(0.0f);
                res.intoArray(scores, i);
            } else {
                res = res.add(queryAdditionalCorrection).add(additionalCorrections).sub(centroidDp);
                if (similarityFunction == VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT) {
                    // The scaling is applied lane by lane with the scalar Lucene helper.
                    res.intoArray(scores, i);
                    for (int j = 0; j < FLOAT_SPECIES.length(); j++) {
                        scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]);
                    }
                } else {
                    // score = max(0, (1 + res) / 2)
                    res = res.add(1.0f).mul(0.5f).max(0.0f);
                    res.intoArray(scores, i);
                }
            }
        }
        in.seek(offset + CORRECTIONS_LENGTH);
    }
}

