/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.security.authc.jwt;

import com.nimbusds.jose.jwk.JWK;
import java.io.IOException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import org.apache.http.impl.nio.client.CloseableHttpAsyncClient;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.hash.MessageDigests;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.SettingsException;
import org.elasticsearch.common.util.concurrent.ListenableFuture;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.threadpool.Scheduler;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.watcher.FileChangesListener;
import org.elasticsearch.watcher.FileWatcher;
import org.elasticsearch.xpack.core.security.authc.RealmConfig;
import org.elasticsearch.xpack.core.security.authc.RealmSettings;
import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings;
import org.elasticsearch.xpack.core.ssl.SSLService;
import org.elasticsearch.xpack.security.PrivilegedFileWatcher;
import org.elasticsearch.xpack.security.authc.jwt.JwkValidateUtil;
import org.elasticsearch.xpack.security.authc.jwt.JwtUtil;
import org.elasticsearch.xpack.security.authc.jwt.PkcJwkSetReloadNotifier;

class JwkSetLoader
implements Releasable {
    private static final Logger logger = LogManager.getLogger(JwkSetLoader.class);
    private static final double URL_RELOAD_JITTER_PCT = 0.01;
    private final AtomicReference<ListenableFuture<Void>> reloadFutureRef = new AtomicReference();
    private final RealmConfig realmConfig;
    private final List<String> allowedJwksAlgsPkc;
    private final PkcJwkSetLoader loader;
    private final PkcJwkSetReloadNotifier reloadNotifier;
    private volatile ContentAndJwksAlgs contentAndJwksAlgs = new ContentAndJwksAlgs(new byte[32], new JwksAlgs(Collections.emptyList(), Collections.emptyList()));

    JwkSetLoader(RealmConfig realmConfig, List<String> allowedJwksAlgsPkc, SSLService sslService, ThreadPool threadPool, PkcJwkSetReloadNotifier reloadNotifier) {
        this.realmConfig = realmConfig;
        this.allowedJwksAlgsPkc = allowedJwksAlgsPkc;
        String jwkSetPath = (String)realmConfig.getSetting(JwtRealmSettings.PKC_JWKSET_PATH);
        assert (Strings.hasText((String)jwkSetPath));
        URI jwkSetPathUri = JwtUtil.parseHttpsUri(jwkSetPath);
        Consumer<byte[]> listener = content -> this.handleReloadedContentAndJwksAlgs((byte[])content);
        this.loader = jwkSetPathUri == null ? new FilePkcJwkSetLoader(realmConfig, threadPool, jwkSetPath, listener) : new UrlPkcJwkSetLoader(realmConfig, threadPool, jwkSetPathUri, JwtUtil.createHttpClient(realmConfig, sslService), listener);
        this.reloadNotifier = reloadNotifier;
        try {
            PlainActionFuture future = new PlainActionFuture();
            this.reload((ActionListener<Void>)future);
            future.actionGet();
        }
        catch (Throwable t) {
            this.close();
            throw t;
        }
        if (((Boolean)realmConfig.getSetting(JwtRealmSettings.PKC_JWKSET_RELOAD_ENABLED)).booleanValue()) {
            this.loader.start();
        }
    }

    void reload(ActionListener<Void> listener) {
        ListenableFuture<Void> future = this.getFuture();
        future.addListener(listener);
    }

    ContentAndJwksAlgs getContentAndJwksAlgs() {
        return this.contentAndJwksAlgs;
    }

    ListenableFuture<Void> getFuture() {
        ListenableFuture newFuture;
        do {
            ListenableFuture<Void> existingFuture;
            if ((existingFuture = this.reloadFutureRef.get()) == null) continue;
            return existingFuture;
        } while (!this.reloadFutureRef.compareAndSet(null, (ListenableFuture<Void>)(newFuture = new ListenableFuture())));
        this.loadInternal((ActionListener<Void>)ActionListener.runBefore((ActionListener)newFuture, () -> {
            ListenableFuture oldValue = this.reloadFutureRef.getAndSet(null);
            assert (oldValue == newFuture) : "future reference changed unexpectedly";
        }));
        return newFuture;
    }

    void loadInternal(ActionListener<Void> listener) {
        this.loader.load((ActionListener<byte[]>)ActionListener.wrap(content -> {
            this.handleReloadedContentAndJwksAlgs((byte[])content);
            listener.onResponse(null);
        }, arg_0 -> listener.onFailure(arg_0)));
    }

    private void handleReloadedContentAndJwksAlgs(byte[] bytes) {
        ContentAndJwksAlgs newContentAndJwksAlgs = this.parseContent(bytes);
        assert (this.contentAndJwksAlgs != null);
        if (!Arrays.equals(this.contentAndJwksAlgs.sha256, newContentAndJwksAlgs.sha256)) {
            logger.info("Reloaded JWK set from sha256=[{}] to sha256=[{}]", (Object)MessageDigests.toHexString((byte[])this.contentAndJwksAlgs.sha256), (Object)MessageDigests.toHexString((byte[])newContentAndJwksAlgs.sha256));
            this.contentAndJwksAlgs = newContentAndJwksAlgs;
            this.reloadNotifier.reloaded();
        }
    }

    private ContentAndJwksAlgs parseContent(byte[] jwkSetContentBytesPkc) {
        String jwkSetContentsPkc = new String(jwkSetContentBytesPkc, StandardCharsets.UTF_8);
        byte[] jwkSetContentsPkcSha256 = JwtUtil.sha256(jwkSetContentsPkc);
        List<JWK> jwksPkc = JwkValidateUtil.loadJwksFromJwkSetString(RealmSettings.getFullSettingKey((RealmConfig)this.realmConfig, (Setting.AffixSetting)JwtRealmSettings.PKC_JWKSET_PATH), jwkSetContentsPkc);
        JwksAlgs jwksAlgsPkc = JwkValidateUtil.filterJwksAndAlgorithms(jwksPkc, this.allowedJwksAlgsPkc);
        logger.debug("Usable PKC: JWKs=[{}] algorithms=[{}] sha256=[{}]", (Object)jwksAlgsPkc.jwks().size(), (Object)String.join((CharSequence)",", jwksAlgsPkc.algs()), (Object)MessageDigests.toHexString((byte[])jwkSetContentsPkcSha256));
        return new ContentAndJwksAlgs(jwkSetContentsPkcSha256, jwksAlgsPkc);
    }

    public void close() {
        this.loader.stop();
    }

    static TimeValue calculateNextUrlReload(TimeValue minVal, TimeValue maxVal, @Nullable Instant targetTime, double maxJitterPct) {
        if (targetTime == null) {
            return minVal;
        }
        Duration min = Duration.ofSeconds(minVal.seconds());
        Duration max = Duration.ofSeconds(maxVal.seconds());
        Duration target = Duration.between(Instant.now(), targetTime);
        if (target.compareTo(min) < 0) {
            return minVal;
        }
        if (target.compareTo(max) > 0) {
            return maxVal;
        }
        Duration jitter = JwkSetLoader.jitterSeconds(target, maxJitterPct);
        return TimeValue.timeValueSeconds((long)(target.toSeconds() + jitter.toSeconds()));
    }

    static Duration jitterSeconds(Duration duration, double maxJitterPct) {
        long maxJitter = (long)((double)duration.toSeconds() * maxJitterPct);
        long jitter = Randomness.get().nextLong(-maxJitter, maxJitter + 1L);
        return Duration.ofSeconds(jitter);
    }

    record ContentAndJwksAlgs(byte[] sha256, JwksAlgs jwksAlgs) {
        ContentAndJwksAlgs {
            Objects.requireNonNull(jwksAlgs, "Filters JWKs and Algs must not be null");
        }
    }

    record JwksAlgs(List<JWK> jwks, List<String> algs) {
        JwksAlgs {
            Objects.requireNonNull(jwks, "JWKs must not be null");
            Objects.requireNonNull(algs, "Algs must not be null");
        }

        boolean isEmpty() {
            return this.jwks.isEmpty() && this.algs.isEmpty();
        }
    }

    static class FilePkcJwkSetLoader
    implements PkcJwkSetLoader {
        private final RealmConfig realmConfig;
        private final String jwkSetPath;
        private final ThreadPool threadPool;
        private final Consumer<byte[]> listener;
        private volatile boolean closed = false;
        private Scheduler.Cancellable task;

        FilePkcJwkSetLoader(RealmConfig realmConfig, ThreadPool threadPool, String jwkSetPath, Consumer<byte[]> listener) {
            this.realmConfig = realmConfig;
            this.jwkSetPath = jwkSetPath;
            this.threadPool = threadPool;
            this.listener = listener;
        }

        @Override
        public void start() {
            FileChangeWatcher fileWatcher = new FileChangeWatcher(JwtUtil.resolvePath(this.realmConfig.env(), this.jwkSetPath));
            TimeValue reloadInterval = (TimeValue)this.realmConfig.getSetting(JwtRealmSettings.PKC_JWKSET_RELOAD_FILE_INTERVAL);
            this.task = this.threadPool.scheduleWithFixedDelay(() -> this.reload(fileWatcher), reloadInterval, (Executor)this.threadPool.generic());
        }

        void reload(FileChangeWatcher fileWatcher) {
            if (this.closed) {
                logger.debug("Skipping file reload because loader is closed");
                return;
            }
            try {
                if (!fileWatcher.changedSinceLastCall()) {
                    logger.debug("No changes detected in PKC JWK set file [{}], aborting", (Object)this.jwkSetPath);
                    return;
                }
                logger.debug("Detected change in PKC JWK set file [{}], reloading", (Object)this.jwkSetPath);
                this.load((ActionListener<byte[]>)ActionListener.wrap(this.listener::accept, e -> logger.warn("Failed to reload PKC JWK set file [" + this.jwkSetPath + "]", (Throwable)e)));
            }
            catch (Exception e2) {
                logger.warn("Failed to check for changes in PKC JWK set file [" + this.jwkSetPath + "]", (Throwable)e2);
            }
        }

        @Override
        public void load(ActionListener<byte[]> listener) {
            try {
                logger.trace("Loading PKC JWKs from path [{}]", (Object)this.jwkSetPath);
                byte[] reloadedBytes = JwtUtil.readFileContents(RealmSettings.getFullSettingKey((RealmConfig)this.realmConfig, (Setting.AffixSetting)JwtRealmSettings.PKC_JWKSET_PATH), this.jwkSetPath, this.realmConfig.env());
                listener.onResponse((Object)reloadedBytes);
            }
            catch (Exception e) {
                listener.onFailure(e);
            }
        }

        @Override
        public void stop() {
            this.closed = true;
            if (this.task != null) {
                this.task.cancel();
            }
        }
    }

    static class UrlPkcJwkSetLoader
    implements PkcJwkSetLoader {
        private final RealmConfig realmConfig;
        private final ThreadPool threadPool;
        private final URI jwkSetPathUri;
        private final Consumer<byte[]> listener;
        private final CloseableHttpAsyncClient httpClient;
        private final TimeValue reloadIntervalMin;
        private final TimeValue reloadIntervalMax;
        private volatile boolean closed = false;
        private Scheduler.Cancellable task;

        UrlPkcJwkSetLoader(RealmConfig realmConfig, ThreadPool threadPool, URI jwkSetPathUri, CloseableHttpAsyncClient httpClient, Consumer<byte[]> listener) {
            this.realmConfig = realmConfig;
            this.threadPool = threadPool;
            this.jwkSetPathUri = jwkSetPathUri;
            this.listener = listener;
            this.httpClient = httpClient;
            this.reloadIntervalMin = (TimeValue)realmConfig.getSetting(JwtRealmSettings.PKC_JWKSET_RELOAD_URL_INTERVAL_MIN);
            this.reloadIntervalMax = (TimeValue)realmConfig.getSetting(JwtRealmSettings.PKC_JWKSET_RELOAD_URL_INTERVAL_MAX);
            if (this.reloadIntervalMax.compareTo(this.reloadIntervalMin) < 0) {
                throw new SettingsException("Invalid PKC JWK set URL reload intervals: max [" + String.valueOf(this.reloadIntervalMax) + "] is less than min [" + String.valueOf(this.reloadIntervalMin) + "]");
            }
        }

        @Override
        public void start() {
            this.scheduleReload(this.reloadIntervalMin);
        }

        void scheduleReload(TimeValue period) {
            if (this.closed) {
                logger.debug("Skipping reload schedule because loader is closed");
                return;
            }
            this.task = this.threadPool.schedule(this::reload, period, (Executor)this.threadPool.generic());
        }

        void reload() {
            this.doLoad((ActionListener<JwtUtil.JwksResponse>)ActionListener.wrap(res -> {
                Instant targetTime = res.expires() != null ? res.expires() : (res.maxAgeSeconds() != null ? Instant.now().plusSeconds(res.maxAgeSeconds().intValue()) : null);
                TimeValue period = JwkSetLoader.calculateNextUrlReload(this.reloadIntervalMin, this.reloadIntervalMax, targetTime, 0.01);
                logger.debug("Successfully reloaded PKC JWK set from HTTPS URI [{}], reload delay is [{}]", (Object)this.jwkSetPathUri, (Object)period);
                this.listener.accept(res.content());
                this.scheduleReload(period);
            }, e -> {
                logger.warn("Failed to reload PKC JWK set from HTTPS URI [" + String.valueOf(this.jwkSetPathUri) + "]", (Throwable)e);
                this.scheduleReload(this.reloadIntervalMin);
            }));
        }

        @Override
        public void load(ActionListener<byte[]> listener) {
            this.doLoad((ActionListener<JwtUtil.JwksResponse>)listener.map(JwtUtil.JwksResponse::content));
        }

        private void doLoad(ActionListener<JwtUtil.JwksResponse> listener) {
            logger.trace("Loading PKC JWKs from https URI [{}]", (Object)this.jwkSetPathUri);
            JwtUtil.readUriContents(RealmSettings.getFullSettingKey((RealmConfig)this.realmConfig, (Setting.AffixSetting)JwtRealmSettings.PKC_JWKSET_PATH), this.jwkSetPathUri, this.httpClient, listener);
        }

        @Override
        public void stop() {
            this.closed = true;
            if (this.task != null) {
                this.task.cancel();
            }
            try {
                this.httpClient.close();
            }
            catch (IOException e) {
                logger.warn(() -> "Exception closing HTTPS client for realm [" + this.realmConfig.name() + "]", (Throwable)e);
            }
        }
    }

    static interface PkcJwkSetLoader {
        public void load(ActionListener<byte[]> var1);

        public void start();

        public void stop();
    }

    static class FileChangeWatcher
    implements FileChangesListener {
        final FileWatcher fileWatcher;
        boolean changed = false;

        FileChangeWatcher(Path path) {
            this.fileWatcher = new PrivilegedFileWatcher(path);
            this.fileWatcher.addListener((Object)this);
        }

        boolean changedSinceLastCall() throws IOException {
            this.fileWatcher.checkAndNotify();
            boolean c = this.changed;
            this.changed = false;
            return c;
        }

        public void onFileChanged(Path file) {
            this.changed = true;
        }

        public void onFileInit(Path file) {
            this.changed = true;
        }
    }
}

