/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.assignment;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.ResourceAlreadyExistsException;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.Strings;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.Priority;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlanner;
import org.elasticsearch.xpack.ml.inference.assignment.planning.ZoneAwareAssignmentPlanner;
import org.elasticsearch.xpack.ml.job.NodeLoad;
import org.elasticsearch.xpack.ml.utils.MlProcessors;

class TrainedModelAssignmentRebalancer {
    /** Logger shared by all instances of this class. */
    private static final Logger logger = LogManager.getLogger(TrainedModelAssignmentRebalancer.class);
    /** Assignment metadata the rebalance starts from; never {@code null} (checked in the constructor). */
    private final TrainedModelAssignmentMetadata currentMetadata;
    /** Load information per node; nodes missing from this map are ignored when building plan nodes. */
    private final Map<DiscoveryNode, NodeLoad> nodeLoads;
    /** ML nodes grouped by zone; the key is a list of strings (presumably zone attribute values - confirm with the caller). */
    private final Map<List<String>, Collection<DiscoveryNode>> mlNodesByZone;
    /** Parameters of a deployment to add during this rebalance, or empty when only existing deployments are rebalanced. */
    private final Optional<StartTrainedModelDeploymentAction.TaskParams> deploymentToAdd;
    /** Scale passed to {@code MlProcessors.get} when computing the processor count of a node. */
    private final int allocatedProcessorsScale;
    /**
     * Creates a rebalancer for the given cluster snapshot.
     *
     * @param currentMetadata          the assignment metadata to rebalance from
     * @param nodeLoads                the load of each node that may host deployments
     * @param mlNodesByZone            the ML nodes grouped by zone
     * @param deploymentToAdd          the parameters of a new deployment to add, or empty
     * @param allocatedProcessorsScale the scale passed to {@code MlProcessors.get} for node processor counts
     * @throws NullPointerException if any reference argument is {@code null}; the message names the argument
     */
    TrainedModelAssignmentRebalancer(
        TrainedModelAssignmentMetadata currentMetadata,
        Map<DiscoveryNode, NodeLoad> nodeLoads,
        Map<List<String>, Collection<DiscoveryNode>> mlNodesByZone,
        Optional<StartTrainedModelDeploymentAction.TaskParams> deploymentToAdd,
        int allocatedProcessorsScale
    ) {
        this.currentMetadata = Objects.requireNonNull(currentMetadata, "currentMetadata");
        this.nodeLoads = Objects.requireNonNull(nodeLoads, "nodeLoads");
        this.mlNodesByZone = Objects.requireNonNull(mlNodesByZone, "mlNodesByZone");
        this.deploymentToAdd = Objects.requireNonNull(deploymentToAdd, "deploymentToAdd");
        this.allocatedProcessorsScale = allocatedProcessorsScale;
    }

    /**
     * Computes the rebalanced assignment metadata.
     *
     * <p>When there is no deployment to add and every existing assignment is satisfied with no outdated
     * routing entries, the current metadata is returned unchanged (as a builder). Otherwise a fresh
     * assignment plan is computed and converted into metadata.
     *
     * @return a builder holding the rebalanced assignments
     * @throws ResourceAlreadyExistsException if the deployment to add already exists in the current metadata
     */
    TrainedModelAssignmentMetadata.Builder rebalance() {
        if (this.deploymentToAdd.isPresent()) {
            StartTrainedModelDeploymentAction.TaskParams newDeployment = this.deploymentToAdd.get();
            if (this.currentMetadata.hasDeployment(newDeployment.getDeploymentId())) {
                throw new ResourceAlreadyExistsException(
                    "[{}] assignment for deployment with model [{}] already exists",
                    newDeployment.getDeploymentId(),
                    newDeployment.getModelId()
                );
            }
        } else if (this.areAllModelsSatisfiedAndNoOutdatedRoutingEntries()) {
            logger.trace(() -> "No need to rebalance as all model deployments are satisfied");
            return TrainedModelAssignmentMetadata.Builder.fromMetadata(this.currentMetadata);
        }
        return this.buildAssignmentsFromPlan(this.computeAssignmentPlan());
    }

