/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.action;

import java.util.ArrayList;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.features.FeatureService;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.protocol.xpack.XPackUsageRequest;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.XPackFeatureUsage;
import org.elasticsearch.xpack.core.action.XPackUsageFeatureAction;
import org.elasticsearch.xpack.core.action.XPackUsageFeatureResponse;
import org.elasticsearch.xpack.core.action.XPackUsageFeatureTransportAction;
import org.elasticsearch.xpack.core.inference.InferenceFeatureSetUsage;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.core.inference.usage.ModelStats;
import org.elasticsearch.xpack.core.inference.usage.SemanticTextStats;
import org.elasticsearch.xpack.inference.InferenceFeatures;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;

/**
 * Transport action that reports usage statistics for the inference plugin.
 *
 * <p>All inference endpoints are fetched through {@link GetInferenceModelAction} and combined with the inference
 * fields declared by non-system, non-hidden indices of the cluster state. The resulting stats are keyed by
 * {@code "<service>:<TASK_TYPE>"} (with {@code "_all"} as the service for per-task-type aggregates) and by
 * {@code "_<service>_<model_id>"} for default models that are compatible with semantic text.
 */
public class TransportInferenceUsageAction extends XPackUsageFeatureTransportAction {

    private static final Logger logger = LogManager.getLogger(TransportInferenceUsageAction.class);

    /** Suffix stripped from model ids so platform-specific variants of a model are counted under one id. */
    private static final String MODEL_ID_LINUX_SUFFIX = "_linux-x86_64";

    /** Task types whose stats carry {@link SemanticTextStats}; stats of all other task types hold {@code null}. */
    private static final EnumSet<TaskType> TASK_TYPES_WITH_SEMANTIC_TEXT_SUPPORT = EnumSet.of(
        TaskType.TEXT_EMBEDDING,
        TaskType.SPARSE_EMBEDDING
    );

    /** Pseudo service name used for stats aggregated over all services of a task type. */
    private static final String ALL_SERVICES = "_all";

    private final FeatureService featureService;
    private final ModelRegistry modelRegistry;
    private final Client client;

    @Inject
    public TransportInferenceUsageAction(
        TransportService transportService,
        ClusterService clusterService,
        ThreadPool threadPool,
        ActionFilters actionFilters,
        ModelRegistry modelRegistry,
        Client client,
        FeatureService featureService
    ) {
        super(XPackUsageFeatureAction.INFERENCE.name(), transportService, clusterService, threadPool, actionFilters);
        this.modelRegistry = modelRegistry;
        // Requests issued by this action are sent with the "ml" origin set.
        this.client = new OriginSettingClient(client, "ml");
        this.featureService = featureService;
    }

    /**
     * Fetches every inference endpoint and replies with the usage computed against {@code state}.
     * If retrieving the endpoints fails, the error is logged and {@link InferenceFeatureSetUsage#EMPTY} is returned
     * instead of failing the usage request.
     */
    @Override
    protected void localClusterStateOperation(
        Task task,
        XPackUsageRequest request,
        ClusterState state,
        ActionListener<XPackUsageFeatureResponse> listener
    ) {
        GetInferenceModelAction.Request getInferenceModelRequest = new GetInferenceModelAction.Request("_all", TaskType.ANY, false);
        client.execute(
            GetInferenceModelAction.INSTANCE,
            getInferenceModelRequest,
            ActionListener.wrap(
                response -> listener.onResponse(new XPackUsageFeatureResponse(collectUsage(response.getEndpoints(), state))),
                e -> {
                    logger.warn(Strings.format("Retrieving inference usage failed with error: %s", e.getMessage()), e);
                    listener.onResponse(new XPackUsageFeatureResponse(InferenceFeatureSetUsage.EMPTY));
                }
            )
        );
    }

    /**
     * Builds the usage from the given endpoints and the index metadata of {@code state}.
     * The same {@code state} snapshot is used for both the indices and the cluster feature check.
     */
    private InferenceFeatureSetUsage collectUsage(List<ModelConfigurations> endpoints, ClusterState state) {
        Map<ServiceAndTaskType, Map<String, List<InferenceFieldMetadata>>> inferenceFieldsByIndexServiceAndTask =
            mapInferenceFieldsByIndexServiceAndTask(state.getMetadata().indicesAllProjects(), endpoints);
        // TreeMap keeps the reported stats ordered by their key.
        Map<String, ModelStats> endpointStats = new TreeMap<>();
        addStatsByServiceAndTask(inferenceFieldsByIndexServiceAndTask, endpoints, endpointStats, state);
        addStatsForDefaultModelsCompatibleWithSemanticText(inferenceFieldsByIndexServiceAndTask, endpoints, endpointStats);
        return new InferenceFeatureSetUsage(endpointStats.values());
    }

