/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.packageloader.action;

import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.RefCountingListener;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig;
import org.elasticsearch.xpack.ml.packageloader.action.ModelDownloadTask;
import org.elasticsearch.xpack.ml.packageloader.action.ModelLoaderUtils;

/**
 * Imports a packaged trained model by reading its definition from the configured model repository and
 * indexing it in {@code DEFAULT_CHUNK_SIZE}-byte parts with {@link PutTrainedModelDefinitionPartAction}.
 * When the package declares a vocabulary file, it is uploaded with {@link PutTrainedModelVocabularyAction}.
 * <p>
 * Locations for which {@code ModelLoaderUtils.uriIsFile} is false are fetched through
 * {@link #NUMBER_OF_STREAMS} parallel {@code HttpStreamChunker}s; file locations are read sequentially
 * from a single stream. All import work is expected to run on the {@code model_download} thread pool.
 */
public class ModelImporter {
    /** Size in bytes of each indexed definition part (1 MiB). */
    private static final int DEFAULT_CHUNK_SIZE = 0x100000;
    /** Number of HTTP range streams the definition is split across when it is not read from a file. */
    public static final int NUMBER_OF_STREAMS = 5;
    /** Thread pool on which every import step must run. */
    private static final String MODEL_DOWNLOAD_THREAD_POOL = "model_download";
    /** Circuit breaker charged while an HTTP download is in flight. */
    private static final String REQUEST_BREAKER = "request";
    /** Bytes reserved on the request breaker for an HTTP download: one chunk per parallel stream. */
    private static final long HTTP_DOWNLOAD_BREAKER_BYTES = (long) DEFAULT_CHUNK_SIZE * NUMBER_OF_STREAMS;
    private static final Logger logger = LogManager.getLogger(ModelImporter.class);

    private final Client client;
    private final String modelId;
    private final ModelPackageConfig config;
    private final ModelDownloadTask task;
    private final ExecutorService executorService;
    /** Number of parts fetched so far over HTTP (shared by all streams); reported as task progress. */
    private final AtomicInteger progressCounter = new AtomicInteger();
    /** Resolved location of the model definition file. */
    private final URI uri;
    private final CircuitBreakerService breakerService;

    /**
     * @param client        client used to send the vocabulary and definition-part requests
     * @param modelId       id the model is imported under; must not be null
     * @param packageConfig package metadata (repository, size, sha256, vocabulary file); must not be null
     * @param task          task that receives progress updates and is checked for cancellation; must not be null
     * @param threadPool    source of the {@code model_download} executor
     * @param cbs           service whose request breaker is charged for HTTP downloads
     * @throws URISyntaxException propagated from {@code ModelLoaderUtils.resolvePackageLocation}
     */
    ModelImporter(
        Client client,
        String modelId,
        ModelPackageConfig packageConfig,
        ModelDownloadTask task,
        ThreadPool threadPool,
        CircuitBreakerService cbs
    ) throws URISyntaxException {
        this.client = client;
        this.modelId = Objects.requireNonNull(modelId);
        this.config = Objects.requireNonNull(packageConfig);
        this.task = Objects.requireNonNull(task);
        this.executorService = threadPool.executor(MODEL_DOWNLOAD_THREAD_POOL);
        this.uri = ModelLoaderUtils.resolvePackageLocation(
            this.config.getModelRepository(),
            this.config.getPackagedModelId() + ModelLoaderUtils.MODEL_FILE_EXTENSION
        );
        this.breakerService = cbs;
    }

    /**
     * Submits the import to the {@code model_download} executor.
     *
     * @param listener completed with {@link AcknowledgedResponse#TRUE} on success, or with the failure
     */
    public void doImport(ActionListener<AcknowledgedResponse> listener) {
        this.executorService.execute(() -> this.doImportInternal(listener));
    }

    /** Loads the optional vocabulary, then dispatches to the HTTP or file import path. */
    private void doImportInternal(ActionListener<AcknowledgedResponse> finalListener) {
        assert ThreadPool.assertCurrentThreadPool(MODEL_DOWNLOAD_THREAD_POOL) : wrongThreadMessage();
        try {
            ModelLoaderUtils.VocabularyParts vocabularyParts = null;
            if (this.config.getVocabularyFile() != null) {
                vocabularyParts = ModelLoaderUtils.loadVocabulary(
                    ModelLoaderUtils.resolvePackageLocation(this.config.getModelRepository(), this.config.getVocabularyFile())
                );
            }
            // Round up so a trailing partial chunk still gets its own part.
            int totalParts = (int) ((this.config.getSize() + DEFAULT_CHUNK_SIZE - 1L) / DEFAULT_CHUNK_SIZE);
            if (ModelLoaderUtils.uriIsFile(this.uri)) {
                this.importFromFile(totalParts, vocabularyParts, finalListener);
            } else {
                this.downloadOverHttp(totalParts, vocabularyParts, finalListener);
            }
        } catch (Exception e) {
            finalListener.onFailure(e);
        }
    }

