/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.gpu.codec;

import com.nvidia.cuvs.CagraIndexParams;
import com.nvidia.cuvs.CuVSMatrix;
import com.nvidia.cuvs.CuVSResources;
import com.nvidia.cuvs.spi.CuVSProvider;
import java.nio.file.Path;
import java.util.Objects;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;
import org.elasticsearch.core.Strings;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.xpack.gpu.GPUSupport;
import org.elasticsearch.xpack.gpu.codec.GPUMemoryService;
import org.elasticsearch.xpack.gpu.codec.RealGPUMemoryService;

public interface CuVSResourceManager {
    public static final double GPU_COMPUTATION_MEMORY_FACTOR = 2.0;

    public ManagedCuVSResources acquire(int var1, int var2, CuVSMatrix.DataType var3, CagraIndexParams var4) throws InterruptedException;

    public void finishedComputation(ManagedCuVSResources var1);

    public void release(ManagedCuVSResources var1);

    public void shutdown();

    public static long estimateNNDescentMemory(int numVectors, int dims, CuVSMatrix.DataType dataType) {
        int elementTypeBytes = switch (dataType) {
            default -> throw new MatchException(null, null);
            case CuVSMatrix.DataType.FLOAT -> 4;
            case CuVSMatrix.DataType.INT, CuVSMatrix.DataType.UINT -> 4;
            case CuVSMatrix.DataType.BYTE -> 1;
        };
        return (long)(2.0 * (double)numVectors * (double)dims * (double)elementTypeBytes);
    }

    public static CuVSResourceManager pooling() {
        return PoolingCuVSResourceManager.Holder.INSTANCE;
    }

