/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.assignment.planning;

import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.Strings;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AbstractPreserveAllocations;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan;
import org.elasticsearch.xpack.ml.inference.assignment.planning.LinearProgrammingPlanSolver;
import org.elasticsearch.xpack.ml.inference.assignment.planning.PreserveAllAllocations;
import org.elasticsearch.xpack.ml.inference.assignment.planning.PreserveOneAllocation;

/**
 * Computes an {@link AssignmentPlan} for a set of nodes and deployments.
 *
 * <p>Deployments requesting zero allocations are kept aside and are only attached to the final plan via
 * {@code AssignmentPlan#withZeroAllocationDeployments}. All other deployments are planned, first trying to
 * preserve current assignments. Then, optionally, the planner tries a second strategy that gives every
 * previously allocated deployment at least one allocation. The better of the two plans is chosen.
 */
public class AssignmentPlanner {
    private static final Logger logger = LogManager.getLogger(AssignmentPlanner.class);

    // Sorted by node id so planning input order is deterministic.
    private final List<AssignmentPlan.Node> nodes;
    // Deployments with allocations() > 0, sorted by deployment id.
    private final List<AssignmentPlan.Deployment> deployments;
    // Deployments with allocations() == 0, sorted by deployment id; never handed to the solvers.
    private final List<AssignmentPlan.Deployment> deploymentsWithZeroAllocations;

    /**
     * Creates a planner over the given nodes and deployments.
     *
     * @param nodes the candidate nodes; copied and sorted by id
     * @param deployments the deployments to plan; split into those with {@code allocations() > 0} (planned)
     *                    and those with {@code allocations() == 0} (only attached to the final plan), each
     *                    copied and sorted by deployment id
     */
    public AssignmentPlanner(List<AssignmentPlan.Node> nodes, List<AssignmentPlan.Deployment> deployments) {
        Comparator<AssignmentPlan.Deployment> byDeploymentId = Comparator.comparing(AssignmentPlan.Deployment::deploymentId);
        this.nodes = nodes.stream().sorted(Comparator.comparing(AssignmentPlan.Node::id)).toList();
        this.deployments = deployments.stream().filter(deployment -> deployment.allocations() > 0).sorted(byDeploymentId).toList();
        this.deploymentsWithZeroAllocations = deployments.stream()
            .filter(deployment -> deployment.allocations() == 0)
            .sorted(byDeploymentId)
            .toList();
    }

    /**
     * Computes the best plan, also trying to give every previously allocated deployment at least one allocation.
     *
     * @return the chosen plan, including the zero-allocation deployments
     */
    public AssignmentPlan computePlan() {
        return computePlan(true);
    }

    /**
     * Computes the best plan.
     *
     * @param tryAssigningAllPreviouslyAllocatedModels when {@code true} and the plan that satisfies current
     *        assignments leaves some previously assigned model unassigned, an alternative plan is computed that
     *        seeds every previously allocated deployment with one allocation, and the better plan is kept
     * @return the chosen plan, including the zero-allocation deployments
     */
    public AssignmentPlan computePlan(boolean tryAssigningAllPreviouslyAllocatedModels) {
        logger.debug(() -> Strings.format("Computing plan for nodes = %s; deployments = %s", this.nodes, this.deployments));

        AssignmentPlan planSatisfyingCurrentAssignments = solveSatisfyingCurrentAssignments();
        logger.debug(() -> "Plan satisfying current assignments =\n" + planSatisfyingCurrentAssignments.prettyPrint());

        AssignmentPlan bestPlan;
        if (planSatisfyingCurrentAssignments.arePreviouslyAssignedModelsAssigned() || tryAssigningAllPreviouslyAllocatedModels == false) {
            bestPlan = planSatisfyingCurrentAssignments;
        } else {
            AssignmentPlan planAllocatingAtLeastOnce = solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocated();
            logger.debug(
                () -> "Plan with at least one allocation for previously assigned models =\n" + planAllocatingAtLeastOnce.prettyPrint()
            );
            if (planAllocatingAtLeastOnce.arePreviouslyAssignedModelsAssigned()) {
                bestPlan = planAllocatingAtLeastOnce;
            } else if (planSatisfyingCurrentAssignments.countPreviouslyAssignedModelsThatAreStillAssigned() >= planAllocatingAtLeastOnce
                .countPreviouslyAssignedModelsThatAreStillAssigned()) {
                // On a tie, prefer the plan that keeps current assignments.
                bestPlan = planSatisfyingCurrentAssignments;
            } else {
                bestPlan = planAllocatingAtLeastOnce;
            }
        }

        bestPlan = bestPlan.withZeroAllocationDeployments(deploymentsWithZeroAllocations);
        // Guarded because prettyPrint() and the stats string are built eagerly here.
        if (logger.isDebugEnabled()) {
            logger.debug("Best plan =\n{}", bestPlan.prettyPrint());
            logger.debug("{}", prettyPrintOverallStats(bestPlan));
        }
        return bestPlan;
    }