    /**
     * Checks whether a rebalance can be skipped.
     *
     * @return {@code true} when every current assignment is satisfied by the nodes present in
     *         {@code nodeLoads} and none has outdated routing entries; also {@code true} when there are
     *         no assignments at all
     */
    private boolean areAllModelsSatisfiedAndNoOutdatedRoutingEntries() {
        // Typed as Set<String> (rather than a raw Set) so the call to isSatisfied is checked by the compiler.
        Set<String> assignableNodeIds = this.nodeLoads.keySet().stream().map(DiscoveryNode::getId).collect(Collectors.toSet());
        return this.currentMetadata.allAssignments()
            .values()
            .stream()
            .allMatch(assignment -> assignment.isSatisfied(assignableNodeIds) && !assignment.hasOutdatedRoutingEntries());
    }

    /**
     * Computes the full assignment plan in two stages: normal priority deployments are planned first,
     * low priority deployments are then planned against that result, and both plans are merged.
     *
     * @return the merged plan covering all deployments
     */
    AssignmentPlan computeAssignmentPlan() {
        Map<List<String>, List<AssignmentPlan.Node>> zoneToNodes = this.createNodesByZoneMap();

        // Only nodes that made it into the zone map are considered assignable.
        Set<String> usableNodeIds = zoneToNodes.values()
            .stream()
            .flatMap(zoneNodes -> zoneNodes.stream())
            .map(node -> node.id())
            .collect(Collectors.toSet());

        AssignmentPlan normalPriorityPlan = this.computePlanForNormalPriorityModels(zoneToNodes, usableNodeIds);
        AssignmentPlan lowPriorityPlan = this.computePlanForLowPriorityModels(usableNodeIds, normalPriorityPlan);
        return this.mergePlans(zoneToNodes, normalPriorityPlan, lowPriorityPlan);
    }

    /**
     * Combines the normal priority and low priority plans into one plan over all nodes of all zones.
     *
     * @param nodesByZone                 the plan nodes grouped by zone
     * @param planForNormalPriorityModels the plan computed for normal priority deployments
     * @param planForLowPriorityModels    the plan computed for low priority deployments
     * @return a plan containing the deployments and assignments of both input plans
     */
    private AssignmentPlan mergePlans(Map<List<String>, List<AssignmentPlan.Node>> nodesByZone, AssignmentPlan planForNormalPriorityModels, AssignmentPlan planForLowPriorityModels) {
        List<AssignmentPlan.Node> everyNode = new ArrayList<>();
        for (List<AssignmentPlan.Node> zoneNodes : nodesByZone.values()) {
            everyNode.addAll(zoneNodes);
        }

        // Assignments are re-keyed onto these node instances when they are copied over.
        Map<String, AssignmentPlan.Node> nodeLookup = everyNode.stream()
            .collect(Collectors.toMap(AssignmentPlan.Node::id, Function.identity()));

        List<AssignmentPlan.Deployment> everyDeployment = new ArrayList<>(planForNormalPriorityModels.models());
        everyDeployment.addAll(planForLowPriorityModels.models());

        AssignmentPlan.Builder merged = AssignmentPlan.builder(everyNode, everyDeployment);
        TrainedModelAssignmentRebalancer.copyAssignments(planForNormalPriorityModels, merged, nodeLookup);
        TrainedModelAssignmentRebalancer.copyAssignments(planForLowPriorityModels, merged, nodeLookup);
        return merged.build();
    }

    /**
     * Copies every assignment of {@code source} into {@code dest}, mapping each node to the instance
     * with the same id in {@code originalNodeById}.
     *
     * @param source           the plan to read assignments from
     * @param dest             the builder receiving the assignments
     * @param originalNodeById the nodes known to {@code dest}, keyed by node id
     */
    private static void copyAssignments(AssignmentPlan source, AssignmentPlan.Builder dest, Map<String, AssignmentPlan.Node> originalNodeById) {
        for (AssignmentPlan.Deployment m : source.models()) {
            // Properly parameterized: the raw Map/Map.Entry form cannot be iterated as Map.Entry.
            Map<AssignmentPlan.Node, Integer> nodeAssignments = source.assignments(m).orElse(Map.of());
            for (Map.Entry<AssignmentPlan.Node, Integer> assignment : nodeAssignments.entrySet()) {
                AssignmentPlan.Node originalNode = originalNodeById.get(assignment.getKey().id());
                dest.assignModelToNode(m, originalNode, assignment.getValue());
                // Memory is accounted explicitly only when the deployment already has allocations on this node.
                if (m.currentAllocationsByNodeId().containsKey(originalNode.id())) {
                    dest.accountMemory(m, originalNode);
                }
            }
        }
    }

