/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.http.netty4;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleUserEventChannelHandler;
import io.netty.handler.ssl.SslClientHelloHandler;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import io.netty.util.ReferenceCounted;
import io.netty.util.concurrent.Future;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Supplier;
import java.util.function.ToLongFunction;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.component.AbstractLifecycleComponent;
import org.elasticsearch.common.component.Lifecycle;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.node.NodeClosedException;
import org.elasticsearch.telemetry.metric.LongWithAttributes;
import org.elasticsearch.telemetry.metric.MeterRegistry;
import org.elasticsearch.transport.NodeDisconnectedException;
import org.elasticsearch.transport.netty4.Netty4Plugin;

/**
 * Limits the number of concurrent TLS handshakes on incoming HTTP connections.
 *
 * <p>One {@link TlsHandshakeThrottle} is kept per thread, keyed by the thread that calls
 * {@link #getThrottleForCurrentThread()}. Each throttle lets at most {@link #maxInProgressTlsHandshakes} handshakes
 * proceed at once. Further handshakes are delayed in a queue. Once that queue holds more than
 * {@link #maxDelayedTlsHandshakes} entries, the excess is dropped by closing its channel. A limit of {@code 0} in-progress
 * handshakes disables throttling.
 *
 * <p>On start, both limits are registered with {@link ClusterSettings#initializeAndWatch}, and four metrics are
 * registered; each metric sums one statistic across all per-thread throttles. On stop, every throttle is closed and
 * the metrics are deregistered.
 */
class TlsHandshakeThrottleManager extends AbstractLifecycleComponent {
    private static final Logger logger = LogManager.getLogger(TlsHandshakeThrottleManager.class);

    private final ClusterSettings clusterSettings;
    private final MeterRegistry meterRegistry;

    /** Metric registrations made in {@link #doStart()} and closed again in {@link #doStop()}. */
    private final List<AutoCloseable> metricsToClose = new ArrayList<>(4);

    /** Maximum concurrently in-progress handshakes per throttle; {@code 0} disables throttling entirely. */
    volatile int maxInProgressTlsHandshakes;

    /** Maximum number of handshakes each throttle may hold in its delay queue before dropping the oldest. */
    volatile int maxDelayedTlsHandshakes;

    /** One throttle per thread, created lazily by {@link #getThrottleForCurrentThread()}. */
    private final Map<Thread, TlsHandshakeThrottle> tlsHandshakeThrottles = ConcurrentCollections.newConcurrentMap();

    TlsHandshakeThrottleManager(ClusterSettings clusterSettings, MeterRegistry meterRegistry) {
        this.clusterSettings = clusterSettings;
        this.meterRegistry = meterRegistry;
    }

    /**
     * Looks up, by key, the instance of {@code setting} that is actually registered with {@code clusterSettings}.
     *
     * <p>NOTE(review): presumably needed because {@link ClusterSettings#initializeAndWatch} only accepts the exact
     * registered instance rather than an equal one — confirm.
     *
     * @throws NullPointerException if no setting with this key is registered
     */
    private static <T> Setting<T> getRegisteredInstance(ClusterSettings clusterSettings, Setting<T> setting) {
        // Assumes the registered setting with the same key has the same value type as the one passed in.
        @SuppressWarnings("unchecked")
        final Setting<T> registeredSetting = (Setting<T>) clusterSettings.get(setting.getKey());
        return Objects.requireNonNull(registeredSetting);
    }

