/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.assignment.planning;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.elasticsearch.core.Strings;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlanner;
import org.elasticsearch.xpack.ml.inference.assignment.planning.LinearProgrammingPlanSolver;
import org.elasticsearch.xpack.ml.inference.assignment.planning.PreserveAllAllocations;

/**
 * Computes an {@link AssignmentPlan} one zone at a time and then combines the
 * per-zone plans into a single plan over all nodes.
 *
 * <p>Zones are processed in a deterministic order. For every zone each deployment
 * that still has allocations left is given a per-zone target, the zone is planned
 * with an {@link AssignmentPlanner}, and the allocations that were placed are
 * subtracted from the remaining count before the next zone is planned. Finally a
 * plan is solved across all nodes while preserving the allocations chosen by the
 * per-zone plans.
 */
public class ZoneAwareAssignmentPlanner {
    private static final Logger logger = LogManager.getLogger(ZoneAwareAssignmentPlanner.class);

    /** Orders single zone attributes; {@code null} sorts first so the tie-break never throws. */
    private static final Comparator<String> ATTRIBUTE_ORDER = Comparator.nullsFirst(Comparator.naturalOrder());

    /** Nodes grouped by the list of zone attributes they belong to, in sorted zone order. */
    private final Map<List<String>, List<AssignmentPlan.Node>> nodesByZone;
    /** The deployments to plan, as originally requested. */
    private final List<AssignmentPlan.Deployment> deployments;

    /**
     * @param nodesByZone nodes keyed by their zone attributes; must not be {@code null}
     * @param deployments the deployments to assign; must not be {@code null}
     * @throws NullPointerException if either argument is {@code null}
     */
    public ZoneAwareAssignmentPlanner(Map<List<String>, List<AssignmentPlan.Node>> nodesByZone, List<AssignmentPlan.Deployment> deployments) {
        this.nodesByZone = ZoneAwareAssignmentPlanner.sortByZone(Objects.requireNonNull(nodesByZone));
        this.deployments = Objects.requireNonNull(deployments);
    }

    /**
     * Copies the given map into a map sorted by zone so that zones are always
     * iterated in the same order.
     *
     * <p>The primary sort key is the concatenation of the zone attributes. That key
     * alone is ambiguous ({@code ["a", "bc"]} and {@code ["ab", "c"]} both join to
     * {@code "abc"}), and a {@link TreeMap} treats keys that compare equal as the
     * same key, which would silently drop one zone's nodes. An element-wise
     * tie-break keeps distinct zones distinct without changing the relative order
     * of zones whose concatenations differ.
     */
    private static Map<List<String>, List<AssignmentPlan.Node>> sortByZone(Map<List<String>, List<AssignmentPlan.Node>> nodesByZone) {
        Comparator<List<String>> byJoinedAttributes = Comparator.comparing(zoneAttributes -> String.join("", zoneAttributes));
        Map<List<String>, List<AssignmentPlan.Node>> sortedByZone = new TreeMap<>(
            byJoinedAttributes.thenComparing(ZoneAwareAssignmentPlanner::compareZoneAttributes)
        );
        sortedByZone.putAll(nodesByZone);
        return sortedByZone;
    }

    /**
     * Lexicographically compares two zone attribute lists element by element; when
     * one list is a prefix of the other the shorter list sorts first. Returns
     * {@code 0} only when both lists have the same elements in the same order.
     */
    private static int compareZoneAttributes(List<String> left, List<String> right) {
        int commonLength = Math.min(left.size(), right.size());
        for (int i = 0; i < commonLength; i++) {
            int comparison = ATTRIBUTE_ORDER.compare(left.get(i), right.get(i));
            if (comparison != 0) {
                return comparison;
            }
        }
        return Integer.compare(left.size(), right.size());
    }

    /**
     * Computes the assignment plan.
     *
     * <p>With exactly one zone the work is delegated directly to an
     * {@link AssignmentPlanner} over that zone's nodes. Otherwise a zone-aware plan
     * is first computed without trying to assign previously assigned models; if
     * that plan reports that previously assigned models are not all assigned, the
     * zone-aware plan is recomputed with that option enabled.
     *
     * @return the computed plan
     */
    public AssignmentPlan computePlan() {
        if (this.nodesByZone.size() == 1) {
            return new AssignmentPlanner(this.nodesByZone.values().iterator().next(), this.deployments).computePlan(true);
        }
        AssignmentPlan plan = this.computePlan(false);
        if (!plan.arePreviouslyAssignedModelsAssigned()) {
            plan = this.computePlan(true);
        }
        return plan;
    }

