/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.deployment;

import java.io.IOException;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.deployment.AbstractPyTorchAction;
import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager;
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcess;
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchResult;

abstract class AbstractControlMessagePyTorchAction<T>
extends AbstractPyTorchAction<T> {
    private static final Logger logger = LogManager.getLogger(AbstractControlMessagePyTorchAction.class);

    AbstractControlMessagePyTorchAction(String deploymentId, long requestId, TimeValue timeout, DeploymentManager.ProcessContext processContext, ThreadPool threadPool, ActionListener<T> listener) {
        super(deploymentId, requestId, timeout, processContext, threadPool, listener);
    }

    abstract int controlOrdinal();

    abstract void writeMessage(XContentBuilder var1) throws IOException;

    abstract T getResult(PyTorchResult var1);

    protected void doRun() throws Exception {
        if (this.isNotified()) {
            logger.debug(() -> Strings.format((String)"[%s] skipping control message on request [%s] as it has timed out", (Object[])new Object[]{this.getDeploymentId(), this.getRequestId()}));
            return;
        }
        String requestIdStr = String.valueOf(this.getRequestId());
        try {
            BytesReference message = this.buildControlMessage(requestIdStr);
            this.getProcessContext().getResultProcessor().registerRequest(requestIdStr, (ActionListener<PyTorchResult>)ActionListener.wrap(this::processResponse, this::onFailure));
            ((PyTorchProcess)this.getProcessContext().getProcess().get()).writeInferenceRequest(message);
        }
        catch (IOException e) {
            logger.error(() -> "[" + this.getDeploymentId() + "] error writing control message to the inference process", (Throwable)e);
            this.onFailure((Exception)((Object)ExceptionsHelper.serverError((String)"Error writing control message to the inference process", (Throwable)e)));
        }
        catch (Exception e) {
            this.onFailure(e);
        }
    }

    final BytesReference buildControlMessage(String requestId) throws IOException {
        XContentBuilder builder = XContentFactory.jsonBuilder();
        builder.startObject();
        builder.field("request_id", requestId);
        builder.field("control", this.controlOrdinal());
        this.writeMessage(builder);
        builder.endObject();
        return BytesReference.bytes((XContentBuilder)builder);
    }

    private void processResponse(PyTorchResult result) {
        if (result.isError()) {
            this.onFailure(result.errorResult());
            return;
        }
        this.onSuccess(this.getResult(result));
    }

    @Override
    protected Logger getLogger() {
        return logger;
    }

    static enum ControlMessageTypes {
        AllocationThreads,
        ClearCache;

    }
}

