/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.common.util.concurrent;

import java.util.AbstractQueue;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedTransferQueue;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.LongAccumulator;
import java.util.concurrent.atomic.LongAdder;
import java.util.function.Function;
import org.elasticsearch.common.ExponentiallyWeightedMovingAverage;
import org.elasticsearch.common.metrics.ExponentialBucketHistogram;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.EsThreadPoolExecutor;
import org.elasticsearch.common.util.concurrent.SizeBlockingQueue;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.util.concurrent.TimedRunnable;
import org.elasticsearch.common.util.concurrent.WrappedRunnable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.telemetry.metric.DoubleWithAttributes;
import org.elasticsearch.telemetry.metric.Instrument;
import org.elasticsearch.telemetry.metric.LongWithAttributes;
import org.elasticsearch.telemetry.metric.MeterRegistry;

/**
 * An {@link EsThreadPoolExecutor} that records timing information about the tasks it runs.
 *
 * <p>Every submitted command is passed through the {@code runnableWrapper} supplied at construction, and the
 * {@code beforeExecute}/{@code afterExecute} hooks expect the result to unwrap to a {@link TimedRunnable}.
 * From those hooks the executor maintains:
 * <ul>
 *   <li>an exponentially weighted moving average and a running total of task execution time (nanoseconds),</li>
 *   <li>a histogram of queue latency (milliseconds),</li>
 *   <li>optionally, the maximum queue latency seen since the last poll,</li>
 *   <li>optionally, the set of tasks currently executing together with their start timestamps.</li>
 * </ul>
 */
public final class TaskExecutionTimeTrackingEsThreadPoolExecutor extends EsThreadPoolExecutor {
    /** Number of buckets used by the queue latency histogram. */
    public static final int QUEUE_LATENCY_HISTOGRAM_BUCKETS = 18;
    /** Percentiles of the queue latency histogram published by {@link #setupMetrics}. */
    private static final int[] LATENCY_PERCENTILES_TO_REPORT = new int[] { 50, 90, 99 };

    private final Function<Runnable, WrappedRunnable> runnableWrapper;
    private final ExponentiallyWeightedMovingAverage executionEWMA;
    private final LongAdder totalExecutionTime = new LongAdder();
    private final boolean trackOngoingTasks;
    /** Maps a running task to the {@link System#nanoTime()} value captured in {@link #beforeExecute}. */
    private final Map<Runnable, Long> ongoingTasks = new ConcurrentHashMap<>();
    private final ExponentialBucketHistogram queueLatencyMillisHistogram = new ExponentialBucketHistogram(
        QUEUE_LATENCY_HISTOGRAM_BUCKETS
    );
    private final boolean trackMaxQueueLatency;
    private final LongAccumulator maxQueueLatencyMillisSinceLastPoll = new LongAccumulator(Long::max, 0L);
    // One tracker per purpose so that polling for one consumer does not reset the window of the other.
    private final UtilizationTracker apmUtilizationTracker = new UtilizationTracker();
    private final UtilizationTracker allocationUtilizationTracker = new UtilizationTracker();

    /**
     * Creates the executor.
     *
     * @param runnableWrapper applied to every command in {@link #wrapRunnable(Runnable)} before it is handed to the superclass
     * @param trackingConfig  supplies the EWMA alpha and the flags enabling ongoing-task and max-queue-latency tracking
     */
    TaskExecutionTimeTrackingEsThreadPoolExecutor(
        String name,
        int corePoolSize,
        int maximumPoolSize,
        long keepAliveTime,
        TimeUnit unit,
        BlockingQueue<Runnable> workQueue,
        Function<Runnable, WrappedRunnable> runnableWrapper,
        ThreadFactory threadFactory,
        RejectedExecutionHandler handler,
        ThreadContext contextHolder,
        EsExecutors.TaskTrackingConfig trackingConfig
    ) {
        super(name, corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, handler, contextHolder);
        this.runnableWrapper = runnableWrapper;
        this.executionEWMA = new ExponentiallyWeightedMovingAverage(trackingConfig.getExecutionTimeEwmaAlpha(), 0.0);
        this.trackOngoingTasks = trackingConfig.trackOngoingTasks();
        this.trackMaxQueueLatency = trackingConfig.trackMaxQueueLatency();
    }

    /**
     * Registers this executor's gauges with the given registry.
     *
     * <p>Two gauges are registered, in this order: the queue latency percentiles and the current thread
     * utilization (polled with {@link UtilizationTrackingPurpose#APM}).
     *
     * @param meterRegistry  registry to register the gauges with
     * @param threadPoolName name embedded in the metric names and descriptions
     * @return the registered instruments
     */
    public List<Instrument> setupMetrics(MeterRegistry meterRegistry, String threadPoolName) {
        Instrument queueLatencyGauge = meterRegistry.registerLongsGauge(
            "es.thread_pool." + threadPoolName + ".queue.latency.histogram",
            "Time tasks spent in the queue for the " + threadPoolName + " thread pool",
            "milliseconds",
            this::pollQueueLatencyPercentiles
        );
        Instrument utilizationGauge = meterRegistry.registerDoubleGauge(
            "es.thread_pool." + threadPoolName + ".threads.utilization.current",
            "fraction of maximum thread time utilized for " + threadPoolName,
            "fraction",
            () -> new DoubleWithAttributes(pollUtilization(UtilizationTrackingPurpose.APM), Map.of())
        );
        return List.of(queueLatencyGauge, utilizationGauge);
    }

