/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.cluster.routing.allocation.allocator;

import org.elasticsearch.cluster.ClusterInfo;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.cluster.metadata.ProjectMetadata;
import org.elasticsearch.cluster.routing.RoutingNodes;
import org.elasticsearch.cluster.routing.allocation.WriteLoadForecaster;
import org.elasticsearch.cluster.routing.allocation.allocator.BalancedShardsAllocator;
import org.elasticsearch.index.shard.ShardId;

/**
 * Weight function that combines four imbalance terms into a single per-node weight:
 * shard count, per-index shard count, write load and disk usage.
 *
 * <p>The four balance factors passed to the constructor are divided by their sum, so the
 * resulting coefficients ({@code theta0..theta3}) add up to one. Each coefficient scales the
 * difference between a node's value and the corresponding per-node average.
 */
public class WeightFunction {
    /** Normalized coefficient applied to the total shard count difference. */
    private final float theta0;
    /** Normalized coefficient applied to the per-index shard count difference. */
    private final float theta1;
    /** Normalized coefficient applied to the write load difference. */
    private final float theta2;
    /** Normalized coefficient applied to the disk usage difference. */
    private final float theta3;

    /**
     * Creates a weight function from the four raw balance factors, normalizing each by their sum.
     *
     * @param shardBalance     raw factor for the total shard count term
     * @param indexBalance     raw factor for the per-index shard count term
     * @param writeLoadBalance raw factor for the write load term
     * @param diskUsageBalance raw factor for the disk usage term
     * @throws IllegalArgumentException if the factors do not sum to a value strictly greater than zero
     *                                  (this includes a NaN sum)
     */
    public WeightFunction(float shardBalance, float indexBalance, float writeLoadBalance, float diskUsageBalance) {
        float sum = shardBalance + indexBalance + writeLoadBalance + diskUsageBalance;
        // Every ordered comparison with NaN is false, so "sum <= 0" alone would let a NaN sum through
        // and silently turn all four coefficients into NaN. Reject it explicitly.
        if (Float.isNaN(sum) || sum <= 0.0f) {
            throw new IllegalArgumentException("Balance factors must sum to a value > 0 but was: " + sum);
        }
        this.theta0 = shardBalance / sum;
        this.theta1 = indexBalance / sum;
        this.theta2 = writeLoadBalance / sum;
        this.theta3 = diskUsageBalance / sum;
    }

    /**
     * Computes the weight of {@code node} including the per-index term for {@code index}.
     *
     * <p>The result is {@link #calculateNodeWeight} evaluated with the node's and balancer's
     * current values, plus {@code theta1} times the difference between the node's shard count
     * for the index and the balancer's average shards per node for that index.
     *
     * @param balancer supplies the per-node averages
     * @param node     the node being weighed
     * @param index    the index whose per-index imbalance is added
     * @return the combined weight
     */
    float calculateNodeWeightWithIndex(BalancedShardsAllocator.Balancer balancer, BalancedShardsAllocator.ModelNode node, BalancedShardsAllocator.ProjectIndex index) {
        float weightIndex = (float) node.numShards(index) - balancer.avgShardsPerNode(index);
        float nodeWeight = calculateNodeWeight(
            node.numShards(),
            balancer.avgShardsPerNode(),
            node.writeLoad(),
            balancer.avgWriteLoadPerNode(),
            node.diskUsageInBytes(),
            balancer.avgDiskUsageInBytesPerNode()
        );
        return nodeWeight + theta1 * weightIndex;
    }

    /**
     * Computes the index-independent part of a node's weight.
     *
     * <p>Each term is the node's value minus the per-node average, scaled by its coefficient:
     * {@code theta0 * shards + theta2 * writeLoad + theta3 * diskUsage}. The write load and disk
     * usage differences are computed in {@code double} and then narrowed to {@code float}.
     *
     * @param nodeNumShards              number of shards on the node
     * @param avgShardsPerNode           average number of shards per node
     * @param nodeWriteLoad              write load of the node
     * @param avgWriteLoadPerNode        average write load per node
     * @param diskUsageInBytes           disk usage of the node in bytes
     * @param avgDiskUsageInBytesPerNode average disk usage per node in bytes
     * @return the weighted sum of the three differences
     */
    public float calculateNodeWeight(int nodeNumShards, float avgShardsPerNode, double nodeWriteLoad, double avgWriteLoadPerNode, double diskUsageInBytes, double avgDiskUsageInBytesPerNode) {
        float weightShard = nodeNumShards - avgShardsPerNode;
        float ingestLoad = (float) (nodeWriteLoad - avgWriteLoadPerNode);
        float diskUsage = (float) (diskUsageInBytes - avgDiskUsageInBytesPerNode);
        return theta0 * weightShard + theta2 * ingestLoad + theta3 * diskUsage;
    }

    /**
     * Returns {@code theta0 + theta1 + theta2 * shardWriteLoad + theta3 * shardSizeBytes}.
     *
     * <p>This is the amount by which the weight formulas above change when the shard count and the
     * per-index shard count each change by one, the write load by {@code shardWriteLoad} and the
     * disk usage by {@code shardSizeBytes}, with the averages held constant.
     *
     * @param shardWriteLoad write load attributed to the shard
     * @param shardSizeBytes size of the shard in bytes
     * @return the weight delta for such a shard
     */
    float minWeightDelta(float shardWriteLoad, float shardSizeBytes) {
        return theta0 + theta1 + theta2 * shardWriteLoad + theta3 * shardSizeBytes;
    }

