/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.deployment;

import java.io.IOException;
import java.util.List;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.deployment.AbstractPyTorchAction;
import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager;
import org.elasticsearch.xpack.ml.inference.deployment.NlpInferenceInput;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcess;
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchResult;

class InferencePyTorchAction
extends AbstractPyTorchAction<InferenceResults> {
    private static final Logger logger = LogManager.getLogger(InferencePyTorchAction.class);
    private final InferenceConfig config;
    private final NlpInferenceInput input;
    @Nullable
    private final CancellableTask parentActionTask;
    private final TrainedModelPrefixStrings.PrefixType prefixType;
    private final boolean chunkResponse;

    InferencePyTorchAction(String deploymentId, long requestId, TimeValue timeout, DeploymentManager.ProcessContext processContext, InferenceConfig config, NlpInferenceInput input, TrainedModelPrefixStrings.PrefixType prefixType, ThreadPool threadPool, @Nullable CancellableTask parentActionTask, boolean chunkResponse, ActionListener<InferenceResults> listener) {
        super(deploymentId, requestId, timeout, processContext, threadPool, listener);
        this.config = config;
        this.input = input;
        this.prefixType = prefixType;
        this.parentActionTask = parentActionTask;
        this.chunkResponse = chunkResponse;
    }

    private boolean isCancelled() {
        if (this.parentActionTask != null) {
            try {
                this.parentActionTask.ensureNotCancelled();
            }
            catch (TaskCancelledException ex) {
                logger.warn(() -> Strings.format((String)"[%s] %s", (Object[])new Object[]{this.getDeploymentId(), ex.getMessage()}));
                return true;
            }
        }
        return false;
    }

    protected void doRun() throws Exception {
        if (this.isNotified()) {
            logger.debug(() -> Strings.format((String)"[%s] skipping inference on request [%s] as it has timed out", (Object[])new Object[]{this.getDeploymentId(), this.getRequestId()}));
            return;
        }
        String requestIdStr = String.valueOf(this.getRequestId());
        if (this.isCancelled()) {
            this.onCancel();
            return;
        }
        try {
            TrainedModelPrefixStrings prefixStrings;
            Object inputText = this.input.extractInput((TrainedModelInput)this.getProcessContext().getModelInput().get());
            if (this.prefixType != TrainedModelPrefixStrings.PrefixType.NONE && (prefixStrings = (TrainedModelPrefixStrings)this.getProcessContext().getPrefixStrings().get()) != null) {
                switch (this.prefixType) {
                    case SEARCH: {
                        if (org.elasticsearch.common.Strings.isNullOrEmpty((String)prefixStrings.searchPrefix())) break;
                        inputText = prefixStrings.searchPrefix() + (String)inputText;
                        break;
                    }
                    case INGEST: {
                        if (org.elasticsearch.common.Strings.isNullOrEmpty((String)prefixStrings.ingestPrefix())) break;
                        inputText = prefixStrings.ingestPrefix() + (String)inputText;
                        break;
                    }
                    default: {
                        throw new IllegalStateException("[" + this.getDeploymentId() + "] Unhandled input prefix type [" + String.valueOf(this.prefixType) + "]");
                    }
                }
            }
            List<String> inputs = List.of(inputText);
            NlpTask.Processor processor = (NlpTask.Processor)this.getProcessContext().getNlpTaskProcessor().get();
            processor.validateInputs(inputs);
            assert (this.config instanceof NlpConfig);
            NlpConfig nlpConfig = (NlpConfig)this.config;
            int span = nlpConfig.getTokenization().getSpan();
            if (this.chunkResponse && nlpConfig.getTokenization().getSpan() <= 0) {
                span = -2;
            }
            NlpTask.Request request = processor.getRequestBuilder(nlpConfig).buildRequest(inputs, requestIdStr, nlpConfig.getTokenization().getTruncate(), span, nlpConfig.getTokenization().maxSequenceLength());
            logger.debug(() -> Strings.format((String)"handling request [%s]", (Object[])new Object[]{requestIdStr}));
            if (this.isCancelled()) {
                this.onCancel();
                return;
            }
            this.getProcessContext().getResultProcessor().registerRequest(requestIdStr, (ActionListener<PyTorchResult>)ActionListener.wrap(result -> this.processResult((PyTorchResult)result, request.tokenization(), processor.getResultProcessor(nlpConfig)), this::onFailure));
            ((PyTorchProcess)this.getProcessContext().getProcess().get()).writeInferenceRequest(request.processInput());
        }
        catch (IOException e) {
            logger.error(() -> "[" + this.getDeploymentId() + "] error writing to inference process", (Throwable)e);
            this.onFailure((Exception)((Object)ExceptionsHelper.serverError((String)"Error writing to inference process", (Throwable)e)));
        }
        catch (ElasticsearchException e) {
            if (e.status().getStatus() >= RestStatus.INTERNAL_SERVER_ERROR.getStatus()) {
                logger.error(() -> "[" + this.getDeploymentId() + "] internal server error running inference", (Throwable)e);
            } else {
                logger.debug(() -> "[" + this.getDeploymentId() + "] error running inference due to input", (Throwable)e);
            }
            this.onFailure((Exception)((Object)e));
        }
        catch (IllegalArgumentException e) {
            logger.debug(() -> "[" + this.getDeploymentId() + "] illegal argument running inference", (Throwable)e);
            this.onFailure(e);
        }
        catch (Exception e) {
            logger.error(() -> "[" + this.getDeploymentId() + "] error running inference", (Throwable)e);
            this.onFailure(e);
        }
    }

    private void processResult(PyTorchResult pyTorchResult, TokenizationResult tokenization, NlpTask.ResultProcessor inferenceResultsProcessor) {
        if (pyTorchResult.isError()) {
            this.onFailure(pyTorchResult.errorResult());
            return;
        }
        logger.debug(() -> Strings.format((String)"[%s] retrieved result for request [%s]", (Object[])new Object[]{this.getDeploymentId(), this.getRequestId()}));
        if (this.isNotified()) {
            logger.debug(() -> Strings.format((String)"[%s] skipping result processing for request [%s] as the request has timed out", (Object[])new Object[]{this.getDeploymentId(), this.getRequestId()}));
            return;
        }
        if (this.isCancelled()) {
            this.onCancel();
            return;
        }
        this.getProcessContext().getResultProcessor().updateStats(pyTorchResult);
        InferenceResults results = inferenceResultsProcessor.processResult(tokenization, pyTorchResult.inferenceResult(), this.chunkResponse);
        logger.debug(() -> Strings.format((String)"[%s] processed result for request [%s]", (Object[])new Object[]{this.getDeploymentId(), this.getRequestId()}));
        this.onSuccess(results);
    }

    @Override
    protected Logger getLogger() {
        return logger;
    }
}

