/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.simdvec.internal;

import java.io.IOException;
import java.lang.foreign.MemorySegment;
import java.nio.ByteOrder;
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.IntVector;
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.Vector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorShape;
import jdk.incubator.vector.VectorSpecies;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.simdvec.ES92Int7VectorsScorer;

/**
 * Panama-vectorized int7 scoring on top of {@link ES92Int7VectorsScorer}.
 *
 * <p>Stored vector bytes and bulk corrections are loaded directly from {@link #memorySegment} at the current file
 * pointer of {@code in}; afterwards {@code in} is seeked past the consumed bytes so that subsequent sequential reads
 * (including the scalar tails below) stay in sync. NOTE(review): this assumes {@code memorySegment} addresses the same
 * bytes as {@code in}, i.e. that a file pointer of {@code in} is a valid offset into the segment — confirm against the
 * concrete subclasses.
 */
abstract class MemorySegmentES92PanamaInt7VectorsScorer extends ES92Int7VectorsScorer {
    private static final VectorSpecies<Byte> BYTE_SPECIES_64 = ByteVector.SPECIES_64;
    private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
    private static final VectorSpecies<Short> SHORT_SPECIES_256 = ShortVector.SPECIES_256;
    private static final VectorSpecies<Integer> INT_SPECIES_128 = IntVector.SPECIES_128;
    private static final VectorSpecies<Integer> INT_SPECIES_256 = IntVector.SPECIES_256;
    private static final VectorSpecies<Integer> INT_SPECIES_512 = IntVector.SPECIES_512;

    /** Preferred vector width of the running platform, in bits; selects the dot-product kernel. */
    private static final int VECTOR_BITSIZE = VectorShape.preferredShape().vectorBitSize();

    /** Number of vectors whose corrections {@link #applyCorrectionsBulk} handles per call. */
    private static final int CORRECTIONS_BULK_SIZE = 16;

    /** Bytes occupied by one per-vector correction array (16 floats or 16 ints) in the bulk block. */
    private static final long CORRECTION_SECTION_BYTES = (long) CORRECTIONS_BULK_SIZE * Float.BYTES;

    /**
     * Vector width used for the correction math. It is capped at exactly one bulk (16 x 32 bits) so that
     * {@code FLOAT_SPECIES.loopBound(CORRECTIONS_BULK_SIZE)} can never be 0 on hardware wider than 512 bits,
     * which would otherwise leave every score uncorrected.
     */
    private static final int CORRECTIONS_BITSIZE = Math.min(VECTOR_BITSIZE, CORRECTIONS_BULK_SIZE * Float.SIZE);
    private static final VectorSpecies<Float> FLOAT_SPECIES = VectorSpecies.of(float.class, VectorShape.forBitSize(CORRECTIONS_BITSIZE));
    private static final VectorSpecies<Integer> INT_SPECIES = VectorSpecies.of(int.class, VectorShape.forBitSize(CORRECTIONS_BITSIZE));

    /** Scale between an int7 quantized step and the quantization interval width (1 / 127). */
    private static final float INT7_SCALE = 0.007874016f;

    protected final MemorySegment memorySegment;