    /**
     * Groups the inference fields of all non-system, non-hidden indices by the (service, task type) of the endpoint
     * they reference, and then by index name. Fields whose inference id is not among {@code endpoints} are ignored.
     */
    private static Map<ServiceAndTaskType, Map<String, List<InferenceFieldMetadata>>> mapInferenceFieldsByIndexServiceAndTask(
        Iterable<IndexMetadata> indicesMetadata,
        List<ModelConfigurations> endpoints
    ) {
        Map<String, ModelConfigurations> inferenceIdToEndpoint = endpoints.stream()
            .collect(Collectors.toMap(ModelConfigurations::getInferenceEntityId, Function.identity()));
        Map<ServiceAndTaskType, Map<String, List<InferenceFieldMetadata>>> inferenceFieldByIndexServiceAndTask = new HashMap<>();
        for (IndexMetadata indexMetadata : indicesMetadata) {
            // System and hidden indices are excluded from the usage stats.
            if (indexMetadata.isSystem() || indexMetadata.isHidden()) {
                continue;
            }
            String indexName = indexMetadata.getIndex().getName();
            for (InferenceFieldMetadata field : indexMetadata.getInferenceFields().values()) {
                ModelConfigurations endpoint = inferenceIdToEndpoint.get(field.getInferenceId());
                if (endpoint == null) {
                    continue;
                }
                inferenceFieldByIndexServiceAndTask.computeIfAbsent(
                    new ServiceAndTaskType(endpoint.getService(), endpoint.getTaskType()),
                    key -> new HashMap<>()
                ).computeIfAbsent(indexName, key -> new ArrayList<>()).add(field);
            }
        }
        return inferenceFieldByIndexServiceAndTask;
    }

    /**
     * Counts endpoints per (service, task type) and per task type across all services, then attaches
     * semantic text stats to both levels.
     */
    private void addStatsByServiceAndTask(
        Map<ServiceAndTaskType, Map<String, List<InferenceFieldMetadata>>> inferenceFieldsByIndexServiceAndTask,
        List<ModelConfigurations> endpoints,
        Map<String, ModelStats> endpointStats,
        ClusterState state
    ) {
        for (ModelConfigurations model : endpoints) {
            endpointStats.computeIfAbsent(
                new ServiceAndTaskType(model.getService(), model.getTaskType()).toString(),
                key -> createEmptyStats(model)
            ).add();
            endpointStats.computeIfAbsent(
                new ServiceAndTaskType(ALL_SERVICES, model.getTaskType()).toString(),
                key -> createEmptyStats(ALL_SERVICES, model.getTaskType())
            ).add();
        }
        // Every key here comes from an endpoint in `endpoints`, so the matching stats entry was created above.
        inferenceFieldsByIndexServiceAndTask.forEach(
            (serviceAndTaskType, inferenceFieldsByIndex) -> addSemanticTextStats(
                inferenceFieldsByIndex,
                endpointStats.get(serviceAndTaskType.toString())
            )
        );
        addTopLevelStatsByTask(inferenceFieldsByIndexServiceAndTask, endpointStats, state);
    }

    private static ModelStats createEmptyStats(ModelConfigurations model) {
        return createEmptyStats(model.getService(), model.getTaskType());
    }

    /** Creates zero-count stats; semantic text stats are only allocated for task types that support them. */
    private static ModelStats createEmptyStats(String service, TaskType taskType) {
        return new ModelStats(
            service,
            taskType,
            0L,
            TASK_TYPES_WITH_SEMANTIC_TEXT_SUPPORT.contains(taskType) ? new SemanticTextStats() : null
        );
    }

