/*
 * Decompiled with CFR 0.152.
 */
package com.nvidia.cuvs.internal.common;

import com.nvidia.cuvs.CuVSResources;
import com.nvidia.cuvs.internal.common.LinkerHelper;
import com.nvidia.cuvs.internal.panama.DLDataType;
import com.nvidia.cuvs.internal.panama.DLDevice;
import com.nvidia.cuvs.internal.panama.DLManagedTensor;
import com.nvidia.cuvs.internal.panama.DLTensor;
import com.nvidia.cuvs.internal.panama.headers_h;
import com.nvidia.cuvs.internal.panama.headers_h_1;
import java.lang.foreign.Arena;
import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.Linker;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SequenceLayout;
import java.lang.foreign.SymbolLookup;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.VarHandle;
import java.util.BitSet;

public class Util {
    /** Status code that the cuVS bindings report as success. */
    public static final int CUVS_SUCCESS = headers_h.CUVS_SUCCESS();

    /** Numeric value (0) of a successful CUDA status. */
    public static final int CUDA_SUCCESS = 0;

    /** Native linker used to create the downcall handles below. */
    private static final Linker LINKER = Linker.nativeLinker();

    /**
     * Symbol lookup chain: the {@code cuvs_c} library (platform-mapped file name) first,
     * then the loader lookup, then the native linker's default lookup.
     */
    static final SymbolLookup SYMBOL_LOOKUP =
            SymbolLookup.libraryLookup(System.mapLibraryName("cuvs_c"), Arena.ofAuto())
                    .or(SymbolLookup.loaderLookup())
                    .or(Linker.nativeLinker().defaultLookup());

    /** Downcall handle for {@code cudaMemcpyAsync}, linked as critical with heap access allowed. */
    private static final MethodHandle cudaMemcpyAsync$mh =
            LINKER.downcallHandle(
                    headers_h.cudaMemcpyAsync$address(),
                    headers_h.cudaMemcpyAsync$descriptor(),
                    Linker.Option.critical(true));

    /**
     * Downcall handle for {@code cudaGetDeviceProperties}. The plain symbol name is tried
     * first and {@code cudaGetDeviceProperties_v2} is the fallback; class initialization
     * fails with {@link UnsatisfiedLinkError} when neither symbol is found.
     */
    private static final MethodHandle cudaGetDeviceProperties$mh =
            LINKER.downcallHandle(
                    SYMBOL_LOOKUP.find("cudaGetDeviceProperties")
                            .or(() -> SYMBOL_LOOKUP.find("cudaGetDeviceProperties_v2"))
                            .orElseThrow(UnsatisfiedLinkError::new),
                    FunctionDescriptor.of(headers_h.C_INT, headers_h.C_POINTER, headers_h.C_INT));

    /** Upper bound, in bytes, for a native error text. */
    static final long MAX_ERROR_TEXT = 1_000_000L;

    /** Not instantiable: every member of this class is static. */
    private Util() {}

    /**
     * Invokes the native {@code cudaGetDeviceProperties} function through its downcall handle.
     *
     * @param prop   native memory passed as the first (pointer) argument of the call
     * @param device value passed as the second (int) argument of the call
     * @return the raw int returned by the native function; it is not checked here
     * @throws AssertionError if the downcall itself throws
     */
    public static int cudaGetDeviceProperties(MemorySegment prop, int device) {
        try {
            // invokeExact is signature-polymorphic: the cast is part of the call-site type and
            // must match the handle's int return type, otherwise the call neither compiles
            // nor links.
            return (int) cudaGetDeviceProperties$mh.invokeExact(prop, device);
        } catch (Throwable ex$) {
            throw new AssertionError("should not reach here", ex$);
        }
    }

