/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.indices;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.ReferenceDocs;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.health.metadata.HealthMetadata;
import org.elasticsearch.index.Index;

/**
 * Checks whether an operation that adds shards (creating an index, opening closed indices, or changing the
 * replica count) would exceed the configured per-node shard limits, and keeps those limits in sync with
 * dynamic cluster setting updates.
 *
 * <p>Limits are evaluated per {@link LimitGroup}. Which groups apply depends on whether the node settings
 * describe a stateless deployment, see {@link #applicableLimitGroups(boolean)}.
 */
public class ShardLimitValidator {
    /** Maximum number of shards per node for non-frozen shards: default 1000, minimum 1, dynamically updatable. */
    public static final Setting<Integer> SETTING_CLUSTER_MAX_SHARDS_PER_NODE = Setting.intSetting(
        "cluster.max_shards_per_node",
        1000,
        1,
        Setting.Property.Dynamic,
        Setting.Property.NodeScope
    );
    /** Maximum number of shards per node for frozen shards: default 3000, minimum 1, dynamically updatable. */
    public static final Setting<Integer> SETTING_CLUSTER_MAX_SHARDS_PER_NODE_FROZEN = Setting.intSetting(
        "cluster.max_shards_per_node.frozen",
        3000,
        1,
        Setting.Property.Dynamic,
        Setting.Property.NodeScope
    );
    public static final String FROZEN_GROUP = "frozen";
    public static final String NORMAL_GROUP = "normal";
    private static final Set<String> VALID_INDEX_SETTING_GROUPS = Set.of(NORMAL_GROUP, FROZEN_GROUP);
    /**
     * Private index setting naming the shard limit group an index is counted against. Defaults to
     * {@link #NORMAL_GROUP}; any value other than {@link #NORMAL_GROUP} or {@link #FROZEN_GROUP} is rejected.
     */
    public static final Setting<String> INDEX_SETTING_SHARD_LIMIT_GROUP = Setting.simpleString(
        "index.shard_limit.group",
        NORMAL_GROUP,
        value -> {
            if (!VALID_INDEX_SETTING_GROUPS.contains(value)) {
                throw new IllegalArgumentException("[" + value + "] is not a valid shard limit group");
            }
        },
        Setting.Property.IndexScope,
        Setting.Property.PrivateIndex,
        Setting.Property.NotCopyableOnResize
    );

    /** Current value of {@link #SETTING_CLUSTER_MAX_SHARDS_PER_NODE}. */
    protected final AtomicInteger shardLimitPerNode = new AtomicInteger();
    /** Current value of {@link #SETTING_CLUSTER_MAX_SHARDS_PER_NODE_FROZEN}. */
    protected final AtomicInteger shardLimitPerNodeFrozen = new AtomicInteger();
    private final boolean isStateless;

    /**
     * Reads the initial limits from {@code settings} and registers consumers so that later updates of either
     * cluster setting are reflected in this validator.
     *
     * @param settings       node settings supplying the initial limits and the stateless flag
     * @param clusterService used to register the settings update consumers
     */
    public ShardLimitValidator(Settings settings, ClusterService clusterService) {
        this.shardLimitPerNode.set(SETTING_CLUSTER_MAX_SHARDS_PER_NODE.get(settings));
        this.shardLimitPerNodeFrozen.set(SETTING_CLUSTER_MAX_SHARDS_PER_NODE_FROZEN.get(settings));
        clusterService.getClusterSettings().addSettingsUpdateConsumer(SETTING_CLUSTER_MAX_SHARDS_PER_NODE, this::setShardLimitPerNode);
        clusterService.getClusterSettings()
            .addSettingsUpdateConsumer(SETTING_CLUSTER_MAX_SHARDS_PER_NODE_FROZEN, this::setShardLimitPerNodeFrozen);
        this.isStateless = DiscoveryNode.isStateless(settings);
    }

    private void setShardLimitPerNode(int newValue) {
        this.shardLimitPerNode.set(newValue);
    }

    private void setShardLimitPerNodeFrozen(int newValue) {
        this.shardLimitPerNodeFrozen.set(newValue);
    }

