/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.assignment;

import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateListener;
import org.elasticsearch.cluster.ClusterStateUpdateTask;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.cluster.metadata.ProjectMetadata;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.gateway.GatewayService;
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.MlMetadata;
import org.elasticsearch.xpack.core.ml.MlTasks;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentRoutingInfoAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingStateAndReason;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentUtils;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.autoscaling.NodeAvailabilityZoneMapper;
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentRebalancer;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AllocationReducer;
import org.elasticsearch.xpack.ml.job.NodeLoad;
import org.elasticsearch.xpack.ml.job.NodeLoadDetector;
import org.elasticsearch.xpack.ml.notifications.SystemAuditor;

public class TrainedModelAssignmentClusterService
implements ClusterStateListener {
    private static final Logger logger = LogManager.getLogger(TrainedModelAssignmentClusterService.class);
    // Collaborators supplied through the constructor; every one except `client` is null-checked there.
    private final ClusterService clusterService;
    private final ThreadPool threadPool;
    private final NodeLoadDetector nodeLoadDetector;
    private final SystemAuditor systemAuditor;
    private final NodeAvailabilityZoneMapper nodeAvailabilityZoneMapper;
    private final Client client;
    // Dynamic ML settings. They are volatile because, on master-eligible nodes, the constructor registers
    // settings-update consumers that overwrite them, while rebalancing reads them from work run on the
    // "ml_utility" executor.
    private volatile int maxMemoryPercentage;
    private volatile boolean useAuto;
    private volatile int maxOpenJobs;
    // NOTE(review): these three are protected rather than private — presumably for subclass/test access; confirm.
    protected volatile int maxLazyMLNodes;
    protected volatile long maxMLNodeSize;
    protected volatile int allocatedProcessorsScale;

    public TrainedModelAssignmentClusterService(Settings settings, ClusterService clusterService, ThreadPool threadPool, NodeLoadDetector nodeLoadDetector, SystemAuditor systemAuditor, NodeAvailabilityZoneMapper nodeAvailabilityZoneMapper, Client client) {
        this.clusterService = Objects.requireNonNull(clusterService);
        this.threadPool = Objects.requireNonNull(threadPool);
        this.nodeLoadDetector = Objects.requireNonNull(nodeLoadDetector);
        this.systemAuditor = Objects.requireNonNull(systemAuditor);
        this.nodeAvailabilityZoneMapper = Objects.requireNonNull(nodeAvailabilityZoneMapper);
        this.maxMemoryPercentage = (Integer)MachineLearning.MAX_MACHINE_MEMORY_PERCENT.get(settings);
        this.useAuto = (Boolean)MachineLearningField.USE_AUTO_MACHINE_MEMORY_PERCENT.get(settings);
        this.maxOpenJobs = (Integer)MachineLearning.MAX_OPEN_JOBS_PER_NODE.get(settings);
        this.maxLazyMLNodes = (Integer)MachineLearningField.MAX_LAZY_ML_NODES.get(settings);
        this.maxMLNodeSize = ((ByteSizeValue)MachineLearning.MAX_ML_NODE_SIZE.get(settings)).getBytes();
        this.allocatedProcessorsScale = (Integer)MachineLearning.ALLOCATED_PROCESSORS_SCALE.get(settings);
        this.client = client;
        if (DiscoveryNode.isMasterNode((Settings)settings)) {
            clusterService.addListener((ClusterStateListener)this);
            clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.MAX_MACHINE_MEMORY_PERCENT, this::setMaxMemoryPercentage);
            clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearningField.USE_AUTO_MACHINE_MEMORY_PERCENT, this::setUseAuto);
            clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.MAX_OPEN_JOBS_PER_NODE, this::setMaxOpenJobs);
            clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearningField.MAX_LAZY_ML_NODES, this::setMaxLazyMLNodes);
            clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.MAX_ML_NODE_SIZE, this::setMaxMLNodeSize);
            clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.ALLOCATED_PROCESSORS_SCALE, this::setAllocatedProcessorsScale);
        }
    }

    /** Settings-update consumer for {@code MachineLearning.MAX_MACHINE_MEMORY_PERCENT}. */
    private void setMaxMemoryPercentage(int newMaxMemoryPercentage) {
        maxMemoryPercentage = newMaxMemoryPercentage;
    }

    /** Settings-update consumer for {@code MachineLearningField.USE_AUTO_MACHINE_MEMORY_PERCENT}. */
    private void setUseAuto(boolean newUseAuto) {
        useAuto = newUseAuto;
    }

    /** Settings-update consumer for {@code MachineLearning.MAX_OPEN_JOBS_PER_NODE}. */
    private void setMaxOpenJobs(int newMaxOpenJobs) {
        maxOpenJobs = newMaxOpenJobs;
    }

    /** Settings-update consumer for {@code MachineLearningField.MAX_LAZY_ML_NODES}. */
    private void setMaxLazyMLNodes(int newMaxLazyNodes) {
        maxLazyMLNodes = newMaxLazyNodes;
    }

    /** Settings-update consumer for {@code MachineLearning.MAX_ML_NODE_SIZE}; caches the size in bytes. */
    private void setMaxMLNodeSize(ByteSizeValue newMaxNodeSize) {
        maxMLNodeSize = newMaxNodeSize.getBytes();
    }

    /** Settings-update consumer for {@code MachineLearning.ALLOCATED_PROCESSORS_SCALE}. */
    private void setAllocatedProcessorsScale(int newScale) {
        allocatedProcessorsScale = newScale;
    }

    /**
     * Submits {@code task} to the cluster service as an unbatched cluster state update labelled {@code source}.
     * Kept in one place so the forbidden-API suppression covers a single call site.
     */
    @SuppressForbidden(reason = "legacy usage of unbatched task")
    private void submitUnbatchedTask(String source, ClusterStateUpdateTask task) {
        clusterService.submitUnbatchedStateUpdateTask(source, task);
    }

    /**
     * Handles cluster state changes.
     * <p>
     * Events are ignored while the "state not recovered" block is present and on nodes that are not the elected
     * master. When nodes join, ML node architecture heterogeneity is logged. If
     * {@code detectReasonToRebalanceModels} reports a reason, an asynchronous rebalance is started whose outcome is
     * only logged.
     */
    @Override
    public void clusterChanged(ClusterChangedEvent event) {
        boolean stateNotRecovered = this.eventStateHasGlobalBlockStateNotRecoveredBlock(event);
        if (stateNotRecovered || !event.localNodeMaster()) {
            return;
        }
        if (event.nodesAdded()) {
            this.logMlNodeHeterogeneity();
        }
        Optional<String> rebalanceReason = TrainedModelAssignmentClusterService.detectReasonToRebalanceModels(event);
        if (rebalanceReason.isEmpty()) {
            return;
        }
        ActionListener<TrainedModelAssignmentMetadata> onRebalanced = ActionListener.wrap(
            newMetadata -> logger.debug(
                () -> Strings.format(
                    "rebalanced model assignments [%s]",
                    org.elasticsearch.common.Strings.toString((ChunkedToXContent) newMetadata, false, true)
                )
            ),
            e -> logger.warn("failed to rebalance models", e)
        );
        this.rebalanceAssignments(event.state(), Optional.empty(), rebalanceReason.get(), onRebalanced);
    }

    /**
     * Returns {@code true} when the state carried by {@code event} still has the gateway's
     * {@code STATE_NOT_RECOVERED_BLOCK} as a global block.
     */
    boolean eventStateHasGlobalBlockStateNotRecoveredBlock(ClusterChangedEvent event) {
        ClusterState state = event.state();
        return state.blocks().hasGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK);
    }

    /**
     * Asynchronously fetches the platform architectures of the ML nodes on the "ml_utility" executor and logs a
     * warning if more than one is found (see {@code getArchitecturesSetActionListener}).
     */
    void logMlNodeHeterogeneity() {
        ActionListener<Set<String>> warnIfHeterogeneous = getArchitecturesSetActionListener();
        ExecutorService utilityExecutor = threadPool.executor("ml_utility");
        MlPlatformArchitecturesUtil.getMlNodesArchitecturesSet(warnIfHeterogeneous, client, utilityExecutor);
    }

    /**
     * Builds a listener that warns when more than one distinct ML node platform architecture is reported, and logs an
     * error if the lookup fails.
     */
    static ActionListener<Set<String>> getArchitecturesSetActionListener() {
        return new ActionListener<>() {
            @Override
            public void onResponse(Set<String> architectures) {
                if (architectures.size() <= 1) {
                    return;
                }
                logger.warn(
                    Strings.format(
                        "Heterogeneous platform architectures were detected among ML nodes. This will prevent the deployment of some trained models. Distinct platform architectures detected: %s",
                        architectures
                    )
                );
            }

            @Override
            public void onFailure(Exception e) {
                logger.error("Failed to detect heterogeneity among ML nodes with exception: ", e);
            }
        };
    }

    /**
     * If any assignment is routed to a node that left the cluster or is shutting down, submits an unbatched cluster
     * state update that rewrites the routing tables via {@code removeRoutingToUnassignableNodes}.
     * Failures are logged only. NOTE(review): no caller of this private method is visible in this part of the file.
     */
    private void removeRoutingToRemovedOrShuttingDownNodes(ClusterChangedEvent event) {
        if (!TrainedModelAssignmentClusterService.areAssignedNodesRemoved(event)) {
            return;
        }
        // ClusterStateUpdateTask has no constructor taking the enclosing service; use the no-arg one.
        submitUnbatchedTask("removing routing entries for removed or shutting down nodes", new ClusterStateUpdateTask() {
            @Override
            public ClusterState execute(ClusterState currentState) {
                return TrainedModelAssignmentClusterService.removeRoutingToUnassignableNodes(currentState);
            }

            @Override
            public void onFailure(Exception e) {
                logger.error("could not remove routing entries for removed or shutting down nodes", e);
            }

            @Override
            public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
                logger.debug(
                    () -> Strings.format(
                        "updated model assignments based on node changes in the cluster; new metadata [%s]",
                        org.elasticsearch.common.Strings.toString(
                            (ChunkedToXContent) TrainedModelAssignmentMetadata.fromState(newState),
                            false,
                            true
                        )
                    )
                );
            }
        });
    }

    /**
     * Returns {@code true} when nodes were removed or node-shutdown metadata changed, and at least one assignment is
     * routed to a node that was removed or is shutting down.
     */
    static boolean areAssignedNodesRemoved(ClusterChangedEvent event) {
        boolean nodesShutdownChanged = event.changedCustomClusterMetadataSet().contains("node_shutdown");
        if (!event.nodesRemoved() && !nodesShutdownChanged) {
            return false;
        }
        Set<String> goneNodeIds = new HashSet<>(TrainedModelAssignmentClusterService.nodesShuttingDown(event.state()));
        for (DiscoveryNode removedNode : event.nodesDelta().removedNodes()) {
            goneNodeIds.add(removedNode.getId());
        }
        return TrainedModelAssignmentMetadata.fromState(event.state())
            .allAssignments()
            .values()
            .stream()
            .anyMatch(assignment -> !Sets.intersection(goneNodeIds, assignment.getNodeRoutingTable().keySet()).isEmpty());
    }

    /**
     * Returns a state in which every assignment no longer routes to unassignable nodes.
     * <p>
     * Routes to nodes that are not shutting down are dropped. Routes to shutting-down nodes that are STARTED or
     * STARTING are replaced with a shutting-down route (see {@code removeRoutingBuilder}). Each touched assignment
     * gets its state recalculated. Returns {@code currentState} itself when nothing changed.
     */
    static ClusterState removeRoutingToUnassignableNodes(ClusterState currentState) {
        // Generic types restored (the decompiled code used raw Sets and passed them into Set<String> parameters).
        Set<String> assignableNodes = TrainedModelAssignmentClusterService.getAssignableNodes(currentState)
            .stream()
            .map(DiscoveryNode::getId)
            .collect(Collectors.toSet());
        TrainedModelAssignmentMetadata metadata = TrainedModelAssignmentMetadata.fromState(currentState);
        TrainedModelAssignmentMetadata.Builder builder = TrainedModelAssignmentMetadata.builder(currentState);
        Set<String> shuttingDownNodes = currentState.metadata().nodeShutdowns().getAllNodeIds();
        for (TrainedModelAssignment assignment : metadata.allAssignments().values()) {
            Set<String> routedNodeIdsToRemove = Sets.difference(assignment.getNodeRoutingTable().keySet(), assignableNodes);
            if (routedNodeIdsToRemove.isEmpty()) {
                continue;
            }
            logger.debug(
                () -> Strings.format(
                    "[%s] removing routing entries to nodes %s because they have been removed or are shutting down",
                    assignment.getDeploymentId(),
                    routedNodeIdsToRemove
                )
            );
            TrainedModelAssignment.Builder assignmentBuilder = TrainedModelAssignmentClusterService.removeRoutingBuilder(
                routedNodeIdsToRemove,
                shuttingDownNodes,
                assignment
            );
            builder.updateAssignment(assignment.getDeploymentId(), assignmentBuilder.calculateAndSetAssignmentState());
        }
        return TrainedModelAssignmentClusterService.update(currentState, builder);
    }

    /**
     * Builds a copy of {@code assignment} with its routes to {@code nodeIds} adjusted.
     * <p>
     * Routes to nodes that are not in {@code shuttingDownNodes} are removed. Routes to shutting-down nodes are replaced
     * by a shutting-down route only when they exist and are STARTED or STARTING; otherwise they are left unchanged.
     */
    private static TrainedModelAssignment.Builder removeRoutingBuilder(
        Set<String> nodeIds,
        Set<String> shuttingDownNodes,
        TrainedModelAssignment assignment
    ) {
        TrainedModelAssignment.Builder assignmentBuilder = TrainedModelAssignment.Builder.fromAssignment(assignment);
        Map<String, RoutingInfo> routingTable = assignment.getNodeRoutingTable();
        for (String nodeId : nodeIds) {
            RoutingInfo currentRoute = routingTable.get(nodeId);
            if (!shuttingDownNodes.contains(nodeId)) {
                logger.debug(
                    () -> Strings.format("[%s] Removing route for unassignable node id [%s]", assignment.getDeploymentId(), nodeId)
                );
                assignmentBuilder.removeRoutingEntry(nodeId);
            } else if (currentRoute != null && currentRoute.getState().isAnyOf(RoutingState.STARTED, RoutingState.STARTING)) {
                logger.debug(
                    () -> Strings.format(
                        "[%s] Found assignment with route to shutting down node id [%s], adding stopping route",
                        assignment.getDeploymentId(),
                        nodeId
                    )
                );
                assignmentBuilder.addOrOverwriteRoutingEntry(nodeId, TrainedModelAssignmentUtils.createShuttingDownRoute(currentRoute));
            }
        }
        return assignmentBuilder;
    }

    /**
     * Applies a single node routing-table update from {@code request} through an unbatched cluster state update task.
     *
     * @param request  identifies the deployment, node and update to apply
     * @param listener acknowledged with {@code AcknowledgedResponse.TRUE} once the new state is processed, or failed
     *                 with the task's exception
     */
    public void updateModelRoutingTable(
        final UpdateTrainedModelAssignmentRoutingInfoAction.Request request,
        final ActionListener<AcknowledgedResponse> listener
    ) {
        logger.debug(
            () -> Strings.format(
                "[%s] updating routing table entry for node [%s], update [%s]",
                request.getDeploymentId(),
                request.getNodeId(),
                request.getUpdate()
            )
        );
        // No-arg task constructor; the decompiled `(this)` argument and `(Object)` response cast did not compile.
        submitUnbatchedTask("updating model routing for node assignment", new ClusterStateUpdateTask() {
            @Override
            public ClusterState execute(ClusterState currentState) {
                return TrainedModelAssignmentClusterService.updateModelRoutingTable(currentState, request);
            }

            @Override
            public void onFailure(Exception e) {
                listener.onFailure(e);
            }

            @Override
            public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
                listener.onResponse(AcknowledgedResponse.TRUE);
            }
        });
    }

    /**
     * Creates an assignment for a new deployment by running a rebalance that includes {@code request}.
     * <p>
     * Fails immediately with {@code CONFLICT} while an ML feature reset is in progress. On success the listener
     * receives the deployment's assignment from the rebalanced metadata; if no assignment is present there, it gets
     * an empty assignment built from {@code request}.
     */
    public void createNewModelAssignment(
        CreateTrainedModelAssignmentAction.Request request,
        ActionListener<TrainedModelAssignment> listener
    ) {
        if (MlMetadata.getMlMetadata(clusterService.state()).isResetMode()) {
            listener.onFailure(
                new ElasticsearchStatusException(
                    "cannot create new assignment [{}] for model [{}] while feature reset is in progress.",
                    RestStatus.CONFLICT,
                    request.getTaskParams().getDeploymentId(),
                    request.getTaskParams().getModelId()
                )
            );
            return;
        }
        // The decompiled `(Object)` cast on onResponse did not compile against ActionListener<TrainedModelAssignment>.
        rebalanceAssignments(
            clusterService.state(),
            Optional.of(request),
            "model deployment started",
            ActionListener.wrap(newMetadata -> {
                TrainedModelAssignment assignment = newMetadata.getDeploymentAssignment(request.getTaskParams().getDeploymentId());
                if (assignment == null) {
                    assignment = TrainedModelAssignment.Builder.empty(request).build();
                }
                listener.onResponse(assignment);
            }, listener::onFailure)
        );
    }

    /**
     * Marks an assignment as stopping (reason "client API call") through an unbatched cluster state update.
     * <p>
     * Despite its name, {@code modelId} is forwarded as the deployment id argument of {@code setToStopping}.
     *
     * @param listener acknowledged with {@code AcknowledgedResponse.TRUE} once the new state is processed, or failed
     *                 with the task's exception
     */
    public void setModelAssignmentToStopping(final String modelId, final ActionListener<AcknowledgedResponse> listener) {
        // No-arg task constructor; the decompiled `(this)` argument and `(Object)` response cast did not compile.
        submitUnbatchedTask("set model assignment stopping", new ClusterStateUpdateTask() {
            @Override
            public ClusterState execute(ClusterState currentState) {
                return TrainedModelAssignmentClusterService.setToStopping(currentState, modelId, "client API call");
            }

            @Override
            public void onFailure(Exception e) {
                listener.onFailure(e);
            }

            @Override
            public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
                listener.onResponse(AcknowledgedResponse.TRUE);
            }
        });
    }

    /**
     * Removes the assignment for {@code deploymentId} through an unbatched cluster state update.
     * <p>
     * Once the new state is processed, an asynchronous rebalance is started whose outcome is only logged, and the
     * listener is acknowledged right away without waiting for that rebalance.
     *
     * @param listener acknowledged with {@code AcknowledgedResponse.TRUE}, or failed with the task's exception
     */
    public void removeModelAssignment(final String deploymentId, final ActionListener<AcknowledgedResponse> listener) {
        submitUnbatchedTask("delete model deployment assignment", new ClusterStateUpdateTask() {
            @Override
            public ClusterState execute(ClusterState currentState) {
                return TrainedModelAssignmentClusterService.removeAssignment(currentState, deploymentId);
            }

            @Override
            public void onFailure(Exception e) {
                listener.onFailure(e);
            }

            @Override
            public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
                ActionListener<TrainedModelAssignmentMetadata> logOutcome = ActionListener.wrap(
                    metadataAfterRebalance -> logger.debug(
                        () -> Strings.format(
                            "Successfully rebalanced model deployments after deployment [%s] was stopped",
                            deploymentId
                        )
                    ),
                    e -> logger.error(
                        Strings.format("Failed to rebalance model deployments after deployment [%s] was stopped", deploymentId),
                        e
                    )
                );
                TrainedModelAssignmentClusterService.this.rebalanceAssignments(
                    newState,
                    Optional.empty(),
                    "model deployment stopped",
                    logOutcome
                );
                // The decompiled `(Object)` cast here did not compile against ActionListener<AcknowledgedResponse>.
                listener.onResponse(AcknowledgedResponse.TRUE);
            }
        });
    }

    /**
     * Removes every model assignment through an unbatched cluster state update.
     *
     * @param listener acknowledged with {@code AcknowledgedResponse.TRUE} once the new state is processed, or failed
     *                 with the task's exception
     */
    public void removeAllModelAssignments(final ActionListener<AcknowledgedResponse> listener) {
        // No-arg task constructor; the decompiled `(this)` argument and `(Object)` response cast did not compile.
        submitUnbatchedTask("delete all model assignments", new ClusterStateUpdateTask() {
            @Override
            public ClusterState execute(ClusterState currentState) {
                return TrainedModelAssignmentClusterService.removeAllAssignments(currentState);
            }

            @Override
            public void onFailure(Exception e) {
                listener.onFailure(e);
            }

            @Override
            public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
                listener.onResponse(AcknowledgedResponse.TRUE);
            }
        });
    }

    /**
     * Writes the built assignment metadata into the state, returning {@code currentState} itself (same instance)
     * when the built metadata equals what is already stored.
     */
    private static ClusterState update(ClusterState currentState, TrainedModelAssignmentMetadata.Builder modelAssignments) {
        TrainedModelAssignmentMetadata existing = TrainedModelAssignmentMetadata.fromState(currentState);
        boolean unchanged = modelAssignments.build().equals(existing);
        return unchanged ? currentState : TrainedModelAssignmentClusterService.forceUpdate(currentState, modelAssignments);
    }

    /**
     * Unconditionally stores the built assignment metadata as the project custom "trained_model_assignment" and
     * removes any custom stored under "trained_model_allocation" (presumably a legacy key — confirm).
     */
    private static ClusterState forceUpdate(ClusterState currentState, TrainedModelAssignmentMetadata.Builder modelAssignments) {
        logger.debug(() -> Strings.format("updated assignments: %s", modelAssignments.build()));
        ProjectMetadata currentProject = currentState.metadata().getProject();
        ProjectMetadata.Builder projectBuilder = ProjectMetadata.builder(currentProject);
        projectBuilder.putCustom("trained_model_assignment", (Metadata.ProjectCustom) modelAssignments.build())
            .removeCustom("trained_model_allocation");
        ClusterState.Builder stateBuilder = ClusterState.builder(currentState);
        return stateBuilder.putProjectMetadata(projectBuilder).build();
    }

    /**
     * Synchronously computes a rebalance that includes {@code request} and returns {@code currentState} with the
     * result applied. Propagates any exception from the rebalance computation.
     */
    ClusterState createModelAssignment(ClusterState currentState, CreateTrainedModelAssignmentAction.Request request) throws Exception {
        TrainedModelAssignmentMetadata.Builder rebalanced = this.rebalanceAssignments(currentState, Optional.of(request));
        return TrainedModelAssignmentClusterService.update(currentState, rebalanced);
    }

    /**
     * Recomputes model assignments asynchronously and publishes the result through an unbatched cluster state update.
     * <p>
     * The ML node platform architectures are fetched first. The plan is then computed on the "ml_utility" executor
     * from {@code clusterState}. When the update task runs, it may first stop the model being added (if the ML nodes
     * report more than one architecture). It applies the precomputed plan only if the resulting state is still
     * compatible with {@code clusterState} (see {@code areClusterStatesCompatibleForRebalance}). Otherwise it leaves the
     * state as is and starts another rebalance from the newer state, passing {@code listener} on to that attempt.
     *
     * @param clusterState            the state the plan is computed from
     * @param createAssignmentRequest the deployment being created, if this rebalance was triggered by one
     * @param reason                  used as the task source and in log/audit messages
     * @param listener                completed with the metadata from the applied state, or with a failure
     */
    private void rebalanceAssignments(
        ClusterState clusterState,
        Optional<CreateTrainedModelAssignmentAction.Request> createAssignmentRequest,
        String reason,
        ActionListener<TrainedModelAssignmentMetadata> listener
    ) {
        ActionListener<Set<String>> architecturesListener = ActionListener.wrap(
            mlNodesArchitectures -> threadPool.executor("ml_utility").execute(() -> {
                logger.debug(() -> Strings.format("Rebalancing model allocations because [%s]", reason));
                TrainedModelAssignmentMetadata.Builder rebalancedMetadata;
                try {
                    rebalancedMetadata = rebalanceAssignments(clusterState, createAssignmentRequest);
                } catch (Exception e) {
                    listener.onFailure(e);
                    return;
                }
                // The anonymous task captures the effectively-final locals above directly.
                submitUnbatchedTask(reason, new ClusterStateUpdateTask() {
                    // Set in execute() and read in clusterStateProcessed().
                    private volatile boolean isUpdated;
                    private volatile boolean isChanged;

                    @Override
                    public ClusterState execute(ClusterState currentState) {
                        ClusterState candidateState = TrainedModelAssignmentClusterService.this
                            .stopPlatformSpecificModelsInHeterogeneousClusters(
                                currentState,
                                mlNodesArchitectures,
                                createAssignmentRequest.map(CreateTrainedModelAssignmentAction.Request::getTaskParams),
                                clusterState
                            );
                        if (TrainedModelAssignmentClusterService.this.areClusterStatesCompatibleForRebalance(clusterState, candidateState)) {
                            isUpdated = true;
                            ClusterState updatedState = TrainedModelAssignmentClusterService.update(candidateState, rebalancedMetadata);
                            // update() returns the same instance when the metadata did not change.
                            isChanged = updatedState != candidateState;
                            return updatedState;
                        }
                        // The plan was computed from a state that no longer matches; recompute from the newer one.
                        TrainedModelAssignmentClusterService.this.rebalanceAssignments(
                            candidateState,
                            createAssignmentRequest,
                            reason,
                            listener
                        );
                        return candidateState;
                    }

                    @Override
                    public void onFailure(Exception e) {
                        listener.onFailure(e);
                    }

                    @Override
                    public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
                        if (!isUpdated) {
                            // A retried rebalance owns the listener now.
                            return;
                        }
                        if (isChanged) {
                            threadPool.executor("ml_utility")
                                .execute(
                                    () -> systemAuditor.info(
                                        Messages.getMessage("Rebalanced trained model allocations because [{0}]", reason)
                                    )
                                );
                        }
                        listener.onResponse(TrainedModelAssignmentMetadata.fromState(newState));
                    }
                });
            }),
            listener::onFailure
        );
        MlPlatformArchitecturesUtil.getMlNodesArchitecturesSet(architecturesListener, client, threadPool.executor("ml_utility"));
    }

    /**
     * When the ML nodes report more than one platform architecture and a model is being added, returns the result of
     * setting that deployment to stopping. Otherwise returns {@code updatedState} unchanged.
     * <p>
     * NOTE(review): the stopping transition is applied to {@code clusterState}, not {@code updatedState}; when the
     * caller passes an older snapshot as {@code clusterState}, changes present only in {@code updatedState} are not
     * carried into the returned state — confirm this is intended.
     */
    ClusterState stopPlatformSpecificModelsInHeterogeneousClusters(
        ClusterState updatedState,
        Set<String> mlNodesArchitectures,
        Optional<StartTrainedModelDeploymentAction.TaskParams> modelToAdd,
        ClusterState clusterState
    ) {
        boolean heterogeneous = mlNodesArchitectures.size() > 1;
        if (!heterogeneous || modelToAdd.isEmpty()) {
            return updatedState;
        }
        StartTrainedModelDeploymentAction.TaskParams params = modelToAdd.get();
        String reasonToStop = Strings.format(
            "ML nodes in this cluster have multiple platform architectures, but can only have one for this model ([%s]); detected architectures: %s",
            params.getModelId(),
            mlNodesArchitectures
        );
        return this.callSetToStopping(reasonToStop, params.getDeploymentId(), clusterState);
    }

    /**
     * Delegates to {@code setToStopping(clusterState, deploymentId, reasonToStop)}. Package-private instance method,
     * presumably so tests can intercept it — confirm.
     */
    ClusterState callSetToStopping(String reasonToStop, String deploymentId, ClusterState clusterState) {
        ClusterState stoppingState = TrainedModelAssignmentClusterService.setToStopping(clusterState, deploymentId, reasonToStop);
        return stoppingState;
    }

    /**
     * Returns {@code true} when a plan computed from {@code source} can be applied to {@code target}.
     * This requires identical assignable nodes, node loads, ML metadata and trained model assignment metadata.
     * Checks run in that order and stop at the first mismatch.
     */
    private boolean areClusterStatesCompatibleForRebalance(ClusterState source, ClusterState target) {
        List<DiscoveryNode> sourceNodes = TrainedModelAssignmentClusterService.getAssignableNodes(source);
        List<DiscoveryNode> targetNodes = TrainedModelAssignmentClusterService.getAssignableNodes(target);
        if (!sourceNodes.equals(targetNodes)) {
            return false;
        }
        if (!this.detectNodeLoads(sourceNodes, source).equals(this.detectNodeLoads(targetNodes, target))) {
            return false;
        }
        if (!MlMetadata.getMlMetadata(source).equals(MlMetadata.getMlMetadata(target))) {
            return false;
        }
        return TrainedModelAssignmentMetadata.fromState(source).equals(TrainedModelAssignmentMetadata.fromState(target));
    }

    /**
     * Synchronously computes rebalanced assignment metadata for {@code currentState}.
     * <p>
     * Runs {@code TrainedModelAssignmentRebalancer} over the assignable nodes and their loads. It then converts
     * STARTED/STARTING routes on shutting-down nodes into stopping routes. When a new deployment is being created, it
     * throws if that deployment cannot be fully allocated and scaling is not possible.
     *
     * @throws Exception from the rebalancer, or an {@code ElasticsearchStatusException} from the allocation check
     */
    private TrainedModelAssignmentMetadata.Builder rebalanceAssignments(
        ClusterState currentState,
        Optional<CreateTrainedModelAssignmentAction.Request> createAssignmentRequest
    ) throws Exception {
        List<DiscoveryNode> nodes = TrainedModelAssignmentClusterService.getAssignableNodes(currentState);
        logger.debug(() -> Strings.format("assignable nodes are %s", nodes.stream().map(DiscoveryNode::getId).toList()));
        Map<DiscoveryNode, NodeLoad> nodeLoads = this.detectNodeLoads(nodes, currentState);
        TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.fromState(currentState);
        TrainedModelAssignmentRebalancer rebalancer = new TrainedModelAssignmentRebalancer(
            currentMetadata,
            nodeLoads,
            this.nodeAvailabilityZoneMapper.buildMlNodesByAvailabilityZone(currentState),
            createAssignmentRequest,
            this.allocatedProcessorsScale
        );
        // Typed Set<String> (was a raw Set passed into a Set<String> parameter).
        Set<String> shuttingDownNodeIds = currentState.metadata().nodeShutdowns().getAllNodeIds();
        TrainedModelAssignmentMetadata.Builder rebalanced = TrainedModelAssignmentClusterService.setShuttingDownNodeRoutesToStopping(
            currentMetadata,
            shuttingDownNodeIds,
            rebalancer.rebalance()
        );
        if (createAssignmentRequest.isPresent()) {
            this.checkModelIsFullyAllocatedIfScalingIsNotPossible(
                createAssignmentRequest.get().getTaskParams().getDeploymentId(),
                rebalanced,
                nodes
            );
        }
        return rebalanced;
    }

    /**
     * For each existing assignment routed to a shutting-down node in STARTED or STARTING state, writes a
     * shutting-down route for that node into {@code builder} and returns {@code builder}.
     * <p>
     * If {@code builder} already holds the deployment, its assignment builder is modified in place. Otherwise a
     * stopped copy of the existing assignment with an emptied routing table is used. Only assignments with at least
     * one such route are (re)written into {@code builder}.
     */
    static TrainedModelAssignmentMetadata.Builder setShuttingDownNodeRoutesToStopping(
        TrainedModelAssignmentMetadata currentMetadata,
        Set<String> shuttingDownNodeIds,
        TrainedModelAssignmentMetadata.Builder builder
    ) {
        if (shuttingDownNodeIds.isEmpty()) {
            return builder;
        }
        for (TrainedModelAssignment existingAssignment : currentMetadata.allAssignments().values()) {
            String deploymentId = existingAssignment.getDeploymentId();
            TrainedModelAssignment.Builder assignmentBuilder;
            if (builder.hasModelDeployment(deploymentId)) {
                assignmentBuilder = builder.getAssignment(deploymentId);
            } else {
                assignmentBuilder = TrainedModelAssignment.Builder.fromAssignment(existingAssignment)
                    .stopAssignment("nodes changed")
                    .clearNodeRoutingTable();
            }
            boolean routesChanged = false;
            for (String nodeId : shuttingDownNodeIds) {
                boolean activeOnNode = existingAssignment.isRoutedToNode(nodeId)
                    && existingAssignment.getNodeRoutingTable().get(nodeId).getState().isAnyOf(RoutingState.STARTED, RoutingState.STARTING);
                if (!activeOnNode) {
                    continue;
                }
                logger.debug(
                    () -> Strings.format(
                        "Found assignment deployment id: [%s] with route to shutting down node id: [%s], adding stopping route",
                        deploymentId,
                        nodeId
                    )
                );
                routesChanged = true;
                RoutingInfo stoppingRoute = TrainedModelAssignmentUtils.createShuttingDownRoute(
                    existingAssignment.getNodeRoutingTable().get(nodeId)
                );
                assignmentBuilder.addOrOverwriteRoutingEntry(nodeId, stoppingRoute);
            }
            if (routesChanged) {
                builder.addOrOverwriteAssignment(deploymentId, assignmentBuilder);
            }
        }
        return builder;
    }

    /**
     * Throws {@code TOO_MANY_REQUESTS} when the assignment for {@code modelId} (a deployment id, as passed by the
     * caller) is not satisfied by {@code nodes} and scaling is not possible.
     * The message differs depending on whether the assignment has any routes at all.
     */
    private void checkModelIsFullyAllocatedIfScalingIsNotPossible(
        String modelId,
        TrainedModelAssignmentMetadata.Builder assignments,
        List<DiscoveryNode> nodes
    ) {
        TrainedModelAssignment assignment = assignments.getAssignment(modelId).build();
        boolean acceptable = this.isScalingPossible(nodes)
            || assignment.isSatisfied(nodes.stream().map(DiscoveryNode::getId).collect(Collectors.toSet()));
        if (acceptable) {
            return;
        }
        if (!assignment.getNodeRoutingTable().isEmpty()) {
            // Partially routed: some, but not all, requested allocations could be placed.
            String partialMsg = "Could not start deployment because there are not enough resources to provide all requested allocations";
            logger.debug(() -> Strings.format("[%s] %s", modelId, partialMsg));
            throw new ElasticsearchStatusException(partialMsg, RestStatus.TOO_MANY_REQUESTS);
        }
        String explanation = assignment.getReason().orElse("none");
        String detailMsg = "Could not start deployment because no suitable nodes were found, allocation explanation [" + explanation + "]";
        logger.warn("[{}] {}", modelId, detailMsg);
        Throwable cause = new IllegalStateException(detailMsg);
        throw new ElasticsearchStatusException(
            "Could not start deployment because no ML nodes with sufficient capacity were found",
            RestStatus.TOO_MANY_REQUESTS,
            cause,
            new Object[0]
        );
    }

    /**
     * Returns the nodes that pass {@code TaskParams.mayAssignToNode} and are not reported by {@code nodesShuttingDown}.
     */
    private static List<DiscoveryNode> getAssignableNodes(ClusterState clusterState) {
        Set<String> shuttingDownNodes = TrainedModelAssignmentClusterService.nodesShuttingDown(clusterState);
        return clusterState.getNodes()
            .getNodes()
            .values()
            .stream()
            .filter(node -> StartTrainedModelDeploymentAction.TaskParams.mayAssignToNode(node) && !shuttingDownNodes.contains(node.getId()))
            .toList();
    }

    /**
     * Computes the ML load of each given node with the configured job and memory limits.
     * No assignment metadata is passed to the detector (the second argument is {@code null}).
     *
     * @param nodes        nodes to evaluate; expected to be distinct (duplicates would make {@code toMap} throw)
     * @param clusterState cluster state the loads are derived from
     * @return a map from each node to its detected load
     */
    private Map<DiscoveryNode, NodeLoad> detectNodeLoads(List<DiscoveryNode> nodes, ClusterState clusterState) {
        Function<DiscoveryNode, NodeLoad> loadOfNode = node -> nodeLoadDetector.detectNodeLoad(
            clusterState,
            null,
            node,
            maxOpenJobs,
            maxMemoryPercentage,
            useAuto
        );
        return nodes.stream().collect(Collectors.toMap(Function.identity(), loadOfNode));
    }

    /**
     * Reports whether the cluster could still grow to fit more ML work: either {@code maxLazyMLNodes}
     * exceeds the number of given nodes, or the smallest known node size is below {@code maxMLNodeSize}.
     *
     * @param nodes the current ML-eligible nodes
     * @return {@code true} if scaling up is still possible
     */
    private boolean isScalingPossible(List<DiscoveryNode> nodes) {
        // Node sizes are read for every node before either condition is evaluated.
        OptionalLong smallestNodeSize = nodes.stream()
            .map(NodeLoadDetector::getNodeSize)
            .flatMapToLong(OptionalLong::stream)
            .min();
        boolean roomForLazyNodes = maxLazyMLNodes > nodes.size();
        if (roomForLazyNodes) {
            return true;
        }
        return smallestNodeSize.stream().anyMatch(size -> size < maxMLNodeSize);
    }

    /**
     * Updates the allocation count and/or adaptive allocations settings of a deployment, starting from the
     * cluster state currently held by the cluster service.
     *
     * @param deploymentId                deployment to update
     * @param numberOfAllocations         requested allocation count, or {@code null} to leave it unchanged
     * @param adaptiveAllocationsSettings adaptive allocations settings updates, or {@code null} for none
     * @param isInternal                  whether the request was issued internally rather than by a user
     * @param listener                    completed with the resulting assignment or the failure
     */
    public void updateDeployment(
        String deploymentId,
        Integer numberOfAllocations,
        AdaptiveAllocationsSettings adaptiveAllocationsSettings,
        boolean isInternal,
        ActionListener<TrainedModelAssignment> listener
    ) {
        ClusterState latestState = clusterService.state();
        updateDeployment(latestState, deploymentId, numberOfAllocations, adaptiveAllocationsSettings, isInternal, listener);
    }

    /**
     * Validates and applies an update to a started deployment's allocation count and/or adaptive allocations settings.
     *
     * <p>The new assignment metadata is computed from {@code clusterState} on the {@code ml_utility} executor and then
     * submitted as an unbatched cluster state update. If the cluster state changed incompatibly in the meantime, the
     * whole request is re-run against the newer state.
     *
     * @param clusterState                       cluster state the update is computed against
     * @param deploymentId                       deployment to update
     * @param numberOfAllocations                requested allocation count, or {@code null} to keep the current one
     * @param adaptiveAllocationsSettingsUpdates adaptive settings to merge into the existing ones, {@code null} for no
     *                                           change, or {@code AdaptiveAllocationsSettings.RESET_PLACEHOLDER} to clear them
     * @param isInternal                         internal requests may set an allocation count even while adaptive
     *                                           allocations is enabled; external requests may not
     * @param listener                           completed with the updated (or unchanged) assignment, or with the failure
     */
    private void updateDeployment(
        final ClusterState clusterState,
        String deploymentId,
        Integer numberOfAllocations,
        AdaptiveAllocationsSettings adaptiveAllocationsSettingsUpdates,
        boolean isInternal,
        ActionListener<TrainedModelAssignment> listener
    ) {
        TrainedModelAssignmentMetadata metadata = TrainedModelAssignmentMetadata.fromState(clusterState);
        TrainedModelAssignment existingAssignment = metadata.getDeploymentAssignment(deploymentId);
        if (existingAssignment == null) {
            listener.onFailure(ExceptionsHelper.missingModelDeployment(deploymentId));
            return;
        }
        // The settings the deployment will have once the update is applied.
        AdaptiveAllocationsSettings adaptiveAllocationsSettings = getAdaptiveAllocationsSettings(
            existingAssignment.getAdaptiveAllocationsSettings(),
            adaptiveAllocationsSettingsUpdates
        );
        if (adaptiveAllocationsSettings != null) {
            // Boolean.TRUE.equals rather than ==, so a non-canonical Boolean instance is still recognised as enabled.
            if (!isInternal && Boolean.TRUE.equals(adaptiveAllocationsSettings.getEnabled()) && numberOfAllocations != null) {
                ValidationException validationException = new ValidationException();
                validationException.addValidationError(
                    "["
                        + String.valueOf(StartTrainedModelDeploymentAction.Request.NUMBER_OF_ALLOCATIONS)
                        + "] cannot be set if adaptive allocations is enabled"
                );
                listener.onFailure(validationException);
                return;
            }
            ActionRequestValidationException validationException = adaptiveAllocationsSettings.validate();
            if (validationException != null) {
                listener.onFailure(validationException);
                return;
            }
        }
        if (!hasUpdates(numberOfAllocations, adaptiveAllocationsSettingsUpdates, existingAssignment)) {
            logger.debug("no updates to be made for deployment [{}]", deploymentId);
            listener.onResponse(existingAssignment);
            return;
        }
        if (existingAssignment.getAssignmentState() != AssignmentState.STARTED) {
            listener.onFailure(
                new ElasticsearchStatusException(
                    "cannot update deployment that is not in [{}] state",
                    RestStatus.CONFLICT,
                    AssignmentState.STARTED
                )
            );
            return;
        }
        ActionListener<TrainedModelAssignmentMetadata.Builder> updatedAssignmentListener = ActionListener.wrap(
            updatedMetadata -> submitUnbatchedTask("update model deployment", new ClusterStateUpdateTask() {
                // Set only when execute() applied the update; otherwise the retry owns the listener.
                private volatile boolean isUpdated;

                @Override
                public ClusterState execute(ClusterState currentState) {
                    if (TrainedModelAssignmentClusterService.this.areClusterStatesCompatibleForRebalance(clusterState, currentState)) {
                        isUpdated = true;
                        return TrainedModelAssignmentClusterService.update(currentState, updatedMetadata);
                    }
                    logger.debug(() -> Strings.format("[%s] Retrying update as cluster state has been modified", deploymentId));
                    // Re-run with the caller's original updates, not the merged settings: the merged value is null
                    // after a RESET_PLACEHOLDER request, and passing it back would silently drop the reset.
                    TrainedModelAssignmentClusterService.this.updateDeployment(
                        currentState,
                        deploymentId,
                        numberOfAllocations,
                        adaptiveAllocationsSettingsUpdates,
                        isInternal,
                        listener
                    );
                    return currentState;
                }

                @Override
                public void onFailure(Exception e) {
                    listener.onFailure(e);
                }

                @Override
                public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
                    if (isUpdated) {
                        TrainedModelAssignment newAssignment = TrainedModelAssignmentMetadata.fromState(newState)
                            .getDeploymentAssignment(deploymentId);
                        if (newAssignment.totalTargetAllocations() > existingAssignment.totalTargetAllocations()) {
                            TrainedModelAssignmentClusterService.this.threadPool.executor("ml_utility")
                                .execute(
                                    () -> TrainedModelAssignmentClusterService.this.systemAuditor.info(
                                        Messages.getMessage("Rebalanced trained model allocations because [{0}]", "model deployment updated")
                                    )
                                );
                        }
                        listener.onResponse(newAssignment);
                    }
                }
            }),
            listener::onFailure
        );
        updateAssignment(clusterState, existingAssignment, numberOfAllocations, adaptiveAllocationsSettings, updatedAssignmentListener);
    }

    /**
     * Tells whether a proposed update would change anything on the existing assignment.
     *
     * @param proposedNumberOfAllocations requested allocation count, or {@code null} if not being changed
     * @param proposedAdaptiveSettings    requested adaptive settings, or {@code null} if not being changed
     * @param existingAssignment          the assignment currently in the cluster state
     * @return {@code true} if the allocation count or the adaptive settings differ from the existing ones
     */
    static boolean hasUpdates(
        Integer proposedNumberOfAllocations,
        AdaptiveAllocationsSettings proposedAdaptiveSettings,
        TrainedModelAssignment existingAssignment
    ) {
        if (proposedNumberOfAllocations != null) {
            Integer currentAllocations = existingAssignment.getTaskParams().getNumberOfAllocations();
            if (!Objects.equals(proposedNumberOfAllocations, currentAllocations)) {
                return true;
            }
        }
        if (proposedAdaptiveSettings == null) {
            return false;
        }
        return !Objects.equals(proposedAdaptiveSettings, existingAssignment.getAdaptiveAllocationsSettings());
    }

    /**
     * Resolves the adaptive allocations settings that result from applying {@code updates} to {@code original}.
     *
     * @param original the deployment's current settings, possibly {@code null}
     * @param updates  {@code null} for no change, the {@code RESET_PLACEHOLDER} sentinel (compared by identity)
     *                 to clear the settings, or settings to merge into {@code original}
     * @return the resulting settings, or {@code null} when there are none
     */
    private AdaptiveAllocationsSettings getAdaptiveAllocationsSettings(AdaptiveAllocationsSettings original, AdaptiveAllocationsSettings updates) {
        if (updates == null) {
            return original;
        }
        boolean resetRequested = updates == AdaptiveAllocationsSettings.RESET_PLACEHOLDER;
        if (resetRequested) {
            return null;
        }
        return original != null ? original.merge(updates) : updates;
    }

    /**
     * Computes the updated assignment metadata on the {@code ml_utility} executor, choosing the strategy
     * from how the requested allocation count compares with the current one.
     *
     * @param clusterState                cluster state the update is computed against
     * @param assignment                  the deployment's existing assignment
     * @param numberOfAllocations         requested allocation count, or {@code null} to keep the current one
     * @param adaptiveAllocationsSettings resolved adaptive settings to store on the assignment
     * @param listener                    receives the updated metadata builder (or a failure)
     */
    private void updateAssignment(
        ClusterState clusterState,
        TrainedModelAssignment assignment,
        Integer numberOfAllocations,
        AdaptiveAllocationsSettings adaptiveAllocationsSettings,
        ActionListener<TrainedModelAssignmentMetadata.Builder> listener
    ) {
        Runnable dispatch = () -> {
            if (numberOfAllocations == null) {
                updateAndKeepNumberOfAllocations(clusterState, assignment, adaptiveAllocationsSettings, listener);
                return;
            }
            int requested = numberOfAllocations;
            int current = assignment.getTaskParams().getNumberOfAllocations();
            if (requested == current) {
                updateAndKeepNumberOfAllocations(clusterState, assignment, adaptiveAllocationsSettings, listener);
            } else if (requested > current) {
                increaseNumberOfAllocations(clusterState, assignment, numberOfAllocations, adaptiveAllocationsSettings, listener);
            } else {
                decreaseNumberOfAllocations(clusterState, assignment, numberOfAllocations, adaptiveAllocationsSettings, listener);
            }
        };
        threadPool.executor("ml_utility").execute(dispatch);
    }

    /**
     * Builds assignment metadata in which only the deployment's adaptive allocations settings change;
     * the allocation count and routing are kept as they are.
     *
     * <p>This is invoked from an executor task, where an uncaught exception would never reach the listener.
     * Any failure while building the update is therefore reported through {@code listener.onFailure}, and the
     * listener is completed exactly once.
     *
     * @param clusterState                cluster state the metadata builder is created from
     * @param assignment                  the deployment's existing assignment
     * @param adaptiveAllocationsSettings settings to store on the assignment
     * @param listener                    receives the updated metadata builder, or the failure
     */
    private void updateAndKeepNumberOfAllocations(
        ClusterState clusterState,
        TrainedModelAssignment assignment,
        AdaptiveAllocationsSettings adaptiveAllocationsSettings,
        ActionListener<TrainedModelAssignmentMetadata.Builder> listener
    ) {
        TrainedModelAssignmentMetadata.Builder builder;
        try {
            TrainedModelAssignment.Builder updatedAssignment = TrainedModelAssignment.Builder.fromAssignment(assignment)
                .setAdaptiveAllocationsSettings(adaptiveAllocationsSettings);
            builder = TrainedModelAssignmentMetadata.builder(clusterState);
            builder.updateAssignment(assignment.getDeploymentId(), updatedAssignment);
        } catch (Exception e) {
            listener.onFailure(e);
            return;
        }
        listener.onResponse(builder);
    }

    /**
     * Raises the deployment's allocation count and rebalances all assignments against the resulting state.
     *
     * <p>When the cluster cannot scale and the rebalance yields fewer target allocations than requested, the
     * update is rejected with {@code TOO_MANY_REQUESTS}. The listener is completed exactly once: the work that
     * may throw is inside the {@code try}, while {@code onResponse}/{@code onFailure} are called outside it, so
     * an exception thrown by the listener itself does not trigger a second {@code onFailure}.
     *
     * @param clusterState                cluster state the update is computed against
     * @param assignment                  the deployment's existing assignment
     * @param numberOfAllocations         the new, larger allocation count
     * @param adaptiveAllocationsSettings resolved adaptive settings to store on the assignment
     * @param listener                    receives the rebalanced metadata builder, or the failure
     */
    private void increaseNumberOfAllocations(
        ClusterState clusterState,
        TrainedModelAssignment assignment,
        int numberOfAllocations,
        AdaptiveAllocationsSettings adaptiveAllocationsSettings,
        ActionListener<TrainedModelAssignmentMetadata.Builder> listener
    ) {
        TrainedModelAssignmentMetadata.Builder rebalancedMetadata;
        boolean insufficientResources;
        try {
            TrainedModelAssignment.Builder updatedAssignment = TrainedModelAssignment.Builder.fromAssignment(assignment)
                .setNumberOfAllocations(numberOfAllocations)
                .setAdaptiveAllocationsSettings(adaptiveAllocationsSettings);
            ClusterState updatedClusterState = TrainedModelAssignmentClusterService.update(
                clusterState,
                TrainedModelAssignmentMetadata.builder(clusterState).updateAssignment(assignment.getDeploymentId(), updatedAssignment)
            );
            rebalancedMetadata = rebalanceAssignments(updatedClusterState, Optional.empty());
            insufficientResources = !isScalingPossible(getAssignableNodes(clusterState))
                && rebalancedMetadata.getAssignment(assignment.getDeploymentId()).build().totalTargetAllocations() < numberOfAllocations;
        } catch (Exception e) {
            listener.onFailure(e);
            return;
        }
        if (insufficientResources) {
            listener.onFailure(
                new ElasticsearchStatusException(
                    "Could not update deployment because there are not enough resources to provide all requested allocations",
                    RestStatus.TOO_MANY_REQUESTS
                )
            );
        } else {
            listener.onResponse(rebalancedMetadata);
        }
    }

    /**
     * Lowers the deployment's allocation count.
     *
     * <p>When the new count is below the assignment's total target allocations, an {@code AllocationReducer}
     * (given the ML nodes grouped by availability zone) computes the reduced routing. Otherwise only the requested
     * count is changed. The assignment reason is cleared whenever the new count does not exceed the total target
     * allocations (presumably because a previous shortfall reason no longer applies — TODO confirm).
     *
     * <p>This runs inside an executor task. Failures while computing the update are reported through
     * {@code listener.onFailure} rather than escaping and leaving the listener uncompleted.
     *
     * @param clusterState                cluster state the update is computed against
     * @param assignment                  the deployment's existing assignment
     * @param numberOfAllocations         the new, smaller allocation count
     * @param adaptiveAllocationsSettings resolved adaptive settings to store on the assignment
     * @param listener                    receives the updated metadata builder, or the failure
     */
    private void decreaseNumberOfAllocations(
        ClusterState clusterState,
        TrainedModelAssignment assignment,
        int numberOfAllocations,
        AdaptiveAllocationsSettings adaptiveAllocationsSettings,
        ActionListener<TrainedModelAssignmentMetadata.Builder> listener
    ) {
        TrainedModelAssignmentMetadata.Builder builder;
        try {
            TrainedModelAssignment.Builder updatedAssignment = numberOfAllocations < assignment.totalTargetAllocations()
                ? new AllocationReducer(assignment, nodeAvailabilityZoneMapper.buildMlNodesByAvailabilityZone(clusterState)).reduceTo(
                    numberOfAllocations
                )
                : TrainedModelAssignment.Builder.fromAssignment(assignment).setNumberOfAllocations(numberOfAllocations);
            updatedAssignment.setAdaptiveAllocationsSettings(adaptiveAllocationsSettings);
            if (numberOfAllocations <= assignment.totalTargetAllocations()) {
                updatedAssignment.setReason(null);
            }
            builder = TrainedModelAssignmentMetadata.builder(clusterState);
            builder.updateAssignment(assignment.getDeploymentId(), updatedAssignment);
        } catch (Exception e) {
            listener.onFailure(e);
            return;
        }
        listener.onResponse(builder);
    }

    /**
     * Moves a deployment's assignment into the stopping state with the given reason.
     *
     * @param clusterState the current cluster state
     * @param deploymentId deployment to stop
     * @param reason       reason recorded on the assignment
     * @return the updated cluster state, or {@code clusterState} unchanged if the assignment is already STOPPING
     * @throws ResourceNotFoundException if no assignment exists for {@code deploymentId}
     */
    static ClusterState setToStopping(ClusterState clusterState, String deploymentId, String reason) {
        TrainedModelAssignment current = TrainedModelAssignmentMetadata.fromState(clusterState).getDeploymentAssignment(deploymentId);
        if (current == null) {
            throw new ResourceNotFoundException("assignment with id [{}] not found", deploymentId);
        }
        boolean alreadyStopping = current.getAssignmentState().equals(AssignmentState.STOPPING);
        if (alreadyStopping) {
            return clusterState;
        }
        TrainedModelAssignmentMetadata.Builder metadataBuilder = TrainedModelAssignmentMetadata.builder(clusterState);
        metadataBuilder.getAssignment(deploymentId).stopAssignment(reason);
        return update(clusterState, metadataBuilder);
    }

    /**
     * Applies a routing-info update reported for one node of a deployment and recomputes the assignment state.
     *
     * <p>An update whose state is STOPPED removes the node's routing entry. It is a no-op when the deployment no
     * longer exists or is not routed to the node. Any other update is ignored while the deployment is STOPPING;
     * otherwise it replaces the node's existing routing entry.
     *
     * @param currentState the current cluster state
     * @param request      identifies the deployment and node and carries the routing update
     * @return the updated cluster state, or {@code currentState} when nothing changes
     * @throws ResourceNotFoundException for a non-STOPPED update when the deployment does not exist or is not
     *                                   routed to the node
     */
    static ClusterState updateModelRoutingTable(ClusterState currentState, UpdateTrainedModelAssignmentRoutingInfoAction.Request request) {
        String deploymentId = request.getDeploymentId();
        String nodeId = request.getNodeId();
        TrainedModelAssignmentMetadata metadata = TrainedModelAssignmentMetadata.fromState(currentState);
        logger.trace(
            () -> Strings.format(
                "[%s] [%s] current metadata before update %s",
                deploymentId,
                nodeId,
                org.elasticsearch.common.Strings.toString((ChunkedToXContent) metadata)
            )
        );
        TrainedModelAssignment existingAssignment = metadata.getDeploymentAssignment(deploymentId);
        TrainedModelAssignmentMetadata.Builder builder = TrainedModelAssignmentMetadata.builder(currentState);
        if (request.getUpdate().getStateAndReason().isPresent()
            && ((RoutingStateAndReason) request.getUpdate().getStateAndReason().get()).getState().equals(RoutingState.STOPPED)) {
            // The route may already have been removed (or the whole assignment deleted); nothing to do then.
            if (existingAssignment == null || !existingAssignment.isRoutedToNode(nodeId)) {
                return currentState;
            }
            builder.getAssignment(deploymentId).removeRoutingEntry(nodeId).calculateAndSetAssignmentState();
            return TrainedModelAssignmentClusterService.update(currentState, builder);
        }
        if (existingAssignment == null) {
            throw new ResourceNotFoundException("assignment with id [{}] not found", deploymentId);
        }
        if (existingAssignment.getAssignmentState().equals(AssignmentState.STOPPING)) {
            logger.debug(
                () -> Strings.format(
                    "[%s] requested update from node [%s] while stopping; update was [%s]",
                    deploymentId,
                    nodeId,
                    request.getUpdate()
                )
            );
            return currentState;
        }
        if (!existingAssignment.isRoutedToNode(nodeId)) {
            throw new ResourceNotFoundException("assignment with id [{}] is not routed to node [{}]", deploymentId, nodeId);
        }
        RoutingInfo routingInfo = (RoutingInfo) existingAssignment.getNodeRoutingTable().get(nodeId);
        builder.getAssignment(deploymentId)
            .updateExistingRoutingEntry(nodeId, request.getUpdate().apply(routingInfo))
            .calculateAndSetAssignmentState();
        return TrainedModelAssignmentClusterService.update(currentState, builder);
    }

    /**
     * Removes the assignment of a single deployment from the cluster state.
     *
     * @param currentState the current cluster state
     * @param deploymentId deployment whose assignment is removed
     * @return the updated cluster state
     * @throws ResourceNotFoundException if no assignment exists for {@code deploymentId}
     */
    static ClusterState removeAssignment(ClusterState currentState, String deploymentId) {
        TrainedModelAssignmentMetadata.Builder metadataBuilder = TrainedModelAssignmentMetadata.builder(currentState);
        if (metadataBuilder.hasModelDeployment(deploymentId)) {
            logger.debug(() -> Strings.format("[%s] removing assignment", deploymentId));
            return update(currentState, metadataBuilder.removeAssignment(deploymentId));
        }
        throw new ResourceNotFoundException("assignment for deployment with id [{}] not found", deploymentId);
    }

    /**
     * Clears every trained model assignment, forcing the update through {@code forceUpdate}.
     *
     * @param currentState the current cluster state
     * @return {@code currentState} when there are no assignments, otherwise the state with empty assignment metadata
     */
    static ClusterState removeAllAssignments(ClusterState currentState) {
        boolean nothingToRemove = TrainedModelAssignmentMetadata.fromState(currentState).allAssignments().isEmpty();
        return nothingToRemove ? currentState : forceUpdate(currentState, TrainedModelAssignmentMetadata.Builder.empty());
    }

    /**
     * Determines why, if at all, trained model assignments should be rebalanced after a cluster change.
     * Reasons are checked in priority order: stopped ML jobs, changed ML nodes, outdated assignments.
     *
     * @param event the cluster change being processed
     * @return the first applicable reason, or empty when there are no assignments or nothing warrants a rebalance
     */
    static Optional<String> detectReasonToRebalanceModels(ClusterChangedEvent event) {
        TrainedModelAssignmentMetadata newMetadata = TrainedModelAssignmentMetadata.fromState(event.state());
        if (newMetadata == null || newMetadata.allAssignments().isEmpty()) {
            return Optional.empty();
        }
        Optional<String> jobsStoppedReason = detectReasonIfMlJobsStopped(event);
        if (jobsStoppedReason.isPresent()) {
            return jobsStoppedReason;
        }
        if (haveMlNodesChanged(event, newMetadata)) {
            return Optional.of("nodes changed");
        }
        if (newMetadata.hasOutdatedAssignments()) {
            return Optional.of("outdated assignments detected");
        }
        return Optional.empty();
    }

    /**
     * Reports ML process tasks that existed in the previous persistent tasks but are absent from the new ones.
     * Only inspected when the {@code persistent_tasks} project metadata changed.
     *
     * @param event the cluster change being processed
     * @return a reason naming the stopped job type(s), or empty when none stopped or there were no previous tasks
     */
    static Optional<String> detectReasonIfMlJobsStopped(ClusterChangedEvent event) {
        if (!event.changedCustomProjectMetadataSet().contains("persistent_tasks")) {
            return Optional.empty();
        }
        PersistentTasksCustomMetadata tasksBefore = PersistentTasksCustomMetadata.getPersistentTasksCustomMetadata(event.previousState());
        if (tasksBefore == null) {
            return Optional.empty();
        }
        PersistentTasksCustomMetadata tasksAfter = PersistentTasksCustomMetadata.getPersistentTasksCustomMetadata(event.state());
        Set<String> stillRunningIds = findMlProcessTaskIds(tasksAfter);
        Set<String> stoppedTaskTypes = MlTasks.findMlProcessTasks(tasksBefore)
            .stream()
            .filter(task -> !stillRunningIds.contains(task.getId()))
            .map(PersistentTasksCustomMetadata.PersistentTask::getTaskName)
            .map(MlTasks::prettyPrintTaskName)
            .collect(Collectors.toSet());
        return switch (stoppedTaskTypes.size()) {
            case 0 -> Optional.empty();
            case 1 -> Optional.of("ML [" + stoppedTaskTypes.iterator().next() + "] job stopped");
            default -> Optional.of("ML " + String.valueOf(stoppedTaskTypes) + " jobs stopped");
        };
    }

    /**
     * Collects the ids of the ML process tasks in the given persistent tasks metadata.
     *
     * @param metadata persistent tasks metadata, may be {@code null}
     * @return the task ids, or an empty set when {@code metadata} is {@code null}
     */
    private static Set<String> findMlProcessTaskIds(@Nullable PersistentTasksCustomMetadata metadata) {
        if (metadata == null) {
            return Set.of();
        }
        return MlTasks.findMlProcessTasks(metadata)
            .stream()
            .map(PersistentTasksCustomMetadata.PersistentTask::getId)
            .collect(Collectors.toSet());
    }

    /**
     * Decides whether a cluster change alters the ML-relevant node set in a way that requires rebalancing.
     *
     * <p>Only evaluated when the node set changed or the {@code node_shutdown} cluster metadata changed. Nodes that
     * newly appear in the shutdown metadata are treated as removed. Present nodes that were shutting down before
     * but no longer are are treated as added. For each assignment that is not itself STOPPING, a rebalance is
     * required when:
     * <ul>
     *   <li>it is routed to a newly shutting-down node whose route is not already STOPPING,</li>
     *   <li>it is routed to a removed node that is not shutting down, or</li>
     *   <li>an added node is accepted by {@code TaskParams.mayAssignToNode} and is not shutting down.</li>
     * </ul>
     *
     * @param event       the cluster change being processed
     * @param newMetadata assignment metadata of the new cluster state
     * @return {@code true} if a rebalance should be triggered
     */
    static boolean haveMlNodesChanged(ClusterChangedEvent event, TrainedModelAssignmentMetadata newMetadata) {
        boolean nodesShutdownChanged = event.changedCustomClusterMetadataSet().contains("node_shutdown");
        if (!event.nodesChanged() && !nodesShutdownChanged) {
            return false;
        }
        // Per-call id used only to correlate the debug log lines for this event.
        String eventIdentity = Long.toHexString(System.nanoTime());
        Set<String> shuttingDownNodes = nodesShuttingDown(event.state());
        DiscoveryNodes.Delta nodesDelta = event.nodesDelta();
        // Both sets are extended with addAll below, so collect into an explicitly mutable HashSet:
        // Collectors.toSet() makes no guarantee about the mutability of the set it returns.
        Set<String> removedNodes = nodesDelta.removedNodes()
            .stream()
            .map(DiscoveryNode::getId)
            .collect(Collectors.toCollection(HashSet::new));
        Set<String> addedNodes = nodesDelta.addedNodes()
            .stream()
            .map(DiscoveryNode::getId)
            .collect(Collectors.toCollection(HashSet::new));
        logger.debug(
            () -> Strings.format(
                "Initial node change info; identity: %s; removed nodes: %s; added nodes: %s; shutting down nodes: %s",
                eventIdentity,
                removedNodes,
                addedNodes,
                shuttingDownNodes
            )
        );
        Set<String> exitingShutDownNodes;
        if (nodesShutdownChanged) {
            Set<String> previousShuttingDownNodes = nodesShuttingDown(event.previousState());
            Set<String> presentNodes = event.state().nodes().stream().map(DiscoveryNode::getId).collect(Collectors.toSet());
            // Still-present nodes that were shutting down before but are not any more.
            Set<String> returningShutDownNodes = Sets.intersection(
                presentNodes,
                Sets.difference(previousShuttingDownNodes, shuttingDownNodes)
            );
            addedNodes.addAll(returningShutDownNodes);
            // Nodes that have newly started shutting down are handled like removed nodes.
            exitingShutDownNodes = Sets.difference(shuttingDownNodes, previousShuttingDownNodes);
            removedNodes.addAll(exitingShutDownNodes);
            logger.debug(
                () -> Strings.format(
                    "Shutting down nodes were changed; identity: %s; previous shutting down nodes: %s; returning nodes: %s",
                    eventIdentity,
                    previousShuttingDownNodes,
                    returningShutDownNodes
                )
            );
        } else {
            exitingShutDownNodes = Collections.emptySet();
        }
        logger.debug(
            () -> Strings.format(
                "identity: %s; added nodes %s; removed nodes %s; shutting down nodes %s; exiting shutdown nodes %s",
                eventIdentity,
                addedNodes,
                removedNodes,
                shuttingDownNodes,
                exitingShutDownNodes
            )
        );
        for (TrainedModelAssignment trainedModelAssignment : newMetadata.allAssignments().values()) {
            if (trainedModelAssignment.getAssignmentState().equals(AssignmentState.STOPPING)) {
                continue;
            }
            for (String nodeId : exitingShutDownNodes) {
                if (trainedModelAssignment.isRoutedToNode(nodeId)
                    && ((RoutingInfo) trainedModelAssignment.getNodeRoutingTable().get(nodeId)).getState() != RoutingState.STOPPING) {
                    logger.debug(
                        () -> Strings.format(
                            "should rebalance because model deployment [%s] has allocations on shutting down node [%s]",
                            trainedModelAssignment.getDeploymentId(),
                            nodeId
                        )
                    );
                    return true;
                }
            }
            for (String nodeId : removedNodes) {
                if (trainedModelAssignment.isRoutedToNode(nodeId) && !shuttingDownNodes.contains(nodeId)) {
                    logger.debug(
                        () -> Strings.format(
                            "should rebalance because model deployment [%s] has allocations on removed node [%s]",
                            trainedModelAssignment.getDeploymentId(),
                            nodeId
                        )
                    );
                    return true;
                }
            }
            for (String nodeId : addedNodes) {
                if (StartTrainedModelDeploymentAction.TaskParams.mayAssignToNode(event.state().nodes().get(nodeId))
                    && !shuttingDownNodes.contains(nodeId)) {
                    logger.debug(() -> Strings.format("should rebalance because ML eligible node [%s] was added", nodeId));
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Returns the ids of every node registered in the cluster's node-shutdown metadata.
     *
     * @param state the cluster state to read
     * @return ids of nodes with a registered shutdown
     */
    static Set<String> nodesShuttingDown(ClusterState state) {
        Metadata clusterMetadata = state.metadata();
        return clusterMetadata.nodeShutdowns().getAllNodeIds();
    }
}