    /**
     * Plans every zone in order, tracking how many allocations each deployment
     * still needs, and then merges the per-zone plans into one plan across all nodes.
     *
     * @param tryAssigningPreviouslyAssignedModels forwarded to the per-zone planner
     */
    private AssignmentPlan computePlan(boolean tryAssigningPreviouslyAssignedModels) {
        logger.debug(
            () -> Strings.format(
                "computing plan%s trying to assign previously assigned models",
                tryAssigningPreviouslyAssignedModels ? "" : " without"
            )
        );
        int remainingZones = this.nodesByZone.size();
        // Mutable on purpose: counts are decremented as each zone places allocations.
        Map<String, Integer> deploymentIdToRemainingAllocations = this.deployments.stream()
            .collect(Collectors.toMap(AssignmentPlan.Deployment::deploymentId, AssignmentPlan.Deployment::allocations));
        List<AssignmentPlan> plans = new ArrayList<>();
        for (Map.Entry<List<String>, List<AssignmentPlan.Node>> zoneToNodes : this.nodesByZone.entrySet()) {
            logger.debug(() -> Strings.format("computing plan for availability zone %s", zoneToNodes.getKey()));
            AssignmentPlan zonePlan = this.computeZonePlan(
                zoneToNodes.getValue(),
                deploymentIdToRemainingAllocations,
                remainingZones,
                tryAssigningPreviouslyAssignedModels
            );
            // Subtract what this zone placed so later zones only target what is left.
            for (AssignmentPlan.Deployment d : zonePlan.deployments()) {
                deploymentIdToRemainingAllocations.computeIfPresent(
                    d.deploymentId(),
                    (deploymentId, remainingAllocations) -> remainingAllocations - zonePlan.totalAllocations(d)
                );
            }
            plans.add(zonePlan);
            --remainingZones;
        }
        AssignmentPlan mergedPlan = this.computePlanAcrossAllNodes(plans);
        logger.debug(() -> "Zone aware plan =\n" + mergedPlan.prettyPrint());
        return mergedPlan;
    }

    /**
     * Plans a single zone. Each deployment with remaining allocations is replaced
     * by a copy whose requested allocations equal this zone's target (one plus an
     * even share of the rest across the remaining zones); deployments with nothing
     * remaining are left out.
     *
     * @param nodes                                the nodes of the zone being planned
     * @param deploymentIdToRemainingAllocations   allocations not yet placed, per deployment id
     * @param remainingZones                       number of zones not yet planned, including this one
     * @param tryAssigningPreviouslyAssignedModels forwarded to the zone's {@link AssignmentPlanner}
     */
    private AssignmentPlan computeZonePlan(List<AssignmentPlan.Node> nodes, Map<String, Integer> deploymentIdToRemainingAllocations, int remainingZones, boolean tryAssigningPreviouslyAssignedModels) {
        Map<String, Integer> deploymentIdToTargetAllocationsPerZone = deploymentIdToRemainingAllocations.entrySet()
            .stream()
            .filter(e -> e.getValue() > 0)
            .collect(
                Collectors.toMap(
                    Map.Entry::getKey,
                    e -> 1 + ZoneAwareAssignmentPlanner.remainingAllocationsPerZoneAfterAssigningOne(remainingZones, e.getValue())
                )
            );
        List<AssignmentPlan.Deployment> modifiedDeployments = this.deployments.stream()
            .filter(d -> deploymentIdToTargetAllocationsPerZone.getOrDefault(d.deploymentId(), 0) > 0)
            .map(
                d -> new AssignmentPlan.Deployment(
                    d.deploymentId(),
                    d.modelId(),
                    d.memoryBytes(),
                    deploymentIdToTargetAllocationsPerZone.get(d.deploymentId()),
                    d.threadsPerAllocation(),
                    d.currentAllocationsByNodeId(),
                    // Carry over the max assigned allocations only while the deployment's remaining
                    // count still equals its full request (no earlier zone placed any of it).
                    tryAssigningPreviouslyAssignedModels
                        && deploymentIdToRemainingAllocations.get(d.deploymentId()).intValue() == d.allocations()
                            ? d.maxAssignedAllocations()
                            : 0,
                    d.getAdaptiveAllocationsSettings(),
                    d.perDeploymentMemoryBytes(),
                    d.perAllocationMemoryBytes()
                )
            )
            .toList();
        return new AssignmentPlanner(nodes, modifiedDeployments).computePlan(tryAssigningPreviouslyAssignedModels);
    }

    /**
     * Returns the even share, rounded down, of the allocations left after setting
     * one aside, divided over the remaining zones. Returns {@code 0} when there is
     * no remaining count or no remaining zone.
     */
    private static int remainingAllocationsPerZoneAfterAssigningOne(int remainingZones, Integer remainingAllocations) {
        if (remainingAllocations == null || remainingZones == 0) {
            return 0;
        }
        return (remainingAllocations - 1) / remainingZones;
    }

