/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.services.elastic.authorization;

import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.threadpool.Scheduler;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.DefaultModelConfig;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationModel;
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;

/**
 * Retrieves which task types and models the Elastic Inference Service has authorized and keeps the
 * service's default inference endpoints in sync with that response.
 *
 * <p>After {@link #init()}, an authorization request is sent on the {@code inference_utility} executor.
 * When periodic authorization is enabled in {@link ElasticInferenceServiceSettings}, a follow-up request is
 * scheduled each time using the configured interval plus random jitter.
 *
 * <p>Each successful response:
 * <ul>
 *   <li>is restricted to the task types this service implements;</li>
 *   <li>registers the authorized default endpoints with the {@link ModelRegistry};</li>
 *   <li>asks the registry to remove default endpoints that are no longer authorized.</li>
 * </ul>
 *
 * <p>Public accessors are {@code synchronized} with the update performed in {@code setAuthorizedContent}.
 */
public class ElasticInferenceServiceAuthorizationHandler implements Closeable {
    private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationHandler.class);

    /** Executor used both for the initial request and for scheduled follow-up requests. */
    private static final String UTILITY_EXECUTOR = "inference_utility";

    private final ServiceComponents serviceComponents;
    /** Latest authorization state; starts out as a disabled service with no default endpoints. */
    private final AtomicReference<AuthorizedContent> authorizedContent = new AtomicReference<>(AuthorizedContent.empty());
    private final ModelRegistry modelRegistry;
    private final ElasticInferenceServiceAuthorizationRequestHandler authorizationHandler;
    /** Default model configurations keyed by model id. */
    private final Map<String, DefaultModelConfig> defaultModelsConfigs;
    /** Released once the first authorization attempt has finished, successfully or not. */
    private final CountDownLatch firstAuthorizationCompletedLatch = new CountDownLatch(1);
    private final EnumSet<TaskType> implementedTaskTypes;
    private final InferenceService inferenceService;
    private final Sender sender;
    /** Optional hook run after each successfully applied authorization response; may be {@code null}. */
    private final Runnable callback;
    /** Most recently scheduled follow-up request, so {@link #close()} can cancel it. */
    private final AtomicReference<Scheduler.ScheduledCancellable> lastAuthTask = new AtomicReference<>();
    private final AtomicBoolean shutdown = new AtomicBoolean(false);
    private final ElasticInferenceServiceSettings elasticInferenceServiceSettings;

    /**
     * Creates a handler with no post-authorization callback.
     *
     * @throws NullPointerException if any argument is {@code null}
     */
    public ElasticInferenceServiceAuthorizationHandler(
        ServiceComponents serviceComponents,
        ModelRegistry modelRegistry,
        ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler,
        Map<String, DefaultModelConfig> defaultModelsConfigs,
        EnumSet<TaskType> implementedTaskTypes,
        InferenceService inferenceService,
        Sender sender,
        ElasticInferenceServiceSettings elasticInferenceServiceSettings
    ) {
        this(
            serviceComponents,
            modelRegistry,
            authorizationRequestHandler,
            defaultModelsConfigs,
            implementedTaskTypes,
            Objects.requireNonNull(inferenceService),
            sender,
            elasticInferenceServiceSettings,
            null
        );
    }

    /**
     * Package-private constructor.
     *
     * <p>Unlike the public constructor, this one does not null-check {@code inferenceService}. It also
     * accepts a {@code callback}, which may be {@code null}. All other arguments must be non-null.
     */
    ElasticInferenceServiceAuthorizationHandler(
        ServiceComponents serviceComponents,
        ModelRegistry modelRegistry,
        ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler,
        Map<String, DefaultModelConfig> defaultModelsConfigs,
        EnumSet<TaskType> implementedTaskTypes,
        InferenceService inferenceService,
        Sender sender,
        ElasticInferenceServiceSettings elasticInferenceServiceSettings,
        Runnable callback
    ) {
        this.serviceComponents = Objects.requireNonNull(serviceComponents);
        this.modelRegistry = Objects.requireNonNull(modelRegistry);
        this.authorizationHandler = Objects.requireNonNull(authorizationRequestHandler);
        this.defaultModelsConfigs = Objects.requireNonNull(defaultModelsConfigs);
        this.implementedTaskTypes = Objects.requireNonNull(implementedTaskTypes);
        this.inferenceService = inferenceService;
        this.sender = Objects.requireNonNull(sender);
        this.elasticInferenceServiceSettings = Objects.requireNonNull(elasticInferenceServiceSettings);
        this.callback = callback;
    }

    /**
     * Starts the authorization flow asynchronously on the {@code inference_utility} executor.
     */
    public void init() {
        logger.debug("Initializing authorization logic");
        serviceComponents.threadPool().executor(UTILITY_EXECUTOR).execute(this::scheduleAndSendAuthorizationRequest);
    }

    /**
     * Blocks until the first authorization attempt has completed, whether it succeeded or failed.
     *
     * @param waitTime maximum time to wait; honoured at millisecond precision
     * @throws IllegalStateException if the wait time expires, or if the calling thread is interrupted
     *         (in which case the thread's interrupt status is restored)
     */
    public void waitForAuthorizationToComplete(TimeValue waitTime) {
        try {
            // Use millis rather than seconds so sub-second waits are not truncated to zero.
            if (!firstAuthorizationCompletedLatch.await(waitTime.millis(), TimeUnit.MILLISECONDS)) {
                throw new IllegalStateException("The wait time has expired for authorization to complete.");
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Waiting for authorization to complete was interrupted", e);
        }
    }

    /**
     * Returns the authorized task types that support streaming.
     *
     * <p>Only {@link TaskType#CHAT_COMPLETION} is considered here. The returned set is a fresh copy.
     */
    public synchronized Set<TaskType> supportedStreamingTasks() {
        EnumSet<TaskType> authorizedStreamingTaskTypes = EnumSet.of(TaskType.CHAT_COMPLETION);
        authorizedStreamingTaskTypes.retainAll(authorizedContent.get().taskTypesAndModels().getAuthorizedTaskTypes());
        return authorizedStreamingTaskTypes;
    }

    /**
     * Returns the authorized default endpoint ids, sorted by inference id.
     *
     * <p>The returned list is unmodifiable.
     */
    public synchronized List<InferenceService.DefaultConfigId> defaultConfigIds() {
        return authorizedContent.get().configIds();
    }

    /**
     * Synchronously completes {@code defaultsListener} with the models of the currently authorized
     * default configurations.
     */
    public synchronized void defaultConfigs(ActionListener<List<Model>> defaultsListener) {
        List<Model> models = authorizedContent.get().defaultModelConfigs().stream().map(DefaultModelConfig::model).toList();
        defaultsListener.onResponse(models);
    }

    /** Returns the task types authorized by the last applied response, limited to implemented task types. */
    public synchronized EnumSet<TaskType> supportedTaskTypes() {
        return authorizedContent.get().taskTypesAndModels().getAuthorizedTaskTypes();
    }

    /** Returns {@code true} when the current authorization state reports the service as not authorized. */
    public synchronized boolean hideFromConfigurationApi() {
        return !authorizedContent.get().taskTypesAndModels().isAuthorized();
    }

    /**
     * Stops future authorization requests and cancels the most recently scheduled one, if any.
     */
    @Override
    public void close() throws IOException {
        shutdown.set(true);
        // Read the reference once so a concurrent update cannot turn the null check into a stale read.
        Scheduler.ScheduledCancellable task = lastAuthTask.get();
        if (task != null) {
            task.cancel();
        }
    }

    /**
     * Schedules the next authorization request when periodic authorization is enabled.
     *
     * <p>The delay is the configured interval plus a random jitter in {@code [0, maxJitter)}. Failures are
     * logged rather than propagated.
     */
    private void scheduleAuthorizationRequest() {
        try {
            if (!elasticInferenceServiceSettings.isPeriodicAuthorizationEnabled()) {
                return;
            }

            long intervalMillis = elasticInferenceServiceSettings.getAuthRequestInterval().millis();
            Random random = Randomness.get();
            // Presumably the jitter keeps many nodes from issuing requests at the same instant — confirm.
            long jitter = (long) (elasticInferenceServiceSettings.getMaxAuthorizationRequestJitter().millis() * random.nextDouble());
            TimeValue waitTime = TimeValue.timeValueMillis(intervalMillis + jitter);

            logger.debug(
                () -> Strings.format(
                    "Scheduling the next authorization call with request interval: %s ms, jitter: %d ms",
                    intervalMillis,
                    jitter
                )
            );
            logger.debug(() -> Strings.format("Next authorization call in %d minutes", waitTime.getMinutes()));

            lastAuthTask.set(
                serviceComponents.threadPool()
                    .schedule(
                        this::scheduleAndSendAuthorizationRequest,
                        waitTime,
                        serviceComponents.threadPool().executor(UTILITY_EXECUTOR)
                    )
            );
        } catch (Exception e) {
            logger.warn("Failed scheduling authorization request", e);
        }
    }

    /**
     * Unless shut down, schedules the next periodic request and then sends the current one.
     */
    private void scheduleAndSendAuthorizationRequest() {
        if (shutdown.get()) {
            return;
        }
        scheduleAuthorizationRequest();
        sendAuthorizationRequest();
    }

    /**
     * Sends one authorization request.
     *
     * <p>On success the response is applied and the optional callback runs. On any failure the
     * first-authorization latch is released so waiters are not left blocked.
     */
    private void sendAuthorizationRequest() {
        try {
            ActionListener<ElasticInferenceServiceAuthorizationModel> listener = ActionListener.wrap(model -> {
                setAuthorizedContent(model);
                if (callback != null) {
                    callback.run();
                }
            }, e -> {
                // Surface the failure instead of silently dropping it, then release any waiters.
                logger.warn("Failure while retrieving authorization", e);
                firstAuthorizationCompletedLatch.countDown();
            });
            authorizationHandler.getAuthorization(listener, sender);
        } catch (Exception e) {
            logger.warn("Failure while sending the request to retrieve authorization", e);
            firstAuthorizationCompletedLatch.countDown();
        }
    }

    /**
     * Applies an authorization response.
     *
     * <p>Steps:
     * <ol>
     *   <li>limit the response to the implemented task types;</li>
     *   <li>publish the new state;</li>
     *   <li>register authorized default ids with the registry;</li>
     *   <li>start removal of default endpoints that are no longer authorized.</li>
     * </ol>
     */
    private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorizationModel auth) {
        logger.debug(() -> Strings.format("Received authorization response, %s", auth));

        ElasticInferenceServiceAuthorizationModel authorizedTaskTypesAndModels = auth.newLimitedToTaskTypes(
            EnumSet.copyOf(implementedTaskTypes)
        );
        logger.debug(() -> Strings.format("Authorization entity limited to service task types, %s", authorizedTaskTypesAndModels));

        Set<String> authorizedDefaultModelIds = getAuthorizedDefaultModelIds(authorizedTaskTypesAndModels);
        List<InferenceService.DefaultConfigId> authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(
            authorizedDefaultModelIds,
            authorizedTaskTypesAndModels
        );
        List<DefaultModelConfig> authorizedDefaultModelObjects = getAuthorizedDefaultModelsObjects(authorizedDefaultModelIds);

        AuthorizedContent newContent = new AuthorizedContent(
            authorizedTaskTypesAndModels,
            authorizedDefaultConfigIds,
            authorizedDefaultModelObjects
        );
        authorizedContent.set(newContent);
        newContent.configIds().forEach(modelRegistry::putDefaultIdIfAbsent);

        handleRevokedDefaultConfigs(authorizedDefaultModelIds);
    }

    /** Returns, in sorted order, the ids of the known default models that the response authorizes. */
    private Set<String> getAuthorizedDefaultModelIds(ElasticInferenceServiceAuthorizationModel auth) {
        Set<String> authorizedModels = auth.getAuthorizedModelIds();
        TreeSet<String> authorizedDefaultModelIds = new TreeSet<>(defaultModelsConfigs.keySet());
        authorizedDefaultModelIds.retainAll(authorizedModels);
        return authorizedDefaultModelIds;
    }

    /**
     * Builds the {@link InferenceService.DefaultConfigId}s for the given authorized default model ids,
     * sorted by inference id.
     *
     * <p>If a model is authorized but its assumed task type is not, a warning is logged and the model is
     * still enabled.
     */
    private List<InferenceService.DefaultConfigId> getAuthorizedDefaultConfigIds(
        Set<String> authorizedDefaultModelIds,
        ElasticInferenceServiceAuthorizationModel auth
    ) {
        List<InferenceService.DefaultConfigId> authorizedConfigIds = new ArrayList<>();
        for (String id : authorizedDefaultModelIds) {
            DefaultModelConfig modelConfig = defaultModelsConfigs.get(id);
            if (modelConfig == null) {
                continue;
            }
            if (!auth.getAuthorizedTaskTypes().contains(modelConfig.model().getTaskType())) {
                logger.warn(
                    Strings.format(
                        "The authorization response included the default model: %s, but did not authorize the assumed task type of the model: %s. Enabling model.",
                        id,
                        modelConfig.model().getTaskType()
                    )
                );
            }
            authorizedConfigIds.add(
                new InferenceService.DefaultConfigId(modelConfig.model().getInferenceEntityId(), modelConfig.settings(), inferenceService)
            );
        }
        authorizedConfigIds.sort(Comparator.comparing(InferenceService.DefaultConfigId::inferenceId));
        return authorizedConfigIds;
    }

    /** Returns the default model configurations for the given ids, sorted by inference entity id. */
    private List<DefaultModelConfig> getAuthorizedDefaultModelsObjects(Set<String> authorizedDefaultModelIds) {
        List<DefaultModelConfig> authorizedModels = new ArrayList<>();
        for (String id : authorizedDefaultModelIds) {
            DefaultModelConfig modelConfig = defaultModelsConfigs.get(id);
            if (modelConfig != null) {
                authorizedModels.add(modelConfig);
            }
        }
        authorizedModels.sort(Comparator.comparing(modelConfig -> modelConfig.model().getInferenceEntityId()));
        return authorizedModels;
    }

    /**
     * Asks the model registry to remove the default inference endpoints whose models are not in
     * {@code authorizedDefaultModelIds}.
     *
     * <p>The first-authorization latch is released when the removal completes, whether it succeeds or
     * fails.
     */
    private void handleRevokedDefaultConfigs(Set<String> authorizedDefaultModelIds) {
        Set<String> unauthorizedDefaultModelIds = new HashSet<>(defaultModelsConfigs.keySet());
        unauthorizedDefaultModelIds.removeAll(authorizedDefaultModelIds);

        // The registry works with inference endpoint ids, not model ids, so translate them.
        Set<String> unauthorizedDefaultInferenceEndpointIds = unauthorizedDefaultModelIds.stream()
            .map(defaultModelsConfigs::get)
            .filter(Objects::nonNull)
            .map(modelConfig -> modelConfig.model().getInferenceEntityId())
            .collect(Collectors.toSet());

        ActionListener<Boolean> deleteInferenceEndpointsListener = ActionListener.wrap(result -> {
            logger.debug(
                () -> Strings.format(
                    "Successfully revoked access to default inference endpoint IDs: %s",
                    unauthorizedDefaultInferenceEndpointIds
                )
            );
            firstAuthorizationCompletedLatch.countDown();
        }, e -> {
            logger.warn(
                Strings.format("Failed to revoke access to default inference endpoint IDs: %s", unauthorizedDefaultInferenceEndpointIds),
                e
            );
            firstAuthorizationCompletedLatch.countDown();
        });

        logger.debug(
            () -> Strings.format(
                "Synchronizing default inference endpoints, attempting to remove ids: %s",
                unauthorizedDefaultInferenceEndpointIds
            )
        );
        modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener);
    }

    /**
     * Immutable snapshot of the applied authorization state.
     *
     * <p>Both lists are copied to unmodifiable lists so callers of {@code defaultConfigIds()} cannot mutate
     * the handler's state.
     */
    private record AuthorizedContent(
        ElasticInferenceServiceAuthorizationModel taskTypesAndModels,
        List<InferenceService.DefaultConfigId> configIds,
        List<DefaultModelConfig> defaultModelConfigs
    ) {
        AuthorizedContent {
            configIds = List.copyOf(configIds);
            defaultModelConfigs = List.copyOf(defaultModelConfigs);
        }

        /** Initial state before any authorization response: a disabled service with no default endpoints. */
        static AuthorizedContent empty() {
            return new AuthorizedContent(ElasticInferenceServiceAuthorizationModel.newDisabledService(), List.of(), List.of());
        }
    }
}

