/*
 * Decompiled with CFR 0.152.
 */
package com.nvidia.cuvs.internal;

import com.nvidia.cuvs.CuVSMatrix;
import com.nvidia.cuvs.CuVSResources;
import com.nvidia.cuvs.internal.CuVSDeviceMatrixImpl;
import com.nvidia.cuvs.internal.CuVSHostMatrixImpl;
import com.nvidia.cuvs.internal.CuVSMatrixInternal;
import com.nvidia.cuvs.internal.common.LinkerHelper;
import com.nvidia.cuvs.internal.common.Util;
import com.nvidia.cuvs.internal.panama.DLDataType;
import com.nvidia.cuvs.internal.panama.DLDevice;
import com.nvidia.cuvs.internal.panama.DLManagedTensor;
import com.nvidia.cuvs.internal.panama.DLTensor;
import com.nvidia.cuvs.internal.panama.headers_h;
import java.lang.foreign.Arena;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SequenceLayout;
import java.lang.foreign.ValueLayout;
import java.util.Locale;

/**
 * Common state and helpers for {@link CuVSMatrixInternal} implementations whose elements are
 * addressed through a {@link MemorySegment}.
 *
 * <p>Stores the backing segment, the element type and its value layout, and the logical shape of
 * {@code size} rows by {@code columns} columns. Also provides static helpers for copying between
 * matrices and for wrapping a DLPack {@code DLManagedTensor} as a {@link CuVSMatrix}.
 */
abstract class CuVSMatrixBaseImpl implements CuVSMatrixInternal {
    /** Memory holding the matrix elements. */
    protected final MemorySegment memorySegment;
    /** Element type of the matrix. */
    protected final CuVSMatrix.DataType dataType;
    /** Layout of a single element. */
    protected final ValueLayout valueLayout;
    /** Number of rows. */
    protected final long size;
    /** Number of columns. */
    protected final long columns;

    /**
     * Stores the given state as-is; no validation is performed here.
     *
     * @param memorySegment memory holding the matrix elements
     * @param dataType      element type
     * @param valueLayout   layout of a single element
     * @param size          number of rows
     * @param columns       number of columns
     */
    protected CuVSMatrixBaseImpl(MemorySegment memorySegment, CuVSMatrix.DataType dataType, ValueLayout valueLayout, long size, long columns) {
        this.memorySegment = memorySegment;
        this.dataType = dataType;
        this.valueLayout = valueLayout;
        this.size = size;
        this.columns = columns;
    }

    /**
     * Copies every element of {@code sourceMatrix} into {@code targetMatrix} via
     * {@code cuvsMatrixCopy}, then calls {@code cuvsStreamSync} on the same cuVS resources handle
     * before returning.
     *
     * <p>Both matrices are converted to DLPack tensors allocated in a confined arena that is
     * closed when this method returns. Native status codes are passed to
     * {@code Util.checkCuVSError} (presumably throwing on failure — confirm in {@code Util}).
     *
     * @param sourceMatrix matrix to read from
     * @param targetMatrix matrix to write into
     * @param resources    cuVS resources whose handle is used for the copy and the sync
     * @throws IllegalArgumentException if the matrices differ in rows, columns, or data type
     */
    protected static void copyMatrix(CuVSMatrixInternal sourceMatrix, CuVSMatrixInternal targetMatrix, CuVSResources resources) {
        if (targetMatrix.columns() != sourceMatrix.columns() || targetMatrix.size() != sourceMatrix.size()) {
            throw new IllegalArgumentException("Source and target matrices must have the same dimensions");
        }
        if (targetMatrix.dataType() != sourceMatrix.dataType()) {
            throw new IllegalArgumentException("Source and target matrices must have the same dataType");
        }
        try (Arena localArena = Arena.ofConfined()) {
            MemorySegment targetTensor = targetMatrix.toTensor(localArena);
            try (CuVSResources.ScopedAccess resourceAccess = resources.access()) {
                long cuvsRes = resourceAccess.handle();
                MemorySegment sourceTensor = sourceMatrix.toTensor(localArena);
                Util.checkCuVSError(headers_h.cuvsMatrixCopy(cuvsRes, sourceTensor, targetTensor), "cuvsMatrixCopy");
                Util.checkCuVSError(headers_h.cuvsStreamSync(cuvsRes), "cuvsStreamSync");
            }
        }
    }

