/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.gpu.codec;

import com.nvidia.cuvs.CuVSMatrix;
import com.nvidia.cuvs.spi.CuVSProvider;
import java.io.IOException;
import java.lang.foreign.MemorySegment;
import java.lang.invoke.MethodHandle;
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.elasticsearch.xpack.gpu.codec.DatasetUtils;

/**
 * {@link DatasetUtils} implementation that wraps memory segments obtained from a
 * {@link MemorySegmentAccessInput} as {@link CuVSMatrix} instances, using the builder
 * method handles supplied by the {@link CuVSProvider}.
 *
 * <p>The class is a stateless singleton; obtain it through {@link #getInstance()}.
 */
public class DatasetUtilsImpl
implements DatasetUtils {
    private static final DatasetUtils INSTANCE = new DatasetUtilsImpl();
    /** Handle invoked as {@code (MemorySegment, int, int, DataType) -> CuVSMatrix}. */
    private static final MethodHandle createDataset$mh = CuVSProvider.provider().newNativeMatrixBuilder();
    /** Handle invoked as {@code (MemorySegment, int, int, int, int, DataType) -> CuVSMatrix}. */
    private static final MethodHandle createDatasetWithStrides$mh = CuVSProvider.provider().newNativeMatrixBuilderWithStrides();

    private DatasetUtilsImpl() {
    }

    /** Returns the shared instance. */
    static DatasetUtils getInstance() {
        return INSTANCE;
    }

    /**
     * Builds a matrix over {@code memorySegment} through the non-strided builder handle.
     *
     * @param memorySegment segment holding the vector data
     * @param size          number of vectors (rows)
     * @param dimensions    number of components per vector (columns)
     * @param dataType      element type of the matrix
     * @return the matrix produced by the builder handle
     * @throws RuntimeException if the handle throws; {@link Error}s and {@link RuntimeException}s
     *                          propagate unchanged, any other throwable is wrapped as the cause
     */
    static CuVSMatrix fromMemorySegment(MemorySegment memorySegment, int size, int dimensions, CuVSMatrix.DataType dataType) {
        try {
            // invokeExact is signature-polymorphic: the cast is what fixes the call-site return
            // type. Without it the expression is typed Object and cannot be returned here.
            return (CuVSMatrix) createDataset$mh.invokeExact(memorySegment, size, dimensions, dataType);
        }
        catch (Throwable e) {
            throw asUnchecked(e);
        }
    }

    /**
     * Builds a matrix over {@code memorySegment} through the strided builder handle.
     *
     * @param memorySegment segment holding the vector data
     * @param size          number of vectors (rows)
     * @param dimensions    number of components per vector (columns)
     * @param rowStride     row stride, forwarded to the builder as-is
     * @param columnStride  column stride, forwarded to the builder as-is
     * @param dataType      element type of the matrix
     * @return the matrix produced by the builder handle
     * @throws RuntimeException if the handle throws; {@link Error}s and {@link RuntimeException}s
     *                          propagate unchanged, any other throwable is wrapped as the cause
     */
    static CuVSMatrix fromMemorySegment(MemorySegment memorySegment, int size, int dimensions, int rowStride, int columnStride, CuVSMatrix.DataType dataType) {
        try {
            // See the non-strided overload: the cast is required for invokeExact.
            return (CuVSMatrix) createDatasetWithStrides$mh.invokeExact(memorySegment, size, dimensions, rowStride, columnStride, dataType);
        }
        catch (Throwable e) {
            throw asUnchecked(e);
        }
    }

    /**
     * Creates a matrix over the whole of {@code input}.
     *
     * @throws IllegalArgumentException if {@code numVectors} or {@code dims} is negative, or if
     *                                  the input is too small for the requested matrix
     */
    @Override
    public CuVSMatrix fromInput(MemorySegmentAccessInput input, int numVectors, int dims, CuVSMatrix.DataType dataType) throws IOException {
        checkNonNegative(numVectors, dims);
        return createCuVSMatrix(input, 0L, input.length(), numVectors, dims, dataType);
    }

    /**
     * Creates a strided matrix over the whole of {@code input}.
     *
     * @throws IllegalArgumentException if {@code numVectors} or {@code dims} is negative, or if
     *                                  the input is too small for {@code numVectors} rows of
     *                                  {@code rowStride} elements
     */
    @Override
    public CuVSMatrix fromInput(MemorySegmentAccessInput input, int numVectors, int dims, int rowStride, int columnStride, CuVSMatrix.DataType dataType) throws IOException {
        checkNonNegative(numVectors, dims);
        return createCuVSMatrix(input, 0L, input.length(), numVectors, dims, rowStride, columnStride, dataType);
    }

    /**
     * Creates a matrix over the {@code len} bytes of {@code input} starting at {@code pos}.
     *
     * @throws IllegalArgumentException if {@code pos} or {@code len} is negative (zero is
     *                                  accepted), if {@code numVectors} or {@code dims} is
     *                                  negative, or if the slice is too small for the matrix
     */
    @Override
    public CuVSMatrix fromSlice(MemorySegmentAccessInput input, long pos, long len, int numVectors, int dims, CuVSMatrix.DataType dataType) throws IOException {
        if (pos < 0L || len < 0L) {
            throw new IllegalArgumentException("pos and len must be positive");
        }
        // Same validation as fromInput: a negative count would make the size product negative
        // and slip past the capacity check in createCuVSMatrix.
        checkNonNegative(numVectors, dims);
        return createCuVSMatrix(input, pos, len, numVectors, dims, dataType);
    }

    /** Slices {@code input}, verifies the slice can hold the matrix, and wraps it. */
    private static CuVSMatrix createCuVSMatrix(MemorySegmentAccessInput input, long pos, long len, int numVectors, int dims, CuVSMatrix.DataType dataType) throws IOException {
        MemorySegment ms = input.segmentSliceOrNull(pos, len);
        // NOTE(review): presumably null when the range cannot be exposed as one segment - confirm.
        assert (ms != null);
        // Widen to long before multiplying so large counts cannot overflow int.
        long requiredBytes = (long)numVectors * (long)dims * (long)elementSize(dataType);
        if (requiredBytes > ms.byteSize()) {
            throwIllegalArgumentException(ms, numVectors, dims);
        }
        return fromMemorySegment(ms, numVectors, dims, dataType);
    }

    /**
     * Strided variant: the capacity check uses {@code rowStride} rather than {@code dims} as the
     * per-row element count.
     */
    private static CuVSMatrix createCuVSMatrix(MemorySegmentAccessInput input, long pos, long len, int numVectors, int dims, int rowStride, int columnStride, CuVSMatrix.DataType dataType) throws IOException {
        MemorySegment ms = input.segmentSliceOrNull(pos, len);
        // NOTE(review): presumably null when the range cannot be exposed as one segment - confirm.
        assert (ms != null);
        long requiredBytes = (long)numVectors * (long)rowStride * (long)elementSize(dataType);
        if (requiredBytes > ms.byteSize()) {
            throwIllegalArgumentException(ms, numVectors, dims);
        }
        return fromMemorySegment(ms, numVectors, dims, rowStride, columnStride, dataType);
    }

    /** Bytes per element: four for {@code FLOAT}, one for every other data type. */
    private static int elementSize(CuVSMatrix.DataType dataType) {
        return dataType == CuVSMatrix.DataType.FLOAT ? Float.BYTES : Byte.BYTES;
    }

    /** Rejects negative vector counts or dimensions. */
    private static void checkNonNegative(int numVectors, int dims) {
        if (numVectors < 0 || dims < 0) {
            throwIllegalArgumentException(numVectors, dims);
        }
    }

    /**
     * Converts a throwable caught from a method handle invocation into something that can be
     * thrown without a checked-exception clause. {@link Error}s are thrown directly from here;
     * {@link RuntimeException}s are returned unchanged; anything else is wrapped, keeping the
     * original as the cause.
     */
    private static RuntimeException asUnchecked(Throwable t) {
        if (t instanceof Error err) {
            throw err;
        }
        if (t instanceof RuntimeException re) {
            return re;
        }
        return new RuntimeException(t);
    }

    /** Always throws: reports that {@code ms} is too small for the requested matrix. */
    static void throwIllegalArgumentException(MemorySegment ms, int numVectors, int dims) {
        String s = "segment of size [" + ms.byteSize() + "] too small for expected " + numVectors + " float vectors of " + dims + " dims";
        throw new IllegalArgumentException(s);
    }

    /**
     * Always throws: reports a negative vector count if {@code numVectors} is negative, otherwise
     * reports negative dimensions.
     */
    static void throwIllegalArgumentException(int numVectors, int dims) {
        String s = numVectors < 0 ? "negative number of vectors: " + numVectors : "negative vector dims: " + dims;
        throw new IllegalArgumentException(s);
    }
}

