/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.security.authc.jwt;

import com.nimbusds.jose.jwk.JWK;
import java.io.IOException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.http.impl.nio.client.CloseableHttpAsyncClient;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.hash.MessageDigests;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.util.concurrent.ListenableFuture;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.xpack.core.security.authc.RealmConfig;
import org.elasticsearch.xpack.core.security.authc.RealmSettings;
import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings;
import org.elasticsearch.xpack.core.ssl.SSLService;
import org.elasticsearch.xpack.security.authc.jwt.JwkValidateUtil;
import org.elasticsearch.xpack.security.authc.jwt.JwtUtil;

/**
 * Loads the PKC JWK set configured by {@link JwtRealmSettings#PKC_JWKSET_PATH} for a JWT realm, either from
 * a local file or, when the setting parses as an HTTPS URI, from a remote URL. It keeps the most recently
 * loaded JWKs and algorithms (filtered by the allowed PKC algorithms) together with a SHA-256 of the raw content.
 *
 * <p>Concurrent calls to {@link #reload(ActionListener)} are coalesced. While a load is in flight, further
 * callers subscribe to the same future instead of starting another load.
 */
public class JwkSetLoader implements Releasable {
    private static final Logger logger = LogManager.getLogger(JwkSetLoader.class);

    /** The in-flight reload, or {@code null} when no reload is running. */
    private final AtomicReference<ListenableFuture<Void>> reloadFutureRef = new AtomicReference<>();
    private final RealmConfig realmConfig;
    private final List<String> allowedJwksAlgsPkc;
    private final String jwkSetPath;
    /** The HTTPS URI parsed from {@link #jwkSetPath}, or {@code null} when the path is not an HTTPS URI. */
    @Nullable
    private final URI jwkSetPathUri;
    /** Client used to fetch a remote JWK set; only created when {@link #jwkSetPathUri} is non-null. */
    @Nullable
    private final CloseableHttpAsyncClient httpClient;
    // Placeholder until the first load completes: an all-zero digest with no JWKs and no algorithms.
    private volatile ContentAndJwksAlgs contentAndJwksAlgs = new ContentAndJwksAlgs(
        new byte[32],
        new JwksAlgs(Collections.emptyList(), Collections.emptyList())
    );

    /**
     * Creates the loader and performs the initial load, blocking until it completes.
     *
     * @param realmConfig        realm configuration supplying the JWK set path and environment
     * @param allowedJwksAlgsPkc algorithms used to filter the loaded JWKs
     * @param sslService         used to build the HTTPS client when the JWK set path is an HTTPS URI
     * @throws RuntimeException  whatever the initial load fails with; the HTTP client (if any) is closed first
     */
    public JwkSetLoader(RealmConfig realmConfig, List<String> allowedJwksAlgsPkc, SSLService sslService) {
        this.realmConfig = realmConfig;
        this.allowedJwksAlgsPkc = allowedJwksAlgsPkc;
        this.jwkSetPath = realmConfig.getSetting(JwtRealmSettings.PKC_JWKSET_PATH);
        assert Strings.hasText(this.jwkSetPath);
        this.jwkSetPathUri = JwtUtil.parseHttpsUri(this.jwkSetPath);
        // Only a remote (HTTPS) JWK set needs an HTTP client.
        this.httpClient = this.jwkSetPathUri == null ? null : JwtUtil.createHttpClient(realmConfig, sslService);
        try {
            final PlainActionFuture<Void> future = new PlainActionFuture<>();
            this.reload(future);
            future.actionGet();
        } catch (Throwable t) {
            // Release the HTTP client so a failed construction does not leak it.
            this.close();
            throw t;
        }
    }

    /**
     * Reloads the JWK set and notifies {@code listener} when the load completes. The stored content is
     * replaced only if its SHA-256 differs from the current one. If a load is already in flight, the
     * listener is attached to that load instead of starting a new one.
     */
    void reload(ActionListener<Void> listener) {
        final ListenableFuture<Void> future = this.getFuture();
        future.addListener(listener);
    }

    /** Returns the most recently loaded content digest and filtered JWKs/algorithms. */
    ContentAndJwksAlgs getContentAndJwksAlgs() {
        return this.contentAndJwksAlgs;
    }

    /**
     * Returns the in-flight reload future, starting a new load if none is running.
     *
     * <p>The shared reference is cleared (via {@code runBefore}) before the future is completed. As a
     * result, a call made after completion starts a fresh load.
     */
    ListenableFuture<Void> getFuture() {
        for (;;) {
            final ListenableFuture<Void> existingFuture = this.reloadFutureRef.get();
            if (existingFuture != null) {
                return existingFuture;
            }
            final ListenableFuture<Void> newFuture = new ListenableFuture<>();
            if (this.reloadFutureRef.compareAndSet(null, newFuture)) {
                this.loadInternal(ActionListener.runBefore(newFuture, () -> {
                    final ListenableFuture<Void> oldValue = this.reloadFutureRef.getAndSet(null);
                    assert oldValue == newFuture : "future reference changed unexpectedly";
                }));
                return newFuture;
            }
            // Another thread installed its future first; loop and subscribe to that one instead.
        }
    }