    @Override
    public long size() {
        return this.size;
    }

    @Override
    public long columns() {
        return this.columns;
    }

    @Override
    public CuVSMatrix.DataType dataType() {
        return this.dataType;
    }

    @Override
    public MemorySegment memorySegment() {
        return this.memorySegment;
    }

    @Override
    public ValueLayout valueLayout() {
        return this.valueLayout;
    }

    /**
     * Maps a matrix element type to the native layout used for one element.
     *
     * @param dataType element type; must not be {@code null}
     * @return {@code C_FLOAT} for FLOAT, {@code C_INT} for INT and UINT, {@code C_CHAR} for BYTE
     */
    protected static ValueLayout valueLayoutFromType(CuVSMatrix.DataType dataType) {
        // Exhaustive over the enum with no default, so a new DataType constant fails to compile
        // here instead of failing at runtime.
        return switch (dataType) {
            case FLOAT -> LinkerHelper.C_FLOAT;
            case INT, UINT -> LinkerHelper.C_INT;
            case BYTE -> LinkerHelper.C_CHAR;
        };
    }

    /**
     * Builds a 32-byte-aligned sequence layout large enough for a matrix of {@code size} rows.
     *
     * @param size      number of rows
     * @param columns   number of columns; used as the row length when {@code rowStride <= 0}
     * @param rowStride elements per row including any padding; values {@code <= 0} mean "no padding"
     * @param dataType  element type
     * @return a sequence layout of {@code size * (rowStride > 0 ? rowStride : columns)} elements
     */
    protected static SequenceLayout sequenceLayoutFromType(long size, long columns, int rowStride, CuVSMatrix.DataType dataType) {
        long elements = rowStride > 0 ? (long)rowStride * size : columns * size;
        return MemoryLayout.sequenceLayout(elements, CuVSMatrixBaseImpl.valueLayoutFromType(dataType)).withByteAlignment(32L);
    }

    /**
     * Wraps the data of a DLPack {@code DLManagedTensor} as a {@link CuVSMatrix} without copying.
     *
     * <p>CUDA tensors become a {@link CuVSDeviceMatrixImpl}. The row and column strides are
     * forwarded when the tensor has them. CPU tensors become a {@link CuVSHostMatrixImpl}, which
     * is built without stride information. Therefore any host strides present must describe a
     * compact row-major layout.
     *
     * <p>NOTE(review): {@code DLTensor.byte_offset} and {@code DLDataType.lanes} are not read
     * here. Tensors with a non-zero byte offset or {@code lanes != 1} are not handled — confirm
     * that producers never emit them.
     *
     * @param dlManagedTensor pointer to a {@code DLManagedTensor}
     * @param resources       cuVS resources passed to device matrices
     * @return a matrix view over the tensor data
     * @throws IllegalArgumentException if {@code data} or {@code shape} is NULL, the tensor is not
     *                                  2D, the dtype is unsupported, host strides are not compact
     *                                  row-major, or the device type is neither CUDA nor CPU
     */
    public static CuVSMatrix fromTensor(MemorySegment dlManagedTensor, CuVSResources resources) {
        MemorySegment dlTensor = DLManagedTensor.dl_tensor(dlManagedTensor);
        MemorySegment dlDevice = DLTensor.device(dlTensor);
        int deviceType = DLDevice.device_type(dlDevice);
        MemorySegment data = DLTensor.data(dlTensor);
        if (data.equals(MemorySegment.NULL)) {
            throw new IllegalArgumentException("[data] must not be NULL");
        }
        int ndim = DLTensor.ndim(dlTensor);
        if (ndim != 2) {
            throw new IllegalArgumentException("CuVSMatrix only supports 2D data");
        }
        MemorySegment dtype = DLTensor.dtype(dlTensor);
        CuVSMatrix.DataType dataType = dataTypeFromTensor(DLDataType.code(dtype), DLDataType.bits(dtype));
        MemorySegment shape = DLTensor.shape(dlTensor);
        if (shape.equals(MemorySegment.NULL)) {
            throw new IllegalArgumentException("[shape] must not be NULL");
        }
        long rows = shape.getAtIndex(headers_h.int64_t, 0L);
        long cols = shape.getAtIndex(headers_h.int64_t, 1L);
        if (deviceType == headers_h.kDLCUDA()) {
            ValueLayout layout = valueLayoutFromType(dataType);
            MemorySegment strides = DLTensor.strides(dlTensor);
            if (strides.equals(MemorySegment.NULL)) {
                return new CuVSDeviceMatrixImpl(resources, data, rows, cols, dataType, layout);
            }
            long rowStride = strides.getAtIndex(headers_h.int64_t, 0L);
            long colStride = strides.getAtIndex(headers_h.int64_t, 1L);
            return new CuVSDeviceMatrixImpl(resources, data, rows, cols, rowStride, colStride, dataType, layout);
        }
        if (deviceType == headers_h.kDLCPU()) {
            MemorySegment strides = DLTensor.strides(dlTensor);
            if (!strides.equals(MemorySegment.NULL)) {
                // The host matrix is built without strides, so reject any layout it would misread.
                requireCompactRowMajor(strides, rows, cols);
            }
            return new CuVSHostMatrixImpl(data, rows, cols, dataType);
        }
        throw new IllegalArgumentException("Unsupported device type: " + deviceType);
    }