    /**
     * Returns the total number of shards in {@code metadata} divided by the number of routing nodes.
     *
     * <p>The division is done in {@code float}, so an empty {@code routingNodes} yields an infinite
     * or NaN result rather than an exception.
     */
    public static float avgShardPerNode(Metadata metadata, RoutingNodes routingNodes) {
        return (float) metadata.getTotalNumberOfShards() / (float) routingNodes.size();
    }

    /**
     * Returns the forecasted write load summed over all indices of all projects, divided by the
     * number of routing nodes.
     */
    public static double avgWriteLoadPerNode(WriteLoadForecaster writeLoadForecaster, Metadata metadata, RoutingNodes routingNodes) {
        return getTotalWriteLoad(writeLoadForecaster, metadata) / (double) routingNodes.size();
    }

    /**
     * Returns the estimated disk usage summed over all indices of all projects, divided by the
     * number of routing nodes.
     */
    public static double avgDiskUsageInBytesPerNode(ClusterInfo clusterInfo, Metadata metadata, RoutingNodes routingNodes) {
        return (double) getTotalDiskUsageInBytes(clusterInfo, metadata) / (double) routingNodes.size();
    }

    /** Sums {@link #getIndexWriteLoad} over every index of every project in {@code metadata}. */
    private static double getTotalWriteLoad(WriteLoadForecaster writeLoadForecaster, Metadata metadata) {
        double writeLoad = 0.0;
        for (ProjectMetadata project : metadata.projects().values()) {
            for (IndexMetadata indexMetadata : project.indices().values()) {
                writeLoad += getIndexWriteLoad(writeLoadForecaster, indexMetadata);
            }
        }
        return writeLoad;
    }

    /**
     * Returns the forecasted write load of the index (zero when no forecast is available)
     * multiplied by the index's number of shard copies.
     */
    private static double getIndexWriteLoad(WriteLoadForecaster writeLoadForecaster, IndexMetadata indexMetadata) {
        double shardWriteLoad = writeLoadForecaster.getForecastedWriteLoad(indexMetadata).orElse(0.0);
        return shardWriteLoad * numberOfCopies(indexMetadata);
    }

    /** Returns the number of shards times one plus the number of replicas. */
    private static int numberOfCopies(IndexMetadata indexMetadata) {
        return indexMetadata.getNumberOfShards() * (1 + indexMetadata.getNumberOfReplicas());
    }

    /** Sums {@link #getIndexDiskUsageInBytes} over every index of every project in {@code metadata}. */
    private static long getTotalDiskUsageInBytes(ClusterInfo clusterInfo, Metadata metadata) {
        long totalDiskUsageInBytes = 0L;
        for (ProjectMetadata project : metadata.projects().values()) {
            for (IndexMetadata indexMetadata : project.indices().values()) {
                totalDiskUsageInBytes += getIndexDiskUsageInBytes(clusterInfo, indexMetadata);
            }
        }
        return totalDiskUsageInBytes;
    }

    /**
     * Estimates the disk usage in bytes of all copies of an index.
     *
     * <p>Returns zero when the index ignores disk watermarks. Otherwise, for every shard number
     * the primary and replica sizes are each taken as the larger of the index's forecasted shard
     * size and the size looked up in {@code clusterInfo}; {@code -1} is used as the "unknown"
     * marker for both sources. A known replica size is counted once per configured replica.
     *
     * <p>If a size is known for every copy, the exact sum is returned. If only some are known,
     * the (truncated) average known size is multiplied by the total number of copies. If none
     * are known, zero is returned.
     *
     * @param clusterInfo   source of the per-shard sizes
     * @param indexMetadata the index to estimate
     * @return the estimated total size in bytes
     */
    static long getIndexDiskUsageInBytes(ClusterInfo clusterInfo, IndexMetadata indexMetadata) {
        if (indexMetadata.ignoreDiskWatermarks()) {
            return 0L;
        }
        long forecastedShardSize = indexMetadata.getForecastedShardSizeInBytes().orElse(-1L);
        long totalSizeInBytes = 0L;
        int shardCount = 0;
        for (int shard = 0; shard < indexMetadata.getNumberOfShards(); ++shard) {
            ShardId shardId = new ShardId(indexMetadata.getIndex(), shard);
            long primaryShardSize = Math.max(forecastedShardSize, clusterInfo.getShardSize(shardId, true, -1L));
            if (primaryShardSize != -1L) {
                totalSizeInBytes += primaryShardSize;
                ++shardCount;
            }
            long replicaShardSize = Math.max(forecastedShardSize, clusterInfo.getShardSize(shardId, false, -1L));
            if (replicaShardSize != -1L) {
                // One known replica size stands in for all replicas of this shard number.
                totalSizeInBytes += replicaShardSize * indexMetadata.getNumberOfReplicas();
                shardCount += indexMetadata.getNumberOfReplicas();
            }
        }
        int copies = numberOfCopies(indexMetadata);
        if (shardCount == copies) {
            return totalSizeInBytes;
        }
        if (shardCount == 0) {
            return 0L;
        }
        // Extrapolate from the copies with a known size: integer average first, then scale up.
        return (totalSizeInBytes / shardCount) * copies;
    }
}