    /**
     * Loads the JWK set from the local file or the HTTPS URI and completes {@code listener}.
     *
     * <p>Failures must be reported through {@link ActionListener#onFailure} rather than thrown.
     * {@link #getFuture()} clears the in-flight reference only when the listener completes. An exception
     * escaping this method would leave a never-completing future installed, and every later
     * {@link #reload(ActionListener)} would wait on it forever.
     */
    void loadInternal(ActionListener<Void> listener) {
        if (this.httpClient == null) {
            logger.trace("Loading PKC JWKs from path [{}]", this.jwkSetPath);
            try {
                final byte[] reloadedBytes = JwtUtil.readFileContents(
                    RealmSettings.getFullSettingKey(this.realmConfig, JwtRealmSettings.PKC_JWKSET_PATH),
                    this.jwkSetPath,
                    this.realmConfig.env()
                );
                this.handleReloadedContentAndJwksAlgs(reloadedBytes);
            } catch (Exception e) {
                listener.onFailure(e);
                return;
            }
            // Called outside the try so an exception thrown by the listener itself is not reported
            // as a second (failed) completion of the same listener.
            listener.onResponse(null);
        } else {
            logger.trace("Loading PKC JWKs from https URI [{}]", this.jwkSetPathUri);
            // map() forwards exceptions thrown by the mapping function to listener.onFailure.
            // NOTE(review): assumes readUriContents reports fetch failures via the listener rather than throwing.
            JwtUtil.readUriContents(
                RealmSettings.getFullSettingKey(this.realmConfig, JwtRealmSettings.PKC_JWKSET_PATH),
                this.jwkSetPathUri,
                this.httpClient,
                listener.map(reloadedBytes -> {
                    logger.trace("Loaded bytes [{}] from [{}]", reloadedBytes.length, this.jwkSetPathUri);
                    this.handleReloadedContentAndJwksAlgs(reloadedBytes);
                    return null;
                })
            );
        }
    }

    /** Parses {@code bytes} and swaps in the result only when its SHA-256 differs from the current content. */
    private void handleReloadedContentAndJwksAlgs(byte[] bytes) {
        final ContentAndJwksAlgs newContentAndJwksAlgs = this.parseContent(bytes);
        assert newContentAndJwksAlgs != null;
        assert this.contentAndJwksAlgs != null;
        if (Arrays.equals(this.contentAndJwksAlgs.sha256, newContentAndJwksAlgs.sha256) == false) {
            logger.info(
                "Reloaded JWK set from sha256=[{}] to sha256=[{}]",
                MessageDigests.toHexString(this.contentAndJwksAlgs.sha256),
                MessageDigests.toHexString(newContentAndJwksAlgs.sha256)
            );
            this.contentAndJwksAlgs = newContentAndJwksAlgs;
        }
    }

    /**
     * Decodes the bytes as UTF-8 and parses them as a JWK set. The JWKs and algorithms are then filtered
     * by {@link #allowedJwksAlgsPkc}.
     */
    private ContentAndJwksAlgs parseContent(byte[] jwkSetContentBytesPkc) {
        final String jwkSetContentsPkc = new String(jwkSetContentBytesPkc, StandardCharsets.UTF_8);
        final byte[] jwkSetContentsPkcSha256 = JwtUtil.sha256(jwkSetContentsPkc);
        final List<JWK> jwksPkc = JwkValidateUtil.loadJwksFromJwkSetString(
            RealmSettings.getFullSettingKey(this.realmConfig, JwtRealmSettings.PKC_JWKSET_PATH),
            jwkSetContentsPkc
        );
        final JwksAlgs jwksAlgsPkc = JwkValidateUtil.filterJwksAndAlgorithms(jwksPkc, this.allowedJwksAlgsPkc);
        logger.debug(
            "Usable PKC: JWKs=[{}] algorithms=[{}] sha256=[{}]",
            jwksAlgsPkc.jwks().size(),
            String.join(",", jwksAlgsPkc.algs()),
            MessageDigests.toHexString(jwkSetContentsPkcSha256)
        );
        return new ContentAndJwksAlgs(jwkSetContentsPkcSha256, jwksAlgsPkc);
    }

    /** Closes the HTTP client, if one was created. Close failures are logged, not thrown. */
    @Override
    public void close() {
        if (this.httpClient != null) {
            try {
                this.httpClient.close();
            } catch (IOException e) {
                logger.warn(() -> "Exception closing HTTPS client for realm [" + this.realmConfig.name() + "]", e);
            }
        }
    }

    /** SHA-256 of the raw JWK set content paired with the JWKs/algorithms derived from it. */
    record ContentAndJwksAlgs(byte[] sha256, JwksAlgs jwksAlgs) {
        ContentAndJwksAlgs {
            Objects.requireNonNull(jwksAlgs, "Filters JWKs and Algs must not be null");
        }
    }

    /** Filtered JWKs and the algorithm names that remain usable with them. */
    record JwksAlgs(List<JWK> jwks, List<String> algs) {
        JwksAlgs {
            Objects.requireNonNull(jwks, "JWKs must not be null");
            Objects.requireNonNull(algs, "Algs must not be null");
        }

        /** Returns {@code true} when there are neither JWKs nor algorithms. */
        boolean isEmpty() {
            return this.jwks.isEmpty() && this.algs.isEmpty();
        }
    }
}

