/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.common.util.concurrent;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.LongAdder;
import java.util.function.Function;
import org.elasticsearch.common.ExponentiallyWeightedMovingAverage;
import org.elasticsearch.common.metrics.ExponentialBucketHistogram;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.EsThreadPoolExecutor;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.util.concurrent.TimedRunnable;
import org.elasticsearch.common.util.concurrent.WrappedRunnable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.telemetry.metric.DoubleWithAttributes;
import org.elasticsearch.telemetry.metric.Instrument;
import org.elasticsearch.telemetry.metric.LongWithAttributes;
import org.elasticsearch.telemetry.metric.MeterRegistry;

/**
 * An {@link EsThreadPoolExecutor} that additionally records timing information about the tasks it runs:
 * <ul>
 *     <li>an exponentially weighted moving average (EWMA) of task execution time, in nanoseconds;</li>
 *     <li>the cumulative execution time of all tasks that reported a timing, in nanoseconds;</li>
 *     <li>a histogram of how long tasks waited in the queue before starting, in milliseconds;</li>
 *     <li>optionally (see {@code EsExecutors.TaskTrackingConfig#trackOngoingTasks()}), the start time of every task
 *     that is currently executing.</li>
 * </ul>
 * Every task handed to {@link #beforeExecute} / {@link #afterExecute} is expected to unwrap (via
 * {@code super.unwrap}) to a {@link TimedRunnable}; this is checked with assertions.
 */
public final class TaskExecutionTimeTrackingEsThreadPoolExecutor extends EsThreadPoolExecutor {
    /** Number of buckets used by the queue latency histogram. */
    public static final int QUEUE_LATENCY_HISTOGRAM_BUCKETS = 18;
    /** Percentiles (out of 100) of queue latency reported by the queue latency gauge. */
    private static final int[] LATENCY_PERCENTILES_TO_REPORT = new int[] { 50, 90, 99 };

    private final Function<Runnable, WrappedRunnable> runnableWrapper;
    /** EWMA of task execution time in nanoseconds; starts at 0.0. */
    private final ExponentiallyWeightedMovingAverage executionEWMA;
    /** Sum of the execution nanos of every task that reported a timing (i.e. not {@code -1}). */
    private final LongAdder totalExecutionTime = new LongAdder();
    private final boolean trackOngoingTasks;
    // Only populated when trackOngoingTasks is true: running task -> System.nanoTime() observed in beforeExecute.
    private final Map<Runnable, Long> ongoingTasks = new ConcurrentHashMap<>();
    // lastPollTime and lastTotalExecutionTime describe one measurement interval and must be read and updated as a
    // pair, so every access in pollUtilization() happens while holding this lock.
    private final Object utilizationLock = new Object();
    private volatile long lastPollTime = System.nanoTime();
    private volatile long lastTotalExecutionTime = 0L;
    private final ExponentialBucketHistogram queueLatencyMillisHistogram = new ExponentialBucketHistogram(
        QUEUE_LATENCY_HISTOGRAM_BUCKETS
    );

    /**
     * Creates the executor. All arguments except {@code runnableWrapper} and {@code trackingConfig} are passed
     * straight to {@link EsThreadPoolExecutor}.
     *
     * @param runnableWrapper applied to every submitted command before the superclass wraps it
     * @param trackingConfig  supplies the EWMA alpha and whether ongoing tasks are tracked
     */
    TaskExecutionTimeTrackingEsThreadPoolExecutor(
        String name,
        int corePoolSize,
        int maximumPoolSize,
        long keepAliveTime,
        TimeUnit unit,
        BlockingQueue<Runnable> workQueue,
        Function<Runnable, WrappedRunnable> runnableWrapper,
        ThreadFactory threadFactory,
        RejectedExecutionHandler handler,
        ThreadContext contextHolder,
        EsExecutors.TaskTrackingConfig trackingConfig
    ) {
        super(name, corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, handler, contextHolder);
        this.runnableWrapper = runnableWrapper;
        this.executionEWMA = new ExponentiallyWeightedMovingAverage(trackingConfig.getEwmaAlpha(), 0.0);
        this.trackOngoingTasks = trackingConfig.trackOngoingTasks();
    }