    /**
     * Throws when a cuVS call did not report success.
     *
     * @param value  status code returned by the cuVS call
     * @param caller name of the call, used as the prefix of the exception message
     * @throws RuntimeException if {@code value} differs from {@link #CUVS_SUCCESS}; the message
     *                          carries the status code and the last error text in brackets
     */
    public static void checkCuVSError(int value, String caller) {
        if (value == CUVS_SUCCESS) {
            return;
        }
        String lastErrorText = Util.getLastErrorText();
        throw new RuntimeException(caller + " returned " + value + "[" + lastErrorText + "]");
    }

    /**
     * Throws when a CUDA call returned a non-zero status.
     *
     * @param value  status code returned by the CUDA call
     * @param caller name of the call, used as the prefix of the exception message
     * @throws RuntimeException if {@code value} is not 0
     */
    public static void checkCudaError(int value, String caller) {
        boolean succeeded = value == 0;
        if (succeeded) {
            return;
        }
        throw new RuntimeException(caller + " returned " + value);
    }

    /**
     * Calls {@code cudaMemcpy} and throws if it reports a non-zero status.
     *
     * @param dest     destination memory
     * @param src      source memory
     * @param numBytes number of bytes passed to the native call
     * @param kind     copy direction; its {@code kind} value is passed to the native call
     * @throws RuntimeException if the native call returns a non-zero status
     */
    public static void cudaMemcpy(MemorySegment dest, MemorySegment src, long numBytes, CudaMemcpyKind kind) {
        Util.checkCudaError(headers_h.cudaMemcpy(dest, src, numBytes, kind.kind), "cudaMemcpy");
    }

    /**
     * Convenience overload of {@code cudaMemcpy} that uses
     * {@link CudaMemcpyKind#INFER_DIRECTION} as the copy kind.
     *
     * @param dest     destination memory
     * @param src      source memory
     * @param numBytes number of bytes passed to the native call
     * @throws RuntimeException if the native call returns a non-zero status
     */
    public static void cudaMemcpy(MemorySegment dest, MemorySegment src, long numBytes) {
        final CudaMemcpyKind inferred = CudaMemcpyKind.INFER_DIRECTION;
        Util.cudaMemcpy(dest, src, numBytes, inferred);
    }

    /**
     * Calls {@code cudaMemcpyAsync} through its downcall handle and throws if it reports a
     * non-zero status.
     *
     * @param dst      destination memory
     * @param src      source memory
     * @param numBytes number of bytes passed to the native call
     * @param kind     copy direction; its {@code kind} value is passed to the native call
     * @param stream   stream argument passed to the native call
     * @throws RuntimeException if the native call returns a non-zero status
     * @throws AssertionError   if the downcall itself throws
     */
    public static void cudaMemcpyAsync(MemorySegment dst, MemorySegment src, long numBytes, CudaMemcpyKind kind, MemorySegment stream) {
        int returnValue;
        try {
            // invokeExact is signature-polymorphic: the (int) cast is part of the call-site
            // type and must match the handle's int return type.
            returnValue = (int) cudaMemcpyAsync$mh.invokeExact(dst, src, numBytes, kind.kind, stream);
        } catch (Throwable ex$) {
            throw new AssertionError("should not reach here", ex$);
        }
        // Checked outside the try block so that a CUDA failure surfaces as the RuntimeException
        // raised by checkCudaError (as in cudaMemcpy) instead of being re-wrapped as an
        // AssertionError by the catch above.
        Util.checkCudaError(returnValue, "cudaMemcpyAsync");
    }

    /**
     * Returns the stream for the given resources. Access to the resources is scoped to this
     * call and closed before returning.
     *
     * @param resources resources whose handle is used to look up the stream
     * @return the stream obtained for the resources handle
     */
    public static MemorySegment getStream(CuVSResources resources) {
        try (CuVSResources.ScopedAccess scoped = resources.access()) {
            return Util.getStream(scoped.handle());
        }
    }

