/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.common.util.concurrent;

import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.LongAdder;
import java.util.function.Function;
import org.elasticsearch.common.ExponentiallyWeightedMovingAverage;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.EsThreadPoolExecutor;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.util.concurrent.TimedRunnable;
import org.elasticsearch.common.util.concurrent.WrappedRunnable;
import org.elasticsearch.core.TimeValue;

/**
 * An {@link EsThreadPoolExecutor} that records how long its tasks take to execute.
 *
 * <p>Every submitted command is passed through the supplied {@code runnableWrapper} before being
 * handed to the parent executor. Once a task finishes, the execution time reported by its
 * {@link TimedRunnable} (in nanoseconds) is folded into an exponentially weighted moving average
 * and added to a running total. When enabled through the tracking configuration, the executor
 * also remembers the {@link System#nanoTime()} value at which each currently running task started.
 */
public final class TaskExecutionTimeTrackingEsThreadPoolExecutor extends EsThreadPoolExecutor {
    /** Applied to every command in {@link #wrapRunnable(Runnable)} before the parent wraps it. */
    private final Function<Runnable, WrappedRunnable> runnableWrapper;
    /** Moving average of the execution times (nanoseconds) of completed tasks. */
    private final ExponentiallyWeightedMovingAverage executionEWMA;
    /** Sum of the execution times (nanoseconds) of completed tasks. */
    private final LongAdder totalExecutionTime = new LongAdder();
    /** Whether {@link #ongoingTasks} is maintained at all. */
    private final boolean trackOngoingTasks;
    /** Running tasks mapped to the {@link System#nanoTime()} reading taken in {@code beforeExecute}. */
    private final Map<Runnable, Long> ongoingTasks = new ConcurrentHashMap<>();

    /**
     * Creates the executor.
     *
     * <p>All arguments except {@code runnableWrapper} and {@code trackingConfig} are forwarded
     * unchanged to the {@link EsThreadPoolExecutor} constructor.
     *
     * @param runnableWrapper wrapper applied to each command before it is handed to the parent
     * @param trackingConfig  supplies the EWMA alpha and the "track ongoing tasks" switch
     */
    TaskExecutionTimeTrackingEsThreadPoolExecutor(
        String name,
        int corePoolSize,
        int maximumPoolSize,
        long keepAliveTime,
        TimeUnit unit,
        BlockingQueue<Runnable> workQueue,
        Function<Runnable, WrappedRunnable> runnableWrapper,
        ThreadFactory threadFactory,
        RejectedExecutionHandler handler,
        ThreadContext contextHolder,
        EsExecutors.TaskTrackingConfig trackingConfig
    ) {
        super(name, corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, handler, contextHolder);
        this.runnableWrapper = runnableWrapper;
        // The moving average starts out at 0.0 until the first task completes.
        this.executionEWMA = new ExponentiallyWeightedMovingAverage(trackingConfig.getEwmaAlpha(), 0.0);
        this.trackOngoingTasks = trackingConfig.trackOngoingTasks();
    }

    /**
     * Applies the configured wrapper to {@code command}, then lets the parent executor wrap the result.
     */
    @Override
    protected Runnable wrapRunnable(Runnable command) {
        WrappedRunnable wrapped = this.runnableWrapper.apply(command);
        return super.wrapRunnable(wrapped);
    }

    /**
     * Undoes the parent's wrapping and, if what remains is a {@link WrappedRunnable}, unwraps
     * that layer as well.
     */
    @Override
    protected Runnable unwrap(Runnable runnable) {
        Runnable inner = super.unwrap(runnable);
        if (!(inner instanceof WrappedRunnable)) {
            return inner;
        }
        return ((WrappedRunnable) inner).unwrap();
    }

    /**
     * @return the current exponentially weighted moving average of task execution time, in nanoseconds
     */
    public double getTaskExecutionEWMA() {
        return this.executionEWMA.getAverage();
    }

    /**
     * @return the accumulated execution time of all recorded tasks, in nanoseconds
     */
    public long getTotalTaskExecutionTime() {
        return this.totalExecutionTime.sum();
    }

    /**
     * @return the number of entries currently held by this executor's work queue
     */
    public int getCurrentQueueSize() {
        return this.getQueue().size();
    }

    /**
     * Records the start timestamp of {@code r} when ongoing-task tracking is enabled; otherwise does nothing.
     */
    @Override
    protected void beforeExecute(Thread t, Runnable r) {
        if (!this.trackOngoingTasks) {
            return;
        }
        this.ongoingTasks.put(r, System.nanoTime());
    }

    /**
     * Folds the finished task's execution time into the statistics and drops it from the
     * ongoing-task map.
     *
     * <p>A reported time of {@code -1} is skipped so that it does not distort the statistics; the
     * assertion below only permits that value when the task failed or was rejected. The map
     * cleanup sits in a {@code finally} block so it happens even if the bookkeeping above it throws.
     *
     * <p>NOTE: this source was produced by a decompiler, which flagged the try/finally in this
     * method as a possible behaviour change relative to the original.
     */
    @Override
    protected void afterExecute(Runnable r, Throwable t) {
        try {
            super.afterExecute(r, t);
            assert (super.unwrap(r) instanceof TimedRunnable) : "expected only TimedRunnables in queue";
            TimedRunnable timed = (TimedRunnable) super.unwrap(r);
            boolean failed = timed.getFailedOrRejected();
            long elapsedNanos = timed.getTotalExecutionNanos();
            assert (elapsedNanos >= 0L || failed && elapsedNanos == -1L)
                : "expected task to always take longer than 0 nanoseconds or have '-1' failure code, got: "
                    + elapsedNanos
                    + ", failedOrRejected: "
                    + failed;
            boolean hasTiming = elapsedNanos != -1L;
            if (hasTiming) {
                this.executionEWMA.addValue(elapsedNanos);
                this.totalExecutionTime.add(elapsedNanos);
            }
        } finally {
            // With tracking disabled nothing is ever inserted, so the map has to be empty.
            assert (this.trackOngoingTasks || this.ongoingTasks.isEmpty());
            if (this.trackOngoingTasks) {
                this.ongoingTasks.remove(r);
            }
        }
    }

    /**
     * Appends the EWMA and the total task execution time, both rendered as {@link TimeValue}s, to {@code sb}.
     */
    @Override
    protected void appendThreadPoolExecutorDetails(StringBuilder sb) {
        sb.append("task execution EWMA = ");
        // The average is a double; it is truncated to whole nanoseconds for display.
        sb.append(TimeValue.timeValueNanos((long) this.executionEWMA.getAverage()));
        sb.append(", ");
        sb.append("total task execution time = ");
        sb.append(TimeValue.timeValueNanos(this.getTotalTaskExecutionTime()));
        sb.append(", ");
    }

    /**
     * @return an unmodifiable snapshot of the running tasks and their start timestamps
     *         ({@link System#nanoTime()} readings), or an empty map when tracking is disabled
     */
    public Map<Runnable, Long> getOngoingTasks() {
        if (this.trackOngoingTasks) {
            return Map.copyOf(this.ongoingTasks);
        }
        return Map.of();
    }

    /**
     * @return the alpha (smoothing factor) the moving average was configured with
     */
    public double getEwmaAlpha() {
        return this.executionEWMA.getAlpha();
    }
}

