/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.security.action.saml;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Executor;
import java.util.function.Predicate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.security.action.saml.SamlInvalidateSessionRequest;
import org.elasticsearch.xpack.core.security.action.saml.SamlInvalidateSessionResponse;
import org.elasticsearch.xpack.core.security.authc.support.TokensInvalidationResult;
import org.elasticsearch.xpack.security.authc.Realms;
import org.elasticsearch.xpack.security.authc.TokenService;
import org.elasticsearch.xpack.security.authc.saml.SamlLogoutRequestHandler;
import org.elasticsearch.xpack.security.authc.saml.SamlRealm;
import org.elasticsearch.xpack.security.authc.saml.SamlRedirect;
import org.elasticsearch.xpack.security.authc.saml.SamlUtils;
import org.opensaml.saml.saml2.core.LogoutResponse;
import org.opensaml.saml.saml2.core.StatusResponseType;

public final class TransportSamlInvalidateSessionAction
extends HandledTransportAction<SamlInvalidateSessionRequest, SamlInvalidateSessionResponse> {
    private static final Logger LOGGER = LogManager.getLogger(TransportSamlInvalidateSessionAction.class);
    private final TokenService tokenService;
    private final Realms realms;
    private final Executor genericExecutor;

    @Inject
    public TransportSamlInvalidateSessionAction(TransportService transportService, ActionFilters actionFilters, TokenService tokenService, Realms realms) {
        super("cluster:admin/xpack/security/saml/invalidate", transportService, actionFilters, SamlInvalidateSessionRequest::new, (Executor)EsExecutors.DIRECT_EXECUTOR_SERVICE);
        this.tokenService = tokenService;
        this.realms = realms;
        this.genericExecutor = transportService.getThreadPool().generic();
    }

    protected void doExecute(Task task, SamlInvalidateSessionRequest request, ActionListener<SamlInvalidateSessionResponse> listener) {
        this.genericExecutor.execute((Runnable)ActionRunnable.wrap(listener, l -> this.doExecuteForked(task, request, (ActionListener<SamlInvalidateSessionResponse>)l)));
    }

    private void doExecuteForked(Task task, SamlInvalidateSessionRequest request, ActionListener<SamlInvalidateSessionResponse> listener) {
        assert (ThreadPool.assertCurrentThreadPool((String[])new String[]{"generic"}));
        List<SamlRealm> realms = SamlRealm.findSamlRealms(this.realms, request.getRealmName(), request.getAssertionConsumerServiceURL());
        if (realms.isEmpty()) {
            listener.onFailure((Exception)SamlUtils.samlException("Cannot find any matching realm for [{}]", request));
        } else if (realms.size() > 1) {
            listener.onFailure((Exception)SamlUtils.samlException("Found multiple matching realms [{}] for [{}]", realms, request));
        } else {
            this.invalidateSession(realms.get(0), request, listener);
        }
    }

    private void invalidateSession(SamlRealm realm, SamlInvalidateSessionRequest request, ActionListener<SamlInvalidateSessionResponse> listener) {
        try {
            SamlLogoutRequestHandler.Result result = realm.getLogoutHandler().parseFromQueryString(request.getQueryString());
            this.findAndInvalidateTokens(realm, result, (ActionListener<Integer>)ActionListener.wrap(count -> listener.onResponse((Object)new SamlInvalidateSessionResponse(realm.name(), count.intValue(), TransportSamlInvalidateSessionAction.buildLogoutResponseUrl(realm, result))), arg_0 -> listener.onFailure(arg_0)));
        }
        catch (ElasticsearchSecurityException e) {
            LOGGER.info("Failed to invalidate SAML session", (Throwable)e);
            listener.onFailure((Exception)((Object)e));
        }
    }

    private static String buildLogoutResponseUrl(SamlRealm realm, SamlLogoutRequestHandler.Result result) {
        LogoutResponse response = realm.buildLogoutResponse(result.getRequestId());
        return new SamlRedirect((StatusResponseType)response, realm.getSigningConfiguration()).getRedirectUrl(result.getRelayState());
    }

    private void findAndInvalidateTokens(SamlRealm realm, SamlLogoutRequestHandler.Result result, ActionListener<Integer> listener) {
        Map<String, Object> tokenMetadata = realm.createTokenMetadata(result.getNameId(), result.getSession());
        if (Strings.isNullOrEmpty((String)((String)tokenMetadata.get("saml_nameid_val")))) {
            LOGGER.debug("Logout request [{}] has no NameID value, so cannot invalidate any sessions", (Object)result);
            listener.onResponse((Object)0);
            return;
        }
        this.tokenService.invalidateActiveTokens(realm.name(), null, TransportSamlInvalidateSessionAction.containsMetadata(tokenMetadata), (ActionListener<TokensInvalidationResult>)ActionListener.wrap(tokensInvalidationResult -> {
            if (LOGGER.isInfoEnabled() && !tokensInvalidationResult.getErrors().isEmpty()) {
                try (XContentBuilder builder = XContentBuilder.builder((XContent)XContentType.JSON.xContent());){
                    tokensInvalidationResult.toXContent(builder, ToXContent.EMPTY_PARAMS);
                    LOGGER.info("Failed to invalidate some SAML access or refresh tokens {}", (Object)Strings.toString((XContentBuilder)builder));
                }
            }
            int totalTokensFound = tokensInvalidationResult.getInvalidatedTokens().size() + tokensInvalidationResult.getPreviouslyInvalidatedTokens().size() + tokensInvalidationResult.getErrors().size();
            listener.onResponse((Object)totalTokensFound);
        }, arg_0 -> listener.onFailure(arg_0)));
    }

    private static Predicate<Map<String, Object>> containsMetadata(Map<String, Object> requiredMetadata) {
        return source -> {
            Map actualMetadata = (Map)source.get("metadata");
            return requiredMetadata.entrySet().stream().allMatch(e -> Objects.equals(actualMetadata.get(e.getKey()), e.getValue()));
        };
    }
}