    /**
     * Ensures an {@code "_all:<TASK_TYPE>"} entry exists for every concrete task type and, for semantic-text
     * capable task types, aggregates the inference fields of all services per index into it.
     * {@link TaskType#ANY} is always skipped; {@link TaskType#EMBEDDING} is skipped unless the cluster in
     * {@code state} has the embedding task type feature.
     */
    private void addTopLevelStatsByTask(
        Map<ServiceAndTaskType, Map<String, List<InferenceFieldMetadata>>> inferenceFieldsByIndexServiceAndTask,
        Map<String, ModelStats> endpointStats,
        ClusterState state
    ) {
        boolean embeddingTaskTypeSupported = featureService.clusterHasFeature(state, InferenceFeatures.EMBEDDING_TASK_TYPE);
        for (TaskType taskType : TaskType.values()) {
            if (taskType == TaskType.ANY || (taskType == TaskType.EMBEDDING && !embeddingTaskTypeSupported)) {
                continue;
            }
            ModelStats allStatsForTaskType = endpointStats.computeIfAbsent(
                new ServiceAndTaskType(ALL_SERVICES, taskType).toString(),
                key -> createEmptyStats(ALL_SERVICES, taskType)
            );
            if (!TASK_TYPES_WITH_SEMANTIC_TEXT_SUPPORT.contains(taskType)) {
                continue;
            }
            // Merge the per-index field lists of every service for this task type.
            Map<String, List<InferenceFieldMetadata>> inferenceFieldsByIndex = inferenceFieldsByIndexServiceAndTask.entrySet()
                .stream()
                .filter(e -> e.getKey().taskType() == taskType)
                .flatMap(e -> e.getValue().entrySet().stream())
                .collect(
                    Collectors.toMap(
                        Map.Entry::getKey,
                        Map.Entry::getValue,
                        (l1, l2) -> Stream.concat(l1.stream(), l2.stream()).toList()
                    )
                );
            addSemanticTextStats(inferenceFieldsByIndex, allStatsForTaskType);
        }
    }

    /**
     * Adds field, index and distinct inference id counts of {@code inferenceFieldsByIndex} to the semantic text
     * stats of {@code stat}. Stats without semantic text stats (task types outside
     * {@link #TASK_TYPES_WITH_SEMANTIC_TEXT_SUPPORT}) are left untouched.
     */
    private static void addSemanticTextStats(Map<String, List<InferenceFieldMetadata>> inferenceFieldsByIndex, ModelStats stat) {
        SemanticTextStats semanticTextStats = stat.semanticTextStats();
        if (semanticTextStats == null) {
            // createEmptyStats leaves semantic text stats null for unsupported task types; dereferencing it
            // would throw a NullPointerException from inside the usage response handler.
            return;
        }
        Set<String> inferenceIds = new HashSet<>();
        for (List<InferenceFieldMetadata> inferenceFields : inferenceFieldsByIndex.values()) {
            semanticTextStats.addFieldCount(inferenceFields.size());
            semanticTextStats.incIndicesCount();
            inferenceFields.forEach(field -> inferenceIds.add(field.getInferenceId()));
        }
        semanticTextStats.setInferenceIdCount(inferenceIds.size());
    }

    /**
     * Adds one stats entry per default model key (see
     * {@link #createStatsKeysWithEndpointCountsForDefaultModelsCompatibleWithSemanticText}), counting the
     * endpoints using that model and the inference fields whose endpoint uses that model.
     */
    private void addStatsForDefaultModelsCompatibleWithSemanticText(
        Map<ServiceAndTaskType, Map<String, List<InferenceFieldMetadata>>> inferenceFieldsByIndexServiceAndTask,
        List<ModelConfigurations> endpoints,
        Map<String, ModelStats> endpointStats
    ) {
        Map<String, String> endpointIdToModelId = endpoints.stream()
            .filter(endpoint -> endpoint.getServiceSettings().modelId() != null)
            .collect(
                Collectors.toMap(ModelConfigurations::getInferenceEntityId, e -> stripLinuxSuffix(e.getServiceSettings().modelId()))
            );
        Map<DefaultModelStatsKey, Long> defaultModelsToEndpointCount =
            createStatsKeysWithEndpointCountsForDefaultModelsCompatibleWithSemanticText(endpoints);
        for (Map.Entry<DefaultModelStatsKey, Long> defaultModelStatsKeyToEndpointCount : defaultModelsToEndpointCount.entrySet()) {
            DefaultModelStatsKey statKey = defaultModelStatsKeyToEndpointCount.getKey();
            Map<String, List<InferenceFieldMetadata>> fieldsByIndex = inferenceFieldsByIndexServiceAndTask.getOrDefault(
                new ServiceAndTaskType(statKey.service(), statKey.taskType()),
                Map.of()
            );
            Map<String, List<InferenceFieldMetadata>> fieldsUsingModel = filterFields(
                fieldsByIndex,
                field -> statKey.modelId().equals(endpointIdToModelId.get(field.getInferenceId()))
            );
            ModelStats stats = new ModelStats(
                statKey.toString(),
                statKey.taskType(),
                defaultModelStatsKeyToEndpointCount.getValue().longValue(),
                new SemanticTextStats()
            );
            addSemanticTextStats(fieldsUsingModel, stats);
            endpointStats.put(statKey.toString(), stats);
        }
    }

