/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.nativeaccess.jdk;

import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.Linker;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.Objects;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.nativeaccess.VectorSimilarityFunctions;
import org.elasticsearch.nativeaccess.jdk.LinkerHelper;
import org.elasticsearch.nativeaccess.jdk.LinkerHelperUtil;
import org.elasticsearch.nativeaccess.lib.LoaderHelper;
import org.elasticsearch.nativeaccess.lib.VectorLibrary;

/**
 * {@link VectorLibrary} backed by the native {@code vec} library, called through the JDK foreign
 * function API.
 *
 * <p>Class initialization loads the {@code vec} library and calls its {@code vec_caps} function.
 * When that returns a positive value, downcall handles for the native similarity functions are
 * bound and {@link #INSTANCE} is created; otherwise every handle and {@link #INSTANCE} is
 * {@code null}. Any failure while binding or calling {@code vec_caps} is rethrown as an
 * {@link AssertionError}.
 */
public final class JdkVectorLibrary
implements VectorLibrary {
    static final Logger logger = LogManager.getLogger(JdkVectorLibrary.class);

    // Downcall handles to the native functions; all null when vec_caps is not positive.
    static final MethodHandle dot7u$mh;
    static final MethodHandle dot7uBulk$mh;
    static final MethodHandle sqr7u$mh;
    static final MethodHandle cosf32$mh;
    static final MethodHandle dotf32$mh;
    static final MethodHandle sqrf32$mh;

    /** The native-backed similarity functions, or {@code null} when {@code vec_caps} is not positive. */
    public static final JdkVectorSimilarityFunctions INSTANCE;

    /**
     * Returns the native-backed similarity functions.
     *
     * @return the shared {@link #INSTANCE}, or {@code null} when the native library reported no
     *     usable capabilities
     */
    @Override
    public VectorSimilarityFunctions getVectorSimilarityFunctions() {
        return INSTANCE;
    }

    static {
        LoaderHelper.loadLibrary("vec");
        MethodHandle vecCaps$mh = LinkerHelper.downcallHandle("vec_caps", FunctionDescriptor.of(ValueLayout.JAVA_INT), new Linker.Option[0]);
        try {
            // invokeExact is signature-polymorphic: the (int) cast makes the call-site type ()int,
            // which must match the handle's type exactly.
            int caps = (int) vecCaps$mh.invokeExact();
            logger.info("vec_caps=" + caps);
            if (caps > 0) {
                // Level 2 binds the "_2"-suffixed native symbols; any other positive level binds the base symbols.
                String suffix = caps == 2 ? "_2" : "";
                FunctionDescriptor intBinary = FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.JAVA_INT);
                FunctionDescriptor floatBinary = FunctionDescriptor.of(ValueLayout.JAVA_FLOAT, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.JAVA_INT);
                FunctionDescriptor bulk = FunctionDescriptor.ofVoid(ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.JAVA_INT, ValueLayout.JAVA_INT, ValueLayout.ADDRESS);
                dot7u$mh = criticalDowncall("vec_dot7u" + suffix, intBinary);
                dot7uBulk$mh = criticalDowncall("vec_dot7u_bulk" + suffix, bulk);
                sqr7u$mh = criticalDowncall("vec_sqr7u" + suffix, intBinary);
                cosf32$mh = criticalDowncall("vec_cosf32" + suffix, floatBinary);
                dotf32$mh = criticalDowncall("vec_dotf32" + suffix, floatBinary);
                sqrf32$mh = criticalDowncall("vec_sqrf32" + suffix, floatBinary);
                INSTANCE = new JdkVectorSimilarityFunctions();
            } else {
                // A negative level is reported by the native side when the CPU has the capabilities
                // but they are disabled; zero means no support and is not warned about.
                if (caps < 0) {
                    logger.warn("Your CPU supports vector capabilities, but they are disabled at OS level. For optimal performance, enable them in your OS/Hypervisor/VM/container");
                }
                dot7u$mh = null;
                dot7uBulk$mh = null;
                sqr7u$mh = null;
                cosf32$mh = null;
                dotf32$mh = null;
                sqrf32$mh = null;
                INSTANCE = null;
            }
        } catch (Throwable t) {
            throw new AssertionError(t);
        }
    }

    /**
     * Binds the named native function with the linker options returned by
     * {@code LinkerHelperUtil.critical()}.
     *
     * @param name the native symbol name
     * @param descriptor the native signature
     * @return the downcall handle
     */
    private static MethodHandle criticalDowncall(String name, FunctionDescriptor descriptor) {
        return LinkerHelper.downcallHandle(name, descriptor, LinkerHelperUtil.critical());
    }

    /**
     * {@link VectorSimilarityFunctions} whose exposed method handles validate their arguments in
     * Java and then invoke the native downcall handles of the enclosing class.
     */
    private static final class JdkVectorSimilarityFunctions
    implements VectorSimilarityFunctions {
        // Handles to the validating static entry points below; returned by the accessor methods.
        static final MethodHandle DOT_HANDLE_7U;
        static final MethodHandle DOT_HANDLE_7U_BULK;
        static final MethodHandle SQR_HANDLE_7U;
        static final MethodHandle COS_HANDLE_FLOAT32;
        static final MethodHandle DOT_HANDLE_FLOAT32;
        static final MethodHandle SQR_HANDLE_FLOAT32;

        private JdkVectorSimilarityFunctions() {
        }

        /**
         * Calls native {@code vec_dot7u*} over the first {@code length} bytes of both segments.
         *
         * @throws IllegalArgumentException if the segments differ in byte size
         * @throws IndexOutOfBoundsException if {@code length} is negative or exceeds the segment size
         */
        static int dotProduct7u(MemorySegment a, MemorySegment b, int length) {
            checkByteArgs(a, b, length);
            return dot7u(a, b, length);
        }

        /**
         * Calls native {@code vec_dot7u_bulk*}.
         *
         * <p>Requires {@code a} to hold at least {@code length * count} bytes (presumably
         * {@code count} contiguous vectors), {@code b} at least {@code length} bytes, and
         * {@code result} at least {@code count * 4} bytes.
         *
         * @throws IndexOutOfBoundsException if any of those size requirements is not met
         */
        static void dotProduct7uBulk(MemorySegment a, MemorySegment b, int length, int count, MemorySegment result) {
            // Checked in long arithmetic: the int products can overflow (e.g. 65536 * 65536 == 0 as int),
            // and casting byteSize() to int misreports segments larger than Integer.MAX_VALUE bytes.
            Objects.checkFromIndexSize(0L, (long) length * count, a.byteSize());
            Objects.checkFromIndexSize(0L, length, b.byteSize());
            Objects.checkFromIndexSize(0L, (long) count * 4L, result.byteSize());
            dot7uBulk(a, b, length, count, result);
        }

        /**
         * Calls native {@code vec_sqr7u*} over the first {@code length} bytes of both segments.
         *
         * @throws IllegalArgumentException if the segments differ in byte size
         * @throws IndexOutOfBoundsException if {@code length} is negative or exceeds the segment size
         */
        static int squareDistance7u(MemorySegment a, MemorySegment b, int length) {
            checkByteArgs(a, b, length);
            return sqr7u(a, b, length);
        }

        /**
         * Calls native {@code vec_cosf32*} over the first {@code elementCount} 4-byte elements.
         *
         * @throws IllegalArgumentException if the segments differ in byte size
         * @throws IndexOutOfBoundsException if {@code elementCount} is negative or too large
         */
        static float cosineF32(MemorySegment a, MemorySegment b, int elementCount) {
            checkFloat32Args(a, b, elementCount);
            return cosf32(a, b, elementCount);
        }

        /**
         * Calls native {@code vec_dotf32*} over the first {@code elementCount} 4-byte elements.
         *
         * @throws IllegalArgumentException if the segments differ in byte size
         * @throws IndexOutOfBoundsException if {@code elementCount} is negative or too large
         */
        static float dotProductF32(MemorySegment a, MemorySegment b, int elementCount) {
            checkFloat32Args(a, b, elementCount);
            return dotf32(a, b, elementCount);
        }

        /**
         * Calls native {@code vec_sqrf32*} over the first {@code elementCount} 4-byte elements.
         *
         * @throws IllegalArgumentException if the segments differ in byte size
         * @throws IndexOutOfBoundsException if {@code elementCount} is negative or too large
         */
        static float squareDistanceF32(MemorySegment a, MemorySegment b, int elementCount) {
            checkFloat32Args(a, b, elementCount);
            return sqrf32(a, b, elementCount);
        }

        /** Validates equal segment sizes and that {@code length} bytes fit in them. */
        private static void checkByteArgs(MemorySegment a, MemorySegment b, int length) {
            checkByteSize(a, b);
            // long overload: avoids truncating byteSize() for segments over Integer.MAX_VALUE bytes
            Objects.checkFromIndexSize(0L, length, a.byteSize());
        }

        /** Validates equal segment sizes and that {@code elementCount} 4-byte elements fit in them. */
        private static void checkFloat32Args(MemorySegment a, MemorySegment b, int elementCount) {
            checkByteSize(a, b);
            Objects.checkFromIndexSize(0L, elementCount, a.byteSize() / Float.BYTES);
        }

        /** Throws {@link IllegalArgumentException} when the two segments differ in byte size. */
        private static void checkByteSize(MemorySegment a, MemorySegment b) {
            if (a.byteSize() != b.byteSize()) {
                throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize());
            }
        }

        // The wrappers below cast invokeExact results so the call-site type matches the downcall
        // handle's type exactly (int or float return); without the cast the result type is Object.

        private static int dot7u(MemorySegment a, MemorySegment b, int length) {
            try {
                return (int) dot7u$mh.invokeExact(a, b, length);
            } catch (Throwable t) {
                throw new AssertionError(t);
            }
        }

        private static void dot7uBulk(MemorySegment a, MemorySegment b, int length, int count, MemorySegment result) {
            try {
                // Expression statement: call-site return type is void, matching the ofVoid descriptor.
                dot7uBulk$mh.invokeExact(a, b, length, count, result);
            } catch (Throwable t) {
                throw new AssertionError(t);
            }
        }

        private static int sqr7u(MemorySegment a, MemorySegment b, int length) {
            try {
                return (int) sqr7u$mh.invokeExact(a, b, length);
            } catch (Throwable t) {
                throw new AssertionError(t);
            }
        }

        private static float cosf32(MemorySegment a, MemorySegment b, int length) {
            try {
                return (float) cosf32$mh.invokeExact(a, b, length);
            } catch (Throwable t) {
                throw new AssertionError(t);
            }
        }

        private static float dotf32(MemorySegment a, MemorySegment b, int length) {
            try {
                return (float) dotf32$mh.invokeExact(a, b, length);
            } catch (Throwable t) {
                throw new AssertionError(t);
            }
        }

        private static float sqrf32(MemorySegment a, MemorySegment b, int length) {
            try {
                return (float) sqrf32$mh.invokeExact(a, b, length);
            } catch (Throwable t) {
                throw new AssertionError(t);
            }
        }

        @Override
        public MethodHandle dotProductHandle7u() {
            return DOT_HANDLE_7U;
        }

        @Override
        public MethodHandle dotProductHandle7uBulk() {
            return DOT_HANDLE_7U_BULK;
        }

        @Override
        public MethodHandle squareDistanceHandle7u() {
            return SQR_HANDLE_7U;
        }

        @Override
        public MethodHandle cosineHandleFloat32() {
            return COS_HANDLE_FLOAT32;
        }

        @Override
        public MethodHandle dotProductHandleFloat32() {
            return DOT_HANDLE_FLOAT32;
        }

        @Override
        public MethodHandle squareDistanceHandleFloat32() {
            return SQR_HANDLE_FLOAT32;
        }

        static {
            MethodHandles.Lookup lookup = MethodHandles.lookup();
            MethodType binaryInt = MethodType.methodType(int.class, MemorySegment.class, MemorySegment.class, int.class);
            MethodType binaryFloat = MethodType.methodType(float.class, MemorySegment.class, MemorySegment.class, int.class);
            MethodType bulk = MethodType.methodType(void.class, MemorySegment.class, MemorySegment.class, int.class, int.class, MemorySegment.class);
            try {
                DOT_HANDLE_7U = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProduct7u", binaryInt);
                SQR_HANDLE_7U = lookup.findStatic(JdkVectorSimilarityFunctions.class, "squareDistance7u", binaryInt);
                DOT_HANDLE_7U_BULK = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProduct7uBulk", bulk);
                COS_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "cosineF32", binaryFloat);
                DOT_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProductF32", binaryFloat);
                SQR_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "squareDistanceF32", binaryFloat);
            } catch (IllegalAccessException | NoSuchMethodException e) {
                throw new RuntimeException(e);
            }
        }
    }
}

