/*
 * Decompiled with CFR 0.152.
 */
package com.nvidia.cuvs.internal;

import com.nvidia.cuvs.CuVSResources;
import com.nvidia.cuvs.HnswIndex;
import com.nvidia.cuvs.HnswIndexParams;
import com.nvidia.cuvs.HnswQuery;
import com.nvidia.cuvs.HnswSearchParams;
import com.nvidia.cuvs.SearchResults;
import com.nvidia.cuvs.internal.CuVSParamsHelper;
import com.nvidia.cuvs.internal.HnswSearchResults;
import com.nvidia.cuvs.internal.common.CloseableHandle;
import com.nvidia.cuvs.internal.common.LinkerHelper;
import com.nvidia.cuvs.internal.common.Util;
import com.nvidia.cuvs.internal.panama.cuvsHnswIndexParams;
import com.nvidia.cuvs.internal.panama.cuvsHnswSearchParams;
import com.nvidia.cuvs.internal.panama.headers_h;
import java.io.InputStream;
import java.lang.foreign.Arena;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SequenceLayout;
import java.util.Objects;

/**
 * Java-side wrapper around a native cuVS HNSW index handle, accessed through the Panama
 * bindings in {@code headers_h}.
 *
 * <p>Instances are created through {@link Builder}, which deserializes an index from an
 * {@link InputStream}. The native handle is released by {@link #close()}. After closing,
 * further {@link #search(HnswQuery)} calls are rejected instead of touching freed native memory.
 *
 * <p>Not documented as thread-safe; the {@code destroyed} flag is a plain field.
 */