    /** Starts tracking both throttle limits from the cluster settings and registers the throttle metrics. */
    @Override
    protected void doStart() {
        clusterSettings.initializeAndWatch(
            getRegisteredInstance(clusterSettings, Netty4Plugin.SETTING_HTTP_NETTY_TLS_HANDSHAKES_MAX_IN_PROGRESS),
            maxInProgressTlsHandshakes -> this.maxInProgressTlsHandshakes = maxInProgressTlsHandshakes
        );
        clusterSettings.initializeAndWatch(
            getRegisteredInstance(clusterSettings, Netty4Plugin.SETTING_HTTP_NETTY_TLS_HANDSHAKES_MAX_DELAYED),
            maxDelayedTlsHandshakes -> this.maxDelayedTlsHandshakes = maxDelayedTlsHandshakes
        );
        metricsToClose.add(
            meterRegistry.registerLongGauge(
                "es.http.tls_handshakes.in_progress.current",
                "current number of in-progress TLS handshakes for HTTP connections",
                "count",
                getMetric(TlsHandshakeThrottle::getInProgressHandshakesCount)
            )
        );
        metricsToClose.add(
            meterRegistry.registerLongGauge(
                "es.http.tls_handshakes.delayed.current",
                "current number of delayed TLS handshakes for HTTP connections",
                "count",
                getMetric(TlsHandshakeThrottle::getCurrentDelayedHandshakesCount)
            )
        );
        metricsToClose.add(
            meterRegistry.registerLongAsyncCounter(
                "es.http.tls_handshakes.delayed.total",
                "total number of TLS handshakes for HTTP connections that were delayed due to throttling",
                "count",
                getMetric(TlsHandshakeThrottle::getTotalDelayedHandshakesCount)
            )
        );
        metricsToClose.add(
            meterRegistry.registerLongAsyncCounter(
                "es.http.tls_handshakes.dropped.total",
                "number of TLS handshakes for HTTP connections dropped due to throttling",
                "count",
                getMetric(TlsHandshakeThrottle::getDroppedHandshakesCount)
            )
        );
    }

    /**
     * Fails every delayed handshake, which closes its channel, and then deregisters the metrics.
     *
     * <p>A failure while closing a metric trips an assertion. When assertions are disabled, the failure is logged and
     * the remaining metrics are still closed.
     */
    @Override
    protected void doStop() {
        // NOTE(review): this runs on whichever thread stops the component, while each throttle's queue otherwise
        // appears to be touched only by its own thread; presumably safe because those threads are idle by now — confirm.
        tlsHandshakeThrottles.values().forEach(TlsHandshakeThrottle::close);
        for (final AutoCloseable metricToClose : metricsToClose) {
            try {
                metricToClose.close();
            } catch (Exception e) {
                assert false : e;
                logger.error(Strings.format("exception closing metric [%s]", metricToClose), e);
            }
        }
    }

    /** Nothing to do: all resources are released in {@link #doStop()}. */
    @Override
    protected void doClose() {}

    /**
     * Returns a metric supplier that sums {@code metricFunction} over every per-thread throttle existing at the time
     * the metric is read.
     */
    private Supplier<LongWithAttributes> getMetric(ToLongFunction<TlsHandshakeThrottle> metricFunction) {
        return () -> {
            long total = 0L;
            for (final TlsHandshakeThrottle tlsHandshakeThrottle : tlsHandshakeThrottles.values()) {
                total += metricFunction.applyAsLong(tlsHandshakeThrottle);
            }
            return new LongWithAttributes(total);
        };
    }

    /**
     * Returns the throttle belonging to the calling thread, creating it on first use.
     *
     * <p>Synchronizes on {@code lifecycle}, presumably so that creation cannot race with a lifecycle transition.
     *
     * @return the calling thread's throttle, or {@code null} if throttling is disabled
     *         ({@code maxInProgressTlsHandshakes == 0})
     * @throws IllegalStateException if this component is already stopped or closed
     */
    @Nullable
    TlsHandshakeThrottle getThrottleForCurrentThread() {
        synchronized (lifecycle) {
            if (lifecycle.stoppedOrClosed()) {
                throw new IllegalStateException("HTTP transport is already stopped");
            }
            if (maxInProgressTlsHandshakes == 0) {
                return null;
            }
            return tlsHandshakeThrottles.computeIfAbsent(Thread.currentThread(), ignored -> new TlsHandshakeThrottle());
        }
    }

