/*
 * Decompiled with CFR 0.152.
 */
package com.nvidia.cuvs.internal;

import com.nvidia.cuvs.BruteForceIndex;
import com.nvidia.cuvs.BruteForceIndexParams;
import com.nvidia.cuvs.BruteForceQuery;
import com.nvidia.cuvs.CuVSMatrix;
import com.nvidia.cuvs.CuVSResources;
import com.nvidia.cuvs.SearchResults;
import com.nvidia.cuvs.internal.BruteForceSearchResults;
import com.nvidia.cuvs.internal.CuVSMatrixInternal;
import com.nvidia.cuvs.internal.common.CloseableRMMAllocation;
import com.nvidia.cuvs.internal.common.LinkerHelper;
import com.nvidia.cuvs.internal.common.Util;
import com.nvidia.cuvs.internal.panama.cuvsFilter;
import com.nvidia.cuvs.internal.panama.headers_h;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.foreign.Arena;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SequenceLayout;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.BitSet;
import java.util.Objects;
import java.util.UUID;

/**
 * {@link BruteForceIndex} implementation that owns a native cuVS brute-force index handle
 * ({@code cuvsBruteForceIndex_t}) together with the host/device memory that backs it.
 *
 * <p>NOTE(review): this file was recovered with the CFR decompiler and {@link #build} failed to
 * decompile, so constructing an index from a dataset currently throws
 * {@link IllegalStateException}. Restore that method from the original sources before use.
 *
 * <p>The {@code destroyed} flag is a plain field with no synchronization; NOTE(review): presumably
 * instances are not meant to be shared across threads — confirm.
 */