    /**
     * Computes the reported percentiles from a snapshot of the queue latency histogram and then clears the
     * histogram, so each reading covers the observations recorded since the previous one.
     */
    private List<LongWithAttributes> pollQueueLatencyPercentiles() {
        long[] snapshot = queueLatencyMillisHistogram.getSnapshot();
        int[] bucketUpperBounds = queueLatencyMillisHistogram.calculateBucketUpperBounds();
        List<LongWithAttributes> metricValues = Arrays.stream(LATENCY_PERCENTILES_TO_REPORT)
            .mapToObj(
                percentile -> new LongWithAttributes(
                    queueLatencyMillisHistogram.getPercentile((float) percentile / 100.0f, snapshot, bucketUpperBounds),
                    Map.of("percentile", String.valueOf(percentile))
                )
            )
            .toList();
        queueLatencyMillisHistogram.clear();
        return metricValues;
    }

    /** Applies the configured wrapper to the command before delegating to the superclass wrapping. */
    @Override
    protected Runnable wrapRunnable(Runnable command) {
        return super.wrapRunnable(runnableWrapper.apply(command));
    }

    /** Undoes the superclass wrapping and, if the result is a {@link WrappedRunnable}, unwraps that too. */
    @Override
    protected Runnable unwrap(Runnable runnable) {
        Runnable unwrapped = super.unwrap(runnable);
        if (unwrapped instanceof WrappedRunnable wrapped) {
            return wrapped.unwrap();
        }
        return unwrapped;
    }

    /** Returns the current moving average of task execution time, in nanoseconds. */
    public double getTaskExecutionEWMA() {
        return executionEWMA.getAverage();
    }

    /** Returns the sum of all recorded task execution times, in nanoseconds. */
    public long getTotalTaskExecutionTime() {
        return totalExecutionTime.sum();
    }

    /** Returns the number of tasks currently waiting in the work queue. */
    public int getCurrentQueueSize() {
        return getQueue().size();
    }

    /**
     * Returns the largest queue latency (milliseconds) observed since the previous call and resets it.
     * Always returns {@code 0} when max queue latency tracking is disabled.
     */
    public long getMaxQueueLatencyMillisSinceLastPollAndReset() {
        if (trackMaxQueueLatency == false) {
            return 0L;
        }
        return maxQueueLatencyMillisSinceLastPoll.getThenReset();
    }

    /**
     * Returns how long, in milliseconds, the task at the head of the queue has existed, without removing it.
     * Returns {@code 0} when max queue latency tracking is disabled or the queue is empty.
     */
    public long peekMaxQueueLatencyInQueueMillis() {
        if (trackMaxQueueLatency == false) {
            return 0L;
        }
        BlockingQueue<Runnable> queue = getQueue();
        assert queue instanceof LinkedTransferQueue || queue instanceof SizeBlockingQueue
            : "Not the type of queue expected: " + queue.getClass();
        // peek() is part of the Queue contract, so no cast to the concrete queue type is needed.
        Runnable task = queue.peek();
        if (task == null) {
            return 0L;
        }
        assert task instanceof WrappedRunnable : "Not the type of task expected: " + task.getClass();
        Runnable wrappedTask = ((WrappedRunnable) task).unwrap();
        // Report the class of the object that actually failed the check (the unwrapped task).
        assert wrappedTask instanceof TimedRunnable
            : "Not the type of task expected: " + (wrappedTask == null ? null : wrappedTask.getClass());
        TimedRunnable timedTask = (TimedRunnable) wrappedTask;
        return TimeUnit.NANOSECONDS.toMillis(timedTask.getTimeSinceCreationNanos());
    }

    /**
     * Returns the fraction of the pool's maximum thread time that was spent executing tasks since the
     * previous poll made for the same purpose. Each purpose has its own independent polling window.
     *
     * @throws IllegalStateException if no tracker is defined for the given purpose
     */
    public double pollUtilization(UtilizationTrackingPurpose utilizationTrackingPurpose) {
        // Switch on the constants themselves rather than ordinal(), so reordering the enum cannot swap trackers.
        return switch (utilizationTrackingPurpose) {
            case APM -> apmUtilizationTracker.pollUtilization();
            case ALLOCATION -> allocationUtilizationTracker.pollUtilization();
            default -> throw new IllegalStateException("No operation defined for [" + utilizationTrackingPurpose + "]");
        };
    }