    protected MemorySegmentES92PanamaInt7VectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) {
        super(in, dimensions);
        this.memorySegment = memorySegment;
    }

    /**
     * Computes the integer dot product between {@code q} and the next stored vector, advancing {@code in} past it.
     *
     * @param q query bytes; must have exactly {@code dimensions} entries
     * @return the raw (uncorrected) dot product
     * @throws IOException if reading the tail bytes from {@code in} fails
     */
    protected long panamaInt7DotProduct(byte[] q) throws IOException {
        assert dimensions == q.length;
        if (dimensions < 16) {
            // Too short for any vector loop to run; use the scalar implementation.
            return super.int7DotProduct(q);
        }
        int limit = vectorLimit();
        int res = dotProductBody(q, limit);
        return res + scalarTail(q, limit);
    }

    /**
     * Computes raw dot products between {@code q} and the next {@code count} stored vectors, writing them into
     * {@code scores[0..count)} and advancing {@code in} past all of them.
     *
     * @param q query bytes; must have exactly {@code dimensions} entries
     * @param count number of consecutive stored vectors to score
     * @param scores output array receiving one raw dot product per vector
     * @throws IOException if reading from {@code in} fails
     */
    protected void panamaInt7DotProductBulk(byte[] q, int count, float[] scores) throws IOException {
        assert dimensions == q.length;
        if (dimensions < 16) {
            super.int7DotProductBulk(q, count, scores);
            return;
        }
        int limit = vectorLimit();
        for (int iter = 0; iter < count; iter++) {
            long res = dotProductBody(q, limit);
            res += scalarTail(q, limit);
            scores[iter] = res;
        }
    }

    /** Returns how many leading dimensions the vectorized kernel processes; the rest go to {@link #scalarTail}. */
    private int vectorLimit() {
        if (VECTOR_BITSIZE >= 512) {
            return BYTE_SPECIES_128.loopBound(dimensions);
        } else if (VECTOR_BITSIZE == 256) {
            return BYTE_SPECIES_64.loopBound(dimensions);
        }
        // 128-bit kernel loads 8 bytes but advances only 4, so it must stop one 8-byte load early to stay in bounds.
        return BYTE_SPECIES_64.loopBound(dimensions - BYTE_SPECIES_64.length());
    }

    /** Dispatches to the kernel matching {@link #VECTOR_BITSIZE}; covers dimensions {@code [0, limit)}. */
    private int dotProductBody(byte[] q, int limit) throws IOException {
        if (VECTOR_BITSIZE >= 512) {
            return dotProductBody512(q, limit);
        } else if (VECTOR_BITSIZE == 256) {
            return dotProductBody256(q, limit);
        }
        return dotProductBody128(q, limit);
    }

    /** Scalar dot product over dimensions {@code [from, dimensions)}, reading the stored bytes sequentially from {@code in}. */
    private int scalarTail(byte[] q, int from) throws IOException {
        int res = 0;
        for (int i = from; i < dimensions; i++) {
            res += in.readByte() * q[i];
        }
        return res;
    }

    /**
     * 512-bit kernel: 16 bytes per step, multiplied as 16-bit lanes (a byte x byte product always fits in a short)
     * and accumulated in 32-bit lanes. Leaves {@code in} positioned {@code limit} bytes further.
     */
    private int dotProductBody512(byte[] q, int limit) throws IOException {
        IntVector acc = IntVector.zero(INT_SPECIES_512);
        long offset = in.getFilePointer();
        for (int i = 0; i < limit; i += BYTE_SPECIES_128.length()) {
            ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_128, q, i);
            ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_128, memorySegment, offset + i, ByteOrder.LITTLE_ENDIAN);
            Vector<Short> va16 = va8.convertShape(VectorOperators.B2S, SHORT_SPECIES_256, 0);
            Vector<Short> vb16 = vb8.convertShape(VectorOperators.B2S, SHORT_SPECIES_256, 0);
            Vector<Short> prod16 = va16.mul(vb16);
            Vector<Integer> prod32 = prod16.convertShape(VectorOperators.S2I, INT_SPECIES_512, 0);
            acc = acc.add(prod32);
        }
        in.seek(offset + limit);
        return acc.reduceLanes(VectorOperators.ADD);
    }

    /** 256-bit kernel: 8 bytes per step, widened straight to 32-bit lanes. Leaves {@code in} positioned {@code limit} bytes further. */
    private int dotProductBody256(byte[] q, int limit) throws IOException {
        IntVector acc = IntVector.zero(INT_SPECIES_256);
        long offset = in.getFilePointer();
        for (int i = 0; i < limit; i += BYTE_SPECIES_64.length()) {
            ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
            ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, ByteOrder.LITTLE_ENDIAN);
            Vector<Integer> va32 = va8.convertShape(VectorOperators.B2I, INT_SPECIES_256, 0);
            Vector<Integer> vb32 = vb8.convertShape(VectorOperators.B2I, INT_SPECIES_256, 0);
            acc = acc.add(va32.mul(vb32));
        }
        in.seek(offset + limit);
        return acc.reduceLanes(VectorOperators.ADD);
    }

    /**
     * 128-bit kernel: loads 8 bytes per step but only widens the first 4 (part 0 of the B2S conversion), then
     * advances by 4, so consecutive loads overlap. Leaves {@code in} positioned {@code limit} bytes further.
     */
    private int dotProductBody128(byte[] q, int limit) throws IOException {
        IntVector acc = IntVector.zero(INT_SPECIES_128);
        long offset = in.getFilePointer();
        for (int i = 0; i < limit; i += BYTE_SPECIES_64.length() >> 1) {
            ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
            ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, ByteOrder.LITTLE_ENDIAN);
            Vector<Short> va16 = va8.convert(VectorOperators.B2S, 0);
            Vector<Short> vb16 = vb8.convert(VectorOperators.B2S, 0);
            Vector<Short> prod16 = va16.mul(vb16);
            acc = acc.add(prod16.convertShape(VectorOperators.S2I, INT_SPECIES_128, 0));
        }
        in.seek(offset + limit);
        return acc.reduceLanes(VectorOperators.ADD);
    }

    /**
     * Turns the 16 raw dot products in {@code scores} into similarity scores in place, using the correction block
     * stored at the current position of {@code in}, then advances {@code in} past that block (256 bytes).
     *
     * <p>Block layout as read below, 16 entries per section, little-endian: lower intervals (float), upper intervals
     * (float), target component sums (int), additional corrections (float).
     *
     * @param queryLowerInterval lower bound of the query's quantization interval
     * @param queryUpperInterval upper bound of the query's quantization interval
     * @param queryComponentSum sum of the query's quantized components
     * @param queryAdditionalCorrection query-side additive correction
     * @param similarityFunction selects the final score transform (EUCLIDEAN, MAXIMUM_INNER_PRODUCT, or the
     *     {@code (1 + x) / 2} form used for everything else)
     * @param centroidDp subtracted in the non-EUCLIDEAN branches
     * @param scores in: raw dot products for the 16 vectors; out: corrected scores
     * @throws IOException if repositioning {@code in} fails
     */
    protected void applyCorrectionsBulk(float queryLowerInterval, float queryUpperInterval, int queryComponentSum, float queryAdditionalCorrection, VectorSimilarityFunction similarityFunction, float centroidDp, float[] scores) throws IOException {
        int limit = FLOAT_SPECIES.loopBound(CORRECTIONS_BULK_SIZE);
        assert limit == CORRECTIONS_BULK_SIZE : "species " + FLOAT_SPECIES + " does not tile " + CORRECTIONS_BULK_SIZE + " lanes";
        long offset = in.getFilePointer();
        float ay = queryLowerInterval;
        float ly = (queryUpperInterval - ay) * INT7_SCALE;
        float y1 = queryComponentSum;
        for (int i = 0; i < limit; i += FLOAT_SPECIES.length()) {
            long laneOffset = (long) i * Float.BYTES;
            FloatVector ax = FloatVector.fromMemorySegment(FLOAT_SPECIES, memorySegment, offset + laneOffset, ByteOrder.LITTLE_ENDIAN);
            FloatVector lx = FloatVector.fromMemorySegment(FLOAT_SPECIES, memorySegment, offset + CORRECTION_SECTION_BYTES + laneOffset, ByteOrder.LITTLE_ENDIAN)
                .sub(ax)
                .mul(INT7_SCALE);
            Vector<Float> targetComponentSums = IntVector.fromMemorySegment(INT_SPECIES, memorySegment, offset + 2 * CORRECTION_SECTION_BYTES + laneOffset, ByteOrder.LITTLE_ENDIAN)
                .convert(VectorOperators.I2F, 0);
            FloatVector additionalCorrections = FloatVector.fromMemorySegment(FLOAT_SPECIES, memorySegment, offset + 3 * CORRECTION_SECTION_BYTES + laneOffset, ByteOrder.LITTLE_ENDIAN);
            FloatVector qcDist = FloatVector.fromArray(FLOAT_SPECIES, scores, i);

            // Keep the exact operation order of the scalar formula so results are bit-for-bit stable.
            FloatVector res1 = ax.mul(ay).mul((float) dimensions);
            FloatVector res2 = lx.mul(ay).mul(targetComponentSums);
            FloatVector res3 = ax.mul(ly).mul(y1);
            FloatVector res4 = lx.mul(ly).mul(qcDist);
            FloatVector res = res1.add(res2).add(res3).add(res4);

            if (similarityFunction == VectorSimilarityFunction.EUCLIDEAN) {
                res = res.mul(-2.0f).add(additionalCorrections).add(queryAdditionalCorrection).add(1.0f);
                res = FloatVector.broadcast(FLOAT_SPECIES, 1.0f).div(res).max(0.0f);
                res.intoArray(scores, i);
            } else {
                res = res.add(queryAdditionalCorrection).add(additionalCorrections).sub(centroidDp);
                if (similarityFunction == VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT) {
                    res.intoArray(scores, i);
                    for (int j = 0; j < FLOAT_SPECIES.length(); j++) {
                        scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]);
                    }
                } else {
                    res.add(1.0f).mul(0.5f).max(0.0f).intoArray(scores, i);
                }
            }
        }
        in.seek(offset + 4 * CORRECTION_SECTION_BYTES);
    }
}