    /**
     * Returns the stream for a raw resources handle.
     *
     * <p>A temporary slot is allocated in a confined arena, {@code cuvsStreamGet} is asked to
     * fill it, and the value read back from the slot is returned.
     *
     * @param resourcesHandle handle passed to {@code cuvsStreamGet}
     * @return the stream value written by {@code cuvsStreamGet}
     * @throws RuntimeException if {@code cuvsStreamGet} does not report success
     */
    public static MemorySegment getStream(long resourcesHandle) {
        try (Arena scratch = Arena.ofConfined()) {
            MemorySegment streamSlot = scratch.allocate(headers_h_1.cudaStream_t);
            Util.checkCuVSError(headers_h.cuvsStreamGet(resourcesHandle, streamSlot), "cuvsStreamGet");
            return streamSlot.get(headers_h_1.cudaStream_t, 0L);
        }
    }

    /**
     * Fetches the text of the last error through {@code cuvsGetLastErrorText}.
     *
     * @return {@code "no last error text"} when the native call yields a NULL pointer;
     *         otherwise the string read from the start of the returned memory, which is
     *         reinterpreted with a size of 1,000,000 bytes before reading
     * @throws RuntimeException wrapping anything thrown while invoking or reading
     */
    static String getLastErrorText() {
        try {
            MemorySegment text = headers_h.cuvsGetLastErrorText.makeInvoker(new MemoryLayout[0]).apply(new Object[0]);
            return text.equals(MemorySegment.NULL)
                    ? "no last error text"
                    : text.reinterpret(1_000_000L).getString(0L);
        } catch (Throwable t) {
            throw new RuntimeException(t);
        }
    }

    /**
     * Copies a Java string into newly allocated native memory as a NUL-terminated C string.
     *
     * <p>The text is encoded as UTF-8. Narrowing each {@code char} to a {@code byte} instead
     * would silently corrupt any character outside ASCII, and building one {@code VarHandle}
     * per character is unnecessary work.
     *
     * @param arena arena that owns the returned segment
     * @param str   text to copy
     * @return a segment holding the UTF-8 bytes of {@code str} followed by a single NUL byte
     */
    public static MemorySegment buildMemorySegment(Arena arena, String str) {
        return arena.allocateFrom(str);
    }

    /**
     * Copies a {@code long[]} into newly allocated native memory.
     *
     * @param arena arena that owns the returned segment
     * @param data  values to copy
     * @return a segment of {@code data.length} elements of layout {@code LinkerHelper.C_LONG}
     *         holding the array contents
     */
    public static MemorySegment buildMemorySegment(Arena arena, long[] data) {
        final int count = data.length;
        MemorySegment segment = arena.allocate(LinkerHelper.C_LONG, count);
        MemorySegment.copy(data, 0, segment, LinkerHelper.C_LONG, 0L, count);
        return segment;
    }

    /**
     * Copies a {@code byte[]} into newly allocated native memory.
     *
     * @param arena arena that owns the returned segment
     * @param data  bytes to copy
     * @return a segment of {@code data.length} elements of layout {@code LinkerHelper.C_CHAR}
     *         holding the array contents
     */
    public static MemorySegment buildMemorySegment(Arena arena, byte[] data) {
        final int count = data.length;
        MemorySegment segment = arena.allocate(LinkerHelper.C_CHAR, count);
        MemorySegment.copy(data, 0, segment, LinkerHelper.C_CHAR, 0L, count);
        return segment;
    }

    /**
     * Copies a two-dimensional float array into newly allocated native memory, row after row.
     *
     * <p>The column count is taken from the first row; the segment has room for
     * {@code data.length * data[0].length} elements (zero when {@code data} is empty).
     *
     * @param arena arena that owns the returned segment
     * @param data  matrix to copy
     * @return the segment filled by {@code copy(MemorySegment, float[][])}
     */
    public static MemorySegment buildMemorySegment(Arena arena, float[][] data) {
        final long rowCount = data.length;
        long colCount = 0L;
        if (rowCount > 0L) {
            colCount = data[0].length;
        }
        MemorySegment segment = arena.allocate(LinkerHelper.C_FLOAT, rowCount * colCount);
        Util.copy(segment, data);
        return segment;
    }