    /**
     * Throttle state for a single thread. It tracks the handshakes in progress, the queue of delayed handshakes, and
     * the statistics exposed as metrics.
     *
     * <p>Delayed handshakes are pushed onto the head of {@link #delayedHandshakes} and resumed from the head, so the
     * most recently delayed handshake is resumed first. When the queue grows beyond the configured limit, the oldest
     * entries (at the tail) are dropped.
     *
     * <p>NOTE(review): the counters are {@code volatile} so metric readers on other threads can observe them. However,
     * they are updated with non-atomic {@code ++}/{@code --}, which appears to rely on each instance being used only by
     * the thread it was created for — confirm.
     */
    class TlsHandshakeThrottle {
        private volatile int inProgressHandshakesCount = 0;
        private final ArrayDeque<AbstractRunnable> delayedHandshakes = new ArrayDeque<>();
        /** Mirror of {@code delayedHandshakes.size()} that other threads can read safely via the metrics. */
        private volatile int delayedHandshakesCount = 0;
        private volatile long totalDelayedHandshakesCount = 0L;
        private volatile long droppedHandshakesCount = 0L;

        TlsHandshakeThrottle() {}

        /** Removes and returns the most recently delayed handshake. */
        private AbstractRunnable takeFirstDelayedHandshake() {
            final AbstractRunnable result = delayedHandshakes.removeFirst();
            delayedHandshakesCount = delayedHandshakes.size();
            return result;
        }

        /** Removes and returns the oldest delayed handshake. */
        private AbstractRunnable takeLastDelayedHandshake() {
            final AbstractRunnable result = delayedHandshakes.removeLast();
            delayedHandshakesCount = delayedHandshakes.size();
            return result;
        }

        /** Queues a handshake at the head of the queue and counts it as delayed. */
        private void addDelayedHandshake(AbstractRunnable abstractRunnable) {
            delayedHandshakes.addFirst(abstractRunnable);
            ++totalDelayedHandshakesCount;
            delayedHandshakesCount = delayedHandshakes.size();
        }

        /**
         * Fails every delayed handshake with a {@link NodeClosedException}. The handshake's failure handler then closes
         * its channel.
         */
        void close() {
            while (delayedHandshakes.isEmpty() == false) {
                takeFirstDelayedHandshake().onFailure(new NodeClosedException((DiscoveryNode) null));
            }
        }

        /** Creates the handler that holds back a new connection's handshake until this throttle admits it. */
        ChannelHandler newHandshakeThrottleHandler(SubscribableListener<Void> handshakeCompletePromise) {
            return new HandshakeThrottleHandler(handshakeCompletePromise);
        }

        /** Creates the handler that completes {@code handshakeCompletePromise} when the TLS handshake finishes. */
        ChannelHandler newHandshakeCompletionWatcher(SubscribableListener<Void> handshakeCompletePromise) {
            return new HandshakeCompletionWatcher(handshakeCompletePromise);
        }

        /**
         * Called once for each admitted handshake when it completes, whether it succeeded or failed.
         *
         * <p>If handshakes are waiting, the freed slot passes straight to the most recently delayed one, so the
         * in-progress count is unchanged. Otherwise the slot is released.
         */
        void handleHandshakeCompletion() {
            if (delayedHandshakes.isEmpty()) {
                --inProgressHandshakesCount;
            } else {
                takeFirstDelayedHandshake().run();
            }
        }

        public int getInProgressHandshakesCount() {
            return inProgressHandshakesCount;
        }

        public int getCurrentDelayedHandshakesCount() {
            return delayedHandshakesCount;
        }

        public long getTotalDelayedHandshakesCount() {
            return totalDelayedHandshakesCount;
        }

        public long getDroppedHandshakesCount() {
            return droppedHandshakesCount;
        }

        /**
         * Inspects the ClientHello and decides whether the handshake may start now, must wait, or must be dropped.
         *
         * <p>Until the handshake is allowed to start, every inbound message and read-complete signal is buffered on
         * {@link #handshakeStartedPromise}. Once the promise succeeds they are forwarded downstream. If it fails, the
         * buffered messages are released instead.
         */
        private class HandshakeThrottleHandler extends SslClientHelloHandler<Void> {
            private final SubscribableListener<Void> handshakeStartedPromise = new SubscribableListener<>();
            private final SubscribableListener<Void> handshakeCompletePromise;