    /**
     * Registers this executor's gauges with the given registry.
     * <ul>
     *     <li>{@code es.thread_pool.<threadPoolName>.queue.latency.histogram}: the 50th, 90th and 99th percentiles of
     *     queue latency in milliseconds, each tagged with a {@code percentile} attribute. The histogram is cleared
     *     after every read, so each reading covers tasks that started since the previous reading.</li>
     *     <li>{@code es.thread_pool.<threadPoolName>.threads.utilization.current}: the value of
     *     {@link #pollUtilization()}, which also starts a new utilization interval.</li>
     * </ul>
     *
     * @param meterRegistry  the registry to register the gauges with
     * @param threadPoolName the thread pool name, used in metric names and descriptions
     * @return the registered instruments
     */
    public List<Instrument> setupMetrics(MeterRegistry meterRegistry, String threadPoolName) {
        return List.of(
            meterRegistry.registerLongsGauge(
                "es.thread_pool." + threadPoolName + ".queue.latency.histogram",
                "Time tasks spent in the queue for the " + threadPoolName + " thread pool",
                "milliseconds",
                this::collectAndResetQueueLatencyPercentiles
            ),
            meterRegistry.registerDoubleGauge(
                "es.thread_pool." + threadPoolName + ".threads.utilization.current",
                "fraction of maximum thread time utilized for " + threadPoolName,
                "fraction",
                () -> new DoubleWithAttributes(pollUtilization(), Map.of())
            )
        );
    }

    /** Computes the reported queue latency percentiles from a histogram snapshot, then clears the histogram. */
    private List<LongWithAttributes> collectAndResetQueueLatencyPercentiles() {
        final long[] snapshot = queueLatencyMillisHistogram.getSnapshot();
        final int[] bucketUpperBounds = queueLatencyMillisHistogram.calculateBucketUpperBounds();
        final List<LongWithAttributes> metricValues = Arrays.stream(LATENCY_PERCENTILES_TO_REPORT)
            .mapToObj(
                percentile -> new LongWithAttributes(
                    queueLatencyMillisHistogram.getPercentile((float) percentile / 100.0f, snapshot, bucketUpperBounds),
                    Map.of("percentile", String.valueOf(percentile))
                )
            )
            .toList();
        // NOTE(review): observations recorded concurrently between getSnapshot() and clear() are presumably
        // dropped from every reading; confirm against ExponentialBucketHistogram's semantics.
        queueLatencyMillisHistogram.clear();
        return metricValues;
    }

    /** Applies {@code runnableWrapper} first, then the superclass's own wrapping. */
    @Override
    protected Runnable wrapRunnable(Runnable command) {
        return super.wrapRunnable(runnableWrapper.apply(command));
    }

    /** Reverses {@link #wrapRunnable}: strips the superclass wrapping, then any {@link WrappedRunnable} layer. */
    @Override
    protected Runnable unwrap(Runnable runnable) {
        final Runnable unwrapped = super.unwrap(runnable);
        if (unwrapped instanceof WrappedRunnable) {
            return ((WrappedRunnable) unwrapped).unwrap();
        }
        return unwrapped;
    }

    /** Returns the EWMA of task execution time, in nanoseconds. */
    public double getTaskExecutionEWMA() {
        return executionEWMA.getAverage();
    }

    /** Returns the total execution time, in nanoseconds, of all tasks that completed with a timing. */
    public long getTotalTaskExecutionTime() {
        return totalExecutionTime.sum();
    }

    /** Returns the number of tasks currently in the work queue. */
    public int getCurrentQueueSize() {
        return getQueue().size();
    }