    /**
     * Copies a two-dimensional float array into {@code memorySegment}, row after row.
     *
     * <p>The row width is taken from {@code data[0].length}; exactly that many elements are
     * copied from the start of every row, to byte offset {@code rowIndex * rowWidth * elementSize}.
     *
     * @param memorySegment destination segment
     * @param data          matrix to copy; nothing is copied when it has no rows
     */
    public static void copy(MemorySegment memorySegment, float[][] data) {
        int rows = data.length;
        int cols = rows > 0 ? data[0].length : 0;
        // Widen to long before multiplying: an int product r * cols overflows once the matrix
        // holds more than Integer.MAX_VALUE elements, yielding a wrong or negative offset.
        long rowBytes = (long) cols * LinkerHelper.C_FLOAT.byteSize();
        for (int r = 0; r < rows; ++r) {
            MemorySegment.copy(data[r], 0, memorySegment, LinkerHelper.C_FLOAT, r * rowBytes, cols);
        }
    }

    /**
     * Copies a two-dimensional int array into {@code memorySegment}, row after row.
     *
     * <p>The row width is taken from {@code data[0].length}; exactly that many elements are
     * copied from the start of every row, to byte offset {@code rowIndex * rowWidth * elementSize}.
     *
     * @param memorySegment destination segment
     * @param data          matrix to copy; nothing is copied when it has no rows
     */
    public static void copy(MemorySegment memorySegment, int[][] data) {
        int rows = data.length;
        int cols = rows > 0 ? data[0].length : 0;
        // Widen to long before multiplying: an int product r * cols overflows once the matrix
        // holds more than Integer.MAX_VALUE elements, yielding a wrong or negative offset.
        long rowBytes = (long) cols * LinkerHelper.C_INT.byteSize();
        for (int r = 0; r < rows; ++r) {
            MemorySegment.copy(data[r], 0, memorySegment, LinkerHelper.C_INT, r * rowBytes, cols);
        }
    }

    /**
     * Copies a two-dimensional byte array into {@code memorySegment}, row after row.
     *
     * <p>The row width is taken from {@code data[0].length}; exactly that many elements are
     * copied from the start of every row, to byte offset {@code rowIndex * rowWidth * elementSize}.
     *
     * @param memorySegment destination segment
     * @param data          matrix to copy; nothing is copied when it has no rows
     */
    public static void copy(MemorySegment memorySegment, byte[][] data) {
        int rows = data.length;
        int cols = rows > 0 ? data[0].length : 0;
        // Widen to long before multiplying: an int product r * cols overflows once the matrix
        // holds more than Integer.MAX_VALUE elements, yielding a wrong or negative offset.
        long rowBytes = (long) cols * LinkerHelper.C_CHAR.byteSize();
        for (int r = 0; r < rows; ++r) {
            MemorySegment.copy(data[r], 0, memorySegment, LinkerHelper.C_CHAR, r * rowBytes, cols);
        }
    }

    /**
     * Packs several bit sets into one, giving each input a fixed-width window.
     *
     * <p>Input {@code i} occupies bits {@code [i * maxSizeOfEachBitSet, (i + 1) * maxSizeOfEachBitSet)}
     * of the result. A {@code null} or empty input turns its whole window on; otherwise the
     * first {@code maxSizeOfEachBitSet} bits of the input are copied and any higher bits are
     * ignored.
     *
     * @param arr                 bit sets to pack, in window order
     * @param maxSizeOfEachBitSet width of each window in bits
     * @return the packed bit set
     */
    public static BitSet concatenate(BitSet[] arr, int maxSizeOfEachBitSet) {
        final int width = maxSizeOfEachBitSet;
        BitSet merged = new BitSet(width * arr.length);
        for (int slot = 0; slot < arr.length; slot++) {
            final BitSet part = arr[slot];
            final int base = slot * width;
            boolean hasBits = part != null && part.length() != 0;
            if (!hasBits) {
                // No bits supplied for this window: switch the whole window on.
                merged.set(base, (slot + 1) * width);
            } else {
                for (int offset = 0; offset < width; offset++) {
                    merged.set(base + offset, part.get(offset));
                }
            }
        }
        return merged;
    }