    /**
     * Records the start of a task: optionally registers it as ongoing, notifies the {@link TimedRunnable},
     * and feeds its queue latency into the histogram and (if enabled) the max-latency accumulator.
     */
    @Override
    protected void beforeExecute(Thread t, Runnable r) {
        if (trackOngoingTasks) {
            ongoingTasks.put(r, System.nanoTime());
        }
        Runnable unwrapped = super.unwrap(r);
        assert unwrapped instanceof TimedRunnable : "expected only TimedRunnables in queue";
        TimedRunnable timedRunnable = (TimedRunnable) unwrapped;
        timedRunnable.beforeExecute();
        long taskQueueLatency = timedRunnable.getQueueTimeNanos();
        assert taskQueueLatency >= 0L;
        long queueLatencyMillis = TimeUnit.NANOSECONDS.toMillis(taskQueueLatency);
        queueLatencyMillisHistogram.addObservation(queueLatencyMillis);
        if (trackMaxQueueLatency) {
            maxQueueLatencyMillisSinceLastPoll.accumulate(queueLatencyMillis);
        }
    }

    /**
     * Records the end of a task: adds its execution time to the moving average and running total unless the
     * task reports {@code -1}, and always removes it from the ongoing-task map when that tracking is enabled.
     */
    @Override
    protected void afterExecute(Runnable r, Throwable t) {
        try {
            super.afterExecute(r, t);
            Runnable unwrapped = super.unwrap(r);
            assert unwrapped instanceof TimedRunnable : "expected only TimedRunnables in queue";
            TimedRunnable timedRunnable = (TimedRunnable) unwrapped;
            boolean failedOrRejected = timedRunnable.getFailedOrRejected();
            long taskExecutionNanos = timedRunnable.getTotalExecutionNanos();
            assert taskExecutionNanos >= 0L || (failedOrRejected && taskExecutionNanos == -1L)
                : "expected task to always take longer than 0 nanoseconds or have '-1' failure code, got: "
                    + taskExecutionNanos
                    + ", failedOrRejected: "
                    + failedOrRejected;
            // -1 is the "no execution time available" marker; it must not be folded into the statistics.
            if (taskExecutionNanos != -1L) {
                executionEWMA.addValue(taskExecutionNanos);
                totalExecutionTime.add(taskExecutionNanos);
            }
        } finally {
            // Runs even if the bookkeeping above throws, so a finished task never lingers in the map.
            assert trackOngoingTasks || ongoingTasks.isEmpty();
            if (trackOngoingTasks) {
                ongoingTasks.remove(r);
            }
        }
    }

    /** Appends the execution time statistics to the executor's textual description. */
    @Override
    protected void appendThreadPoolExecutorDetails(StringBuilder sb) {
        sb.append("task execution EWMA = ")
            .append(TimeValue.timeValueNanos((long) executionEWMA.getAverage()))
            .append(", ")
            .append("total task execution time = ")
            .append(TimeValue.timeValueNanos(getTotalTaskExecutionTime()))
            .append(", ");
    }

    /**
     * Returns an immutable snapshot of the currently executing tasks mapped to the {@link System#nanoTime()}
     * value recorded when each started, or an empty map when ongoing-task tracking is disabled.
     */
    public Map<Runnable, Long> getOngoingTasks() {
        return trackOngoingTasks ? Map.copyOf(ongoingTasks) : Map.of();
    }

    /** Returns the alpha of the execution time moving average. */
    public double getExecutionEwmaAlpha() {
        return executionEWMA.getAlpha();
    }

    /** Returns whether max queue latency tracking is enabled. */
    public boolean trackingMaxQueueLatency() {
        return trackMaxQueueLatency;
    }

    /**
     * Computes utilization over the window between two successive polls, based on the enclosing executor's
     * total execution time and maximum pool size. Polling is synchronized so concurrent callers see
     * consistent, non-overlapping windows.
     */
    private final class UtilizationTracker {
        private long lastPollTime = System.nanoTime();
        private long lastTotalExecutionTime = 0L;

        /**
         * Returns the execution time accumulated since the last poll divided by the maximum thread time
         * available in that window (elapsed time multiplied by the maximum pool size).
         */
        synchronized double pollUtilization() {
            long currentTotalExecutionTimeNanos = totalExecutionTime.sum();
            long currentPollTimeNanos = System.nanoTime();

            long totalExecutionTimeSinceLastPollNanos = currentTotalExecutionTimeNanos - lastTotalExecutionTime;
            long timeSinceLastPoll = currentPollTimeNanos - lastPollTime;
            long maximumExecutionTimeSinceLastPollNanos = timeSinceLastPoll * (long) getMaximumPoolSize();

            if (maximumExecutionTimeSinceLastPollNanos <= 0L) {
                // No measurable window: dividing would yield NaN or Infinity. Report zero and keep the
                // previous baseline so the next poll accounts for the whole interval.
                return 0.0;
            }

            double utilizationSinceLastPoll = (double) totalExecutionTimeSinceLastPollNanos
                / (double) maximumExecutionTimeSinceLastPollNanos;

            lastTotalExecutionTime = currentTotalExecutionTimeNanos;
            lastPollTime = currentPollTimeNanos;
            return utilizationSinceLastPoll;
        }
    }

    /** Identifies the consumer polling utilization; each purpose has its own polling window. */
    public enum UtilizationTrackingPurpose {
        APM,
        ALLOCATION
    }
}