    /**
     * Reserves request-breaker memory, opens one range stream per split of the definition and starts the
     * parallel download. The reservation is released when {@code finalListener} is completed.
     */
    private void downloadOverHttp(
        int totalParts,
        @Nullable ModelLoaderUtils.VocabularyParts vocabularyParts,
        ActionListener<AcknowledgedResponse> finalListener
    ) {
        // If the breaker trips, the exception propagates to doImportInternal, which fails finalListener.
        this.breakerService.getBreaker(REQUEST_BREAKER).addEstimateBytesAndMaybeBreak(HTTP_DOWNLOAD_BREAKER_BYTES, "model importer");
        ActionListener<AcknowledgedResponse> breakerFreeingListener = ActionListener.runAfter(
            finalListener,
            () -> this.breakerService.getBreaker(REQUEST_BREAKER).addWithoutBreaking(-HTTP_DOWNLOAD_BREAKER_BYTES)
        );

        List<ModelLoaderUtils.HttpStreamChunker> downloaders;
        try {
            List<ModelLoaderUtils.RequestRange> ranges = ModelLoaderUtils.split(this.config.getSize(), NUMBER_OF_STREAMS, DEFAULT_CHUNK_SIZE);
            downloaders = new ArrayList<>(ranges.size());
            for (ModelLoaderUtils.RequestRange range : ranges) {
                downloaders.add(new ModelLoaderUtils.HttpStreamChunker(this.uri, range, DEFAULT_CHUNK_SIZE));
            }
        } catch (Exception e) {
            // Fail through the breaker-freeing listener so the reservation taken above is returned.
            breakerFreeingListener.onFailure(e);
            return;
        }
        this.downloadModelDefinition(this.config.getSize(), totalParts, vocabularyParts, downloaders, breakerFreeingListener);
    }

    /** Opens the model file, imports it sequentially, and closes the stream once it has been consumed. */
    private void importFromFile(
        int totalParts,
        @Nullable ModelLoaderUtils.VocabularyParts vocabularyParts,
        ActionListener<AcknowledgedResponse> finalListener
    ) throws Exception {
        InputStream modelInputStream = ModelLoaderUtils.getFileInputStream(this.uri);
        try {
            ModelLoaderUtils.InputStreamChunker chunkIterator = new ModelLoaderUtils.InputStreamChunker(modelInputStream, DEFAULT_CHUNK_SIZE);
            this.readModelDefinitionFromFile(this.config.getSize(), totalParts, chunkIterator, vocabularyParts, finalListener);
        } finally {
            // readModelDefinitionFromFile reads every chunk before it returns, so the stream is no longer needed.
            // A close failure is only logged: the listener may already have been completed.
            this.closeQuietly(modelInputStream);
        }
    }

    /**
     * Downloads the definition over the given range streams and indexes each chunk.
     * <p>
     * Every downloader except the last runs in parallel, together with the vocabulary upload when one
     * is supplied. Once all of those complete successfully, the last downloader is read for exactly one
     * chunk (presumably {@code split} makes the final range a single part), which is indexed last. The
     * summed byte count is then checked against the expected size.
     *
     * @param size            expected definition size in bytes
     * @param totalParts      total number of parts the definition is indexed in
     * @param vocabularyParts vocabulary to upload, or {@code null} for none
     * @param downloaders     one chunker per range; the last one supplies the final part
     * @param finalListener   completed with {@link AcknowledgedResponse#TRUE} or the first failure
     */
    void downloadModelDefinition(
        long size,
        int totalParts,
        @Nullable ModelLoaderUtils.VocabularyParts vocabularyParts,
        List<ModelLoaderUtils.HttpStreamChunker> downloaders,
        ActionListener<AcknowledgedResponse> finalListener
    ) {
        try (
            RefCountingListener countingListener = new RefCountingListener(
                1,
                ActionListener.wrap(ignored -> this.executorService.execute(() -> {
                    ModelLoaderUtils.HttpStreamChunker finalDownloader = downloaders.get(downloaders.size() - 1);
                    this.downloadFinalPart(size, totalParts, finalDownloader, finalListener.delegateFailureAndWrap((l, r) -> {
                        this.checkDownloadComplete(downloaders);
                        l.onResponse(AcknowledgedResponse.TRUE);
                    }));
                }), finalListener::onFailure)
            )
        ) {
            if (vocabularyParts != null) {
                this.uploadVocabulary(vocabularyParts, countingListener);
            }
            for (int streamSplit = 0; streamSplit < downloaders.size() - 1; ++streamSplit) {
                ModelLoaderUtils.HttpStreamChunker downloader = downloaders.get(streamSplit);
                ActionListener<Void> rangeDownloadedListener = countingListener.acquire();
                this.executorService.execute(
                    () -> this.downloadPartInRange(size, totalParts, downloader, this.executorService, countingListener, rangeDownloadedListener)
                );
            }
        }
    }