    /**
     * Returns the currently configured per-node limit for the given group. Only {@link LimitGroup#FROZEN}
     * uses the frozen limit; every other group uses the regular limit.
     */
    private int getShardLimitPerNode(LimitGroup limitGroup) {
        // Switching on the constants (not ordinals) keeps this correct if the enum is ever reordered and
        // lets the compiler enforce exhaustiveness.
        return switch (limitGroup) {
            case NORMAL, INDEX, SEARCH -> this.shardLimitPerNode.get();
            case FROZEN -> this.shardLimitPerNodeFrozen.get();
        };
    }

    /**
     * Returns the per-node limit for the given group as recorded in {@code shardLimits}.
     *
     * @param limitGroup  the group to look up
     * @param shardLimits the limits to read from
     * @return the frozen limit for {@link LimitGroup#FROZEN}, the regular limit for every other group
     */
    public static int getShardLimitPerNode(LimitGroup limitGroup, HealthMetadata.ShardLimits shardLimits) {
        return switch (limitGroup) {
            case NORMAL, INDEX, SEARCH -> shardLimits.maxShardsPerNode();
            case FROZEN -> shardLimits.maxShardsPerNodeFrozen();
        };
    }

    /**
     * Validates that an index with the given settings can be added without exceeding the shard limit.
     *
     * @param settings       settings of the index whose shards would be added
     * @param discoveryNodes nodes used to derive the cluster-wide maximum
     * @param metadata       metadata used to count the shards that are currently open
     * @throws ValidationException if any applicable group would exceed its limit
     */
    public void validateShardLimit(Settings settings, DiscoveryNodes discoveryNodes, Metadata metadata) {
        List<LimitGroup> limitGroups = applicableLimitGroups(this.isStateless);
        Map<LimitGroup, Integer> shardsToCreatePerGroup = limitGroups.stream()
            .collect(Collectors.toUnmodifiableMap(Function.identity(), limitGroup -> limitGroup.newShardsTotal(settings)));
        throwIfLimitExceeded(checkShardLimitOnGroups(limitGroups, shardsToCreatePerGroup, discoveryNodes, metadata));
    }

    /**
     * Validates that the given indices can be opened without exceeding the shard limit. Indices that are
     * not currently closed contribute no new shards.
     *
     * @param discoveryNodes nodes used to derive the cluster-wide maximum
     * @param metadata       metadata holding the indices and the current open-shard count
     * @param indicesToOpen  indices that are about to be opened
     * @throws ValidationException if any applicable group would exceed its limit
     */
    public void validateShardLimit(DiscoveryNodes discoveryNodes, Metadata metadata, Index[] indicesToOpen) {
        List<LimitGroup> limitGroups = applicableLimitGroups(this.isStateless);
        Map<LimitGroup, Integer> shardsToCreatePerGroup = new HashMap<>();
        for (Index index : indicesToOpen) {
            IndexMetadata imd = metadata.indexMetadata(index);
            if (imd.getState().equals(IndexMetadata.State.CLOSE) == false) {
                continue;
            }
            for (LimitGroup limitGroup : limitGroups) {
                shardsToCreatePerGroup.merge(limitGroup, limitGroup.newShardsTotal(imd.getSettings()), Integer::sum);
            }
        }
        throwIfLimitExceeded(checkShardLimitOnGroups(limitGroups, shardsToCreatePerGroup, discoveryNodes, metadata));
    }

    /**
     * Validates that setting the replica count of the given indices to {@code replicas} does not exceed the
     * shard limit. The per-index deltas are summed per group before the check.
     *
     * @param discoveryNodes nodes used to derive the cluster-wide maximum
     * @param metadata       metadata holding the indices and the current open-shard count
     * @param indices        indices whose replica count is being changed
     * @param replicas       the new replica count
     * @throws ValidationException if any applicable group would exceed its limit
     */
    public void validateShardLimitOnReplicaUpdate(DiscoveryNodes discoveryNodes, Metadata metadata, Index[] indices, int replicas) {
        List<LimitGroup> limitGroups = applicableLimitGroups(this.isStateless);
        Map<LimitGroup, Integer> shardsToCreatePerGroup = new HashMap<>();
        for (Index index : indices) {
            IndexMetadata imd = metadata.indexMetadata(index);
            for (LimitGroup limitGroup : limitGroups) {
                shardsToCreatePerGroup.merge(limitGroup, limitGroup.newShardsTotal(imd.getSettings(), replicas), Integer::sum);
            }
        }
        throwIfLimitExceeded(checkShardLimitOnGroups(limitGroups, shardsToCreatePerGroup, discoveryNodes, metadata));
    }

