/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.action;

import java.io.IOException;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.Flow;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportResponseHandler;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.inference.InferenceContext;
import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
import org.elasticsearch.xpack.inference.common.RateLimitAssignment;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
import org.elasticsearch.xpack.inference.telemetry.InferenceTimer;

public abstract class BaseTransportInferenceAction<Request extends BaseInferenceActionRequest>
extends HandledTransportAction<Request, InferenceAction.Response> {
    private static final Logger log = LogManager.getLogger(BaseTransportInferenceAction.class);
    private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference";
    private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]";
    private final XPackLicenseState licenseState;
    private final ModelRegistry modelRegistry;
    private final InferenceServiceRegistry serviceRegistry;
    private final InferenceStats inferenceStats;
    private final StreamingTaskManager streamingTaskManager;
    private final InferenceServiceRateLimitCalculator inferenceServiceRateLimitCalculator;
    private final NodeClient nodeClient;
    private final ThreadPool threadPool;
    private final TransportService transportService;
    private final Random random;

    public BaseTransportInferenceAction(String inferenceActionName, TransportService transportService, ActionFilters actionFilters, XPackLicenseState licenseState, ModelRegistry modelRegistry, InferenceServiceRegistry serviceRegistry, InferenceStats inferenceStats, StreamingTaskManager streamingTaskManager, Writeable.Reader<Request> requestReader, InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator, NodeClient nodeClient, ThreadPool threadPool) {
        super(inferenceActionName, transportService, actionFilters, requestReader, (Executor)EsExecutors.DIRECT_EXECUTOR_SERVICE);
        this.licenseState = licenseState;
        this.modelRegistry = modelRegistry;
        this.serviceRegistry = serviceRegistry;
        this.inferenceStats = inferenceStats;
        this.streamingTaskManager = streamingTaskManager;
        this.inferenceServiceRateLimitCalculator = inferenceServiceNodeLocalRateLimitCalculator;
        this.nodeClient = nodeClient;
        this.threadPool = threadPool;
        this.transportService = transportService;
        this.random = Randomness.get();
    }

    protected abstract boolean isInvalidTaskTypeForInferenceEndpoint(Request var1, UnparsedModel var2);

    protected abstract ElasticsearchStatusException createInvalidTaskTypeException(Request var1, UnparsedModel var2);

    protected abstract void doInference(Model var1, Request var2, InferenceService var3, ActionListener<InferenceServiceResults> var4);

    protected void doExecute(Task task, Request request, ActionListener<InferenceAction.Response> listener) {
        if (!InferencePlugin.INFERENCE_API_FEATURE.check(this.licenseState)) {
            listener.onFailure((Exception)LicenseUtils.newComplianceException((String)"inference"));
            return;
        }
        InferenceTimer timer = InferenceTimer.start();
        ActionListener getModelListener = ActionListener.wrap(unparsedModel -> {
            String serviceName = unparsedModel.service();
            try {
                this.validateRequest(request, (UnparsedModel)unparsedModel);
            }
            catch (Exception e) {
                this.recordMetrics((UnparsedModel)unparsedModel, timer, (Throwable)e);
                listener.onFailure(e);
                return;
            }
            InferenceContext context = request.getContext();
            if (Objects.nonNull(context)) {
                this.threadPool.getThreadContext().putHeader("X-elastic-product-use-case", context.productUseCase());
            }
            InferenceService service = (InferenceService)this.serviceRegistry.getService(serviceName).get();
            NodeRoutingDecision routingDecision = this.determineRouting(serviceName, request, (UnparsedModel)unparsedModel);
            if (routingDecision.currentNodeShouldHandleRequest()) {
                Model model = service.parsePersistedConfigWithSecrets(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings(), unparsedModel.secrets());
                this.inferOnServiceWithMetrics(model, request, service, timer, listener);
            } else {
                request.setHasBeenRerouted(true);
                this.rerouteRequest(request, listener, routingDecision.targetNode);
            }
        }, e -> {
            try {
                this.inferenceStats.inferenceDuration().record(timer.elapsedMillis(), InferenceStats.responseAttributes(e));
            }
            catch (Exception metricsException) {
                log.atDebug().withThrowable((Throwable)metricsException).log("Failed to record metrics when the model is missing, dropping metrics");
            }
            listener.onFailure(e);
        });
        this.modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), (ActionListener<UnparsedModel>)getModelListener);
    }

    private void validateRequest(Request request, UnparsedModel unparsedModel) {
        String serviceName = unparsedModel.service();
        TaskType requestTaskType = request.getTaskType();
        Optional service = this.serviceRegistry.getService(serviceName);
        BaseTransportInferenceAction.validationHelper(service::isEmpty, () -> BaseTransportInferenceAction.unknownServiceException(serviceName, request.getInferenceEntityId()));
        BaseTransportInferenceAction.validationHelper(() -> !request.getTaskType().isAnyOrSame(unparsedModel.taskType()), () -> BaseTransportInferenceAction.requestModelTaskTypeMismatchException(requestTaskType, unparsedModel.taskType()));
        BaseTransportInferenceAction.validationHelper(() -> this.isInvalidTaskTypeForInferenceEndpoint(request, unparsedModel), () -> this.createInvalidTaskTypeException(request, unparsedModel));
    }

    private NodeRoutingDecision determineRouting(String serviceName, Request request, UnparsedModel unparsedModel) {
        TaskType modelTaskType = unparsedModel.taskType();
        if (!this.inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceName, modelTaskType) || request.hasBeenRerouted()) {
            return NodeRoutingDecision.handleLocally();
        }
        RateLimitAssignment rateLimitAssignment = this.inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceName, modelTaskType);
        if (rateLimitAssignment == null) {
            return NodeRoutingDecision.handleLocally();
        }
        List<DiscoveryNode> responsibleNodes = rateLimitAssignment.responsibleNodes();
        if (responsibleNodes == null || responsibleNodes.isEmpty()) {
            return NodeRoutingDecision.handleLocally();
        }
        DiscoveryNode nodeToHandleRequest = responsibleNodes.get(this.random.nextInt(responsibleNodes.size()));
        String localNodeId = this.nodeClient.getLocalNodeId();
        if (nodeToHandleRequest.getId().equals(localNodeId)) {
            return NodeRoutingDecision.handleLocally();
        }
        return NodeRoutingDecision.routeTo(nodeToHandleRequest);
    }

    private static void validationHelper(Supplier<Boolean> validationFailure, Supplier<ElasticsearchStatusException> exceptionCreator) {
        if (validationFailure.get().booleanValue()) {
            throw exceptionCreator.get();
        }
    }

    private void rerouteRequest(Request request, final ActionListener<InferenceAction.Response> listener, DiscoveryNode nodeToHandleRequest) {
        this.transportService.sendRequest(nodeToHandleRequest, "cluster:internal/xpack/inference", request, (TransportResponseHandler)new TransportResponseHandler<InferenceAction.Response>(){

            public Executor executor() {
                return BaseTransportInferenceAction.this.threadPool.executor("inference_utility");
            }

            public void handleResponse(InferenceAction.Response response) {
                listener.onResponse((Object)response);
            }

            public void handleException(TransportException exp) {
                listener.onFailure((Exception)exp);
            }

            public InferenceAction.Response read(StreamInput in) throws IOException {
                return new InferenceAction.Response(in);
            }
        });
    }

    private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) {
        try {
            this.inferenceStats.inferenceDuration().record(timer.elapsedMillis(), InferenceStats.responseAttributes(model, t));
        }
        catch (Exception e) {
            log.atDebug().withThrowable((Throwable)e).log("Failed to record metrics with an unparsed model, dropping metrics");
        }
    }

    private void inferOnServiceWithMetrics(Model model, Request request, InferenceService service, InferenceTimer timer, ActionListener<InferenceAction.Response> listener) {
        this.inferenceStats.requestCount().incrementBy(1L, InferenceStats.modelAttributes(model));
        this.inferOnService(model, request, service, (ActionListener<InferenceServiceResults>)ActionListener.wrap(inferenceResults -> {
            if (request.isStreaming()) {
                Flow.Processor taskProcessor = this.streamingTaskManager.create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
                inferenceResults.publisher().subscribe(taskProcessor);
                Flow.Publisher instrumentedStream = this.publisherWithMetrics(timer, model, taskProcessor);
                Flow.Publisher streamErrorHandler = this.streamErrorHandler(instrumentedStream);
                listener.onResponse((Object)new InferenceAction.Response(inferenceResults, streamErrorHandler));
            } else {
                this.recordMetrics(model, timer, null);
                listener.onResponse((Object)new InferenceAction.Response(inferenceResults));
            }
        }, e -> {
            this.recordMetrics(model, timer, (Throwable)e);
            listener.onFailure(e);
        }));
    }

    private <T> Flow.Publisher<T> publisherWithMetrics(final InferenceTimer timer, final Model model, Flow.Processor<T, T> upstream) {
        return downstream -> upstream.subscribe(new Flow.Subscriber<T>(){

            @Override
            public void onSubscribe(final Flow.Subscription subscription) {
                downstream.onSubscribe(new Flow.Subscription(){

                    @Override
                    public void request(long n) {
                        subscription.request(n);
                    }

                    @Override
                    public void cancel() {
                        BaseTransportInferenceAction.this.recordMetrics(model, timer, null);
                        subscription.cancel();
                    }
                });
            }

            @Override
            public void onNext(T item) {
                downstream.onNext(item);
            }

            @Override
            public void onError(Throwable throwable) {
                BaseTransportInferenceAction.this.recordMetrics(model, timer, throwable);
                downstream.onError(throwable);
            }

            @Override
            public void onComplete() {
                BaseTransportInferenceAction.this.recordMetrics(model, timer, null);
                downstream.onComplete();
            }
        });
    }

    protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Publisher<T> upstream) {
        return upstream;
    }

    private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) {
        try {
            this.inferenceStats.inferenceDuration().record(timer.elapsedMillis(), InferenceStats.responseAttributes(model, ExceptionsHelper.unwrapCause((Throwable)t)));
        }
        catch (Exception e) {
            log.atDebug().withThrowable((Throwable)e).log("Failed to record metrics with a parsed model, dropping metrics");
        }
    }

    private void inferOnService(Model model, Request request, InferenceService service, ActionListener<InferenceServiceResults> listener) {
        if (!request.isStreaming() || service.canStream(request.getTaskType())) {
            this.doInference(model, request, service, listener);
        } else {
            listener.onFailure((Exception)this.unsupportedStreamingTaskException(request, service));
        }
    }

    private ElasticsearchStatusException unsupportedStreamingTaskException(Request request, InferenceService service) {
        Set supportedTasks = service.supportedStreamingTasks();
        if (supportedTasks.isEmpty()) {
            return new ElasticsearchStatusException(Strings.format((String)"Streaming is not allowed for service [%s].", (Object[])new Object[]{service.name()}), RestStatus.METHOD_NOT_ALLOWED, new Object[0]);
        }
        String validTasks = supportedTasks.stream().map(TaskType::toString).collect(Collectors.joining(","));
        return new ElasticsearchStatusException(Strings.format((String)"Streaming is not allowed for service [%s] and task [%s]. Supported tasks: [%s]", (Object[])new Object[]{service.name(), request.getTaskType(), validTasks}), RestStatus.METHOD_NOT_ALLOWED, new Object[0]);
    }

    private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) {
        return new ElasticsearchStatusException("Unknown service [{}] for model [{}]", RestStatus.BAD_REQUEST, new Object[]{service, inferenceId});
    }

    private static ElasticsearchStatusException requestModelTaskTypeMismatchException(TaskType requested, TaskType expected) {
        return new ElasticsearchStatusException("Incompatible task_type, the requested type [{}] does not match the model type [{}]", RestStatus.BAD_REQUEST, new Object[]{requested, expected});
    }

    private record NodeRoutingDecision(boolean currentNodeShouldHandleRequest, DiscoveryNode targetNode) {
        static NodeRoutingDecision handleLocally() {
            return new NodeRoutingDecision(true, null);
        }

        static NodeRoutingDecision routeTo(DiscoveryNode node) {
            return new NodeRoutingDecision(false, node);
        }
    }
}