    /**
     * Plans all deployments whose priority is not {@code LOW}, including the deployment to add when it
     * is not low priority, using the zone aware planner.
     *
     * @param nodesByZone       the plan nodes grouped by zone
     * @param assignableNodeIds the ids of nodes that may currently receive assignments
     * @return the plan for the non-low-priority deployments
     */
    private AssignmentPlan computePlanForNormalPriorityModels(Map<List<String>, List<AssignmentPlan.Node>> nodesByZone, Set<String> assignableNodeIds) {
        List<AssignmentPlan.Deployment> planDeployments = new ArrayList<>();

        for (TrainedModelAssignment assignment : this.currentMetadata.allAssignments().values()) {
            StartTrainedModelDeploymentAction.TaskParams params = assignment.getTaskParams();
            if (params.getPriority() == Priority.LOW) {
                continue;
            }

            // Keep a route only if its node is assignable, it has both current and target allocations,
            // and it is in one of the STARTING / STARTED / FAILED states.
            Map<String, Integer> currentAssignments = new HashMap<>();
            for (Map.Entry<String, RoutingInfo> route : assignment.getNodeRoutingTable().entrySet()) {
                RoutingInfo info = route.getValue();
                boolean keep = assignableNodeIds.contains(route.getKey())
                    && info.getCurrentAllocations() > 0
                    && info.getTargetAllocations() > 0
                    && info.getState().isAnyOf(RoutingState.STARTING, RoutingState.STARTED, RoutingState.FAILED);
                if (keep) {
                    currentAssignments.put(route.getKey(), info.getTargetAllocations());
                }
            }

            planDeployments.add(
                new AssignmentPlan.Deployment(
                    assignment.getDeploymentId(),
                    params.estimateMemoryUsageBytes(),
                    params.getNumberOfAllocations(),
                    params.getThreadsPerAllocation(),
                    currentAssignments,
                    assignment.getMaxAssignedAllocations()
                )
            );
        }

        if (this.deploymentToAdd.isPresent()) {
            StartTrainedModelDeploymentAction.TaskParams added = this.deploymentToAdd.get();
            if (added.getPriority() != Priority.LOW) {
                // A brand new deployment has no current allocations and no previously assigned maximum.
                planDeployments.add(
                    new AssignmentPlan.Deployment(
                        added.getDeploymentId(),
                        added.estimateMemoryUsageBytes(),
                        added.getNumberOfAllocations(),
                        added.getThreadsPerAllocation(),
                        Map.of(),
                        0
                    )
                );
            }
        }

        return new ZoneAwareAssignmentPlanner(nodesByZone, planDeployments).computePlan();
    }

    /**
     * Plans all {@code LOW} priority deployments, including the deployment to add when it is low
     * priority, against the memory left over by the plan for the other deployments.
     *
     * <p>Existing low priority assignments are processed in ascending order of estimated memory usage;
     * each one keeps only those current assignments that still fit into the remaining node memory.
     *
     * @param assignableNodeIds              the ids of nodes that may currently receive assignments
     * @param planExcludingLowPriorityModels the plan already computed for the other deployments
     * @return the plan for the low priority deployments
     */
    private AssignmentPlan computePlanForLowPriorityModels(Set<String> assignableNodeIds, AssignmentPlan planExcludingLowPriorityModels) {
        // Every node is given the same fixed core count for this planning pass (presumably so that CPU
        // never limits low priority deployments - confirm). One constant feeds both the nodes and the log.
        final int lowPriorityCores = 100;

        List<AssignmentPlan.Node> planNodes = this.mlNodesByZone.values()
            .stream()
            .flatMap(Collection::stream)
            .map(
                discoveryNode -> new AssignmentPlan.Node(
                    discoveryNode.getId(),
                    planExcludingLowPriorityModels.getRemainingNodeMemory(discoveryNode.getId()),
                    lowPriorityCores
                )
            )
            .toList();

        // Mutable ledger of memory left per node; findFittingAssignments decrements it as models are kept.
        Map<String, Long> remainingNodeMemory = new HashMap<>();
        for (AssignmentPlan.Node node : planNodes) {
            remainingNodeMemory.put(node.id(), node.availableMemoryBytes());
        }

        List<TrainedModelAssignment> lowPriorityAssignments = this.currentMetadata.allAssignments()
            .values()
            .stream()
            .filter(assignment -> assignment.getTaskParams().getPriority() == Priority.LOW)
            .sorted(Comparator.comparingLong(assignment -> assignment.getTaskParams().estimateMemoryUsageBytes()))
            .toList();

        List<AssignmentPlan.Deployment> planDeployments = new ArrayList<>();
        for (TrainedModelAssignment assignment : lowPriorityAssignments) {
            StartTrainedModelDeploymentAction.TaskParams params = assignment.getTaskParams();
            planDeployments.add(
                new AssignmentPlan.Deployment(
                    assignment.getDeploymentId(),
                    params.estimateMemoryUsageBytes(),
                    params.getNumberOfAllocations(),
                    params.getThreadsPerAllocation(),
                    this.findFittingAssignments(assignment, assignableNodeIds, remainingNodeMemory),
                    assignment.getMaxAssignedAllocations(),
                    Priority.LOW
                )
            );
        }

        if (this.deploymentToAdd.isPresent() && this.deploymentToAdd.get().getPriority() == Priority.LOW) {
            StartTrainedModelDeploymentAction.TaskParams taskParams = this.deploymentToAdd.get();
            planDeployments.add(
                new AssignmentPlan.Deployment(
                    taskParams.getDeploymentId(),
                    taskParams.estimateMemoryUsageBytes(),
                    taskParams.getNumberOfAllocations(),
                    taskParams.getThreadsPerAllocation(),
                    Map.of(),
                    0,
                    Priority.LOW
                )
            );
        }

        logger.debug(() -> Strings.format("Computing plan for low priority deployments. CPU cores fixed to [%s].", lowPriorityCores));
        return new AssignmentPlanner(planNodes, planDeployments).computePlan();
    }