    /**
     * Solves a plan over the nodes of all zones while preserving the allocations
     * chosen by the per-zone plans, then re-expresses the result in terms of the
     * original deployments and nodes.
     *
     * @param plans the per-zone plans, one per zone
     */
    private AssignmentPlan computePlanAcrossAllNodes(List<AssignmentPlan> plans) {
        logger.debug(() -> "computing plan across all nodes");
        List<AssignmentPlan.Node> allNodes = new ArrayList<>();
        this.nodesByZone.values().forEach(allNodes::addAll);
        Map<String, Map<String, Integer>> allocationsByNodeIdByDeploymentId = this.mergeAllocationsByNodeIdByDeploymentId(plans);
        // Copies of the original deployments whose current allocations are those of the per-zone plans.
        List<AssignmentPlan.Deployment> modelsAccountingPlans = this.deployments.stream()
            .map(
                d -> new AssignmentPlan.Deployment(
                    d.deploymentId(),
                    d.modelId(),
                    d.memoryBytes(),
                    d.allocations(),
                    d.threadsPerAllocation(),
                    allocationsByNodeIdByDeploymentId.get(d.deploymentId()),
                    d.maxAssignedAllocations(),
                    d.getAdaptiveAllocationsSettings(),
                    d.perDeploymentMemoryBytes(),
                    d.perAllocationMemoryBytes()
                )
            )
            .toList();
        PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations(allNodes, modelsAccountingPlans);
        List<AssignmentPlan.Node> planNodes = preserveAllAllocations.nodesPreservingAllocations();
        List<AssignmentPlan.Deployment> planDeployments = preserveAllAllocations.modelsPreservingAllocations();
        AssignmentPlan solvedPlan = new LinearProgrammingPlanSolver(planNodes, planDeployments).solvePlan(false);
        AssignmentPlan planWithPreserved = preserveAllAllocations.mergePreservedAllocations(solvedPlan);
        return this.swapOriginalDeploymentsInPlan(planWithPreserved, allNodes, modelsAccountingPlans);
    }

    /**
     * Rebuilds {@code plan} so that it refers to the original deployments and the
     * given nodes, matching both by id.
     *
     * @param plan            the plan expressed in terms of {@code planDeployments}
     * @param allNodes        the nodes the final plan is built over
     * @param planDeployments the deployments whose assignments are read from {@code plan}
     */
    private AssignmentPlan swapOriginalDeploymentsInPlan(AssignmentPlan plan, List<AssignmentPlan.Node> allNodes, List<AssignmentPlan.Deployment> planDeployments) {
        Map<String, AssignmentPlan.Deployment> originalDeploymentsById = this.deployments.stream()
            .collect(Collectors.toMap(AssignmentPlan.Deployment::deploymentId, Function.identity()));
        Map<String, AssignmentPlan.Node> originalNodeById = allNodes.stream()
            .collect(Collectors.toMap(AssignmentPlan.Node::id, Function.identity()));
        AssignmentPlan.Builder finalPlanBuilder = AssignmentPlan.builder(allNodes, this.deployments);
        for (AssignmentPlan.Deployment planDeployment : planDeployments) {
            AssignmentPlan.Deployment originalDeployment = originalDeploymentsById.get(planDeployment.deploymentId());
            Map<AssignmentPlan.Node, Integer> nodeAssignments = plan.assignments(planDeployment).orElse(Map.of());
            for (Map.Entry<AssignmentPlan.Node, Integer> assignment : nodeAssignments.entrySet()) {
                AssignmentPlan.Node originalNode = originalNodeById.get(assignment.getKey().id());
                finalPlanBuilder.assignModelToNode(originalDeployment, originalNode, assignment.getValue());
            }
        }
        return finalPlanBuilder.build();
    }

    /**
     * Sums, for every deployment id and node id, the allocations assigned across
     * all given plans. Every known deployment gets an entry, which stays empty if
     * no plan assigned it anywhere.
     *
     * @param plans the per-zone plans to merge
     * @return deployment id to (node id to total allocations)
     */
    private Map<String, Map<String, Integer>> mergeAllocationsByNodeIdByDeploymentId(List<AssignmentPlan> plans) {
        Map<String, Map<String, Integer>> allocationsByNodeIdByDeploymentId = new HashMap<>();
        this.deployments.forEach(d -> allocationsByNodeIdByDeploymentId.put(d.deploymentId(), new HashMap<>()));
        for (AssignmentPlan plan : plans) {
            for (AssignmentPlan.Deployment m : plan.deployments()) {
                Map<String, Integer> nodeIdToAllocations = allocationsByNodeIdByDeploymentId.get(m.deploymentId());
                Optional<Map<AssignmentPlan.Node, Integer>> assignments = plan.assignments(m);
                if (assignments.isEmpty()) {
                    continue;
                }
                for (Map.Entry<AssignmentPlan.Node, Integer> nodeAssignment : assignments.get().entrySet()) {
                    nodeIdToAllocations.merge(nodeAssignment.getKey().id(), nodeAssignment.getValue(), Integer::sum);
                }
            }
        }
        return allocationsByNodeIdByDeploymentId;
    }
}