    /**
     * Returns the limit groups that are checked for the given deployment type.
     *
     * @param isStateless whether the deployment is stateless
     * @return {@code [INDEX, SEARCH]} when stateless, otherwise {@code [NORMAL, FROZEN]}
     */
    public static List<LimitGroup> applicableLimitGroups(boolean isStateless) {
        return isStateless ? List.of(LimitGroup.INDEX, LimitGroup.SEARCH) : List.of(LimitGroup.NORMAL, LimitGroup.FROZEN);
    }

    /** Converts a failed check into a {@link ValidationException}; does nothing when shards can be added. */
    private static void throwIfLimitExceeded(Result result) {
        if (result.canAddShards() == false) {
            ValidationException e = new ValidationException();
            e.addValidationError(errorMessageFrom(result));
            throw e;
        }
    }

    /**
     * Checks every group in order and returns the first failing result, or the result of the last group
     * when all of them pass. Groups missing from {@code shardsToCreatePerGroup} are checked with zero new shards.
     */
    private Result checkShardLimitOnGroups(
        List<LimitGroup> limitGroups,
        Map<LimitGroup, Integer> shardsToCreatePerGroup,
        DiscoveryNodes discoveryNodes,
        Metadata metadata
    ) {
        assert limitGroups.containsAll(shardsToCreatePerGroup.keySet())
            : "limit groups " + limitGroups + " do not contain groups for shards creation " + shardsToCreatePerGroup.keySet();
        Result result = null;
        for (LimitGroup limitGroup : limitGroups) {
            result = limitGroup.checkShardLimit(
                getShardLimitPerNode(limitGroup),
                shardsToCreatePerGroup.getOrDefault(limitGroup, 0),
                discoveryNodes,
                metadata
            );
            if (result.canAddShards() == false) {
                return result;
            }
        }
        assert result != null;
        return result;
    }

    /** Counts the data nodes that satisfy {@code nodePredicate}. */
    private static int nodeCount(DiscoveryNodes discoveryNodes, Predicate<DiscoveryNode> nodePredicate) {
        return (int) discoveryNodes.getDataNodes().values().stream().filter(nodePredicate).count();
    }

    private static boolean hasFrozen(DiscoveryNode node) {
        return node.getRoles().contains(DiscoveryNodeRole.DATA_FROZEN_NODE_ROLE);
    }

    /** Returns true if the node has at least one data-bearing role other than the frozen role. */
    private static boolean hasNonFrozen(DiscoveryNode node) {
        return node.getRoles().stream().anyMatch(r -> r.canContainData() && r != DiscoveryNodeRole.DATA_FROZEN_NODE_ROLE);
    }

    /**
     * Builds the user-facing error message for a failed check.
     *
     * @param result a result whose {@code currentUsedShards} is present
     * @throws java.util.NoSuchElementException if {@code currentUsedShards} is empty
     */
    static String errorMessageFrom(Result result) {
        return "this action would add ["
            + result.totalShardsToAdd()
            + "] shards, but this cluster currently has ["
            + result.currentUsedShards().orElseThrow()
            + "]/["
            + result.maxShardsInCluster()
            + "] maximum "
            + result.group()
            + " shards open; for more information, see "
            + ReferenceDocs.MAX_SHARDS_PER_NODE;
    }

    private static boolean isOpenIndex(IndexMetadata indexMetadata) {
        return indexMetadata.getState().equals(IndexMetadata.State.OPEN);
    }

    private static boolean matchesIndexSettingGroup(IndexMetadata indexMetadata, String group) {
        return group.equals(INDEX_SETTING_SHARD_LIMIT_GROUP.get(indexMetadata.getSettings()));
    }