    public static class PoolingCuVSResourceManager
    implements CuVSResourceManager {
        static final Logger logger = LogManager.getLogger(CuVSResourceManager.class);
        static final int MAX_RESOURCES = 4;
        private final ManagedCuVSResources[] pool;
        private final int capacity;
        private final GPUMemoryService gpuMemoryService;
        private int createdCount;
        ReentrantLock lock = new ReentrantLock();
        Condition enoughResourcesCondition = this.lock.newCondition();

        PoolingCuVSResourceManager(int capacity, GPUMemoryService gpuMemoryService) {
            if (capacity < 1 || capacity > 4) {
                throw new IllegalArgumentException("Resource count must be between 1 and 4");
            }
            this.capacity = capacity;
            this.gpuMemoryService = gpuMemoryService;
            this.pool = new ManagedCuVSResources[4];
        }

        private ManagedCuVSResources getResourceFromPool() {
            for (int i = 0; i < this.createdCount; ++i) {
                ManagedCuVSResources res = this.pool[i];
                if (res.isLocked()) continue;
                return res;
            }
            if (this.createdCount < this.capacity) {
                ManagedCuVSResources res = new ManagedCuVSResources(Objects.requireNonNull(this.createNew()));
                this.pool[this.createdCount++] = res;
                return res;
            }
            return null;
        }

        private int numLockedResources() {
            int lockedResources = 0;
            for (int i = 0; i < this.createdCount; ++i) {
                ManagedCuVSResources res = this.pool[i];
                if (!res.isLocked()) continue;
                ++lockedResources;
            }
            return lockedResources;
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        @Override
        public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType, CagraIndexParams cagraIndexParams) throws InterruptedException {
            try {
                long started = System.nanoTime();
                this.lock.lock();
                boolean allConditionsMet = false;
                ManagedCuVSResources res = null;
                long requiredMemoryInBytes = this.estimateRequiredMemory(numVectors, dims, dataType, cagraIndexParams);
                logger.debug("Estimated memory for [{}] vectors, [{}] dims of type [{}] is [{} B]", new Object[]{numVectors, dims, dataType.name(), requiredMemoryInBytes});
                while (!allConditionsMet) {
                    boolean enoughMemory;
                    res = this.getResourceFromPool();
                    if (res != null) {
                        long totalMemoryInBytes = this.gpuMemoryService.totalMemoryInBytes(res);
                        if (requiredMemoryInBytes > totalMemoryInBytes) {
                            String message = Strings.format((String)"Requested GPU memory for [%d] vectors, [%d] dims is greater than the GPU total memory [%d B]", (Object[])new Object[]{numVectors, dims, totalMemoryInBytes});
                            logger.error(message);
                            throw new IllegalArgumentException(message);
                        }
                        if (this.numLockedResources() == 0) {
                            logger.debug("No resources currently locked, proceeding");
                            break;
                        }
                        long availableMemoryInBytes = this.gpuMemoryService.availableMemoryInBytes(res);
                        enoughMemory = requiredMemoryInBytes <= availableMemoryInBytes;
                        logger.debug("Free device memory [{} B], enoughMemory[{}]", new Object[]{availableMemoryInBytes, enoughMemory});
                    } else {
                        logger.debug("No resources available in pool");
                        enoughMemory = false;
                    }
                    if (allConditionsMet = enoughMemory) continue;
                    this.enoughResourcesCondition.await();
                }
                long elapsed = started - System.nanoTime();
                logger.debug("Resource acquired in [{}ms]", new Object[]{(double)elapsed / 1000000.0});
                this.gpuMemoryService.reserveMemory(requiredMemoryInBytes);
                res.lock(() -> this.gpuMemoryService.releaseMemory(requiredMemoryInBytes));
                ManagedCuVSResources managedCuVSResources = res;
                return managedCuVSResources;
            }
            finally {
                this.lock.unlock();
            }
        }

        private long estimateRequiredMemory(int numVectors, int dims, CuVSMatrix.DataType dataType, CagraIndexParams cagraIndexParams) {
            if (cagraIndexParams.getCagraGraphBuildAlgo() == CagraIndexParams.CagraGraphBuildAlgo.IVF_PQ && cagraIndexParams.getCuVSIvfPqParams() != null && cagraIndexParams.getCuVSIvfPqParams().getIndexParams() != null && cagraIndexParams.getCuVSIvfPqParams().getIndexParams().getPqDim() != 0) {
                int elementTypeBytes = switch (dataType) {
                    default -> throw new MatchException(null, null);
                    case CuVSMatrix.DataType.FLOAT -> 4;
                    case CuVSMatrix.DataType.INT, CuVSMatrix.DataType.UINT -> 4;
                    case CuVSMatrix.DataType.BYTE -> 1;
                };
                int pqDim = cagraIndexParams.getCuVSIvfPqParams().getIndexParams().getPqDim();
                int pqBits = cagraIndexParams.getCuVSIvfPqParams().getIndexParams().getPqBits();
                int numClusters = cagraIndexParams.getCuVSIvfPqParams().getIndexParams().getnLists();
                double approximatedIvfBytes = (double)numVectors * ((double)pqDim * ((double)pqBits / 8.0) + (double)elementTypeBytes) + (double)((long)numClusters * 4L);
                return (long)(2.0 * approximatedIvfBytes);
            }
            return CuVSResourceManager.estimateNNDescentMemory(numVectors, dims, dataType);
        }

        protected CuVSResources createNew() {
            return GPUSupport.cuVSResourcesOrNull(true);
        }

        @Override
        public void finishedComputation(ManagedCuVSResources resources) {
            logger.debug("Computation finished");
        }

        @Override
        public void release(ManagedCuVSResources resources) {
            logger.debug("Releasing resources to pool");
            try {
                this.lock.lock();
                assert (resources.isLocked());
                resources.unlock();
                this.enoughResourcesCondition.signalAll();
            }
            finally {
                this.lock.unlock();
            }
        }

        @Override
        public void shutdown() {
            for (int i = 0; i < this.createdCount; ++i) {
                ManagedCuVSResources res = this.pool[i];
                assert (res != null);
                res.delegate.close();
            }
        }

        static class Holder {
            static final PoolingCuVSResourceManager INSTANCE = new PoolingCuVSResourceManager(4, new RealGPUMemoryService(CuVSProvider.provider().gpuInfoProvider()));

            Holder() {
            }
        }
    }

    public static final class ManagedCuVSResources
    implements CuVSResources {
        private final CuVSResources delegate;
        private static final Runnable NOT_LOCKED = () -> {};
        private Runnable unlockAction = NOT_LOCKED;

        ManagedCuVSResources(CuVSResources resources) {
            this.delegate = resources;
        }

        public CuVSResources.ScopedAccess access() {
            return this.delegate.access();
        }

        public int deviceId() {
            return this.delegate.deviceId();
        }

        public void close() {
            throw new UnsupportedOperationException("this resource is managed, cannot be closed by clients");
        }

        public Path tempDirectory() {
            return null;
        }

        public String toString() {
            return "ManagedCuVSResources[delegate=" + String.valueOf(this.delegate) + "]";
        }

        void lock(Runnable unlockAction) {
            this.unlockAction = unlockAction;
        }

        void unlock() {
            this.unlockAction.run();
            this.unlockAction = NOT_LOCKED;
        }

        boolean isLocked() {
            return this.unlockAction != NOT_LOCKED;
        }
    }
}