    /**
     * Picks a plan that honours current assignments. It starts from the more flexible "keep one allocation"
     * strategy and falls back to, or compares against, the "keep all allocations" strategy.
     */
    private AssignmentPlan solveSatisfyingCurrentAssignments() {
        AssignmentPlan planKeepingOneAllocation = solveKeepingOneAllocationOnCurrentAssignments();
        if (planKeepingOneAllocation.satisfiesAllModels()) {
            return planKeepingOneAllocation;
        }
        if (planKeepingOneAllocation.satisfiesCurrentAssignments() == false) {
            return solvePreservingAllAllocationsOnCurrentAssignments();
        }
        AssignmentPlan planKeepingAllAllocations = solvePreservingAllAllocationsOnCurrentAssignments();
        // On a tie (compareTo == 0), prefer the plan that preserves all allocations.
        return planKeepingAllAllocations.compareTo(planKeepingOneAllocation) >= 0 ? planKeepingAllAllocations : planKeepingOneAllocation;
    }

    /**
     * Builds an alternative plan in two steps.
     *
     * <p>First, previously allocated deployments are planned alone with a single allocation each and no
     * current assignments. The node chosen for each is then treated as that deployment's only current
     * assignment, with one allocation. Finally, the full set of deployments is re-planned.
     */
    private AssignmentPlan solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocated() {
        logger.debug(() -> "Attempting to solve assigning at least one allocation to previously assigned deployments");

        List<AssignmentPlan.Deployment> previouslyAssignedModelsOnly = deployments.stream()
            .filter(AssignmentPlan.Deployment::hasEverBeenAllocated)
            .map(
                m -> new AssignmentPlan.Deployment(
                    m.deploymentId(),
                    m.modelId(),
                    m.memoryBytes(),
                    1,
                    m.threadsPerAllocation(),
                    Map.of(),
                    m.maxAssignedAllocations(),
                    m.getAdaptiveAllocationsSettings(),
                    m.perDeploymentMemoryBytes(),
                    m.perAllocationMemoryBytes()
                )
            )
            .toList();
        AssignmentPlan planWithSingleAllocationForPreviouslyAssignedModels = new LinearProgrammingPlanSolver(
            nodes,
            previouslyAssignedModelsOnly
        ).solvePlan(true);

        Map<String, String> deploymentIdToNodeIdWithSingleAllocation = new HashMap<>();
        for (AssignmentPlan.Deployment deployment : planWithSingleAllocationForPreviouslyAssignedModels.deployments()) {
            Set<AssignmentPlan.Node> assignedNodes = planWithSingleAllocationForPreviouslyAssignedModels.assignments(deployment)
                .orElse(Map.of())
                .keySet();
            if (assignedNodes.isEmpty()) {
                continue;
            }
            // Each deployment was planned with exactly one allocation, so it can land on at most one node.
            assert assignedNodes.size() == 1 : "expected a single node for deployment [" + deployment.deploymentId() + "]";
            deploymentIdToNodeIdWithSingleAllocation.put(deployment.deploymentId(), assignedNodes.iterator().next().id());
        }

        List<AssignmentPlan.Deployment> planDeployments = deployments.stream().map(m -> {
            String nodeId = deploymentIdToNodeIdWithSingleAllocation.get(m.deploymentId());
            Map<String, Integer> currentAllocationsByNodeId = nodeId == null ? Map.of() : Map.of(nodeId, 1);
            return new AssignmentPlan.Deployment(
                m.deploymentId(),
                m.modelId(),
                m.memoryBytes(),
                m.allocations(),
                m.threadsPerAllocation(),
                currentAllocationsByNodeId,
                m.maxAssignedAllocations(),
                m.getAdaptiveAllocationsSettings(),
                m.perDeploymentMemoryBytes(),
                m.perAllocationMemoryBytes()
            );
        }).toList();
        // Passing false stops computePlan from calling back into this method, so the recursion is one level deep.
        return new AssignmentPlanner(nodes, planDeployments).computePlan(false);
    }