    /**
     * A group of shards that share one limit. Each group defines which nodes provide capacity, which
     * existing shards count against that capacity, and how many shards an index adds to the group.
     */
    public enum LimitGroup {
        /** Shards of open indices in the {@code normal} group, hosted on nodes with a non-frozen data role. */
        NORMAL(NORMAL_GROUP) {
            @Override
            public int numberOfNodes(DiscoveryNodes discoveryNodes) {
                return ShardLimitValidator.nodeCount(discoveryNodes, ShardLimitValidator::hasNonFrozen);
            }

            @Override
            public int countShards(IndexMetadata indexMetadata) {
                boolean counted = ShardLimitValidator.isOpenIndex(indexMetadata)
                    && ShardLimitValidator.matchesIndexSettingGroup(indexMetadata, groupName());
                return counted ? indexMetadata.getTotalNumberOfShards() : 0;
            }

            @Override
            public int newShardsTotal(int shards, int replicas) {
                return shards * (1 + replicas);
            }

            @Override
            protected int newReplicaShards(boolean isFrozenIndex, int shards, int replicaIncrease) {
                return isFrozenIndex ? 0 : shards * replicaIncrease;
            }
        },
        /** Shards of open indices in the {@code frozen} group, hosted on nodes with the frozen data role. */
        FROZEN(FROZEN_GROUP) {
            @Override
            public int numberOfNodes(DiscoveryNodes discoveryNodes) {
                return ShardLimitValidator.nodeCount(discoveryNodes, ShardLimitValidator::hasFrozen);
            }

            @Override
            public int countShards(IndexMetadata indexMetadata) {
                boolean counted = ShardLimitValidator.isOpenIndex(indexMetadata)
                    && ShardLimitValidator.matchesIndexSettingGroup(indexMetadata, groupName());
                return counted ? indexMetadata.getTotalNumberOfShards() : 0;
            }

            @Override
            public int newShardsTotal(int shards, int replicas) {
                return shards * (1 + replicas);
            }

            @Override
            protected int newReplicaShards(boolean isFrozenIndex, int shards, int replicaIncrease) {
                return isFrozenIndex ? shards * replicaIncrease : 0;
            }
        },
        /** Primary shards of open indices, hosted on nodes with the index role. */
        INDEX("index") {
            @Override
            public int numberOfNodes(DiscoveryNodes discoveryNodes) {
                return ShardLimitValidator.nodeCount(discoveryNodes, node -> node.hasRole(DiscoveryNodeRole.INDEX_ROLE.roleName()));
            }

            @Override
            public int countShards(IndexMetadata indexMetadata) {
                return ShardLimitValidator.isOpenIndex(indexMetadata) ? indexMetadata.getNumberOfShards() : 0;
            }

            @Override
            public int newShardsTotal(int shards, int replicas) {
                return shards;
            }

            @Override
            protected int newReplicaShards(boolean isFrozenIndex, int shards, int replicaIncrease) {
                // Replica changes never add shards to this group.
                return 0;
            }
        },
        /** Replica shards of open indices, hosted on nodes with the search role. */
        SEARCH("search") {
            @Override
            public int numberOfNodes(DiscoveryNodes discoveryNodes) {
                return ShardLimitValidator.nodeCount(discoveryNodes, node -> node.hasRole(DiscoveryNodeRole.SEARCH_ROLE.roleName()));
            }

            @Override
            public int countShards(IndexMetadata indexMetadata) {
                return ShardLimitValidator.isOpenIndex(indexMetadata)
                    ? indexMetadata.getNumberOfShards() * indexMetadata.getNumberOfReplicas()
                    : 0;
            }

            @Override
            public int newShardsTotal(int shards, int replicas) {
                return shards * replicas;
            }

            @Override
            protected int newReplicaShards(boolean isFrozenIndex, int shards, int replicaIncrease) {
                return shards * replicaIncrease;
            }
        };

        private final String groupName;

        LimitGroup(String groupName) {
            this.groupName = groupName;
        }

        public String groupName() {
            return this.groupName;
        }

        @Override
        public String toString() {
            return this.groupName;
        }

        /** Number of data nodes that provide capacity for this group. */
        public abstract int numberOfNodes(DiscoveryNodes discoveryNodes);

        /** Number of shards of the given index that currently count against this group. */
        public abstract int countShards(IndexMetadata indexMetadata);

        /** Number of shards this group gains from an index with the given primary and replica counts. */
        public abstract int newShardsTotal(int shards, int replicas);

        /**
         * Number of shards this group gains when the replica count of an index changes.
         *
         * @param isFrozenIndex   whether the index belongs to the frozen group
         * @param shards          number of primary shards of the index
         * @param replicaIncrease new replica count minus current replica count; may be negative
         */
        protected abstract int newReplicaShards(boolean isFrozenIndex, int shards, int replicaIncrease);