            HandshakeThrottleHandler(SubscribableListener<Void> handshakeCompletePromise) {
                this.handshakeCompletePromise = handshakeCompletePromise;
            }

            @Override
            public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
                // Take an extra reference so msg can still be forwarded (or released) once the start promise
                // completes. super.channelRead also receives msg in order to parse the ClientHello.
                if (msg instanceof ReferenceCounted retained) {
                    retained.retain();
                }
                handshakeStartedPromise.addListener(new ActionListener<Void>() {
                    @Override
                    public void onResponse(Void unused) {
                        ctx.fireChannelRead(msg);
                    }

                    @Override
                    public void onFailure(Exception e) {
                        if (msg instanceof ReferenceCounted toRelease) {
                            toRelease.release();
                        }
                    }
                });
                super.channelRead(ctx, msg);
            }

            @Override
            public void channelReadComplete(final ChannelHandlerContext ctx) throws Exception {
                handshakeStartedPromise.addListener(new ActionListener<Void>() {
                    @Override
                    public void onResponse(Void unused) {
                        ctx.fireChannelReadComplete();
                    }

                    @Override
                    public void onFailure(Exception e) {
                        // nothing was buffered, so there is nothing to release
                    }
                });
                super.channelReadComplete(ctx);
            }

            @Override
            public void channelInactive(ChannelHandlerContext ctx) throws Exception {
                // This has no effect if the promise has already completed. Otherwise it releases the buffered messages.
                handshakeStartedPromise.onFailure(
                    new NodeDisconnectedException(null, "connection closed before handshake started", null, null)
                );
                super.channelInactive(ctx);
            }

            /**
             * Admits, delays, or rejects the handshake. The returned future is always already complete.
             */
            @Override
            protected Future<Void> lookup(final ChannelHandlerContext ctx, ByteBuf clientHello) {
                if (clientHello == null) {
                    logger.debug("lookup with no ClientHello, closing [{}]", ctx.channel());
                    ctx.channel().close();
                    final IllegalArgumentException exception = new IllegalArgumentException(
                        "did not receive initial ClientHello on channel [" + ctx.channel() + "]"
                    );
                    handshakeStartedPromise.onFailure(exception);
                    return ctx.executor().newFailedFuture(exception);
                }
                if (ctx.channel().isActive() == false) {
                    logger.debug("lookup after channel inactive, ignoring [{}]", ctx.channel());
                    final NodeDisconnectedException exception = new NodeDisconnectedException(
                        null,
                        "lookup after channel inactive [" + ctx.channel() + "]",
                        null,
                        null
                    );
                    handshakeStartedPromise.onFailure(exception);
                    return ctx.executor().newFailedFuture(exception);
                }

                // Read the volatile limit once so this decision uses a consistent value.
                final int maxInProgressTlsHandshakes = TlsHandshakeThrottleManager.this.maxInProgressTlsHandshakes;
                if (maxInProgressTlsHandshakes == 0 || TlsHandshakeThrottle.this.inProgressHandshakesCount < maxInProgressTlsHandshakes) {
                    // Admit immediately. The slot is released (or handed on) when the handshake completes.
                    ++TlsHandshakeThrottle.this.inProgressHandshakesCount;
                    handshakeCompletePromise.addListener(ActionListener.running(TlsHandshakeThrottle.this::handleHandshakeCompletion));
                    ctx.channel().pipeline().remove(this);
                    handshakeStartedPromise.onResponse(null);
                } else {
                    logger.debug(
                        "[{}] in-progress TLS handshakes already, enqueueing new handshake on [{}]",
                        TlsHandshakeThrottle.this.inProgressHandshakesCount,
                        ctx.channel()
                    );
                    TlsHandshakeThrottle.this.addDelayedHandshake(new AbstractRunnable() {
                        @Override
                        public void onFailure(Exception e) {
                            logger.debug(
                                "[{}] in-progress and [{}] delayed TLS handshakes, cancelling handshake on [{}]: {}",
                                TlsHandshakeThrottle.this.inProgressHandshakesCount,
                                TlsHandshakeThrottle.this.delayedHandshakes.size(),
                                ctx.channel(),
                                e.getMessage()
                            );
                            ctx.channel().close();
                        }

                        @Override
                        protected void doRun() {
                            // Runs from handleHandshakeCompletion, taking over the slot of a handshake that just finished.
                            logger.debug(
                                "[{}] in flight and [{}] delayed TLS handshakes, processing delayed handshake on [{}]",
                                TlsHandshakeThrottle.this.inProgressHandshakesCount,
                                TlsHandshakeThrottle.this.delayedHandshakes.size(),
                                ctx.channel()
                            );
                            HandshakeThrottleHandler.this.handshakeCompletePromise.addListener(
                                ActionListener.running(TlsHandshakeThrottle.this::handleHandshakeCompletion)
                            );
                            ctx.pipeline().remove(HandshakeThrottleHandler.this);
                            HandshakeThrottleHandler.this.handshakeStartedPromise.onResponse(null);
                        }

                        @Override
                        public String toString() {
                            return "delayed handshake on [" + ctx.channel() + "]";
                        }
                    });

                    // Drop the oldest delayed handshakes until the queue fits within the configured limit.
                    final int maxDelayedTlsHandshakes = TlsHandshakeThrottleManager.this.maxDelayedTlsHandshakes;
                    while (TlsHandshakeThrottle.this.delayedHandshakes.size() > maxDelayedTlsHandshakes) {
                        final AbstractRunnable lastDelayedHandshake = TlsHandshakeThrottle.this.takeLastDelayedHandshake();
                        ++TlsHandshakeThrottle.this.droppedHandshakesCount;
                        lastDelayedHandshake.onFailure(new ElasticsearchException("too many in-flight TLS handshakes"));
                    }
                }
                // NOTE(review): requests another read regardless of outcome; presumably because auto-read is off for
                // these channels — confirm.
                ctx.read();
                return ctx.executor().newSucceededFuture(null);
            }

