/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.assignment.planning;

import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.core.Strings;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;

/**
 * Reduces the target number of allocations of a {@link TrainedModelAssignment}.
 *
 * <p>Each step picks the zone that currently holds the most allocations and its
 * smallest per-node assignment. It then either drops that assignment entirely or
 * takes a single allocation from it. This prefers removing whole assignments while
 * keeping allocations roughly balanced across zones.
 */
public class AllocationReducer {
    private static final Logger logger = LogManager.getLogger(AllocationReducer.class);
    private final TrainedModelAssignment assignment;
    // Node ids grouped by zone key. Only nodes listed here take part in the reduction.
    private final Map<List<String>, Set<String>> nodeIdsByZone;

    /**
     * @param assignment  the assignment whose allocations will be reduced; must not be null
     * @param nodesByZone candidate nodes grouped by zone key
     */
    public AllocationReducer(TrainedModelAssignment assignment, Map<List<String>, Collection<DiscoveryNode>> nodesByZone) {
        this.assignment = Objects.requireNonNull(assignment);
        this.nodeIdsByZone = nodesByZone.entrySet()
            .stream()
            .collect(
                Collectors.toMap(
                    Map.Entry::getKey,
                    e -> e.getValue().stream().map(DiscoveryNode::getId).collect(Collectors.toSet())
                )
            );
    }

    /**
     * Computes a builder for the assignment reduced to {@code numberOfAllocations} target allocations.
     *
     * <p>The current total is the sum of target allocations over nodes that belong to one of the
     * zones given at construction time. Routing entries for nodes outside every zone keep their
     * original target allocations.
     *
     * @param numberOfAllocations the desired total number of allocations; must be non-negative
     * @return a builder seeded from the original assignment with updated target allocations
     * @throws IllegalArgumentException if {@code numberOfAllocations} is negative, or is not strictly
     *                                  less than the current total
     * @throws IllegalStateException    if no zones exist or the largest zone has no assignments left
     *                                  while allocations still need to be removed
     */
    public TrainedModelAssignment.Builder reduceTo(int numberOfAllocations) {
        if (numberOfAllocations < 0) {
            throw new IllegalArgumentException("number of allocations must be non-negative; got [" + numberOfAllocations + "]");
        }

        // Both maps are mutated below, so they are built as explicitly mutable HashMaps
        // (Collectors.toMap makes no guarantee about the mutability of its result).
        final Map<String, Integer> allocationsByNode = new HashMap<>();
        assignment.getNodeRoutingTable()
            .forEach((nodeId, routingInfo) -> allocationsByNode.put(nodeId, routingInfo.getTargetAllocations()));

        final Map<List<String>, Integer> allocationsByZone = new HashMap<>();
        nodeIdsByZone.forEach(
            (zone, nodeIds) -> allocationsByZone.put(
                zone,
                nodeIds.stream().mapToInt(nodeId -> allocationsByNode.getOrDefault(nodeId, 0)).sum()
            )
        );

        int totalRemainingAllocations = allocationsByZone.values().stream().mapToInt(Integer::intValue).sum();
        if (totalRemainingAllocations <= numberOfAllocations) {
            String msg = "request to reduce allocations is greater than or equal to the existing target number of allocations";
            throw new IllegalArgumentException(msg);
        }

        while (totalRemainingAllocations > numberOfAllocations) {
            final int allocationsToRemove = totalRemainingAllocations - numberOfAllocations;

            final List<Map.Entry<List<String>, Integer>> zonesInAscendingOrder = allocationsByZone.entrySet()
                .stream()
                .sorted(Map.Entry.comparingByValue())
                .toList();
            if (zonesInAscendingOrder.isEmpty()) {
                logger.warn("no allocations remain in any zone");
                throw new IllegalStateException("no allocations remain in any zone");
            }
            final Map.Entry<List<String>, Integer> largestZoneEntry = zonesInAscendingOrder.get(zonesInAscendingOrder.size() - 1);
            final List<String> largestZone = largestZoneEntry.getKey();
            final int largestZoneAllocations = largestZoneEntry.getValue();
            // With a single zone there is no cross-zone balance to preserve.
            final int minAllocationsInOtherZones = zonesInAscendingOrder.size() <= 1 ? 0 : zonesInAscendingOrder.get(0).getValue();

            final Set<String> largestZoneNodeIds = nodeIdsByZone.get(largestZone);
            final List<Map.Entry<String, Integer>> largestZoneAssignmentsInAscendingOrder = allocationsByNode.entrySet()
                .stream()
                .filter(e -> largestZoneNodeIds.contains(e.getKey()))
                .sorted(Map.Entry.comparingByValue())
                .toList();
            if (largestZoneAssignmentsInAscendingOrder.isEmpty()) {
                logger.warn("no assignments remain in the largest zone");
                throw new IllegalStateException("no assignments remain in the largest zone");
            }

            // Copy key and value out of the entry before touching the map: a Map.Entry's behaviour
            // is undefined once its backing map has been structurally modified (e.g. by remove).
            final Map.Entry<String, Integer> smallestAssignment = largestZoneAssignmentsInAscendingOrder.get(0);
            final String smallestNodeId = smallestAssignment.getKey();
            final int smallestNodeAllocations = smallestAssignment.getValue();

            if (canAssignmentBeRemovedEntirely(
                smallestNodeAllocations,
                minAllocationsInOtherZones,
                largestZoneAllocations,
                allocationsToRemove
            )) {
                logger.debug(
                    () -> Strings.format(
                        "[%s] removing assignment with [%s] allocations on node [%s]",
                        assignment.getDeploymentId(),
                        smallestNodeAllocations,
                        smallestNodeId
                    )
                );
                allocationsByNode.remove(smallestNodeId);
                allocationsByZone.computeIfPresent(largestZone, (k, v) -> v - smallestNodeAllocations);
                totalRemainingAllocations -= smallestNodeAllocations;
            } else {
                logger.debug(
                    () -> Strings.format(
                        "[%s] removing 1 allocation from assignment with [%s] allocations on node [%s]",
                        assignment.getDeploymentId(),
                        smallestNodeAllocations,
                        smallestNodeId
                    )
                );
                allocationsByNode.computeIfPresent(smallestNodeId, (k, v) -> v - 1);
                allocationsByZone.computeIfPresent(largestZone, (k, v) -> v - 1);
                totalRemainingAllocations--;
            }
        }
        return buildUpdatedAssignment(numberOfAllocations, allocationsByNode);
    }

