/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.services.nvidia.action;

import java.util.Map;
import java.util.Objects;
import org.elasticsearch.core.Strings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.nvidia.action.NvidiaActionVisitor;
import org.elasticsearch.xpack.inference.services.nvidia.completion.NvidiaChatCompletionModel;
import org.elasticsearch.xpack.inference.services.nvidia.completion.NvidiaCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.nvidia.embeddings.NvidiaEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.nvidia.embeddings.NvidiaEmbeddingsResponseHandler;
import org.elasticsearch.xpack.inference.services.nvidia.request.completion.NvidiaChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.nvidia.request.embeddings.NvidiaEmbeddingsRequest;
import org.elasticsearch.xpack.inference.services.nvidia.request.rerank.NvidiaRerankRequest;
import org.elasticsearch.xpack.inference.services.nvidia.rerank.NvidiaRerankModel;
import org.elasticsearch.xpack.inference.services.nvidia.rerank.NvidiaRerankResponseHandler;
import org.elasticsearch.xpack.inference.services.nvidia.response.rerank.NvidiaRerankResponseEntity;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiEmbeddingsResponseEntity;

/**
 * Builds {@link ExecutableAction} instances for the Nvidia model kinds exposed through
 * {@link NvidiaActionVisitor}: embeddings, chat completion and rerank.
 *
 * <p>Each {@code create} overload wires a {@link GenericRequestManager} with the response handler
 * for that task and a function that turns the incoming inputs into the matching Nvidia request,
 * then wraps the manager in a sender-backed action.
 */
public class NvidiaActionCreator implements NvidiaActionVisitor {
    /** Template for the failure message; arguments are the task type and the inference entity id. */
    private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE = "Failed to send Nvidia %s request from inference entity id [%s]";
    /** Prefix handed to {@link SingleInputSenderExecutableAction} for completion requests. */
    private static final String COMPLETION_ERROR_PREFIX = "Nvidia completions";
    /** Role value passed to {@link UnifiedChatInput} when wrapping a {@link ChatCompletionInput}. */
    private static final String USER_ROLE = "user";

    private static final ResponseHandler EMBEDDINGS_HANDLER = new NvidiaEmbeddingsResponseHandler(
        "Nvidia text embedding",
        OpenAiEmbeddingsResponseEntity::fromResponse
    );
    private static final ResponseHandler COMPLETION_HANDLER = new NvidiaCompletionResponseHandler(
        "Nvidia completion",
        OpenAiChatCompletionResponseEntity::fromResponse
    );
    // The rerank entity parser only needs the response, so the request argument is ignored.
    private static final ResponseHandler RERANK_HANDLER = new NvidiaRerankResponseHandler(
        "Nvidia rerank",
        (request, response) -> NvidiaRerankResponseEntity.fromResponse(response)
    );

    private final Sender sender;
    private final ServiceComponents serviceComponents;

    /**
     * @param sender            used by every action produced by this creator; must not be {@code null}
     * @param serviceComponents source of the thread pool and truncator; must not be {@code null}
     * @throws NullPointerException if either argument is {@code null}
     */
    public NvidiaActionCreator(Sender sender, ServiceComponents serviceComponents) {
        this.sender = Objects.requireNonNull(sender);
        this.serviceComponents = Objects.requireNonNull(serviceComponents);
    }

    /**
     * Creates the action for an embeddings model.
     *
     * @param model        the configured embeddings model
     * @param taskSettings settings passed to {@link NvidiaEmbeddingsModel#of} to derive the model actually used
     * @return an action that sends embeddings requests through the configured sender
     */
    @Override
    public ExecutableAction create(NvidiaEmbeddingsModel model, Map<String, Object> taskSettings) {
        NvidiaEmbeddingsModel effectiveModel = NvidiaEmbeddingsModel.of(model, taskSettings);
        GenericRequestManager<EmbeddingsInput> requestManager = new GenericRequestManager<EmbeddingsInput>(
            this.serviceComponents.threadPool(),
            effectiveModel,
            EMBEDDINGS_HANDLER,
            input -> this.toEmbeddingsRequest(effectiveModel, input),
            EmbeddingsInput.class
        );
        String failureMessage = buildErrorMessage(TaskType.TEXT_EMBEDDING, model.getInferenceEntityId());
        return new SenderExecutableAction(this.sender, requestManager, failureMessage);
    }

    /**
     * Creates the action for a chat completion model.
     *
     * @param model the chat completion model
     * @return a single-input action that sends chat completion requests through the configured sender
     */
    @Override
    public ExecutableAction create(NvidiaChatCompletionModel model) {
        GenericRequestManager<ChatCompletionInput> requestManager = new GenericRequestManager<ChatCompletionInput>(
            this.serviceComponents.threadPool(),
            model,
            COMPLETION_HANDLER,
            input -> toChatCompletionRequest(model, input),
            ChatCompletionInput.class
        );
        String failureMessage = buildErrorMessage(TaskType.COMPLETION, model.getInferenceEntityId());
        return new SingleInputSenderExecutableAction(this.sender, requestManager, failureMessage, COMPLETION_ERROR_PREFIX);
    }

    /**
     * Creates the action for a rerank model.
     *
     * @param model the rerank model
     * @return an action that sends rerank requests through the configured sender
     */
    @Override
    public ExecutableAction create(NvidiaRerankModel model) {
        GenericRequestManager<QueryAndDocsInputs> requestManager = new GenericRequestManager<QueryAndDocsInputs>(
            this.serviceComponents.threadPool(),
            model,
            RERANK_HANDLER,
            input -> toRerankRequest(model, input),
            QueryAndDocsInputs.class
        );
        String failureMessage = buildErrorMessage(TaskType.RERANK, model.getInferenceEntityId());
        return new SenderExecutableAction(this.sender, requestManager, failureMessage);
    }

    /**
     * Formats the "failed to send" message for a task type and inference entity id.
     *
     * @param requestType the task type whose string form is inserted into the message
     * @param inferenceId the inference entity id inserted into the message
     * @return the formatted error message
     */
    public static String buildErrorMessage(TaskType requestType, String inferenceId) {
        return Strings.format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, requestType.toString(), inferenceId);
    }

    /** Builds an embeddings request, truncating the text inputs to the model's configured max input tokens. */
    private NvidiaEmbeddingsRequest toEmbeddingsRequest(NvidiaEmbeddingsModel model, EmbeddingsInput input) {
        return new NvidiaEmbeddingsRequest(
            this.serviceComponents.truncator(),
            Truncator.truncate(input.getTextInputs(), model.getServiceSettings().maxInputTokens()),
            model,
            input.getInputType()
        );
    }

    /** Wraps the completion input in a {@link UnifiedChatInput} with the user role and builds the chat request. */
    private static NvidiaChatCompletionRequest toChatCompletionRequest(NvidiaChatCompletionModel model, ChatCompletionInput input) {
        return new NvidiaChatCompletionRequest(new UnifiedChatInput(input, USER_ROLE), model);
    }

    /** Builds a rerank request from the query and chunks carried by the input. */
    private static NvidiaRerankRequest toRerankRequest(NvidiaRerankModel model, QueryAndDocsInputs input) {
        return new NvidiaRerankRequest(input.getQuery(), input.getChunks(), model);
    }
}

