/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.services.llama;

import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SettingsConfiguration;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.core.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.llama.LlamaModel;
import org.elasticsearch.xpack.inference.services.llama.action.LlamaActionCreator;
import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel;
import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.llama.request.completion.LlamaChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

/**
 * Inference service registered under the name {@value #NAME}.
 *
 * <p>Supports the {@code TEXT_EMBEDDING}, {@code COMPLETION} and {@code CHAT_COMPLETION} task types.
 * Embedding endpoints are backed by {@link LlamaEmbeddingsModel}; both completion task types are
 * backed by {@link LlamaChatCompletionModel}.
 */
public class LlamaService extends SenderService {
    /** Service identifier used when creating models and in the service configuration. */
    public static final String NAME = "llama";
    /** Display name reported through {@link Configuration}. */
    private static final String SERVICE_NAME = "Llama";
    private static final TransportVersion ML_INFERENCE_LLAMA_ADDED = TransportVersion.fromName("ml_inference_llama_added");
    /** Batch size handed to {@link EmbeddingRequestChunker} when splitting chunked embedding requests. */
    static final int EMBEDDING_MAX_BATCH_SIZE = 20;
    /**
     * Task types this service accepts. This instance is shared: it is returned as-is by
     * {@link #supportedTaskTypes()} and read by {@link Configuration}, so it must not be mutated.
     */
    private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES = EnumSet.of(
        TaskType.TEXT_EMBEDDING,
        TaskType.COMPLETION,
        TaskType.CHAT_COMPLETION
    );
    /** Response handler shared by all unified chat completion requests. */
    private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new LlamaChatCompletionResponseHandler(
        "llama chat completion",
        OpenAiChatCompletionResponseEntity::fromResponse
    );

    /**
     * Creates the service, taking the cluster service from the supplied factory context.
     *
     * @param factory           factory for the HTTP request sender
     * @param serviceComponents shared service components
     * @param context           factory context providing the {@link ClusterService}
     */
    public LlamaService(
        HttpRequestSender.Factory factory,
        ServiceComponents serviceComponents,
        InferenceServiceExtension.InferenceServiceFactoryContext context
    ) {
        this(factory, serviceComponents, context.clusterService());
    }