    /**
     * Decides whether a per-node assignment may be dropped as a whole rather than reduced by one.
     *
     * @param assignmentAllocations      target allocations of the candidate assignment
     * @param minAllocationsInOtherZones allocations of the smallest zone, or 0 when there is only one zone
     * @param zoneAllocations            total allocations of the zone the candidate belongs to
     * @param allocationsToRemove        how many allocations still need to be removed
     * @return {@code true} if the assignment can be removed entirely
     */
    private static boolean canAssignmentBeRemovedEntirely(
        int assignmentAllocations,
        int minAllocationsInOtherZones,
        int zoneAllocations,
        int allocationsToRemove
    ) {
        // A single-allocation assignment can only be reduced by removing it.
        if (assignmentAllocations == 1) {
            return true;
        }
        // Removing it entirely would overshoot the requested reduction.
        if (assignmentAllocations > allocationsToRemove) {
            return false;
        }
        // No other zone to stay balanced against.
        if (minAllocationsInOtherZones == 0) {
            return true;
        }
        // Still at least as large as the smallest zone after removal, so balance is roughly kept.
        return zoneAllocations - assignmentAllocations >= minAllocationsInOtherZones;
    }

    /**
     * Builds the reduced assignment.
     *
     * <p>Nodes still present in {@code allocationsByNode} get their target allocations updated,
     * keeping current allocations, state and reason. All other routing entries are removed.
     */
    private TrainedModelAssignment.Builder buildUpdatedAssignment(int numberOfAllocations, Map<String, Integer> allocationsByNode) {
        TrainedModelAssignment.Builder reducedAssignmentBuilder = TrainedModelAssignment.Builder.fromAssignment(assignment);
        reducedAssignmentBuilder.setNumberOfAllocations(numberOfAllocations);
        for (Map.Entry<String, RoutingInfo> routingEntry : assignment.getNodeRoutingTable().entrySet()) {
            final String nodeId = routingEntry.getKey();
            if (allocationsByNode.containsKey(nodeId)) {
                final RoutingInfo existingRoutingInfo = routingEntry.getValue();
                reducedAssignmentBuilder.updateExistingRoutingEntry(
                    nodeId,
                    new RoutingInfo(
                        existingRoutingInfo.getCurrentAllocations(),
                        allocationsByNode.get(nodeId),
                        existingRoutingInfo.getState(),
                        existingRoutingInfo.getReason()
                    )
                );
            } else {
                reducedAssignmentBuilder.removeRoutingEntry(nodeId);
            }
        }
        return reducedAssignmentBuilder;
    }
}