    /** Solves while preserving one allocation on each current assignment ({@link PreserveOneAllocation}). */
    private AssignmentPlan solveKeepingOneAllocationOnCurrentAssignments() {
        logger.trace("Solving preserving one allocation on current assignments");
        return solvePreservingCurrentAssignments(new PreserveOneAllocation(nodes, deployments));
    }

    /** Solves while preserving all allocations on current assignments ({@link PreserveAllAllocations}). */
    private AssignmentPlan solvePreservingAllAllocationsOnCurrentAssignments() {
        logger.trace("Solving preserving all allocations on current assignments");
        return solvePreservingCurrentAssignments(new PreserveAllAllocations(nodes, deployments));
    }

    /**
     * Runs the linear-programming solver on the nodes and deployments produced by the given preserving
     * strategy, then merges the preserved allocations back into the solved plan.
     *
     * @param preserveAllocations the strategy that supplies the adjusted nodes and deployments
     * @return the solved plan with preserved allocations merged in
     */
    private static AssignmentPlan solvePreservingCurrentAssignments(AbstractPreserveAllocations preserveAllocations) {
        List<AssignmentPlan.Node> planNodes = preserveAllocations.nodesPreservingAllocations();
        List<AssignmentPlan.Deployment> planDeployments = preserveAllocations.modelsPreservingAllocations();
        logger.trace(() -> Strings.format("Nodes after applying allocation preserving strategy = %s", planNodes));
        logger.trace(() -> Strings.format("Deployments after applying allocation preserving strategy = %s", planDeployments));
        AssignmentPlan assignmentPlan = new LinearProgrammingPlanSolver(planNodes, planDeployments).solvePlan(false);
        return preserveAllocations.mergePreservedAllocations(assignmentPlan);
    }

    /**
     * Builds a one-line summary of memory, allocation and core usage for the given plan.
     *
     * <p>Only deployments with at least one requested allocation are counted. Used memory is approximated as
     * {@code memoryBytes()} once per node the deployment is assigned to. NOTE(review): this ignores
     * {@code perDeploymentMemoryBytes()} and {@code perAllocationMemoryBytes()}; confirm whether that is intended.
     */
    private String prettyPrintOverallStats(AssignmentPlan assignmentPlan) {
        long totalAvailableMem = nodes.stream().mapToLong(AssignmentPlan.Node::availableMemoryBytes).sum();
        int totalCores = nodes.stream().mapToInt(AssignmentPlan.Node::cores).sum();

        int totalAllocationsRequired = 0;
        int totalAllocationsAssigned = 0;
        int totalCoresUsed = 0;
        long totalUsedMem = 0L;
        for (AssignmentPlan.Deployment m : deployments) {
            totalAllocationsRequired += m.allocations();
            Optional<Map<AssignmentPlan.Node, Integer>> assignments = assignmentPlan.assignments(m);
            if (assignments.isEmpty()) {
                continue;
            }
            Map<AssignmentPlan.Node, Integer> allocationsByNode = assignments.get();
            int allocations = allocationsByNode.values().stream().mapToInt(Integer::intValue).sum();
            totalAllocationsAssigned += allocations;
            totalCoresUsed += allocations * m.threadsPerAllocation();
            totalUsedMem += m.memoryBytes() * (long) allocationsByNode.size();
        }

        return new StringBuilder("Overall Stats: ").append("(used memory = ")
            .append(ByteSizeValue.ofBytes(totalUsedMem))
            .append(") (total available memory = ")
            .append(ByteSizeValue.ofBytes(totalAvailableMem))
            .append(") (allocations = ")
            .append(totalAllocationsAssigned)
            .append("/")
            .append(totalAllocationsRequired)
            .append(") (cores = ")
            .append(totalCoresUsed)
            .append("/")
            .append(totalCores)
            .append(")")
            .toString();
    }
}

