/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.action;

import java.util.HashMap;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.telemetry.InferenceStats;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.inference.InferenceContext;
import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.InferenceLicenceCheck;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry;
import org.elasticsearch.xpack.inference.telemetry.InferenceTimer;

/**
 * Base transport action for inference requests.
 *
 * <p>The action resolves the inference endpoint named by the request and checks that the endpoint's service
 * is licensed. It then validates the request against the endpoint and delegates to {@link #doInference}.
 *
 * <p>Request-count and request-duration metrics are recorded through {@link InferenceStats}. For streaming
 * requests, the duration is recorded once, when the stream completes, fails, or is cancelled (whichever
 * happens first).
 *
 * @param <Request> the concrete inference request type handled by the subclass
 */
public abstract class BaseTransportInferenceAction<Request extends BaseInferenceActionRequest>
extends HandledTransportAction<Request, InferenceAction.Response> {
    private static final Logger log = LogManager.getLogger(BaseTransportInferenceAction.class);
    private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference";
    private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]";
    /** Thread-context header populated from the request's {@link InferenceContext}, if not already set. */
    private static final String PRODUCT_USE_CASE_HEADER = "X-elastic-product-use-case";
    private final XPackLicenseState licenseState;
    private final InferenceEndpointRegistry endpointRegistry;
    private final InferenceServiceRegistry serviceRegistry;
    private final InferenceStats inferenceStats;
    private final StreamingTaskManager streamingTaskManager;
    private final NodeClient nodeClient;
    private final ThreadPool threadPool;
    // NOTE(review): transportService and random are not read anywhere in this class — confirm they are
    // still needed (possibly left over from node-routing logic, see NodeRoutingDecision below).
    private final TransportService transportService;
    private final Random random;

    public BaseTransportInferenceAction(String inferenceActionName, TransportService transportService, ActionFilters actionFilters, XPackLicenseState licenseState, InferenceEndpointRegistry endpointRegistry, InferenceServiceRegistry serviceRegistry, InferenceStats inferenceStats, StreamingTaskManager streamingTaskManager, Writeable.Reader<Request> requestReader, NodeClient nodeClient, ThreadPool threadPool) {
        // The cast selects the super constructor overload that takes an Executor; handling runs on the
        // calling thread (direct executor).
        super(inferenceActionName, transportService, actionFilters, requestReader, (Executor) EsExecutors.DIRECT_EXECUTOR_SERVICE);
        this.licenseState = licenseState;
        this.endpointRegistry = endpointRegistry;
        this.serviceRegistry = serviceRegistry;
        this.inferenceStats = inferenceStats;
        this.streamingTaskManager = streamingTaskManager;
        this.nodeClient = nodeClient;
        this.threadPool = threadPool;
        this.transportService = transportService;
        this.random = Randomness.get();
    }

    /**
     * Returns {@code true} when the request's task type cannot be served by the given inference endpoint.
     */
    protected abstract boolean isInvalidTaskTypeForInferenceEndpoint(Request request, Model model);

    /**
     * Builds the exception reported when {@link #isInvalidTaskTypeForInferenceEndpoint} returns {@code true}.
     */
    protected abstract ElasticsearchStatusException createInvalidTaskTypeException(Request request, Model model);

    /**
     * Performs the actual inference call against {@code service} and completes {@code listener} with the results.
     */
    protected abstract void doInference(Model model, Request request, InferenceService service, ActionListener<InferenceServiceResults> listener);

    /**
     * Resolves the endpoint, checks licensing, validates the request and runs inference.
     *
     * <p>The duration metric is recorded when endpoint lookup fails, when validation fails, and by
     * {@link #inferOnServiceWithMetrics} for the inference call itself.
     */
    @Override
    protected void doExecute(Task task, Request request, ActionListener<InferenceAction.Response> listener) {
        InferenceTimer timer = InferenceTimer.start();
        ActionListener<Model> getModelListener = ActionListener.wrap(model -> {
            String serviceName = model.getConfigurations().getService();
            // NOTE(review): unlike validation failures, licence failures are not recorded in the duration
            // metric — confirm this is intentional.
            if (!InferenceLicenceCheck.isServiceLicenced(serviceName, this.licenseState)) {
                listener.onFailure(InferenceLicenceCheck.complianceException(serviceName));
                return;
            }
            try {
                this.validateRequest(request, model);
            } catch (Exception e) {
                this.recordRequestDurationMetrics(model, timer, e);
                listener.onFailure(e);
                return;
            }
            // Propagate the caller's product use case, but never overwrite a header that is already present.
            InferenceContext context = request.getContext();
            if (context != null && this.threadPool.getThreadContext().getHeader(PRODUCT_USE_CASE_HEADER) == null) {
                this.threadPool.getThreadContext().putHeader(PRODUCT_USE_CASE_HEADER, context.productUseCase());
            }
            // validateRequest has already rejected unknown services, so the service is present here.
            InferenceService service = this.serviceRegistry.getService(serviceName).orElseThrow();
            String localNodeId = this.nodeClient.getLocalNodeId();
            this.inferOnServiceWithMetrics(model, request, service, timer, localNodeId, listener);
        }, e -> {
            // No model is available, so only response attributes can be attached to the metric.
            try {
                this.inferenceStats.inferenceDuration().record(timer.elapsedMillis(), InferenceStats.responseAttributes(e));
            } catch (Exception metricsException) {
                log.atDebug().withThrowable(metricsException).log("Failed to record metrics when the model is missing, dropping metrics");
            }
            listener.onFailure(e);
        });
        this.endpointRegistry.getEndpoint(request.getInferenceEntityId(), getModelListener);
    }

    /**
     * Checks, in order, that the endpoint's service is registered, that the request task type matches
     * the model's, and that the subclass accepts the task type.
     *
     * @throws ElasticsearchStatusException describing the first failed check
     */
    private void validateRequest(Request request, Model model) {
        String serviceName = model.getConfigurations().getService();
        TaskType requestTaskType = request.getTaskType();
        Optional<InferenceService> service = this.serviceRegistry.getService(serviceName);
        validationHelper(service::isEmpty, () -> unknownServiceException(serviceName, request.getInferenceEntityId()));
        validationHelper(() -> !request.getTaskType().isAnyOrSame(model.getTaskType()), () -> requestModelTaskTypeMismatchException(requestTaskType, model.getTaskType()));
        validationHelper(() -> this.isInvalidTaskTypeForInferenceEndpoint(request, model), () -> this.createInvalidTaskTypeException(request, model));
    }

    /** Throws the lazily-built exception when {@code validationFailure} reports a failure. */
    private static void validationHelper(Supplier<Boolean> validationFailure, Supplier<ElasticsearchStatusException> exceptionCreator) {
        if (validationFailure.get()) {
            throw exceptionCreator.get();
        }
    }

    /** Records the elapsed request duration with service and (unwrapped) response attributes. */
    private void recordRequestDurationMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) {
        HashMap<String, Object> metricAttributes = new HashMap<>();
        metricAttributes.putAll(InferenceStats.serviceAttributes(model));
        metricAttributes.putAll(InferenceStats.responseAttributes(ExceptionsHelper.unwrapCause(t)));
        this.inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes);
    }

    /**
     * Increments the request counter and runs inference, recording the duration metric on completion.
     * For streaming requests the metric is deferred until the returned stream terminates or is cancelled.
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // stream element type is not visible from this class
    private void inferOnServiceWithMetrics(Model model, Request request, InferenceService service, InferenceTimer timer, String localNodeId, ActionListener<InferenceAction.Response> listener) {
        this.recordRequestCountMetrics(model, request, localNodeId);
        this.inferOnService(model, request, service, ActionListener.wrap(inferenceResults -> {
            if (request.isStreaming()) {
                Flow.Processor taskProcessor = this.streamingTaskManager.create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
                inferenceResults.publisher().subscribe(taskProcessor);
                Flow.Publisher instrumentedStream = this.publisherWithMetrics(timer, model, request, localNodeId, taskProcessor);
                Flow.Publisher streamErrorHandler = this.streamErrorHandler(instrumentedStream);
                listener.onResponse(new InferenceAction.Response(inferenceResults, streamErrorHandler));
            } else {
                this.recordRequestDurationMetrics(model, timer, request, localNodeId, null);
                listener.onResponse(new InferenceAction.Response(inferenceResults));
            }
        }, e -> {
            this.recordRequestDurationMetrics(model, timer, request, localNodeId, e);
            listener.onFailure(e);
        }));
    }

    /**
     * Wraps {@code upstream} so the request-duration metric is recorded exactly once per subscription.
     * The metric is recorded on the first of onComplete, onError, or Subscription#cancel.
     *
     * <p>Reactive Streams requires {@code Subscription.cancel()} to be idempotent and thread-safe, and signals
     * may still arrive after cancellation. The atomic flag therefore guards against double-recording.
     */
    private <T> Flow.Publisher<T> publisherWithMetrics(InferenceTimer timer, Model model, Request request, String localNodeId, Flow.Processor<T, T> upstream) {
        return downstream -> upstream.subscribe(new Flow.Subscriber<T>() {
            private final AtomicBoolean durationRecorded = new AtomicBoolean(false);

            /** Records the duration metric if no earlier signal has already done so. */
            private void recordDurationOnce(@Nullable Throwable t) {
                if (this.durationRecorded.compareAndSet(false, true)) {
                    BaseTransportInferenceAction.this.recordRequestDurationMetrics(model, timer, request, localNodeId, t);
                }
            }

            @Override
            public void onSubscribe(Flow.Subscription subscription) {
                downstream.onSubscribe(new Flow.Subscription() {

                    @Override
                    public void request(long n) {
                        subscription.request(n);
                    }

                    @Override
                    public void cancel() {
                        recordDurationOnce(null);
                        subscription.cancel();
                    }
                });
            }

            @Override
            public void onNext(T item) {
                downstream.onNext(item);
            }

            @Override
            public void onError(Throwable throwable) {
                recordDurationOnce(throwable);
                downstream.onError(throwable);
            }

            @Override
            public void onComplete() {
                recordDurationOnce(null);
                downstream.onComplete();
            }
        });
    }

    /**
     * Hook letting subclasses decorate the streaming response publisher. The default returns it unchanged.
     */
    protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Publisher<T> upstream) {
        return upstream;
    }

    /** Increments the request counter with the model's service attributes. */
    private void recordRequestCountMetrics(Model model, Request request, String localNodeId) {
        HashMap<String, Object> requestCountAttributes = new HashMap<>();
        requestCountAttributes.putAll(InferenceStats.serviceAttributes(model));
        this.inferenceStats.requestCount().incrementBy(1L, requestCountAttributes);
    }

    /**
     * Records the elapsed request duration with combined service/response attributes.
     * {@code request} and {@code localNodeId} are currently not used for attributes.
     */
    private void recordRequestDurationMetrics(Model model, InferenceTimer timer, Request request, String localNodeId, @Nullable Throwable t) {
        HashMap<String, Object> metricAttributes = new HashMap<>();
        metricAttributes.putAll(InferenceStats.serviceAndResponseAttributes(model, ExceptionsHelper.unwrapCause(t)));
        this.inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes);
    }

    /** Runs inference, failing fast when streaming is requested but the service cannot stream the model's task. */
    private void inferOnService(Model model, Request request, InferenceService service, ActionListener<InferenceServiceResults> listener) {
        if (!request.isStreaming() || service.canStream(model.getTaskType())) {
            this.doInference(model, request, service, listener);
        } else {
            listener.onFailure(this.unsupportedStreamingTaskException(request, service));
        }
    }

    /** Builds a 405 error listing the service's streaming-capable task types, if any. */
    private ElasticsearchStatusException unsupportedStreamingTaskException(Request request, InferenceService service) {
        Set<TaskType> supportedTasks = service.supportedStreamingTasks();
        if (supportedTasks.isEmpty()) {
            return new ElasticsearchStatusException(Strings.format("Streaming is not allowed for service [%s].", service.name()), RestStatus.METHOD_NOT_ALLOWED);
        }
        String validTasks = supportedTasks.stream().map(TaskType::toString).collect(Collectors.joining(","));
        return new ElasticsearchStatusException(Strings.format("Streaming is not allowed for service [%s] and task [%s]. Supported tasks: [%s]", service.name(), request.getTaskType(), validTasks), RestStatus.METHOD_NOT_ALLOWED);
    }

    private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) {
        return new ElasticsearchStatusException("Unknown service [{}] for model [{}]", RestStatus.BAD_REQUEST, service, inferenceId);
    }

    private static ElasticsearchStatusException requestModelTaskTypeMismatchException(TaskType requested, TaskType expected) {
        return new ElasticsearchStatusException("Incompatible task_type, the requested type [{}] does not match the model type [{}]", RestStatus.BAD_REQUEST, requested, expected);
    }

    /**
     * Outcome of deciding which node handles a request.
     * NOTE(review): not referenced anywhere in this class — confirm whether it is still needed.
     */
    private record NodeRoutingDecision(boolean currentNodeShouldHandleRequest, DiscoveryNode targetNode) {
        static NodeRoutingDecision handleLocally() {
            return new NodeRoutingDecision(true, null);
        }

        static NodeRoutingDecision routeTo(DiscoveryNode node) {
            return new NodeRoutingDecision(false, node);
        }
    }
}

