/*
 * Decompiled with CFR 0.152.
 */
package com.nvidia.cuvs.internal;

import com.nvidia.cuvs.CuVSDeviceMatrix;
import com.nvidia.cuvs.CuVSHostMatrix;
import com.nvidia.cuvs.CuVSMatrix;
import com.nvidia.cuvs.CuVSResources;
import com.nvidia.cuvs.RowView;
import com.nvidia.cuvs.internal.CuVSMatrixBaseImpl;
import com.nvidia.cuvs.internal.CuVSMatrixInternal;
import com.nvidia.cuvs.internal.SliceRowView;
import com.nvidia.cuvs.internal.common.Util;
import com.nvidia.cuvs.internal.panama.headers_h;
import java.lang.foreign.Arena;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.VarHandle;
import java.util.Locale;

/**
 * Host-memory (CPU) implementation of {@link CuVSHostMatrix} backed by a {@link MemorySegment}.
 *
 * <p>Rows may be packed or strided. When {@code rowStride > 0}, consecutive rows start
 * {@code rowStride} elements apart; otherwise rows are laid out back to back with
 * {@code columns} elements each.
 */
public class CuVSHostMatrixImpl
extends CuVSMatrixBaseImpl
implements CuVSHostMatrix {
    /**
     * Var handle built from the backing sequence layout. {@link #get(int, int)} invokes it as
     * {@code (segment, 0L, flatElementIndex)}.
     */
    protected final VarHandle accessor$vh;
    /** Elements between the starts of consecutive rows, or a value {@code <= 0} for packed rows. */
    private final int rowStride;
    /** Column stride forwarded to {@link #toTensor(Arena)}; not otherwise used by this class. */
    private final int columnStride;
    /** Bytes between the starts of consecutive rows (the row pitch). */
    private final long rowPitchBytes;
    /** Bytes of element data in a single row, i.e. {@code columns * elementSize}. */
    private final long rowDataBytes;
    /** Size in bytes of a single element, taken from the value layout. */
    private final long elementSize;

    /**
     * Creates a packed (non-strided) host matrix over {@code memorySegment}.
     *
     * @param memorySegment backing memory
     * @param size number of rows
     * @param columns number of columns
     * @param dataType element data type
     */
    public CuVSHostMatrixImpl(MemorySegment memorySegment, long size, long columns, CuVSMatrix.DataType dataType) {
        this(memorySegment, size, columns, -1, -1, dataType, CuVSHostMatrixImpl.valueLayoutFromType(dataType), CuVSHostMatrixImpl.sequenceLayoutFromType(size, columns, -1, dataType));
    }

    /**
     * Creates a host matrix over {@code memorySegment} with explicit strides.
     *
     * @param memorySegment backing memory
     * @param size number of rows
     * @param columns number of columns
     * @param rowStride elements between consecutive row starts; {@code <= 0} means packed rows
     * @param columnStride column stride reported by {@link #toTensor(Arena)}
     * @param dataType element data type
     * @throws IllegalArgumentException if {@code rowStride > 0} and {@code rowStride < columns}
     */
    public CuVSHostMatrixImpl(MemorySegment memorySegment, long size, long columns, int rowStride, int columnStride, CuVSMatrix.DataType dataType) {
        this(memorySegment, size, columns, rowStride, columnStride, dataType, CuVSHostMatrixImpl.valueLayoutFromType(dataType), CuVSHostMatrixImpl.sequenceLayoutFromType(size, columns, rowStride, dataType));
    }

    /**
     * Full constructor for subclasses that supply their own layouts.
     *
     * @throws IllegalArgumentException if {@code rowStride > 0} and {@code rowStride < columns}
     */
    protected CuVSHostMatrixImpl(MemorySegment memorySegment, long size, long columns, int rowStride, int columnStride, CuVSMatrix.DataType dataType, ValueLayout valueLayout, MemoryLayout sequenceLayout) {
        super(memorySegment, dataType, valueLayout, size, columns);
        if (rowStride > 0 && (long) rowStride < columns) {
            throw new IllegalArgumentException("Row stride cannot be less than the number of columns");
        }
        this.rowStride = rowStride;
        this.columnStride = columnStride;
        this.elementSize = valueLayout.byteSize();
        this.rowDataBytes = columns * this.elementSize;
        this.rowPitchBytes = rowStride > 0 ? (long) rowStride * this.elementSize : this.rowDataBytes;
        this.accessor$vh = sequenceLayout.varHandle(MemoryLayout.PathElement.sequenceElement());
    }

    /**
     * Returns a view over row {@code index}. The view is a slice of the backing segment, not a copy.
     */
    @Override
    public RowView getRow(long index) {
        assert (index >= 0 && index < this.size) : String.format(Locale.ROOT, "Index out of bound ([%d], size [%d])", index, this.size);
        // Offset by the row pitch, not columns * elementSize, so strided matrices yield the right row.
        MemorySegment row = this.memorySegment.asSlice(index * this.rowPitchBytes, this.rowDataBytes);
        return new SliceRowView(row, this.columns, this.valueLayout, this.dataType, this.elementSize);
    }

    /** Copies every row into {@code array}; requires an {@code INT} or {@code UINT} matrix. */
    @Override
    public void toArray(int[][] array) {
        assert ((long) array.length >= this.size) : String.format(Locale.ROOT, "Input array is not large enough (required: [%d], actual [%d])", this.size, array.length);
        assert (array.length == 0 || (long) array[0].length >= this.columns) : String.format(Locale.ROOT, "Input array is not wide enough (required: [%d], actual [%d])", this.columns, array[0].length);
        assert (this.dataType == CuVSMatrix.DataType.INT || this.dataType == CuVSMatrix.DataType.UINT) : String.format(Locale.ROOT, "Input array is of the wrong type for dataType [%s]", this.dataType.toString());
        this.copyRowsInto(array);
    }

    /** Copies every row into {@code array}; requires a {@code FLOAT} matrix. */
    @Override
    public void toArray(float[][] array) {
        assert ((long) array.length >= this.size) : String.format(Locale.ROOT, "Input array is not large enough (required: [%d], actual [%d])", this.size, array.length);
        assert (array.length == 0 || (long) array[0].length >= this.columns) : String.format(Locale.ROOT, "Input array is not wide enough (required: [%d], actual [%d])", this.columns, array[0].length);
        assert (this.dataType == CuVSMatrix.DataType.FLOAT) : String.format(Locale.ROOT, "Input array is of the wrong type for dataType [%s]", this.dataType.toString());
        this.copyRowsInto(array);
    }

    /** Copies every row into {@code array}; requires a {@code BYTE} matrix. */
    @Override
    public void toArray(byte[][] array) {
        assert ((long) array.length >= this.size) : String.format(Locale.ROOT, "Input array is not large enough (required: [%d], actual [%d])", this.size, array.length);
        assert (array.length == 0 || (long) array[0].length >= this.columns) : String.format(Locale.ROOT, "Input array is not wide enough (required: [%d], actual [%d])", this.columns, array[0].length);
        assert (this.dataType == CuVSMatrix.DataType.BYTE) : String.format(Locale.ROOT, "Input array is of the wrong type for dataType [%s]", this.dataType.toString());
        this.copyRowsInto(array);
    }

    /**
     * Copies {@code columns} elements of each row into the matching primitive array in {@code rows}.
     * Honours the row pitch.
     */
    private void copyRowsInto(Object[] rows) {
        for (int r = 0; r < this.size; ++r) {
            MemorySegment.copy(this.memorySegment, this.valueLayout, (long) r * this.rowPitchBytes, rows[r], 0, (int) this.columns);
        }
    }

    /**
     * Returns a {@link CuVSHostMatrix} view delegating to this matrix. The view shares this matrix's
     * memory; it does not copy it.
     */
    @Override
    public CuVSHostMatrix toHost() {
        return new CuVSHostMatrixDelegate(this);
    }

    /**
     * Copies this matrix into {@code hostMatrix}, honouring the row strides of both matrices.
     *
     * @throws IllegalArgumentException if dimensions or data types differ
     */
    @Override
    public void toHost(CuVSHostMatrix hostMatrix) {
        CuVSMatrixInternal target = (CuVSMatrixInternal) hostMatrix;
        if (target.columns() != this.columns || target.size() != this.size) {
            throw new IllegalArgumentException("Source and target matrices must have the same dimensions");
        }
        if (target.dataType() != this.dataType) {
            throw new IllegalArgumentException("Source and target matrices must have the same dataType");
        }
        MemorySegment destination = target.memorySegment();
        if (this.rowStride <= 0 && target.rowStride() <= 0L) {
            // Both sides are packed, so one bulk copy moves every row.
            MemorySegment.copy(this.memorySegment, 0L, destination, 0L, this.size * this.rowPitchBytes);
            return;
        }
        long targetPitchBytes = target.rowStride() > 0L ? target.rowStride() * this.elementSize : this.rowDataBytes;
        // Use a long counter: size is a long, and an int counter could overflow and never terminate.
        for (long r = 0; r < this.size; ++r) {
            MemorySegment.copy(this.memorySegment, r * this.rowPitchBytes, destination, r * targetPitchBytes, this.rowDataBytes);
        }
    }

    /** Copies this matrix to {@code deviceMatrix} using {@code cuVSResources}. */
    @Override
    public void toDevice(CuVSDeviceMatrix deviceMatrix, CuVSResources cuVSResources) {
        CuVSHostMatrixImpl.copyMatrix(this, (CuVSMatrixInternal) deviceMatrix, cuVSResources);
    }

    /** Does nothing; this class does not release the backing segment itself. */
    @Override
    public void close() {
    }

    /** Reads the element at ({@code row}, {@code col}) as an {@code int}, honouring the row stride. */
    @Override
    public int get(int row, int col) {
        long rowPitch = this.rowStride > 0 ? (long) this.rowStride : this.columns;
        // VarHandle.get is signature-polymorphic: the cast fixes the call-site return type to int.
        return (int) this.accessor$vh.get(this.memorySegment, 0L, (long) row * rowPitch + (long) col);
    }

    /** Returns the row stride in elements, or a value {@code <= 0} for packed rows. */
    @Override
    public long rowStride() {
        return this.rowStride;
    }

    /**
     * Builds a CPU tensor descriptor for this matrix.
     *
     * <p>Strides {@code {rowStride, columnStride}} are passed only for strided matrices. Packed
     * matrices pass {@code null} strides.
     */
    @Override
    public MemorySegment toTensor(Arena arena) {
        // Same "> 0" test as the row-pitch computation, so rowStride == 0 is treated as packed everywhere.
        long[] strides = this.rowStride > 0 ? new long[] {this.rowStride, this.columnStride} : null;
        return Util.prepareTensor(arena, this.memorySegment, new long[] {this.size, this.columns}, strides, this.code(), this.bits(), headers_h.kDLCPU());
    }

    /**
     * View returned by {@link CuVSHostMatrixImpl#toHost()}. Every accessor forwards to the wrapped
     * matrix, {@code toHost()} returns this view itself, and {@code close()} is a no-op.
     */
    private static final class CuVSHostMatrixDelegate
    implements CuVSHostMatrix,
    CuVSMatrixInternal {
        private final CuVSHostMatrixImpl hostMatrix;

        public CuVSHostMatrixDelegate(CuVSHostMatrixImpl cuVSHostMatrix) {
            this.hostMatrix = cuVSHostMatrix;
        }

        @Override
        public int get(int row, int col) {
            return this.hostMatrix.get(row, col);
        }

        @Override
        public long size() {
            return this.hostMatrix.size();
        }

        @Override
        public long columns() {
            return this.hostMatrix.columns();
        }

        @Override
        public CuVSMatrix.DataType dataType() {
            return this.hostMatrix.dataType();
        }

        @Override
        public RowView getRow(long row) {
            return this.hostMatrix.getRow(row);
        }

        @Override
        public void toArray(int[][] array) {
            this.hostMatrix.toArray(array);
        }

        @Override
        public void toArray(float[][] array) {
            this.hostMatrix.toArray(array);
        }

        @Override
        public void toArray(byte[][] array) {
            this.hostMatrix.toArray(array);
        }

        @Override
        public void toHost(CuVSHostMatrix target) {
            this.hostMatrix.toHost(target);
        }

        @Override
        public CuVSHostMatrix toHost() {
            return this;
        }

        @Override
        public void toDevice(CuVSDeviceMatrix deviceMatrix, CuVSResources cuVSResources) {
            this.hostMatrix.toDevice(deviceMatrix, cuVSResources);
        }

        @Override
        public MemorySegment memorySegment() {
            return this.hostMatrix.memorySegment();
        }

        @Override
        public ValueLayout valueLayout() {
            return this.hostMatrix.valueLayout();
        }

        @Override
        public long rowStride() {
            return this.hostMatrix.rowStride();
        }

        @Override
        public MemorySegment toTensor(Arena arena) {
            return this.hostMatrix.toTensor(arena);
        }

        /** No-op; presumably so closing the view leaves the wrapped matrix untouched — confirm intent. */
        @Override
        public void close() {
        }
    }
}