public class BruteForceIndexImpl
implements BruteForceIndex {
    /** Resources used when serializing, deserializing and releasing this index. */
    private final CuVSResources resources;
    /** Native index pointer plus the dataset allocation / arena it keeps alive. */
    private final IndexReference bruteForceIndexReference;
    /** Set by {@link #close()}; search, (de)serialization and close check it first. */
    private boolean destroyed;

    /**
     * Builds a new index from {@code dataset}. The dataset is closed when this constructor returns,
     * whether or not building succeeded.
     *
     * @throws NullPointerException if {@code dataset} is {@code null} (e.g. the builder was given
     *     neither a dataset nor an input stream)
     */
    private BruteForceIndexImpl(CuVSMatrix dataset, CuVSResources resources, BruteForceIndexParams bruteForceIndexParams) throws Exception {
        Objects.requireNonNull(dataset, "dataset");
        try (dataset) {
            this.resources = resources;
            assert (dataset instanceof CuVSMatrixInternal);
            this.bruteForceIndexReference = this.build((CuVSMatrixInternal)dataset, bruteForceIndexParams);
        }
    }

    /** Loads a previously serialized index from {@code inputStream} (which is closed afterwards). */
    private BruteForceIndexImpl(InputStream inputStream, CuVSResources resources) throws Throwable {
        this.resources = resources;
        this.bruteForceIndexReference = this.deserialize(inputStream);
    }

    /** @throws IllegalStateException if {@link #close()} has already been called */
    private void checkNotDestroyed() {
        if (this.destroyed) {
            throw new IllegalStateException("destroyed");
        }
    }

    /**
     * Destroys the native index and releases the dataset allocation and tensor arena.
     *
     * <p>The host/device memory is released even when the native destroy call reports an error,
     * and the index is marked destroyed in every case.
     *
     * @throws IllegalStateException if the index was already closed
     */
    @Override
    public void close() {
        this.checkNotDestroyed();
        try {
            int returnValue = headers_h.cuvsBruteForceIndexDestroy(this.bruteForceIndexReference.indexPtr);
            Util.checkCuVSError(returnValue, "cuvsBruteForceIndexDestroy");
        }
        finally {
            // Release the backing memory even if the native destroy failed; otherwise the RMM
            // allocation and the arena would leak.
            try {
                this.bruteForceIndexReference.close(this.resources);
            }
            finally {
                this.destroyed = true;
            }
        }
    }

    /*
     * NOTE(review): CFR 0.152 failed to decompile this method
     * (org.benf.cfr.reader.util.ConfusedCFRException: "Started 2 blocks at once").
     * The body below is the decompiler's placeholder, not the real implementation.
     */
    private IndexReference build(CuVSMatrixInternal dataset, BruteForceIndexParams bruteForceIndexParams) {
        throw new IllegalStateException("Decompilation failed");
    }

    /**
     * Searches the index for the {@code topK} nearest neighbours of every query vector.
     *
     * <p>Query vectors and the optional prefilter bitmap are copied to device memory allocated from
     * the query's resources, the native search is run, and neighbour ids / distances are copied
     * back into host segments from a confined arena that is closed when this method returns.
     * NOTE(review): this assumes {@code BruteForceSearchResults.create} copies the data out of those
     * segments rather than retaining them — confirm.
     *
     * @param cuvsQuery the query vectors, top-k, optional prefilters and id mapping
     * @return the search results
     * @throws IllegalStateException if the index has been closed
     */
    @Override
    public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
        this.checkNotDestroyed();
        try (Arena localArena = Arena.ofConfined();){
            long numQueries = cuvsQuery.getQueryVectors().length;
            int topk = cuvsQuery.getTopK();
            long numBlocks = (long)topk * numQueries;
            int vectorDimension = numQueries > 0L ? cuvsQuery.getQueryVectors()[0].length : 0;
            SequenceLayout neighborsSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, LinkerHelper.C_LONG);
            SequenceLayout distancesSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, LinkerHelper.C_FLOAT);
            MemorySegment neighborsMemorySegment = localArena.allocate(neighborsSequenceLayout);
            MemorySegment distancesMemorySegment = localArena.allocate(distancesSequenceLayout);

            long prefilterWords;
            long prefilterBytes;
            MemorySegment prefilterDataMemorySegment;
            BitSet[] prefilters = cuvsQuery.getPrefilters();
            if (prefilters != null && prefilters.length > 0) {
                BitSet concatenatedFilters = Util.concatenate(prefilters, cuvsQuery.getNumDocs());
                // prefilters.length bitsets of numDocs bits each, packed into 32-bit words.
                long prefilterDataLength = (long)cuvsQuery.getNumDocs() * (long)prefilters.length;
                prefilterWords = (prefilterDataLength + 31L) / 32L;
                prefilterBytes = LinkerHelper.C_INT_BYTE_SIZE * prefilterWords;
                // BitSet.toLongArray() omits trailing all-zero words, so its result can be shorter
                // than prefilterBytes; the host-to-device copy below would then read past the end of
                // the host buffer. Pad with zero words so every copied byte is backed.
                // NOTE(review): assumes Util.buildMemorySegment allocates 8 bytes per long — confirm.
                long[] setWords = concatenatedFilters.toLongArray();
                long[] filters = new long[Math.toIntExact((prefilterBytes + Long.BYTES - 1L) / Long.BYTES)];
                System.arraycopy(setWords, 0, filters, 0, Math.min(setWords.length, filters.length));
                prefilterDataMemorySegment = Util.buildMemorySegment(localArena, filters);
            } else {
                prefilterWords = 0L;
                prefilterBytes = 0L;
                prefilterDataMemorySegment = MemorySegment.NULL;
            }

            MemorySegment querySeg = Util.buildMemorySegment(localArena, cuvsQuery.getQueryVectors());
            try (CuVSResources.ScopedAccess resourcesAccessor = cuvsQuery.getResources().access();){
                long cuvsResources = resourcesAccessor.handle();
                long queriesBytes = LinkerHelper.C_FLOAT_BYTE_SIZE * numQueries * (long)vectorDimension;
                long neighborsBytes = LinkerHelper.C_LONG_BYTE_SIZE * numQueries * (long)topk;
                long distanceBytes = LinkerHelper.C_FLOAT_BYTE_SIZE * numQueries * (long)topk;
                try (CloseableRMMAllocation queriesDP = CloseableRMMAllocation.allocateRMMSegment(cuvsResources, queriesBytes);
                     CloseableRMMAllocation neighborsDP = CloseableRMMAllocation.allocateRMMSegment(cuvsResources, neighborsBytes);
                     CloseableRMMAllocation distancesDP = CloseableRMMAllocation.allocateRMMSegment(cuvsResources, distanceBytes);
                     CloseableRMMAllocation prefilterDP = prefilterBytes > 0L ? CloseableRMMAllocation.allocateRMMSegment(cuvsResources, prefilterBytes) : CloseableRMMAllocation.EMPTY;){
                    Util.cudaMemcpy(queriesDP.handle(), querySeg, queriesBytes, Util.CudaMemcpyKind.INFER_DIRECTION);
                    long[] queriesShape = new long[]{numQueries, vectorDimension};
                    MemorySegment queriesTensor = Util.prepareTensor(localArena, queriesDP.handle(), queriesShape, headers_h.kDLFloat(), 32, headers_h.kDLCUDA());
                    long[] neighborsShape = new long[]{numQueries, topk};
                    MemorySegment neighborsTensor = Util.prepareTensor(localArena, neighborsDP.handle(), neighborsShape, headers_h.kDLInt(), 64, headers_h.kDLCUDA());
                    long[] distancesShape = new long[]{numQueries, topk};
                    MemorySegment distancesTensor = Util.prepareTensor(localArena, distancesDP.handle(), distancesShape, headers_h.kDLFloat(), 32, headers_h.kDLCUDA());
                    MemorySegment prefilter = cuvsFilter.allocate(localArena);
                    if (prefilterDataMemorySegment == MemorySegment.NULL) {
                        // No prefilter: filter type 0 with a null address.
                        cuvsFilter.type(prefilter, 0);
                        cuvsFilter.addr(prefilter, 0L);
                    } else {
                        long[] prefilterShape = new long[]{prefilterWords};
                        Util.cudaMemcpy(prefilterDP.handle(), prefilterDataMemorySegment, prefilterBytes, Util.CudaMemcpyKind.HOST_TO_DEVICE);
                        MemorySegment prefilterTensor = Util.prepareTensor(localArena, prefilterDP.handle(), prefilterShape, headers_h.kDLUInt(), 32, headers_h.kDLCUDA());
                        cuvsFilter.type(prefilter, 2);
                        cuvsFilter.addr(prefilter, prefilterTensor.address());
                    }
                    int returnValue = headers_h.cuvsStreamSync(cuvsResources);
                    Util.checkCuVSError(returnValue, "cuvsStreamSync");
                    returnValue = headers_h.cuvsBruteForceSearch(cuvsResources, this.bruteForceIndexReference.indexPtr, queriesTensor, neighborsTensor, distancesTensor, prefilter);
                    Util.checkCuVSError(returnValue, "cuvsBruteForceSearch");
                    // Wait for the search to finish before copying results back to the host.
                    returnValue = headers_h.cuvsStreamSync(cuvsResources);
                    Util.checkCuVSError(returnValue, "cuvsStreamSync");
                    Util.cudaMemcpy(neighborsMemorySegment, neighborsDP.handle(), neighborsBytes, Util.CudaMemcpyKind.INFER_DIRECTION);
                    Util.cudaMemcpy(distancesMemorySegment, distancesDP.handle(), distanceBytes, Util.CudaMemcpyKind.INFER_DIRECTION);
                }
            }
            return BruteForceSearchResults.create(neighborsSequenceLayout, distancesSequenceLayout, neighborsMemorySegment, distancesMemorySegment, topk, cuvsQuery.getMapping(), numQueries);
        }
    }

    /**
     * Serializes the index to {@code outputStream} through a temporary {@code .bf} file created in
     * the resources' temp directory; the file is deleted afterwards.
     *
     * @throws IllegalStateException if the index has been closed (checked before any file is created)
     */
    @Override
    public void serialize(OutputStream outputStream) throws Throwable {
        this.checkNotDestroyed();
        Path path = Files.createTempFile(this.resources.tempDirectory(), UUID.randomUUID().toString(), ".bf", new FileAttribute[0]);
        this.serialize(outputStream, path);
    }

    /**
     * Serializes the index into {@code tempFile}, streams that file to {@code outputStream}, and
     * deletes {@code tempFile} — also when serialization or copying fails.
     *
     * @param outputStream destination of the serialized bytes (not closed by this method)
     * @param tempFile scratch file the native serializer writes to
     * @throws IllegalStateException if the index has been closed
     */
    @Override
    public void serialize(OutputStream outputStream, Path tempFile) throws Throwable {
        this.checkNotDestroyed();
        Path tempFilePath = tempFile.toAbsolutePath();
        try {
            try (Arena localArena = Arena.ofConfined();
                 CuVSResources.ScopedAccess resourcesAccessor = this.resources.access();){
                int returnValue = headers_h.cuvsBruteForceSerialize(resourcesAccessor.handle(), localArena.allocateFrom(tempFilePath.toString()), this.bruteForceIndexReference.indexPtr);
                Util.checkCuVSError(returnValue, "cuvsBruteForceSerialize");
            }
            try (InputStream inputStream = Files.newInputStream(tempFilePath, new OpenOption[0]);){
                inputStream.transferTo(outputStream);
            }
        }
        finally {
            // Covers the native serialize step too, so a failure does not leave the file behind.
            Files.deleteIfExists(tempFilePath);
        }
    }

    /**
     * Creates an empty native brute-force index via {@code cuvsBruteForceIndexCreate}.
     *
     * @return the native index pointer read back from the out-parameter
     */
    private static MemorySegment createBruteForceIndex() {
        try (Arena localArena = Arena.ofConfined();){
            MemorySegment indexPtrPtr = localArena.allocate(headers_h.cuvsBruteForceIndex_t);
            int returnValue = headers_h.cuvsBruteForceIndexCreate(indexPtrPtr);
            Util.checkCuVSError(returnValue, "cuvsBruteForceIndexCreate");
            return indexPtrPtr.get(headers_h.cuvsBruteForceIndex_t, 0L);
        }
    }

    /**
     * Reads a serialized index from {@code inputStream} (closing it) into a new native index.
     *
     * <p>The stream is spooled to a temporary {@code .bf} file, which is always deleted. If anything
     * fails after the native index was created, that index is destroyed before rethrowing.
     *
     * @return a reference to the deserialized native index
     */
    private IndexReference deserialize(InputStream inputStream) throws Throwable {
        this.checkNotDestroyed();
        Path tmpIndexFile = Files.createTempFile(this.resources.tempDirectory(), UUID.randomUUID().toString(), ".bf", new FileAttribute[0]).toAbsolutePath();
        try {
            MemorySegment indexPtr = BruteForceIndexImpl.createBruteForceIndex();
            try (inputStream;
                 OutputStream outputStream = Files.newOutputStream(tmpIndexFile, new OpenOption[0]);
                 Arena arena = Arena.ofConfined();
                 CuVSResources.ScopedAccess resourcesAccessor = this.resources.access();){
                inputStream.transferTo(outputStream);
                int returnValue = headers_h.cuvsBruteForceDeserialize(resourcesAccessor.handle(), arena.allocateFrom(tmpIndexFile.toString()), indexPtr);
                Util.checkCuVSError(returnValue, "cuvsBruteForceDeserialize");
            }
            catch (Throwable t) {
                // Do not leak the native index created above; keep the original failure primary.
                try {
                    headers_h.cuvsBruteForceIndexDestroy(indexPtr);
                }
                catch (Throwable suppressed) {
                    t.addSuppressed(suppressed);
                }
                throw t;
            }
            return new IndexReference(indexPtr);
        }
        finally {
            Files.deleteIfExists(tmpIndexFile);
        }
    }

    /** Returns a builder bound to {@code cuvsResources} (must not be {@code null}). */
    public static BruteForceIndex.Builder newBuilder(CuVSResources cuvsResources) {
        return new Builder(Objects.requireNonNull(cuvsResources));
    }

    /** Native index pointer plus the host/device memory that must outlive it. */
    private static class IndexReference {
        private final CloseableRMMAllocation datasetAllocationHandle;
        private final long datasetBytes;
        /** May be {@code null} (deserialized indexes have no tensor arena). */
        private final Arena tensorDataArena;
        private final MemorySegment indexPtr;

        private IndexReference(CloseableRMMAllocation datasetAllocationHandle, long datasetBytes, Arena tensorDataArena, MemorySegment indexPtr) {
            this.datasetAllocationHandle = datasetAllocationHandle;
            this.datasetBytes = datasetBytes;
            this.tensorDataArena = tensorDataArena;
            this.indexPtr = indexPtr;
        }

        /** Reference for an index with no owned dataset allocation or arena (e.g. deserialized). */
        private IndexReference(MemorySegment indexPtr) {
            this.datasetAllocationHandle = CloseableRMMAllocation.EMPTY;
            this.datasetBytes = 0L;
            this.tensorDataArena = null;
            this.indexPtr = indexPtr;
        }

        /**
         * Releases the dataset allocation (while holding access to {@code resources}) and then the
         * tensor arena; the arena is closed even if releasing the allocation fails. Does not destroy
         * the native index itself.
         */
        private void close(CuVSResources resources) {
            try (CuVSResources.ScopedAccess resourcesAccessor = resources.access();){
                this.datasetAllocationHandle.close();
            }
            finally {
                if (this.tensorDataArena != null) {
                    this.tensorDataArena.close();
                }
            }
        }
    }

    /**
     * Builder for {@link BruteForceIndexImpl}. If an input stream is set, the index is deserialized
     * from it; otherwise it is built from the dataset (which must then be set).
     */
    public static class Builder
    implements BruteForceIndex.Builder {
        private CuVSMatrix dataset;
        private final CuVSResources cuvsResources;
        private BruteForceIndexParams bruteForceIndexParams;
        private InputStream inputStream;

        public Builder(CuVSResources cuvsResources) {
            this.cuvsResources = cuvsResources;
        }

        @Override
        public Builder withIndexParams(BruteForceIndexParams bruteForceIndexParams) {
            this.bruteForceIndexParams = bruteForceIndexParams;
            return this;
        }

        /** Deserialize the index from {@code inputStream}; takes precedence over any dataset. */
        @Override
        public Builder from(InputStream inputStream) {
            this.inputStream = inputStream;
            return this;
        }

        /** Use {@code vectors} (wrapped via {@code CuVSMatrix.ofArray}) as the dataset. */
        @Override
        public Builder withDataset(float[][] vectors) {
            this.dataset = CuVSMatrix.ofArray(vectors);
            return this;
        }

        /** Use {@code dataset}; note it is closed once the index has been built. */
        @Override
        public Builder withDataset(CuVSMatrix dataset) {
            this.dataset = dataset;
            return this;
        }

        /**
         * Creates the index.
         *
         * @throws NullPointerException if neither an input stream nor a dataset was supplied
         */
        @Override
        public BruteForceIndexImpl build() throws Throwable {
            if (this.inputStream != null) {
                return new BruteForceIndexImpl(this.inputStream, this.cuvsResources);
            }
            return new BruteForceIndexImpl(this.dataset, this.cuvsResources, this.bruteForceIndexParams);
        }
    }
}