    /**
     * Selects the current node assignments of {@code assignment} that still fit into the remaining
     * node memory, reserving the model's memory on every node that is kept.
     *
     * @param assignment          the existing assignment to inspect
     * @param assignableNodeIds   the ids of nodes that may currently receive assignments
     * @param remainingNodeMemory memory left per node id; decremented in place for each kept node
     * @return target allocations per node id for the routes that fit
     */
    private Map<String, Integer> findFittingAssignments(TrainedModelAssignment assignment, Set<String> assignableNodeIds, Map<String, Long> remainingNodeMemory) {
        long requiredBytes = assignment.getTaskParams().estimateMemoryUsageBytes();
        Map<String, Integer> fitting = new HashMap<>();

        for (Map.Entry<String, RoutingInfo> route : assignment.getNodeRoutingTable().entrySet()) {
            String nodeId = route.getKey();
            RoutingInfo info = route.getValue();

            if (!assignableNodeIds.contains(nodeId)) {
                continue;
            }
            if (!info.getState().isAnyOf(RoutingState.STARTING, RoutingState.STARTED, RoutingState.FAILED)) {
                continue;
            }
            int targetAllocations = info.getTargetAllocations();
            if (targetAllocations <= 0) {
                continue;
            }

            // Each node id appears at most once in the routing table, so the order of iteration
            // does not influence which routes are kept.
            long available = remainingNodeMemory.get(nodeId);
            if (available >= requiredBytes) {
                fitting.put(nodeId, targetAllocations);
                remainingNodeMemory.put(nodeId, available - requiredBytes);
            }
        }
        return fitting;
    }