    /**
     * Returns the fraction of the available thread time spent executing tasks since the previous call (or since
     * construction, for the first call), and starts a new measurement interval.
     * <p>
     * The available thread time is the elapsed time multiplied by {@link #getMaximumPoolSize()}. Execution time is
     * only added in {@link #afterExecute}, so a task that runs across several intervals is attributed entirely to the
     * interval in which it finishes; the result can therefore exceed {@code 1.0}.
     * <p>
     * Every call resets the interval, so all callers (including the utilization gauge) share one measurement window.
     * Safe to call concurrently.
     *
     * @return the utilization since the last poll, or {@code 0.0} if no measurable time has elapsed since then
     */
    public double pollUtilization() {
        synchronized (utilizationLock) {
            final long currentTotalExecutionTimeNanos = totalExecutionTime.sum();
            final long currentPollTimeNanos = System.nanoTime();
            final long timeSinceLastPollNanos = currentPollTimeNanos - lastPollTime;
            if (timeSinceLastPollNanos <= 0) {
                // The clock did not advance (back-to-back polls / coarse nanoTime resolution). Avoid dividing by zero
                // and keep the current interval open so the execution time seen so far is not discarded.
                return 0.0;
            }
            final long totalExecutionTimeSinceLastPollNanos = currentTotalExecutionTimeNanos - lastTotalExecutionTime;
            // Computed in double so a long interval combined with a large pool size cannot overflow a long.
            final double maximumExecutionTimeSinceLastPollNanos = (double) timeSinceLastPollNanos * getMaximumPoolSize();
            lastTotalExecutionTime = currentTotalExecutionTimeNanos;
            lastPollTime = currentPollTimeNanos;
            return totalExecutionTimeSinceLastPollNanos / maximumExecutionTimeSinceLastPollNanos;
        }
    }

    /**
     * Records the task's start time (when tracking ongoing tasks), notifies the underlying {@link TimedRunnable}, and
     * adds its queue latency (converted to milliseconds) to the queue latency histogram.
     */
    @Override
    protected void beforeExecute(Thread t, Runnable r) {
        if (trackOngoingTasks) {
            ongoingTasks.put(r, System.nanoTime());
        }
        final Runnable unwrapped = super.unwrap(r);
        assert unwrapped instanceof TimedRunnable : "expected only TimedRunnables in queue";
        final TimedRunnable timedRunnable = (TimedRunnable) unwrapped;
        timedRunnable.beforeExecute();
        final long taskQueueLatency = timedRunnable.getQueueTimeNanos();
        assert taskQueueLatency >= 0L;
        queueLatencyMillisHistogram.addObservation(TimeUnit.NANOSECONDS.toMillis(taskQueueLatency));
    }

    /**
     * Folds the finished task's execution time into the EWMA and the running total, then (always, via
     * {@code finally}) removes the task from the ongoing-task map.
     */
    @Override
    protected void afterExecute(Runnable r, Throwable t) {
        try {
            super.afterExecute(r, t);
            final Runnable unwrapped = super.unwrap(r);
            assert unwrapped instanceof TimedRunnable : "expected only TimedRunnables in queue";
            final TimedRunnable timedRunnable = (TimedRunnable) unwrapped;
            final boolean failedOrRejected = timedRunnable.getFailedOrRejected();
            final long taskExecutionNanos = timedRunnable.getTotalExecutionNanos();
            assert taskExecutionNanos >= 0L || (failedOrRejected && taskExecutionNanos == -1L)
                : "expected task to always take longer than 0 nanoseconds or have '-1' failure code, got: "
                    + taskExecutionNanos
                    + ", failedOrRejected: "
                    + failedOrRejected;
            // -1 is only expected for failed or rejected tasks (see assertion above); those carry no usable timing.
            if (taskExecutionNanos != -1L) {
                executionEWMA.addValue(taskExecutionNanos);
                totalExecutionTime.add(taskExecutionNanos);
            }
        } finally {
            // beforeExecute only inserts when tracking is enabled, so otherwise the map must still be empty.
            assert trackOngoingTasks || ongoingTasks.isEmpty();
            if (trackOngoingTasks) {
                ongoingTasks.remove(r);
            }
        }
    }

    /** Appends the execution-time EWMA and the total task execution time to the executor description. */
    @Override
    protected void appendThreadPoolExecutorDetails(StringBuilder sb) {
        sb.append("task execution EWMA = ")
            .append(TimeValue.timeValueNanos((long) executionEWMA.getAverage()))
            .append(", ")
            .append("total task execution time = ")
            .append(TimeValue.timeValueNanos(getTotalTaskExecutionTime()))
            .append(", ");
    }

    /**
     * Returns an immutable snapshot of the currently executing tasks, mapped to the {@link System#nanoTime()} value
     * observed when each started, or an empty map when ongoing-task tracking is disabled.
     */
    public Map<Runnable, Long> getOngoingTasks() {
        return trackOngoingTasks ? Map.copyOf(ongoingTasks) : Map.of();
    }

    /** Returns the smoothing factor of the execution-time EWMA. */
    public double getEwmaAlpha() {
        return executionEWMA.getAlpha();
    }
}