        /**
         * Number of shards this group gains from an index with the given settings. An index in the frozen
         * group contributes only to {@link #FROZEN}; any other index contributes only to the remaining groups.
         */
        public int newShardsTotal(Settings indexSettings) {
            boolean isFrozenIndex = FROZEN_GROUP.equals(INDEX_SETTING_SHARD_LIMIT_GROUP.get(indexSettings));
            if (isFrozenIndex != (this == FROZEN)) {
                // The index belongs to a different group, so it adds nothing to this one.
                return newShardsTotal(0, 0);
            }
            int numberOfShards = IndexMetadata.INDEX_NUMBER_OF_SHARDS_SETTING.get(indexSettings);
            int numberOfReplicas = IndexMetadata.INDEX_NUMBER_OF_REPLICAS_SETTING.get(indexSettings);
            return newShardsTotal(numberOfShards, numberOfReplicas);
        }

        /**
         * Number of shards this group gains when the replica count of an index with the given settings is
         * changed to {@code updatedReplicas}.
         */
        public int newShardsTotal(Settings indexSettings, int updatedReplicas) {
            boolean isFrozenIndex = FROZEN_GROUP.equals(INDEX_SETTING_SHARD_LIMIT_GROUP.get(indexSettings));
            int shards = IndexMetadata.INDEX_NUMBER_OF_SHARDS_SETTING.get(indexSettings);
            int replicas = IndexMetadata.INDEX_NUMBER_OF_REPLICAS_SETTING.get(indexSettings);
            return newReplicaShards(isFrozenIndex, shards, updatedReplicas - replicas);
        }

        /**
         * Checks whether an index with {@code numberOfNewShards} primaries and {@code replicas} replicas can
         * be added to this group.
         *
         * @param maxConfiguredShardsPerNode the per-node limit for this group
         * @param numberOfNewShards          number of primary shards to add
         * @param replicas                   number of replicas per primary
         * @param discoveryNodes             nodes used to derive the cluster-wide maximum
         * @param metadata                   metadata used to count the shards that are currently open
         * @return the outcome of the check
         */
        public Result checkShardLimit(
            int maxConfiguredShardsPerNode,
            int numberOfNewShards,
            int replicas,
            DiscoveryNodes discoveryNodes,
            Metadata metadata
        ) {
            return checkShardLimit(maxConfiguredShardsPerNode, newShardsTotal(numberOfNewShards, replicas), discoveryNodes, metadata);
        }

        /**
         * Checks whether {@code newShards} additional shards fit into this group. The check always passes
         * when the group has no nodes or when {@code newShards} is not positive.
         */
        private Result checkShardLimit(int maxConfiguredShardsPerNode, int newShards, DiscoveryNodes discoveryNodes, Metadata metadata) {
            int nodeCount = numberOfNodes(discoveryNodes);
            // Multiply in long and saturate: the per-node setting has no upper bound, so an int product can
            // wrap to a negative maximum and wrongly reject every request.
            long uncappedMax = (long) maxConfiguredShardsPerNode * nodeCount;
            int maxShardsInCluster = (int) Math.max(Integer.MIN_VALUE, Math.min(Integer.MAX_VALUE, uncappedMax));
            if (nodeCount == 0 || newShards <= 0) {
                return new Result(true, Optional.empty(), newShards, maxShardsInCluster, this);
            }

            int currentOpenShards = metadata.getTotalOpenIndexShards();
            // First pass: compare against the total open shard count. Only if that exceeds the maximum do
            // we pay for counting the shards that actually belong to this group.
            if ((long) currentOpenShards + newShards > maxShardsInCluster) {
                long currentFilteredShards = metadata.projects()
                    .values()
                    .stream()
                    .flatMap(projectMetadata -> projectMetadata.indices().values().stream())
                    .mapToLong(this::countShards)
                    .sum();
                if (currentFilteredShards + newShards > maxShardsInCluster) {
                    return new Result(false, Optional.of(currentFilteredShards), newShards, maxShardsInCluster, this);
                }
            }
            return new Result(true, Optional.empty(), newShards, maxShardsInCluster, this);
        }
    }

    /**
     * Outcome of a shard limit check.
     *
     * @param canAddShards       whether the requested shards can be added
     * @param currentUsedShards  number of shards currently counted in the group; only present when the check failed
     * @param totalShardsToAdd   number of shards the action would add to the group
     * @param maxShardsInCluster per-node limit multiplied by the number of nodes in the group
     * @param group              the group that was checked
     */
    public record Result(
        boolean canAddShards,
        Optional<Long> currentUsedShards,
        int totalShardsToAdd,
        int maxShardsInCluster,
        LimitGroup group
    ) {}
}

