/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.simdvec;

import java.io.IOException;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.Objects;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.Constants;
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
import org.elasticsearch.simdvec.ESNextOSQVectorsScorer;
import org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport;
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;

/**
 * Static entry points for the vector utility routines of the simdvec module.
 *
 * <p>Most methods validate the sizes of their array arguments and then forward to the
 * {@code ESVectorUtilSupport} instance obtained from {@code ESVectorizationProvider.getInstance()}
 * when this class is initialized. The scorer factories forward to the same provider.
 */
public class ESVectorUtil {
    /** Handle to either {@code andBitCountInt} or {@code andBitCountLong}; chosen once at class init. */
    private static final MethodHandle BIT_COUNT_MH;

    /** Provider-selected implementation that the delegating methods below forward to. */
    private static final ESVectorUtilSupport IMPL;

    static {
        // NOTE(review): the int-width kernel is selected on aarch64 and the long-width kernel
        // everywhere else; presumably this reflects measured performance — confirm before changing.
        final String bitCountMethod = Constants.OS_ARCH.equals("aarch64") ? "andBitCountInt" : "andBitCountLong";
        try {
            BIT_COUNT_MH = MethodHandles.lookup()
                .findStatic(ESVectorUtil.class, bitCountMethod, MethodType.methodType(int.class, byte[].class, byte[].class));
        } catch (IllegalAccessException | NoSuchMethodException e) {
            // Both candidates are declared in this class, so a failed lookup is a programming error.
            throw new AssertionError(e);
        }
        IMPL = ESVectorizationProvider.getInstance().getVectorUtilSupport();
    }

    /**
     * Creates an {@code ES91OSQVectorsScorer} through the vectorization provider.
     *
     * @param input     the index input handed to the scorer
     * @param dimension the vector dimension handed to the scorer
     * @return the scorer created by the provider
     * @throws IOException if the provider's factory throws it
     */
    public static ES91OSQVectorsScorer getES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException {
        return ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(input, dimension);
    }

    /**
     * Creates an {@code ESNextOSQVectorsScorer} through the vectorization provider.
     *
     * @param input      the index input handed to the scorer
     * @param queryBits  bits per query component, forwarded unchanged
     * @param indexBits  bits per indexed component, forwarded unchanged
     * @param dimension  the vector dimension
     * @param dataLength forwarded unchanged to the provider
     * @return the scorer created by the provider
     * @throws IOException if the provider's factory throws it
     */
    public static ESNextOSQVectorsScorer getESNextOSQVectorsScorer(IndexInput input, byte queryBits, byte indexBits, int dimension, int dataLength) throws IOException {
        return ESVectorizationProvider.getInstance().newESNextOSQVectorsScorer(input, queryBits, indexBits, dimension, dataLength);
    }

    /**
     * Creates an {@code ES91Int4VectorsScorer} through the vectorization provider.
     *
     * @param input     the index input handed to the scorer
     * @param dimension the vector dimension handed to the scorer
     * @return the scorer created by the provider
     * @throws IOException if the provider's factory throws it
     */
    public static ES91Int4VectorsScorer getES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException {
        return ESVectorizationProvider.getInstance().newES91Int4VectorsScorer(input, dimension);
    }

    /**
     * Creates an {@code ES92Int7VectorsScorer} through the vectorization provider.
     *
     * @param input     the index input handed to the scorer
     * @param dimension the vector dimension handed to the scorer
     * @return the scorer created by the provider
     * @throws IOException if the provider's factory throws it
     */
    public static ES92Int7VectorsScorer getES92Int7VectorsScorer(IndexInput input, int dimension) throws IOException {
        return ESVectorizationProvider.getInstance().newES92Int7VectorsScorer(input, dimension);
    }

    /**
     * Forwards to the provider's {@code ipByteBinByte}. Judging by the name this is an inner product
     * between a byte query and a bit-packed document vector (TODO confirm against the implementation).
     *
     * @throws IllegalArgumentException if {@code q.length != 4 * d.length}
     */
    public static long ipByteBinByte(byte[] q, byte[] d) {
        if (q.length != d.length * 4) {
            throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= 4 x " + d.length);
        }
        return IMPL.ipByteBinByte(q, d);
    }