    /**
     * Throws unless the two int64 strides in {@code strides} describe a compact row-major layout
     * for a {@code rows x cols} matrix (row stride {@code cols}, column stride 1).
     *
     * <p>Strides are compared as element counts. A stride along a dimension of extent 0 or 1 never
     * addresses memory, so it is not checked.
     */
    private static void requireCompactRowMajor(MemorySegment strides, long rows, long cols) {
        if (rows == 0L || cols == 0L) {
            return; // empty matrix: no element is ever addressed
        }
        long rowStride = strides.getAtIndex(headers_h.int64_t, 0L);
        long colStride = strides.getAtIndex(headers_h.int64_t, 1L);
        boolean rowsCompact = rows <= 1L || rowStride == cols;
        boolean colsCompact = cols <= 1L || colStride == 1L;
        if (!(rowsCompact && colsCompact)) {
            throw new IllegalArgumentException(String.format(Locale.ROOT,
                    "Host tensors must be contiguous row-major; got strides (%d, %d) for shape (%d, %d)",
                    rowStride, colStride, rows, cols));
        }
    }

    /**
     * Maps a DLPack {@code (code, bits)} pair to a matrix element type.
     *
     * <p>32-bit UInt, Int, and Float map to UINT, INT, and FLOAT. 8-bit Int or UInt maps to BYTE.
     *
     * @param code DLPack type code (a {@code uint8_t} carried in a Java {@code byte})
     * @param bits bit width (a {@code uint8_t} carried in a Java {@code byte})
     * @throws IllegalArgumentException for any other combination
     */
    private static CuVSMatrix.DataType dataTypeFromTensor(byte code, byte bits) {
        if (bits == 32) {
            if (code == headers_h.kDLUInt()) {
                return CuVSMatrix.DataType.UINT;
            }
            if (code == headers_h.kDLInt()) {
                return CuVSMatrix.DataType.INT;
            }
            if (code == headers_h.kDLFloat()) {
                return CuVSMatrix.DataType.FLOAT;
            }
        } else if (bits == 8 && (code == headers_h.kDLInt() || code == headers_h.kDLUInt())) {
            return CuVSMatrix.DataType.BYTE;
        }
        // The DLPack fields are unsigned; format them that way so e.g. 128 bits is not shown as -128.
        throw new IllegalArgumentException(String.format(Locale.ROOT, "Unsupported data type (code=%d, bits=%d)",
                Byte.toUnsignedInt(code), Byte.toUnsignedInt(bits)));
    }
}