    /**
     * Downloads and indexes the next chunk of one range, then schedules itself for the following chunk
     * until the range is exhausted. The range listener is completed once the range is fully indexed,
     * on the first failure, or immediately if the shared counting listener is already failing.
     */
    private void downloadPartInRange(
        long size,
        int totalParts,
        ModelLoaderUtils.HttpStreamChunker downloadChunker,
        ExecutorService executorService,
        RefCountingListener countingListener,
        ActionListener<Void> rangeFullyDownloadedListener
    ) {
        assert ThreadPool.assertCurrentThreadPool(MODEL_DOWNLOAD_THREAD_POOL) : wrongThreadMessage();
        if (countingListener.isFailing()) {
            // Another range or the vocabulary upload has already failed; stop without adding a second error.
            rangeFullyDownloadedListener.onResponse(null);
            return;
        }
        try {
            this.throwIfTaskCancelled();
            ModelLoaderUtils.HttpStreamChunker.BytesAndPartIndex bytesAndIndex = downloadChunker.next();
            this.task.setProgress(totalParts, this.progressCounter.getAndIncrement());
            this.indexPart(bytesAndIndex.partIndex(), totalParts, size, bytesAndIndex.bytes());
        } catch (Exception e) {
            rangeFullyDownloadedListener.onFailure(e);
            return;
        }
        if (downloadChunker.hasNext()) {
            // Resubmit instead of looping on this thread, presumably so the parallel ranges interleave on the pool.
            executorService.execute(
                () -> this.downloadPartInRange(size, totalParts, downloadChunker, executorService, countingListener, rangeFullyDownloadedListener)
            );
        } else {
            rangeFullyDownloadedListener.onResponse(null);
        }
    }

    /** Downloads and indexes a single chunk from {@code downloader}; used for the last part of the model. */
    private void downloadFinalPart(
        long size,
        int totalParts,
        ModelLoaderUtils.HttpStreamChunker downloader,
        ActionListener<AcknowledgedResponse> lastPartWrittenListener
    ) {
        assert ThreadPool.assertCurrentThreadPool(MODEL_DOWNLOAD_THREAD_POOL) : wrongThreadMessage();
        try {
            // Check cancellation here too, matching the other download paths.
            this.throwIfTaskCancelled();
            ModelLoaderUtils.HttpStreamChunker.BytesAndPartIndex bytesAndIndex = downloader.next();
            this.task.setProgress(totalParts, this.progressCounter.getAndIncrement());
            this.indexPart(bytesAndIndex.partIndex(), totalParts, size, bytesAndIndex.bytes());
            lastPartWrittenListener.onResponse(AcknowledgedResponse.TRUE);
        } catch (Exception e) {
            lastPartWrittenListener.onFailure(e);
        }
    }

    /**
     * Reads all {@code totalParts} chunks from {@code chunkIterator} on the calling thread and indexes each one.
     * Then it verifies the sha256 and size of what was read. The optional vocabulary upload runs concurrently.
     * {@code finalListener} is completed on the executor once the parts and the vocabulary upload have all
     * finished, or with the first failure.
     */
    void readModelDefinitionFromFile(
        long size,
        int totalParts,
        ModelLoaderUtils.InputStreamChunker chunkIterator,
        @Nullable ModelLoaderUtils.VocabularyParts vocabularyParts,
        ActionListener<AcknowledgedResponse> finalListener
    ) {
        try (
            RefCountingListener countingListener = new RefCountingListener(
                1,
                ActionListener.wrap(
                    ignored -> this.executorService.execute(() -> finalListener.onResponse(AcknowledgedResponse.TRUE)),
                    finalListener::onFailure
                )
            )
        ) {
            try {
                if (vocabularyParts != null) {
                    this.uploadVocabulary(vocabularyParts, countingListener);
                }
                for (int part = 0; part < totalParts; ++part) {
                    this.throwIfTaskCancelled();
                    this.task.setProgress(totalParts, part);
                    BytesArray definition = chunkIterator.next();
                    this.indexPart(part, totalParts, size, definition);
                }
                this.task.setProgress(totalParts, totalParts);
                this.checkDownloadComplete(chunkIterator, totalParts);
            } catch (Exception e) {
                // Report through the counting listener so the failure is delivered once all refs are released.
                countingListener.acquire().onFailure(e);
            }
        }
    }