    /**
     * Returns, per (service, task type, model id), the number of endpoints using a model id that is also used by a
     * preconfigured endpoint of a semantic-text capable task type. Model ids are compared with the Linux suffix
     * stripped; endpoints without a model id are ignored. Counted endpoints may be of any task type.
     */
    private Map<DefaultModelStatsKey, Long> createStatsKeysWithEndpointCountsForDefaultModelsCompatibleWithSemanticText(
        List<ModelConfigurations> endpoints
    ) {
        Set<String> modelIds = endpoints.stream()
            .filter(endpoint -> TASK_TYPES_WITH_SEMANTIC_TEXT_SUPPORT.contains(endpoint.getTaskType()))
            .filter(endpoint -> modelRegistry.containsPreconfiguredInferenceEndpointId(endpoint.getInferenceEntityId()))
            .filter(endpoint -> endpoint.getServiceSettings().modelId() != null)
            .map(endpoint -> stripLinuxSuffix(endpoint.getServiceSettings().modelId()))
            .collect(Collectors.toSet());
        return endpoints.stream()
            .filter(endpoint -> endpoint.getServiceSettings().modelId() != null)
            .filter(endpoint -> modelIds.contains(stripLinuxSuffix(endpoint.getServiceSettings().modelId())))
            .map(
                endpoint -> new DefaultModelStatsKey(
                    endpoint.getService(),
                    endpoint.getTaskType(),
                    stripLinuxSuffix(endpoint.getServiceSettings().modelId())
                )
            )
            .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
    }

    /** Returns a copy of {@code fieldsByIndex} keeping only matching fields and dropping indices left empty. */
    private static Map<String, List<InferenceFieldMetadata>> filterFields(
        Map<String, List<InferenceFieldMetadata>> fieldsByIndex,
        Predicate<InferenceFieldMetadata> predicate
    ) {
        Map<String, List<InferenceFieldMetadata>> filtered = new HashMap<>();
        for (Map.Entry<String, List<InferenceFieldMetadata>> entry : fieldsByIndex.entrySet()) {
            List<InferenceFieldMetadata> filteredFields = entry.getValue().stream().filter(predicate).toList();
            if (!filteredFields.isEmpty()) {
                filtered.put(entry.getKey(), filteredFields);
            }
        }
        return filtered;
    }

    /** Removes a trailing {@value #MODEL_ID_LINUX_SUFFIX} from {@code modelId}, if present. */
    private static String stripLinuxSuffix(String modelId) {
        if (modelId.endsWith(MODEL_ID_LINUX_SUFFIX)) {
            return modelId.substring(0, modelId.length() - MODEL_ID_LINUX_SUFFIX.length());
        }
        return modelId;
    }

    /** Grouping key rendered as {@code "<service>:<TASK_TYPE>"}. */
    private record ServiceAndTaskType(String service, TaskType taskType) {
        @Override
        public String toString() {
            return service + ":" + taskType.name();
        }
    }

    /**
     * Default-model grouping key rendered as {@code "_<service>_<model_id>"} with dots in the model id replaced
     * by underscores.
     *
     * <p>NOTE(review): the rendered key does not include the task type, so two keys that differ only in task type
     * map to the same stats entry and the later {@code put} wins. It is left unchanged because the key format is
     * part of the reported usage — confirm whether that collision can occur in practice.
     */
    private record DefaultModelStatsKey(String service, TaskType taskType, String modelId) {
        @Override
        public String toString() {
            return "_" + service + "_" + modelId.replace('.', '_');
        }
    }
}