    /**
     * Forwards to the provider's {@code ipByteBit}; {@code d} must supply one bit per byte of {@code q}.
     *
     * @throws IllegalArgumentException if {@code q.length != 8 * d.length}
     */
    public static int ipByteBit(byte[] q, byte[] d) {
        if (q.length != d.length * 8) {
            throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= 8 x " + d.length);
        }
        return IMPL.ipByteBit(q, d);
    }

    /**
     * Forwards to the provider's {@code ipFloatBit}; {@code d} must supply one bit per float of {@code q}.
     *
     * @throws IllegalArgumentException if {@code q.length != 8 * d.length}
     */
    public static float ipFloatBit(float[] q, byte[] d) {
        if (q.length != d.length * 8) {
            throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= 8 x " + d.length);
        }
        return IMPL.ipFloatBit(q, d);
    }

    /**
     * Forwards to the provider's {@code ipFloatByte}; both vectors must have the same length.
     *
     * @throws IllegalArgumentException if {@code q.length != d.length}
     */
    public static float ipFloatByte(float[] q, byte[] d) {
        if (q.length != d.length) {
            throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + d.length);
        }
        return IMPL.ipFloatByte(q, d);
    }

    /**
     * Returns the number of bits set in the bitwise AND of {@code a} and {@code b}.
     *
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public static int andBitCount(byte[] a, byte[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
        }
        try {
            // invokeExact is signature-polymorphic: the (int) cast makes the call-site type
            // (byte[],byte[])int, which must match the handle's type exactly.
            return (int) BIT_COUNT_MH.invokeExact(a, b);
        } catch (Error | RuntimeException e) {
            throw e;
        } catch (Throwable t) {
            // The target methods declare no checked exceptions; wrap defensively anyway.
            throw new RuntimeException(t);
        }
    }

    /**
     * Counts the bits set in {@code a & b}, combining four bytes per step through
     * {@code BitUtil.VH_NATIVE_INT} and handling the trailing 0-3 bytes individually.
     * Byte order does not affect the result of a population count.
     * Callers must ensure {@code b.length >= a.length}.
     */
    static int andBitCountInt(byte[] a, byte[] b) {
        int count = 0;
        final int wordEnd = a.length & ~3; // largest multiple of Integer.BYTES <= a.length
        int i = 0;
        for (; i < wordEnd; i += Integer.BYTES) {
            // VarHandle.get is signature-polymorphic; the casts fix the access type to int.
            count += Integer.bitCount((int) BitUtil.VH_NATIVE_INT.get(a, i) & (int) BitUtil.VH_NATIVE_INT.get(b, i));
        }
        for (; i < a.length; i++) {
            count += Integer.bitCount(a[i] & b[i] & 0xFF);
        }
        return count;
    }

    /**
     * Counts the bits set in {@code a & b}, combining eight bytes per step through
     * {@code BitUtil.VH_NATIVE_LONG} and handling the trailing 0-7 bytes individually.
     * Byte order does not affect the result of a population count.
     * Callers must ensure {@code b.length >= a.length}.
     */
    static int andBitCountLong(byte[] a, byte[] b) {
        int count = 0;
        final int wordEnd = a.length & ~7; // largest multiple of Long.BYTES <= a.length
        int i = 0;
        for (; i < wordEnd; i += Long.BYTES) {
            // VarHandle.get is signature-polymorphic; the casts fix the access type to long.
            count += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(a, i) & (long) BitUtil.VH_NATIVE_LONG.get(b, i));
        }
        for (; i < a.length; i++) {
            count += Integer.bitCount(a[i] & b[i] & 0xFF);
        }
        return count;
    }

    /**
     * Derives the grid step for {@code points} evenly spaced levels over
     * {@code [lowerInterval, upperInterval]} and forwards to the provider's {@code calculateOSQLoss}.
     *
     * <p>The step is {@code (upperInterval - lowerInterval) / (points - 1)}, so with {@code points == 1}
     * it is not finite; callers presumably always pass at least two points.
     *
     * @return the value computed by the provider
     */
    public static float calculateOSQLoss(float[] target, float lowerInterval, float upperInterval, int points, float norm2, float lambda, int[] quantize) {
        assert upperInterval >= lowerInterval;
        final float step = (upperInterval - lowerInterval) / ((float) points - 1.0f);
        final float invStep = 1.0f / step;
        return IMPL.calculateOSQLoss(target, lowerInterval, upperInterval, step, invStep, norm2, lambda, quantize);
    }

    /**
     * Forwards to the provider's {@code calculateOSQGridPoints}, which fills {@code pts}.
     * The sizes of {@code quantize} (at least {@code target.length}) and {@code pts} (exactly 5) are
     * only checked when assertions are enabled.
     */
    public static void calculateOSQGridPoints(float[] target, int[] quantize, int points, float[] pts) {
        assert target.length <= quantize.length;
        assert pts.length == 5;
        IMPL.calculateOSQGridPoints(target, quantize, points, pts);
    }

    /**
     * Validates that {@code target}, {@code centroid} and {@code centered} have equal lengths and forwards
     * to the provider's Euclidean variant. {@code stats} is expected (asserted) to hold 5 values.
     * Presumably {@code centered} receives {@code target - centroid} — confirm against the implementation.
     *
     * @throws IllegalArgumentException if the vector lengths differ
     */
    public static void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats) {
        // Validate before asserting so callers get the documented IllegalArgumentException even under -ea,
        // matching centerAndCalculateOSQStatsDp.
        if (target.length != centroid.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + target.length + "!=" + centroid.length);
        }
        if (centered.length != target.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + centered.length + "!=" + target.length);
        }
        assert stats.length == 5;
        IMPL.centerAndCalculateOSQStatsEuclidean(target, centroid, centered, stats);
    }

    /**
     * Validates that {@code target}, {@code centroid} and {@code centered} have equal lengths and forwards
     * to the provider's dot-product variant. {@code stats} is expected (asserted) to hold 6 values.
     *
     * @throws IllegalArgumentException if the vector lengths differ
     */
    public static void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats) {
        if (target.length != centroid.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + target.length + "!=" + centroid.length);
        }
        if (centered.length != target.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + centered.length + "!=" + target.length);
        }
        assert stats.length == 6;
        IMPL.centerAndCalculateOSQStatsDp(target, centroid, centered, stats);
    }

    /**
     * Writes the element-wise difference {@code v1[i] - v2[i]} into {@code result[i]}.
     *
     * @throws IllegalArgumentException if the three arrays do not all have the same length
     */
    public static void subtract(float[] v1, float[] v2, float[] result) {
        if (v1.length != v2.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + v2.length);
        }
        if (result.length != v1.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + result.length + "!=" + v1.length);
        }
        for (int i = 0; i < v1.length; ++i) {
            result[i] = v1[i] - v2[i];
        }
    }

    /**
     * Forwards to the provider's {@code soarDistance} after checking that all vectors share a length.
     *
     * @throws IllegalArgumentException if the vector lengths differ
     */
    public static float soarDistance(float[] v1, float[] centroid, float[] originalResidual, float soarLambda, float rnorm) {
        if (v1.length != centroid.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + centroid.length);
        }
        if (originalResidual.length != v1.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + originalResidual.length + "!=" + v1.length);
        }
        return IMPL.soarDistance(v1, centroid, originalResidual, soarLambda, rnorm);
    }

    /**
     * Forwards to the provider's {@code quantizeVectorWithIntervals}, which writes into {@code destination}.
     * The meaning of the returned int is defined by the provider (NOTE(review): document once confirmed).
     *
     * @throws IllegalArgumentException if {@code destination} is shorter than {@code vector},
     *                                  or {@code bit} is outside {@code [1, 8]}
     */
    public static int quantizeVectorWithIntervals(float[] vector, int[] destination, float lowInterval, float upperInterval, byte bit) {
        if (vector.length > destination.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + vector.length + "!=" + destination.length);
        }
        if (bit <= 0 || bit > 8) {
            throw new IllegalArgumentException("bit must be between 1 and 8, but was: " + bit);
        }
        return IMPL.quantizeVectorWithIntervals(vector, destination, lowInterval, upperInterval, bit);
    }

    /**
     * Forwards four same-length vectors to the provider's {@code squareDistanceBulk}; presumably
     * {@code distances[k]} receives the squared distance from {@code q} to {@code vk} — confirm.
     *
     * @throws IllegalArgumentException if any vector length differs from {@code q.length},
     *                                  or {@code distances.length != 4}
     */
    public static void squareDistanceBulk(float[] q, float[] v0, float[] v1, float[] v2, float[] v3, float[] distances) {
        if (q.length != v0.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + q.length + "!=" + v0.length);
        }
        if (q.length != v1.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + q.length + "!=" + v1.length);
        }
        if (q.length != v2.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + q.length + "!=" + v2.length);
        }
        if (q.length != v3.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + q.length + "!=" + v3.length);
        }
        if (distances.length != 4) {
            throw new IllegalArgumentException("distances array must have length 4, but was: " + distances.length);
        }
        IMPL.squareDistanceBulk(q, v0, v1, v2, v3, distances);
    }

    /**
     * Forwards four same-length centroids to the provider's {@code soarDistanceBulk}, which fills the
     * 4-element {@code distances} array.
     *
     * @throws IllegalArgumentException if any centroid or {@code originalResidual} differs in length from
     *                                  {@code v1}, or {@code distances.length != 4}
     */
    public static void soarDistanceBulk(float[] v1, float[] c0, float[] c1, float[] c2, float[] c3, float[] originalResidual, float soarLambda, float rnorm, float[] distances) {
        if (v1.length != c0.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + c0.length);
        }
        if (v1.length != c1.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + c1.length);
        }
        if (v1.length != c2.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + c2.length);
        }
        if (v1.length != c3.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + c3.length);
        }
        if (v1.length != originalResidual.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + originalResidual.length);
        }
        if (distances.length != 4) {
            throw new IllegalArgumentException("distances array must have length 4, but was: " + distances.length);
        }
        IMPL.soarDistanceBulk(v1, c0, c1, c2, c3, originalResidual, soarLambda, rnorm, distances);
    }

    /**
     * Forwards to the provider's {@code packAsBinary}; {@code packed} must provide at least one bit per
     * element of {@code vector}.
     *
     * @throws IllegalArgumentException if {@code packed.length * 8 < vector.length}
     */
    public static void packAsBinary(int[] vector, byte[] packed) {
        if (packed.length * 8 < vector.length) {
            throw new IllegalArgumentException("packed array is too small: " + packed.length * 8 + " < " + vector.length);
        }
        IMPL.packAsBinary(vector, packed);
    }

    /**
     * Forwards to the provider's {@code transposeHalfByte}; {@code quantQueryByte} must provide at least
     * four bits per element of {@code q}.
     *
     * @throws IllegalArgumentException if {@code quantQueryByte.length * 8 < 4 * q.length}
     */
    public static void transposeHalfByte(int[] q, byte[] quantQueryByte) {
        if (quantQueryByte.length * 8 < 4 * q.length) {
            throw new IllegalArgumentException("packed array is too small: " + quantQueryByte.length * 8 + " < " + 4 * q.length);
        }
        IMPL.transposeHalfByte(q, quantQueryByte);
    }

    /**
     * Forwards to the provider's {@code indexOf} for the range {@code [offset, offset + length)}.
     * Presumably returns the position of {@code marker} or a not-found sentinel — confirm against the
     * implementation.
     *
     * @throws IndexOutOfBoundsException if the range is not within {@code bytes}
     */
    public static int indexOf(byte[] bytes, int offset, int length, byte marker) {
        Objects.checkFromIndexSize(offset, length, bytes.length);
        return IMPL.indexOf(bytes, offset, length, marker);
    }
}

