/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.common;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
import org.elasticsearch.xpack.inference.common.RateLimitAssignment;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.SenderService;

/**
 * Computes which cluster nodes are responsible for a given inference service / task type and makes
 * the service's rate limit "node-local" by dividing it across those nodes.
 *
 * <p>Assignments are recomputed whenever the set of nodes in the cluster changes and are stored per
 * service name and task type. Only services and task types listed in
 * {@link #SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS} participate.
 */
public class InferenceServiceNodeLocalRateLimitCalculator implements InferenceServiceRateLimitCalculator {

    /** Default upper bound on the number of nodes assigned to one service / task-type grouping. */
    public static final Integer DEFAULT_MAX_NODES_PER_GROUPING = 3;

    /**
     * Service name to the task types whose rate limits are made node-local, together with the strategy
     * that decides how many nodes are assigned to each of them.
     */
    static final Map<String, Collection<NodeLocalRateLimitConfig>> SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS = Map.of(
        "elastic",
        List.of(new NodeLocalRateLimitConfig(TaskType.SPARSE_EMBEDDING, numNodesInCluster -> DEFAULT_MAX_NODES_PER_GROUPING))
    );

    private static final Logger logger = LogManager.getLogger(InferenceServiceNodeLocalRateLimitCalculator.class);

    private final InferenceServiceRegistry serviceRegistry;

    // service name -> (task type -> assignment). Each inner map is fully built before it is put here
    // and is published as an unmodifiable view, so readers never observe a partially built map.
    private final ConcurrentHashMap<String, Map<TaskType, RateLimitAssignment>> serviceAssignments;

    /**
     * Creates the calculator and registers it as a cluster state listener.
     *
     * @param clusterService  the cluster service this instance subscribes to
     * @param serviceRegistry registry used to look up the configured inference services
     */
    @Inject
    public InferenceServiceNodeLocalRateLimitCalculator(ClusterService clusterService, InferenceServiceRegistry serviceRegistry) {
        this.serviceRegistry = serviceRegistry;
        this.serviceAssignments = new ConcurrentHashMap<>();
        // Register only after all fields are assigned: clusterChanged() reads both of them, so a
        // callback delivered right after registration must not observe unset fields.
        clusterService.addListener((ClusterStateListener) this);
    }

    /**
     * Recomputes assignments when the set of nodes in the cluster changed; other cluster state
     * updates are ignored.
     *
     * @param event the cluster changed event
     */
    public void clusterChanged(ClusterChangedEvent event) {
        if (event.nodesChanged()) {
            updateAssignments(event);
        }
    }

    /**
     * Returns whether requests for {@code taskType} of service {@code serviceName} may be rerouted,
     * i.e. whether that combination appears in {@link #SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS}.
     */
    @Override
    public boolean isTaskTypeReroutingSupported(String serviceName, TaskType taskType) {
        return SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.getOrDefault(serviceName, List.of())
            .stream()
            .anyMatch(rateLimitConfig -> taskType.equals(rateLimitConfig.taskType()));
    }

    /**
     * Returns the most recently computed assignment for the service / task type, or {@code null} if
     * none has been computed yet (no node change observed, or the combination is not configured).
     */
    @Override
    public RateLimitAssignment getRateLimitAssignment(String service, TaskType taskType) {
        Map<TaskType, RateLimitAssignment> assignmentsPerTaskType = serviceAssignments.get(service);
        if (assignmentsPerTaskType == null) {
            return null;
        }
        return assignmentsPerTaskType.get(taskType);
    }

    /**
     * Recomputes node assignments for every configured service and updates each service's rate-limit
     * divisor to the number of assigned nodes.
     */
    private void updateAssignments(ClusterChangedEvent event) {
        ClusterState newClusterState = event.state();
        Collection<DiscoveryNode> nodes = newClusterState.nodes().getAllNodes();
        // Sort by node id so every node computes the same assignment from the same cluster state.
        List<DiscoveryNode> sortedNodes = nodes.stream().sorted(Comparator.comparing(DiscoveryNode::getId)).toList();

        for (Map.Entry<String, Collection<NodeLocalRateLimitConfig>> entry : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.entrySet()) {
            String serviceName = entry.getKey();
            Optional<InferenceService> service = serviceRegistry.getService(serviceName);
            if (service.isEmpty()) {
                logger.warn("Service [{}] is configured for node-local rate limiting but was not found in the service registry", serviceName);
                continue;
            }
            InferenceService inferenceService = service.get();

            // One map per service, shared by all its task types (previously a new map per task type
            // replaced the earlier ones, so only the last configured task type survived).
            Map<TaskType, RateLimitAssignment> perTaskTypeAssignments = new HashMap<>();
            for (NodeLocalRateLimitConfig rateLimitConfig : entry.getValue()) {
                List<DiscoveryNode> assignedNodes = calculateServiceAssignment(rateLimitConfig.maxNodesPerGroupingStrategy(), sortedNodes);
                // NOTE(review): the divisor is set on the service's single sender, so with several task
                // types per service the last config's node count wins.
                updateRateLimits(inferenceService, assignedNodes.size());
                perTaskTypeAssignments.put(rateLimitConfig.taskType(), new RateLimitAssignment(assignedNodes));
            }
            serviceAssignments.put(serviceName, Collections.unmodifiableMap(perTaskTypeAssignments));
        }
    }

    /**
     * Picks the first {@code min(numberOfNodes, strategy(numberOfNodes))} nodes from the sorted list
     * (none if the strategy returns a non-positive value).
     *
     * @param maxNodesPerGroupingStrategy decides the upper bound on assigned nodes
     * @param sortedNodes                 all cluster nodes, sorted by id
     * @return a new mutable list with the assigned nodes
     */
    private List<DiscoveryNode> calculateServiceAssignment(MaxNodesPerGroupingStrategy maxNodesPerGroupingStrategy, List<DiscoveryNode> sortedNodes) {
        int numberOfNodes = sortedNodes.size();
        int nodesPerGrouping = Math.max(0, Math.min(numberOfNodes, maxNodesPerGroupingStrategy.calculate(numberOfNodes)));
        return new ArrayList<>(sortedNodes.subList(0, nodesPerGrouping));
    }

    /**
     * Divides the service's rate limit across {@code responsibleNodes} nodes; services that are not
     * {@link SenderService}s are left untouched.
     * NOTE(review): {@code responsibleNodes} is 0 when the cluster state has no nodes; confirm the
     * sender tolerates a zero divisor.
     */
    private void updateRateLimits(InferenceService service, int responsibleNodes) {
        if (service instanceof SenderService senderService) {
            Sender sender = senderService.getSender();
            sender.updateRateLimitDivisor(responsibleNodes);
        }
    }

    InferenceServiceRegistry serviceRegistry() {
        return this.serviceRegistry;
    }

    /**
     * Node-local rate limiting configuration for one task type of a service.
     *
     * @param taskType                    the task type the configuration applies to
     * @param maxNodesPerGroupingStrategy maps the cluster size to the maximum number of assigned nodes
     */
    record NodeLocalRateLimitConfig(TaskType taskType, MaxNodesPerGroupingStrategy maxNodesPerGroupingStrategy) {
    }

    /** Maps the number of nodes in the cluster to the maximum number of nodes in one grouping. */
    @FunctionalInterface
    private interface MaxNodesPerGroupingStrategy {
        Integer calculate(Integer numNodesInCluster);
    }
}