    /**
     * Builds a managed tensor without explicit strides by delegating to the overload that
     * takes a strides array, passing {@code null} for it.
     *
     * @param arena      arena used for all allocations
     * @param data       segment stored as the tensor's data pointer
     * @param shape      tensor dimensions
     * @param code       data type code
     * @param bits       data type width in bits
     * @param deviceType device type value
     * @return the allocated managed tensor segment
     */
    public static MemorySegment prepareTensor(Arena arena, MemorySegment data, long[] shape, int code, int bits, int deviceType) {
        final long[] noStrides = null;
        return Util.prepareTensor(arena, data, shape, noStrides, code, bits, deviceType);
    }

    /**
     * Allocates and fills a {@code DLManagedTensor} describing {@code data}.
     *
     * <p>The device and data type structs are fully populated before being stored into the
     * tensor, and the tensor before being stored into the managed tensor; that order is kept
     * deliberately. NOTE(review): the generated struct setters presumably copy by value, in
     * which case later writes to the nested segments would not be seen - confirm before
     * reordering.
     *
     * @param arena      arena used for all allocations
     * @param data       segment stored as the tensor's data pointer
     * @param shape      tensor dimensions; its length becomes {@code ndim}
     * @param strides    per-dimension strides, or {@code null} to store a NULL strides pointer
     * @param code       data type code, narrowed to {@code byte}
     * @param bits       data type width in bits, narrowed to {@code byte}
     * @param deviceType device type value
     * @return the allocated managed tensor segment
     */
    public static MemorySegment prepareTensor(Arena arena, MemorySegment data, long[] shape, long[] strides, int code, int bits, int deviceType) {
        MemorySegment managed = DLManagedTensor.allocate(arena);
        MemorySegment dlTensor = DLTensor.allocate(arena);
        DLTensor.data(dlTensor, data);

        // Device descriptor: only the device type is set here.
        MemorySegment device = DLDevice.allocate(arena);
        DLDevice.device_type(device, deviceType);
        DLTensor.device(dlTensor, device);

        // Data type descriptor: always a single lane.
        MemorySegment dataType = DLDataType.allocate(arena);
        DLDataType.code(dataType, (byte)code);
        DLDataType.bits(dataType, (byte)bits);
        DLDataType.lanes(dataType, (short)1);
        DLTensor.dtype(dlTensor, dataType);

        DLTensor.ndim(dlTensor, shape.length);
        DLTensor.shape(dlTensor, Util.buildMemorySegment(arena, shape));

        MemorySegment stridesSegment;
        if (strides == null) {
            stridesSegment = MemorySegment.NULL;
        } else {
            assert (shape.length == strides.length);
            stridesSegment = Util.buildMemorySegment(arena, strides);
        }
        DLTensor.strides(dlTensor, stridesSegment);

        DLManagedTensor.dl_tensor(managed, dlTensor);
        assert (bits == DLDataType.bits(DLTensor.dtype(DLManagedTensor.dl_tensor(managed))));
        return managed;
    }

    /**
     * Copy directions accepted by the {@code cudaMemcpy} helpers. Each constant carries the
     * integer that is handed to the native call.
     */
    public enum CudaMemcpyKind {
        HOST_TO_HOST(headers_h.cudaMemcpyHostToHost()),
        HOST_TO_DEVICE(headers_h.cudaMemcpyHostToDevice()),
        DEVICE_TO_HOST(headers_h.cudaMemcpyDeviceToHost()),
        DEVICE_TO_DEVICE(headers_h.cudaMemcpyDeviceToDevice()),
        // NOTE(review): literal 4 presumably corresponds to CUDA's cudaMemcpyDefault - confirm.
        INFER_DIRECTION(4);

        /** Integer value passed to the native copy functions. */
        public final int kind;

        CudaMemcpyKind(int k) {
            this.kind = k;
        }
    }
}