            /** Nothing to do: {@link #lookup} has already acted on its outcome and always returns a completed future. */
            @Override
            protected void onLookupComplete(ChannelHandlerContext ctx, Future<Void> future) {}
        }

        /**
         * Completes {@code handshakeCompletePromise} when the TLS handshake finishes, or fails it if the channel closes
         * first. It removes itself from the pipeline after the first handshake-completion event.
         */
        private static class HandshakeCompletionWatcher extends SimpleUserEventChannelHandler<SslHandshakeCompletionEvent> {
            private final SubscribableListener<Void> handshakeCompletePromise;

            HandshakeCompletionWatcher(SubscribableListener<Void> handshakeCompletePromise) {
                this.handshakeCompletePromise = handshakeCompletePromise;
            }

            @Override
            protected void eventReceived(ChannelHandlerContext ctx, SslHandshakeCompletionEvent evt) {
                ctx.pipeline().remove(this);
                if (evt.isSuccess()) {
                    handshakeCompletePromise.onResponse(null);
                } else {
                    // Presumably rethrows fatal errors (e.g. Errors) on a separate thread — see ExceptionsHelper.
                    ExceptionsHelper.maybeDieOnAnotherThread(evt.cause());
                    // Listeners accept only Exceptions, so any other Throwable is wrapped.
                    handshakeCompletePromise.onFailure(
                        evt.cause() instanceof Exception exception
                            ? exception
                            : new ElasticsearchException("TLS handshake failed", evt.cause())
                    );
                }
            }

            @Override
            public void channelInactive(ChannelHandlerContext ctx) throws Exception {
                if (handshakeCompletePromise.isDone() == false) {
                    handshakeCompletePromise.onFailure(new ElasticsearchException("channel closed before TLS handshake completed"));
                }
                super.channelInactive(ctx);
            }
        }
    }
}