    /** Asynchronously uploads the vocabulary, holding a ref on {@code countingListener} until it completes. */
    private void uploadVocabulary(ModelLoaderUtils.VocabularyParts vocabularyParts, RefCountingListener countingListener) {
        PutTrainedModelVocabularyAction.Request request = new PutTrainedModelVocabularyAction.Request(
            this.modelId,
            vocabularyParts.vocab(),
            vocabularyParts.merges(),
            vocabularyParts.scores(),
            true
        );
        this.client.execute(
            PutTrainedModelVocabularyAction.INSTANCE,
            request,
            countingListener.acquire(
                r -> logger.debug(() -> Strings.format("[%s] imported model vocabulary [%s]", this.modelId, this.config.getVocabularyFile()))
            )
        );
    }

    /** Indexes one definition part and blocks until the request completes (via {@code actionGet}). */
    private void indexPart(int partIndex, int totalParts, long totalSize, BytesArray bytes) {
        PutTrainedModelDefinitionPartAction.Request modelPartRequest = new PutTrainedModelDefinitionPartAction.Request(
            this.modelId,
            bytes,
            partIndex,
            totalSize,
            totalParts,
            true
        );
        this.client.execute(PutTrainedModelDefinitionPartAction.INSTANCE, modelPartRequest).actionGet();
    }

    /** Verifies the bytes read across all HTTP downloaders add up to the expected model size. */
    private void checkDownloadComplete(List<ModelLoaderUtils.HttpStreamChunker> downloaders) {
        long totalBytesRead = downloaders.stream().mapToLong(ModelLoaderUtils.HttpStreamChunker::getTotalBytesRead).sum();
        int totalParts = downloaders.stream().mapToInt(ModelLoaderUtils.HttpStreamChunker::getCurrentPart).sum();
        this.checkSize(totalBytesRead);
        logger.debug(() -> Strings.format("finished importing model [%s] using [%d] parts", this.modelId, totalParts));
    }

    /** Verifies the sha256 and size of a definition that was read from a file. */
    private void checkDownloadComplete(ModelLoaderUtils.InputStreamChunker fileInputStream, int totalParts) {
        this.checkSha256(fileInputStream.getSha256());
        this.checkSize(fileInputStream.getTotalBytesRead());
        logger.debug(() -> Strings.format("finished importing model [%s] using [%d] parts", this.modelId, totalParts));
    }

    /** @throws ElasticsearchStatusException (500) if {@code sha256} differs from the package's checksum */
    private void checkSha256(String sha256) {
        if (!this.config.getSha256().equals(sha256)) {
            String message = Strings.format(
                "Model sha256 checksums do not match, expected [%s] but got [%s]",
                this.config.getSha256(),
                sha256
            );
            throw new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR);
        }
    }

    /** @throws ElasticsearchStatusException (500) if {@code definitionSize} differs from the package's size */
    private void checkSize(long definitionSize) {
        if (this.config.getSize() != definitionSize) {
            String message = Strings.format("Model size does not match, expected [%d] but got [%d]", this.config.getSize(), definitionSize);
            throw new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR);
        }
    }

    /** @throws TaskCancelledException if the download task has been cancelled */
    private void throwIfTaskCancelled() {
        if (this.task.isCancelled()) {
            logger.info("Model [{}] download task cancelled", this.modelId);
            throw new TaskCancelledException(
                Strings.format("Model [%s] download task cancelled with reason [%s]", this.modelId, this.task.getReasonCancelled())
            );
        }
    }

    /** Closes {@code stream} if non-null, logging rather than propagating any {@link IOException}. */
    private void closeQuietly(@Nullable InputStream stream) {
        if (stream == null) {
            return;
        }
        try {
            stream.close();
        } catch (IOException e) {
            logger.warn(() -> Strings.format("[%s] failed to close model definition input stream", this.modelId), e);
        }
    }

    /** Builds the assertion message used when import work runs outside the download thread pool. */
    private static String wrongThreadMessage() {
        return Strings.format(
            "Model download must execute from [%s] but thread is [%s]",
            MODEL_DOWNLOAD_THREAD_POOL,
            Thread.currentThread().getName()
        );
    }
}