    /**
     * Creates the service with an explicit cluster service.
     *
     * @param factory           factory for the HTTP request sender
     * @param serviceComponents shared service components
     * @param clusterService    cluster service forwarded to {@link SenderService}
     */
    public LlamaService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
        super(factory, serviceComponents, clusterService);
    }

    /**
     * Executes an inference request against a {@link LlamaModel}.
     *
     * <p>Any other model type is reported to {@code listener} as a failure rather than thrown.
     * {@code taskSettings} is not read by this implementation.
     */
    @Override
    protected void doInfer(
        Model model,
        InferenceInputs inputs,
        Map<String, Object> taskSettings,
        TimeValue timeout,
        ActionListener<InferenceServiceResults> listener
    ) {
        if (!(model instanceof LlamaModel)) {
            listener.onFailure(ServiceUtils.createInvalidModelException(model));
            return;
        }
        LlamaModel llamaModel = (LlamaModel) model;
        LlamaActionCreator actionCreator = new LlamaActionCreator(this.getSender(), this.getServiceComponents());
        llamaModel.accept(actionCreator).execute(inputs, timeout, listener);
    }

    /** Delegates input type validation to {@link ServiceUtils#validateInputTypeIsUnspecifiedOrInternal}. */
    @Override
    protected void validateInputType(InputType inputType, Model model, ValidationException validationException) {
        ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(inputType, validationException);
    }

    /**
     * Builds the model implementation matching {@code taskType}.
     *
     * @param inferenceId      inference endpoint identifier
     * @param taskType         requested task type
     * @param serviceSettings  raw service settings map
     * @param chunkingSettings chunking settings; only forwarded for {@code TEXT_EMBEDDING}
     * @param secretSettings   raw secret settings map
     * @param context          whether the configuration comes from a request or from persisted state
     * @return a {@link LlamaEmbeddingsModel} for {@code TEXT_EMBEDDING}, or a {@link LlamaChatCompletionModel}
     *         for {@code COMPLETION} and {@code CHAT_COMPLETION}
     * @throws RuntimeException the exception built by {@link ServiceUtils#createInvalidTaskTypeException} for any
     *                          other task type
     */
    protected LlamaModel createModel(
        String inferenceId,
        TaskType taskType,
        Map<String, Object> serviceSettings,
        ChunkingSettings chunkingSettings,
        Map<String, Object> secretSettings,
        ConfigurationParseContext context
    ) {
        switch (taskType) {
            case TEXT_EMBEDDING:
                return new LlamaEmbeddingsModel(inferenceId, taskType, NAME, serviceSettings, chunkingSettings, secretSettings, context);
            case CHAT_COMPLETION:
            case COMPLETION:
                return new LlamaChatCompletionModel(inferenceId, taskType, NAME, serviceSettings, secretSettings, context);
            default:
                throw ServiceUtils.createInvalidTaskTypeException(inferenceId, NAME, taskType, context);
        }
    }

    /**
     * Returns a copy of an embeddings model whose service settings carry the given embedding size.
     *
     * <p>If the model has no similarity measure configured, {@link SimilarityMeasure#DOT_PRODUCT} is used.
     * All other service settings are carried over unchanged.
     *
     * @param model         model to update; must be a {@link LlamaEmbeddingsModel}
     * @param embeddingSize number of dimensions to record in the service settings
     * @return a new {@link LlamaEmbeddingsModel} with the updated service settings
     * @throws RuntimeException the exception built by
     *                          {@link ServiceUtils#invalidModelTypeForUpdateModelWithEmbeddingDetails} when
     *                          {@code model} is not a {@link LlamaEmbeddingsModel}
     */
    public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
        if (!(model instanceof LlamaEmbeddingsModel)) {
            throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
        }
        LlamaEmbeddingsModel embeddingsModel = (LlamaEmbeddingsModel) model;
        LlamaEmbeddingsServiceSettings serviceSettings = embeddingsModel.getServiceSettings();

        SimilarityMeasure similarityFromModel = serviceSettings.similarity();
        SimilarityMeasure similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;

        LlamaEmbeddingsServiceSettings updatedServiceSettings = new LlamaEmbeddingsServiceSettings(
            serviceSettings.modelId(),
            serviceSettings.uri(),
            embeddingSize,
            similarityToUse,
            serviceSettings.maxInputTokens(),
            serviceSettings.rateLimitSettings()
        );
        return new LlamaEmbeddingsModel(embeddingsModel, updatedServiceSettings);
    }

    /**
     * Splits the inputs into batches of at most {@link #EMBEDDING_MAX_BATCH_SIZE} using the model's chunking
     * settings and executes one embeddings action per batch.
     *
     * <p>Only {@link LlamaEmbeddingsModel} is accepted; any other model type is reported to {@code listener}
     * as a failure. {@code taskSettings} is not read by this implementation.
     */
    @Override
    protected void doChunkedInfer(
        Model model,
        List<ChunkInferenceInput> inputs,
        Map<String, Object> taskSettings,
        InputType inputType,
        TimeValue timeout,
        ActionListener<List<ChunkedInference>> listener
    ) {
        if (!(model instanceof LlamaEmbeddingsModel)) {
            listener.onFailure(ServiceUtils.createInvalidModelException(model));
            return;
        }
        LlamaEmbeddingsModel llamaModel = (LlamaEmbeddingsModel) model;
        LlamaActionCreator actionCreator = new LlamaActionCreator(this.getSender(), this.getServiceComponents());

        // Typed list (was a raw List) so the loop variable below type-checks; the batch size now comes from
        // the named constant instead of a duplicated literal.
        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
            inputs,
            EMBEDDING_MAX_BATCH_SIZE,
            llamaModel.getConfigurations().getChunkingSettings()
        ).batchRequestsWithListeners(listener);

        for (EmbeddingRequestChunker.BatchRequestAndListener request : batchedRequests) {
            ExecutableAction action = llamaModel.accept(actionCreator);
            // NOTE(review): cast kept from the original; likely redundant if listener() is already declared
            // as ActionListener<InferenceServiceResults> - confirm against BatchRequestAndListener.
            action.execute(
                new EmbeddingsInput(request.batch().inputs(), inputType),
                timeout,
                (ActionListener<InferenceServiceResults>) request.listener()
            );
        }
    }

    /**
     * Executes a unified chat completion request.
     *
     * <p>The stored model is first combined with the incoming request through
     * {@link LlamaChatCompletionModel#of}, and the resulting model is used to build the outgoing
     * {@link LlamaChatCompletionRequest}. Any model that is not a {@link LlamaChatCompletionModel} is reported
     * to {@code listener} as a failure.
     */
    @Override
    protected void doUnifiedCompletionInfer(
        Model model,
        UnifiedChatInput inputs,
        TimeValue timeout,
        ActionListener<InferenceServiceResults> listener
    ) {
        if (!(model instanceof LlamaChatCompletionModel)) {
            listener.onFailure(ServiceUtils.createInvalidModelException(model));
            return;
        }
        LlamaChatCompletionModel llamaChatCompletionModel = (LlamaChatCompletionModel) model;
        LlamaChatCompletionModel overriddenModel = LlamaChatCompletionModel.of(llamaChatCompletionModel, inputs.getRequest());

        GenericRequestManager<UnifiedChatInput> manager = new GenericRequestManager<>(
            this.getServiceComponents().threadPool(),
            overriddenModel,
            UNIFIED_CHAT_COMPLETION_HANDLER,
            unifiedChatInput -> new LlamaChatCompletionRequest(unifiedChatInput, overriddenModel),
            UnifiedChatInput.class
        );
        String errorMessage = LlamaActionCreator.buildErrorMessage(TaskType.CHAT_COMPLETION, model.getInferenceEntityId());
        SenderExecutableAction action = new SenderExecutableAction(this.getSender(), manager, errorMessage);
        action.execute(inputs, timeout, listener);
    }

    /** Returns a fresh set containing {@code COMPLETION} and {@code CHAT_COMPLETION}. */
    public Set<TaskType> supportedStreamingTasks() {
        return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
    }

    /** Returns the lazily built service configuration held by {@link Configuration}. */
    public InferenceServiceConfiguration getConfiguration() {
        return Configuration.get();
    }

    /** Returns the shared set of supported task types; the returned instance is not a copy. */
    public EnumSet<TaskType> supportedTaskTypes() {
        return SUPPORTED_TASK_TYPES;
    }

    /** Returns {@link #NAME}. */
    public String name() {
        return NAME;
    }

    /**
     * Parses a model configuration supplied in a request and hands the resulting model to the listener.
     *
     * <p>The {@code service_settings}, {@code task_settings} and (for {@code TEXT_EMBEDDING} only)
     * {@code chunking_settings} entries are removed from {@code config}. After the model is created, the
     * remaining contents of {@code config}, the service settings map and the task settings map are each passed
     * to {@link ServiceUtils#throwIfNotEmptyMap}. Any exception raised along the way is delivered through
     * {@link ActionListener#onFailure} instead of being thrown.
     *
     * @param modelId             inference endpoint identifier
     * @param taskType            requested task type
     * @param config              mutable request configuration; consumed entries are removed
     * @param parsedModelListener receives the parsed model or the failure
     */
    public void parseRequestConfig(String modelId, TaskType taskType, Map<String, Object> config, ActionListener<Model> parsedModelListener) {
        try {
            Map<String, Object> serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, "service_settings");
            Map<String, Object> taskSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(config, "task_settings");

            ChunkingSettings chunkingSettings = null;
            if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
                chunkingSettings = ChunkingSettingsBuilder.fromMap(ServiceUtils.removeFromMapOrDefaultEmpty(config, "chunking_settings"));
            }

            // The service settings map is deliberately passed for both the service and the secret settings.
            // NOTE(review): presumably request bodies carry the secret inside service_settings - confirm.
            LlamaModel model = this.createModel(
                modelId,
                taskType,
                serviceSettingsMap,
                chunkingSettings,
                serviceSettingsMap,
                ConfigurationParseContext.REQUEST
            );

            ServiceUtils.throwIfNotEmptyMap(config, NAME);
            ServiceUtils.throwIfNotEmptyMap(serviceSettingsMap, NAME);
            ServiceUtils.throwIfNotEmptyMap(taskSettingsMap, NAME);

            parsedModelListener.onResponse(model);
        } catch (Exception e) {
            parsedModelListener.onFailure(e);
        }
    }

    /**
     * Rebuilds a model from persisted configuration together with its persisted secrets.
     *
     * <p>{@code task_settings} is removed from {@code config} and discarded. Chunking settings are only read
     * for {@code TEXT_EMBEDDING}.
     *
     * @param modelId  inference endpoint identifier
     * @param taskType persisted task type
     * @param config   mutable persisted configuration; consumed entries are removed
     * @param secrets  mutable persisted secrets; the {@code secret_settings} entry is removed
     * @return the model created in the {@link ConfigurationParseContext#PERSISTENT} context
     */
    public Model parsePersistedConfigWithSecrets(String modelId, TaskType taskType, Map<String, Object> config, Map<String, Object> secrets) {
        Map<String, Object> serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, "service_settings");
        ServiceUtils.removeFromMapOrDefaultEmpty(config, "task_settings");
        Map<String, Object> secretSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(secrets, "secret_settings");

        ChunkingSettings chunkingSettings = null;
        if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
            chunkingSettings = ChunkingSettingsBuilder.fromMap(ServiceUtils.removeFromMap(config, "chunking_settings"));
        }
        return this.createModelFromPersistent(modelId, taskType, serviceSettingsMap, chunkingSettings, secretSettingsMap);
    }

    /** Creates a model through {@link #createModel} using the {@link ConfigurationParseContext#PERSISTENT} context. */
    private LlamaModel createModelFromPersistent(
        String inferenceEntityId,
        TaskType taskType,
        Map<String, Object> serviceSettings,
        ChunkingSettings chunkingSettings,
        Map<String, Object> secretSettings
    ) {
        return this.createModel(
            inferenceEntityId,
            taskType,
            serviceSettings,
            chunkingSettings,
            secretSettings,
            ConfigurationParseContext.PERSISTENT
        );
    }

    /**
     * Rebuilds a model from persisted configuration without secrets.
     *
     * <p>Behaves like {@link #parsePersistedConfigWithSecrets} except that {@code null} is passed as the
     * secret settings map.
     *
     * @param modelId  inference endpoint identifier
     * @param taskType persisted task type
     * @param config   mutable persisted configuration; consumed entries are removed
     * @return the model created in the {@link ConfigurationParseContext#PERSISTENT} context
     */
    public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
        Map<String, Object> serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, "service_settings");
        ServiceUtils.removeFromMapOrDefaultEmpty(config, "task_settings");

        ChunkingSettings chunkingSettings = null;
        if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
            chunkingSettings = ChunkingSettingsBuilder.fromMap(ServiceUtils.removeFromMap(config, "chunking_settings"));
        }
        return this.createModelFromPersistent(modelId, taskType, serviceSettingsMap, chunkingSettings, null);
    }

    /** Returns the transport version named {@code ml_inference_llama_added}. */
    public TransportVersion getMinimalSupportedVersion() {
        return ML_INFERENCE_LLAMA_ADDED;
    }

    /**
     * Holder for the service's {@link InferenceServiceConfiguration}, which is built on first access through
     * a {@link LazyInitializable}.
     */
    public static class Configuration {
        private static final LazyInitializable<InferenceServiceConfiguration, RuntimeException> CONFIGURATION = new LazyInitializable<>(
            () -> {
                Map<String, SettingsConfiguration> configurationMap = new HashMap<>();

                configurationMap.put(
                    "url",
                    new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription("The URL endpoint to use for the requests.")
                        .setLabel("URL")
                        .setRequired(true)
                        .setSensitive(false)
                        .setUpdatable(false)
                        .setType(SettingsConfigurationFieldType.STRING)
                        .build()
                );
                configurationMap.put(
                    "model_id",
                    new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription(
                        "Refer to the Llama models documentation for the list of available models."
                    )
                        .setLabel("Model")
                        .setRequired(true)
                        .setSensitive(false)
                        .setUpdatable(false)
                        .setType(SettingsConfigurationFieldType.STRING)
                        .build()
                );
                // Secret and rate-limit settings are contributed by their shared definitions.
                configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES));
                configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES));

                return new InferenceServiceConfiguration.Builder().setService(LlamaService.NAME)
                    .setName(LlamaService.SERVICE_NAME)
                    .setTaskTypes(SUPPORTED_TASK_TYPES)
                    .setConfigurations(configurationMap)
                    .build();
            }
        );

        /** Returns the service configuration, computing it on the first call. */
        public static InferenceServiceConfiguration get() {
            return CONFIGURATION.getOrCompute();
        }

        private Configuration() {}
    }
}

