/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.services.elastic.authorization;

import java.util.ArrayList;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.EmptySecretSettings;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SecretSettings;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity;
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings;

public class ElasticInferenceServiceAuthorizationModel {
    private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationModel.class);
    private static final String UNKNOWN_TASK_TYPE_LOG_MESSAGE = "Authorized endpoint id [{}] has unknown task type [{}], skipping";
    private static final String UNSUPPORTED_TASK_TYPE_LOG_MESSAGE = "Authorized endpoint id [{}] has unsupported task type [{}], skipping";
    private final Map<String, ElasticInferenceServiceModel> authorizedEndpoints;
    private final EnumSet<TaskType> taskTypes;

    public static ElasticInferenceServiceAuthorizationModel of(ElasticInferenceServiceAuthorizationResponseEntity responseEntity, String baseEisUrl) {
        ElasticInferenceServiceComponents components = new ElasticInferenceServiceComponents(baseEisUrl);
        return ElasticInferenceServiceAuthorizationModel.createInternal(responseEntity.getAuthorizedEndpoints(), components);
    }

    private static ElasticInferenceServiceAuthorizationModel createInternal(List<ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint> responseEndpoints, ElasticInferenceServiceComponents components) {
        ArrayList<ElasticInferenceServiceModel> validEndpoints = new ArrayList<ElasticInferenceServiceModel>();
        for (ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint : responseEndpoints) {
            ElasticInferenceServiceModel model = ElasticInferenceServiceAuthorizationModel.createModel(authorizedEndpoint, components);
            if (model == null) continue;
            validEndpoints.add(model);
        }
        return new ElasticInferenceServiceAuthorizationModel(validEndpoints);
    }

    private static ElasticInferenceServiceModel createModel(ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint, ElasticInferenceServiceComponents components) {
        try {
            TaskType taskType = ElasticInferenceServiceAuthorizationModel.getTaskType(authorizedEndpoint.taskType().elasticsearchTaskType());
            if (taskType == null) {
                logger.warn(UNKNOWN_TASK_TYPE_LOG_MESSAGE, (Object)authorizedEndpoint.id(), (Object)authorizedEndpoint.taskType());
                return null;
            }
            return switch (taskType) {
                case TaskType.CHAT_COMPLETION -> ElasticInferenceServiceAuthorizationModel.createCompletionModel(authorizedEndpoint, TaskType.CHAT_COMPLETION, components);
                case TaskType.COMPLETION -> ElasticInferenceServiceAuthorizationModel.createCompletionModel(authorizedEndpoint, TaskType.COMPLETION, components);
                case TaskType.SPARSE_EMBEDDING -> ElasticInferenceServiceAuthorizationModel.createSparseTextEmbeddingsModel(authorizedEndpoint, components);
                case TaskType.TEXT_EMBEDDING -> ElasticInferenceServiceAuthorizationModel.createDenseTextEmbeddingsModel(authorizedEndpoint, components);
                case TaskType.RERANK -> ElasticInferenceServiceAuthorizationModel.createRerankModel(authorizedEndpoint, components);
                default -> {
                    logger.info(UNSUPPORTED_TASK_TYPE_LOG_MESSAGE, (Object)authorizedEndpoint.id(), (Object)taskType);
                    yield null;
                }
            };
        }
        catch (Exception e) {
            logger.atWarn().withThrowable((Throwable)e).log("Failed to create model for authorized endpoint id [{}] with task type [{}], skipping", (Object)authorizedEndpoint.id(), (Object)authorizedEndpoint.taskType());
            return null;
        }
    }

    private static TaskType getTaskType(String taskType) {
        try {
            return TaskType.fromString((String)taskType);
        }
        catch (IllegalArgumentException e) {
            return null;
        }
    }

    private static ElasticInferenceServiceCompletionModel createCompletionModel(ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint, TaskType taskType, ElasticInferenceServiceComponents components) {
        return new ElasticInferenceServiceCompletionModel(authorizedEndpoint.id(), taskType, "elastic", new ElasticInferenceServiceCompletionServiceSettings(authorizedEndpoint.modelName()), (TaskSettings)EmptyTaskSettings.INSTANCE, (SecretSettings)EmptySecretSettings.INSTANCE, components);
    }

    private static ElasticInferenceServiceSparseEmbeddingsModel createSparseTextEmbeddingsModel(ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint, ElasticInferenceServiceComponents components) {
        return new ElasticInferenceServiceSparseEmbeddingsModel(authorizedEndpoint.id(), TaskType.SPARSE_EMBEDDING, "elastic", new ElasticInferenceServiceSparseEmbeddingsServiceSettings(authorizedEndpoint.modelName(), null), (TaskSettings)EmptyTaskSettings.INSTANCE, (SecretSettings)EmptySecretSettings.INSTANCE, components, ChunkingSettingsBuilder.fromMap(ElasticInferenceServiceAuthorizationModel.getChunkingSettingsMap(ElasticInferenceServiceAuthorizationModel.getConfigurationOrEmpty(authorizedEndpoint))));
    }

    private static ElasticInferenceServiceAuthorizationResponseEntity.Configuration getConfigurationOrEmpty(ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint) {
        if (authorizedEndpoint.configuration() != null) {
            return authorizedEndpoint.configuration();
        }
        return ElasticInferenceServiceAuthorizationResponseEntity.Configuration.EMPTY;
    }

    private static Map<String, Object> getChunkingSettingsMap(ElasticInferenceServiceAuthorizationResponseEntity.Configuration configuration) {
        return Objects.requireNonNullElse(configuration.chunkingSettings(), new HashMap());
    }

    private static ElasticInferenceServiceDenseTextEmbeddingsModel createDenseTextEmbeddingsModel(ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint, ElasticInferenceServiceComponents components) {
        ElasticInferenceServiceAuthorizationResponseEntity.Configuration config = ElasticInferenceServiceAuthorizationModel.getConfigurationOrEmpty(authorizedEndpoint);
        ElasticInferenceServiceAuthorizationModel.validateConfigurationForTextEmbedding(config);
        return new ElasticInferenceServiceDenseTextEmbeddingsModel(authorizedEndpoint.id(), TaskType.TEXT_EMBEDDING, "elastic", new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(authorizedEndpoint.modelName(), ElasticInferenceServiceAuthorizationModel.getSimilarityMeasure(config), config.dimensions(), null), (TaskSettings)EmptyTaskSettings.INSTANCE, (SecretSettings)EmptySecretSettings.INSTANCE, components, ChunkingSettingsBuilder.fromMap(ElasticInferenceServiceAuthorizationModel.getChunkingSettingsMap(config)));
    }

    private static void validateConfigurationForTextEmbedding(ElasticInferenceServiceAuthorizationResponseEntity.Configuration config) {
        ElasticInferenceServiceAuthorizationModel.validateFieldPresent("element_type", config.elementType(), TaskType.TEXT_EMBEDDING);
        ElasticInferenceServiceAuthorizationModel.validateFieldPresent("dimensions", config.dimensions(), TaskType.TEXT_EMBEDDING);
        ElasticInferenceServiceAuthorizationModel.validateFieldPresent("similarity", config.similarity(), TaskType.TEXT_EMBEDDING);
    }

    private static void validateFieldPresent(String field, Object fieldValue, TaskType taskType) {
        if (fieldValue == null) {
            throw new IllegalArgumentException(Strings.format((String)"Required field [%s] is missing for task_type [%s]", (Object[])new Object[]{field, taskType.toString()}));
        }
    }

    private static SimilarityMeasure getSimilarityMeasure(ElasticInferenceServiceAuthorizationResponseEntity.Configuration configuration) {
        return SimilarityMeasure.fromString((String)configuration.similarity());
    }

    private static ElasticInferenceServiceRerankModel createRerankModel(ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedEndpoint authorizedEndpoint, ElasticInferenceServiceComponents components) {
        return new ElasticInferenceServiceRerankModel(authorizedEndpoint.id(), TaskType.RERANK, "elastic", new ElasticInferenceServiceRerankServiceSettings(authorizedEndpoint.modelName()), (TaskSettings)EmptyTaskSettings.INSTANCE, (SecretSettings)EmptySecretSettings.INSTANCE, components);
    }

    public static ElasticInferenceServiceAuthorizationModel unauthorized() {
        return new ElasticInferenceServiceAuthorizationModel(List.of());
    }

    ElasticInferenceServiceAuthorizationModel(List<ElasticInferenceServiceModel> authorizedEndpoints) {
        Objects.requireNonNull(authorizedEndpoints);
        this.authorizedEndpoints = authorizedEndpoints.stream().collect(Collectors.toMap(Model::getInferenceEntityId, Function.identity(), (firstModel, secondModel) -> {
            logger.warn("Found inference id collision for id [{}], ignoring second model", (Object)firstModel.inferenceEntityId());
            return firstModel;
        }, HashMap::new));
        EnumSet<TaskType> taskTypesSet = EnumSet.noneOf(TaskType.class);
        taskTypesSet.addAll(this.authorizedEndpoints.values().stream().map(Model::getTaskType).toList());
        this.taskTypes = taskTypesSet;
    }

    public boolean isAuthorized() {
        return !this.authorizedEndpoints.isEmpty();
    }

    public ElasticInferenceServiceAuthorizationModel newLimitedToTaskTypes(EnumSet<TaskType> taskTypes) {
        List<ElasticInferenceServiceModel> endpoints = this.authorizedEndpoints.values().stream().filter(endpoint -> taskTypes.contains(endpoint.getTaskType())).toList();
        return new ElasticInferenceServiceAuthorizationModel(endpoints);
    }

    public EnumSet<TaskType> getTaskTypes() {
        return EnumSet.copyOf(this.taskTypes);
    }

    public Set<String> getEndpointIds() {
        return Set.copyOf(this.authorizedEndpoints.keySet());
    }

    public List<Model> getEndpoints(Set<String> endpointIds) {
        return endpointIds.stream().map(this.authorizedEndpoints::get).filter(Objects::nonNull).toList();
    }

    public String toString() {
        return Strings.format((String)"AuthorizationModel{authorizedEndpoints=%s, taskTypes=%s}", (Object[])new Object[]{this.authorizedEndpoints, this.taskTypes});
    }

    public boolean equals(Object o) {
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        ElasticInferenceServiceAuthorizationModel that = (ElasticInferenceServiceAuthorizationModel)o;
        return Objects.equals(this.authorizedEndpoints, that.authorizedEndpoints) && Objects.equals(this.taskTypes, that.taskTypes);
    }

    public int hashCode() {
        return Objects.hash(this.authorizedEndpoints, this.taskTypes);
    }
}