public class HnswIndexImpl
implements HnswIndex {
    private final CuVSResources resources;
    private final HnswIndexParams hnswIndexParams;
    private final IndexReference hnswIndexReference;
    /** Set once {@link #close()} has attempted to destroy the native handle. */
    private boolean destroyed;

    private HnswIndexImpl(InputStream inputStream, CuVSResources resources, HnswIndexParams hnswIndexParams) throws Throwable {
        this.hnswIndexParams = hnswIndexParams;
        this.resources = resources;
        this.hnswIndexReference = this.deserialize(inputStream);
    }

    /**
     * Destroys the native HNSW index via {@code cuvsHnswIndexDestroy}.
     *
     * <p>Idempotent: only the first call reaches the native destroy function. The flag is set in
     * {@code finally} so that a failed destroy is not retried on the same (possibly already
     * released) handle.
     */
    @Override
    public void close() {
        if (this.destroyed) {
            return;
        }
        try {
            int returnValue = headers_h.cuvsHnswIndexDestroy(this.hnswIndexReference.getMemorySegment());
            Util.checkCuVSError(returnValue, "cuvsHnswIndexDestroy");
        } finally {
            this.destroyed = true;
        }
    }

    /**
     * Runs an HNSW k-nearest-neighbour search for every vector in {@code query}.
     *
     * <p>Queries, neighbours ({@code numQueries x topK}, 64-bit unsigned) and distances
     * ({@code numQueries x topK}, 32-bit float) are exchanged with the native library as host
     * ({@code kDLCPU}) DLPack tensors. The query's resource stream is synchronized both before
     * and after the native search.
     *
     * @param query the search request; all query vectors must have the same length
     * @return the search results built from the native neighbour/distance buffers
     * @throws IllegalStateException if this index has already been closed
     * @throws IllegalArgumentException if the query vectors do not all have the same dimension
     * @throws Throwable any error surfaced by the native calls
     */
    @Override
    public SearchResults search(HnswQuery query) throws Throwable {
        checkNotDestroyed();
        try (Arena localArena = Arena.ofConfined()) {
            int topK = query.getTopK();
            float[][] queryVectors = query.getQueryVectors();
            int numQueries = queryVectors.length;
            int vectorDimension = numQueries > 0 ? queryVectors[0].length : 0;
            // The tensor shape below is {numQueries, vectorDimension}; ragged input would not match it.
            for (int i = 1; i < numQueries; i++) {
                if (queryVectors[i].length != vectorDimension) {
                    throw new IllegalArgumentException("All query vectors must have the same dimension: row 0 has "
                            + vectorDimension + " but row " + i + " has " + queryVectors[i].length);
                }
            }
            // Widen before multiplying so large topK * numQueries does not overflow int.
            long numBlocks = (long) topK * (long) numQueries;
            SequenceLayout neighborsSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, LinkerHelper.C_LONG);
            SequenceLayout distancesSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, LinkerHelper.C_FLOAT);
            MemorySegment neighborsMemorySegment = localArena.allocate(neighborsSequenceLayout);
            MemorySegment distancesMemorySegment = localArena.allocate(distancesSequenceLayout);
            MemorySegment querySeg = Util.buildMemorySegment(localArena, queryVectors);
            long[] queriesShape = new long[]{numQueries, vectorDimension};
            MemorySegment queriesTensor = Util.prepareTensor(localArena, querySeg, queriesShape, headers_h.kDLFloat(), 32, headers_h.kDLCPU());
            long[] neighborsShape = new long[]{numQueries, topK};
            MemorySegment neighborsTensor = Util.prepareTensor(localArena, neighborsMemorySegment, neighborsShape, headers_h.kDLUInt(), 64, headers_h.kDLCPU());
            long[] distancesShape = new long[]{numQueries, topK};
            MemorySegment distancesTensor = Util.prepareTensor(localArena, distancesMemorySegment, distancesShape, headers_h.kDLFloat(), 32, headers_h.kDLCPU());
            try (CuVSResources.ScopedAccess resourcesAccessor = query.getResources().access()) {
                long cuvsRes = resourcesAccessor.handle();
                int returnValue = headers_h.cuvsStreamSync(cuvsRes);
                Util.checkCuVSError(returnValue, "cuvsStreamSync");
                returnValue = headers_h.cuvsHnswSearch(cuvsRes, HnswIndexImpl.segmentFromSearchParams(localArena, query.getHnswSearchParams()), this.hnswIndexReference.getMemorySegment(), queriesTensor, neighborsTensor, distancesTensor);
                Util.checkCuVSError(returnValue, "cuvsHnswSearch");
                returnValue = headers_h.cuvsStreamSync(cuvsRes);
                Util.checkCuVSError(returnValue, "cuvsStreamSync");
            }
            // NOTE(review): the neighbour/distance segments belong to localArena, which closes when
            // this method returns; this assumes HnswSearchResults.create copies the data eagerly — confirm.
            return HnswSearchResults.create(neighborsSequenceLayout, distancesSequenceLayout, neighborsMemorySegment, distancesMemorySegment, topK, query.getMapping(), numQueries);
        }
    }

    /** Rejects use of the native handle after {@link #close()} has run. */
    private void checkNotDestroyed() {
        if (this.destroyed) {
            throw new IllegalStateException("HNSW index has already been closed");
        }
    }

    /**
     * Allocates an empty native HNSW index via {@code cuvsHnswIndexCreate}.
     *
     * <p>The out-pointer lives in a short-lived confined arena. Only the handle value read back
     * from it is kept in the returned reference.
     */
    private static IndexReference createHnswIndex() {
        try (Arena localArena = Arena.ofConfined()) {
            MemorySegment indexPtrPtr = localArena.allocate(headers_h.cuvsHnswIndex_t);
            int returnValue = headers_h.cuvsHnswIndexCreate(indexPtrPtr);
            Util.checkCuVSError(returnValue, "cuvsHnswIndexCreate");
            return new IndexReference(indexPtrPtr.get(headers_h.cuvsHnswIndex_t, 0L));
        }
    }

    /*
     * Placeholder: CFR 0.152 failed to decompile this method
     * (org.benf.cfr.reader.util.ConfusedCFRException: "Started 3 blocks at once"), so its real body
     * is missing and this stub always throws. As a result, every Builder.build() call that gets past
     * argument validation fails with IllegalStateException.
     * TODO: restore the original implementation from the upstream cuVS sources.
     */
    private IndexReference deserialize(InputStream inputStream) throws Throwable {
        throw new IllegalStateException("Decompilation failed");
    }

    /**
     * Copies {@code ef_construction} and {@code num_threads} from {@code params} into a freshly
     * created native HNSW index-params struct.
     *
     * <p>The returned {@link CloseableHandle} owns native state; callers are presumably expected
     * to close it (e.g. with try-with-resources) — confirm against {@code CuVSParamsHelper}.
     */
    private CloseableHandle segmentFromIndexParams(HnswIndexParams params) {
        CloseableHandle hnswParams = CuVSParamsHelper.createHnswIndexParams();
        cuvsHnswIndexParams.ef_construction(hnswParams.handle(), params.getEfConstruction());
        cuvsHnswIndexParams.num_threads(hnswParams.handle(), params.getNumThreads());
        return hnswParams;
    }

    /**
     * Allocates a native {@code cuvsHnswSearchParams} struct in {@code arena} and copies
     * {@code ef} and {@code num_threads} from {@code params} into it.
     */
    private static MemorySegment segmentFromSearchParams(Arena arena, HnswSearchParams params) {
        MemorySegment seg = cuvsHnswSearchParams.allocate(arena);
        cuvsHnswSearchParams.ef(seg, params.ef());
        cuvsHnswSearchParams.num_threads(seg, params.numThreads());
        return seg;
    }

    /**
     * Creates a builder bound to the given resources.
     *
     * @throws NullPointerException if {@code cuvsResources} is null
     */
    public static HnswIndex.Builder newBuilder(CuVSResources cuvsResources) {
        return new Builder(Objects.requireNonNull(cuvsResources));
    }

    /** Holds the native {@code cuvsHnswIndex_t} handle returned by {@code cuvsHnswIndexCreate}. */
    protected static class IndexReference {
        private final MemorySegment memorySegment;

        protected IndexReference(MemorySegment indexMemorySegment) {
            this.memorySegment = indexMemorySegment;
        }

        protected MemorySegment getMemorySegment() {
            return this.memorySegment;
        }
    }

    /** Fluent builder that deserializes an {@link HnswIndexImpl} from a stream. */
    public static class Builder
    implements HnswIndex.Builder {
        private final CuVSResources cuvsResources;
        private InputStream inputStream;
        private HnswIndexParams hnswIndexParams;

        public Builder(CuVSResources cuvsResources) {
            this.cuvsResources = cuvsResources;
        }

        /** Sets the stream the serialized index is read from. */
        @Override
        public Builder from(InputStream inputStream) {
            this.inputStream = inputStream;
            return this;
        }

        /** Sets the HNSW index parameters used during deserialization. */
        @Override
        public Builder withIndexParams(HnswIndexParams hnswIndexParameters) {
            this.hnswIndexParams = hnswIndexParameters;
            return this;
        }

        /**
         * Builds the index from the configured stream.
         *
         * @throws NullPointerException if {@link #from(InputStream)} was never given a stream
         */
        @Override
        public HnswIndexImpl build() throws Throwable {
            // Fail fast with a clear message instead of handing a null stream to deserialize.
            Objects.requireNonNull(this.inputStream, "inputStream");
            return new HnswIndexImpl(this.inputStream, this.cuvsResources, this.hnswIndexParams);
        }
    }
}