    /**
     * Converts the ML nodes of each zone into plan nodes.
     *
     * <p>A node is skipped, with a warning, when it has no entry in {@code nodeLoads} or when its load
     * reports an error. Zones are kept even if all of their nodes are skipped.
     *
     * @return plan nodes grouped by the same zone keys as {@code mlNodesByZone}
     */
    private Map<List<String>, List<AssignmentPlan.Node>> createNodesByZoneMap() {
        return this.mlNodesByZone.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, zoneEntry -> {
            List<AssignmentPlan.Node> nodes = new ArrayList<>();
            // Iterate the typed collection directly; a raw Collection cannot be iterated as DiscoveryNode.
            for (DiscoveryNode discoveryNode : zoneEntry.getValue()) {
                // Single lookup instead of containsKey followed by get.
                NodeLoad load = this.nodeLoads.get(discoveryNode);
                if (load == null) {
                    logger.warn(Strings.format("ignoring node [%s] as no load could be detected", discoveryNode.getId()));
                    continue;
                }
                if (!org.elasticsearch.common.Strings.isNullOrEmpty(load.getError())) {
                    logger.warn(Strings.format("ignoring node [%s] as detecting its load failed with [%s]", discoveryNode.getId(), load.getError()));
                    continue;
                }
                nodes.add(
                    new AssignmentPlan.Node(
                        discoveryNode.getId(),
                        TrainedModelAssignmentRebalancer.getNodeFreeMemoryExcludingPerNodeOverheadAndNativeInference(load),
                        MlProcessors.get(discoveryNode, this.allocatedProcessorsScale).roundUp()
                    )
                );
            }
            return nodes;
        }));
    }

    /**
     * Computes the memory of a node that planning may hand out: the free memory excluding the per-node
     * overhead, minus the memory already assigned to native inference.
     *
     * @param load the load of the node
     * @return the difference of the two values reported by {@code load}, in the same unit
     */
    private static long getNodeFreeMemoryExcludingPerNodeOverheadAndNativeInference(NodeLoad load) {
        long freeWithoutOverhead = load.getFreeMemoryExcludingPerNodeOverhead();
        long assignedToNativeInference = load.getAssignedNativeInferenceMemory();
        return freeWithoutOverhead - assignedToNativeInference;
    }

    /**
     * Translates an assignment plan into assignment metadata.
     *
     * <p>For each deployment in the plan an assignment is built from the existing assignment's task
     * parameters, or from the deployment to add when no assignment exists yet. Start time and maximum
     * assigned allocations are carried over from an existing assignment. Routes to nodes that were
     * already routed keep their current allocations and state, except that {@code FAILED} routes are
     * reset to {@code STARTING} with an empty reason; routes to new nodes start in {@code STARTING}.
     *
     * @param assignmentPlan the plan to translate
     * @return a builder holding one assignment per deployment of the plan
     * @throws NullPointerException if a planned deployment has no existing assignment and is not the
     *                              deployment to add
     */
    private TrainedModelAssignmentMetadata.Builder buildAssignmentsFromPlan(AssignmentPlan assignmentPlan) {
        TrainedModelAssignmentMetadata.Builder builder = TrainedModelAssignmentMetadata.Builder.empty();
        for (AssignmentPlan.Deployment deployment : assignmentPlan.models()) {
            TrainedModelAssignment existingAssignment = this.currentMetadata.getDeploymentAssignment(deployment.id());

            StartTrainedModelDeploymentAction.TaskParams taskParams;
            if (existingAssignment == null && this.deploymentToAdd.isPresent()) {
                taskParams = this.deploymentToAdd.get();
            } else {
                // Reuse the assignment already looked up; fail with a descriptive message if it is missing.
                taskParams = Objects.requireNonNull(
                    existingAssignment,
                    () -> "no existing assignment and no deployment to add for planned deployment [" + deployment.id() + "]"
                ).getTaskParams();
            }

            TrainedModelAssignment.Builder assignmentBuilder = TrainedModelAssignment.Builder.empty(taskParams);
            if (existingAssignment != null) {
                assignmentBuilder.setStartTime(existingAssignment.getStartTime());
                assignmentBuilder.setMaxAssignedAllocations(existingAssignment.getMaxAssignedAllocations());
            }

            // Properly parameterized: the raw Map/Map.Entry form cannot be iterated as Map.Entry.
            Map<AssignmentPlan.Node, Integer> assignments = assignmentPlan.assignments(deployment).orElseGet(Map::of);
            for (Map.Entry<AssignmentPlan.Node, Integer> assignment : assignments.entrySet()) {
                String nodeId = assignment.getKey().id();
                int targetAllocations = assignment.getValue();

                if (existingAssignment != null && existingAssignment.isRoutedToNode(nodeId)) {
                    RoutingInfo existingRoutingInfo = existingAssignment.getNodeRoutingTable().get(nodeId);
                    RoutingState state = existingRoutingInfo.getState();
                    String reason = existingRoutingInfo.getReason();
                    if (state == RoutingState.FAILED) {
                        // A failed route that is planned again is given a fresh start.
                        state = RoutingState.STARTING;
                        reason = "";
                    }
                    assignmentBuilder.addRoutingEntry(
                        nodeId,
                        new RoutingInfo(existingRoutingInfo.getCurrentAllocations(), targetAllocations, state, reason)
                    );
                } else {
                    assignmentBuilder.addRoutingEntry(
                        nodeId,
                        new RoutingInfo(targetAllocations, targetAllocations, RoutingState.STARTING, "")
                    );
                }
            }

            assignmentBuilder.calculateAndSetAssignmentState();
            this.explainAssignments(assignmentPlan, this.nodeLoads, deployment).ifPresent(assignmentBuilder::setReason);
            builder.addNewAssignment(deployment.id(), assignmentBuilder);
        }
        return builder;
    }

    /**
     * Builds a human readable explanation of why a deployment did not get all of its allocations.
     *
     * @param assignmentPlan the computed plan
     * @param nodeLoads      the load of each node to explain
     * @param deployment     the deployment to explain
     * @return empty when the plan satisfies the deployment or no node yields a reason; otherwise either
     *         a message that no ML nodes exist, or the per-node reasons ordered by node id and joined
     *         with {@code |}
     */
    private Optional<String> explainAssignments(AssignmentPlan assignmentPlan, Map<DiscoveryNode, NodeLoad> nodeLoads, AssignmentPlan.Deployment deployment) {
        if (assignmentPlan.satisfiesAllocations(deployment)) {
            return Optional.empty();
        }
        if (nodeLoads.isEmpty()) {
            return Optional.of("No ML nodes exist in the cluster");
        }

        // TreeMap keeps the output ordered by node id. It is parameterized so that the stream below
        // sees typed entries; with a raw TreeMap the lambda parameter would be Object.
        Map<String, String> nodeToReason = new TreeMap<>();
        for (Map.Entry<DiscoveryNode, NodeLoad> nodeAndLoad : nodeLoads.entrySet()) {
            Optional<String> reason = this.explainAssignment(assignmentPlan, nodeAndLoad.getKey(), nodeAndLoad.getValue(), deployment);
            reason.ifPresent(s -> nodeToReason.put(nodeAndLoad.getKey().getId(), s));
        }
        if (nodeToReason.isEmpty()) {
            return Optional.empty();
        }

        return Optional.of(
            nodeToReason.entrySet()
                .stream()
                .map(entry -> Strings.format("Could not assign (more) allocations on node [%s]. Reason: %s", entry.getKey(), entry.getValue()))
                .collect(Collectors.joining("|"))
        );
    }

    /**
     * Explains why a single node could not take (more) allocations of a deployment.
     *
     * <p>Checks, in order: an error reported by the node load, insufficient remaining memory, and
     * insufficient remaining processors.
     *
     * @param assignmentPlan the computed plan
     * @param node           the node to explain
     * @param load           the load of {@code node}
     * @param deployment     the deployment to explain
     * @return the first applicable reason, or empty if none of the checks applies
     */
    private Optional<String> explainAssignment(AssignmentPlan assignmentPlan, DiscoveryNode node, NodeLoad load, AssignmentPlan.Deployment deployment) {
        if (!org.elasticsearch.common.Strings.isNullOrEmpty(load.getError())) {
            return Optional.of(load.getError());
        }

        // Plan nodes are created with DiscoveryNode#getId as their id, so every plan lookup uses that key.
        String nodeId = node.getId();
        long remainingMemory = assignmentPlan.getRemainingNodeMemory(nodeId);

        if (deployment.memoryBytes() > remainingMemory) {
            // The per-node overhead is treated as already accounted for when the node hosts any job or
            // model, or when the plan has consumed some of its processors.
            boolean isPerNodeOverheadAccountedFor = load.getNumAssignedJobsAndModels() > 0
                || assignmentPlan.getRemainingNodeCores(nodeId) < MlProcessors.get(node, this.allocatedProcessorsScale).roundUp();
            long overhead = isPerNodeOverheadAccountedFor ? 0L : MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD.getBytes();
            long requiredMemory = deployment.memoryBytes() + overhead;
            long nodeFreeMemory = remainingMemory + overhead;
            return Optional.of(
                ParameterizedMessage.format(
                    "This node has insufficient available memory. Available memory for ML [{} ({})], free memory [{} ({})], estimated memory required for this model [{} ({})].",
                    new Object[] {
                        load.getMaxMlMemory(),
                        ByteSizeValue.ofBytes(load.getMaxMlMemory()).toString(),
                        nodeFreeMemory,
                        ByteSizeValue.ofBytes(nodeFreeMemory).toString(),
                        requiredMemory,
                        ByteSizeValue.ofBytes(requiredMemory).toString() }
                )
            );
        }

        int remainingCores = assignmentPlan.getRemainingNodeCores(nodeId);
        if (deployment.threadsPerAllocation() > remainingCores) {
            return Optional.of(
                ParameterizedMessage.format(
                    "This node has insufficient allocated processors. Available processors [{}], free processors [{}], processors required for each allocation of this model [{}]",
                    new Object[] {
                        MlProcessors.get(node, this.allocatedProcessorsScale).roundUp(),
                        remainingCores,
                        deployment.threadsPerAllocation() }
                )
            );
        }
        return Optional.empty();
    }
}

