/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.services.elastic.authorization;

import java.util.Collection;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.Model;
import org.elasticsearch.persistent.AllocatedPersistentTask;
import org.elasticsearch.persistent.PersistentTasksService;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.threadpool.Scheduler;
import org.elasticsearch.xpack.core.inference.action.StoreInferenceEndpointsAction;
import org.elasticsearch.xpack.core.inference.results.ModelStoreResponse;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints;
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationModel;
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
import org.elasticsearch.xpack.inference.services.elastic.authorization.PreconfiguredEndpointModelAdapter;

public class AuthorizationPoller
extends AllocatedPersistentTask {
    public static final String TASK_NAME = "eis-authorization-poller";
    private static final Logger logger = LogManager.getLogger(AuthorizationPoller.class);
    private final ServiceComponents serviceComponents;
    private final ModelRegistry modelRegistry;
    private final ElasticInferenceServiceAuthorizationRequestHandler authorizationHandler;
    private final Sender sender;
    private final Runnable callback;
    private final AtomicReference<Scheduler.ScheduledCancellable> lastAuthTask = new AtomicReference<Object>(null);
    private final AtomicBoolean shutdown = new AtomicBoolean(false);
    private final ElasticInferenceServiceSettings elasticInferenceServiceSettings;
    private final AtomicBoolean initialized = new AtomicBoolean(false);
    private final ElasticInferenceServiceComponents elasticInferenceServiceComponents;
    private final Client client;
    private final CountDownLatch receivedFirstAuthResponseLatch = new CountDownLatch(1);

    public static AuthorizationPoller create(TaskFields taskFields, Parameters parameters) {
        return new AuthorizationPoller(Objects.requireNonNull(taskFields), Objects.requireNonNull(parameters));
    }

    private AuthorizationPoller(TaskFields taskFields, Parameters parameters) {
        this(taskFields, parameters.serviceComponents, parameters.authorizationRequestHandler, parameters.sender, parameters.elasticInferenceServiceSettings, parameters.modelRegistry, parameters.client, null);
    }

    AuthorizationPoller(TaskFields taskFields, ServiceComponents serviceComponents, ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, Sender sender, ElasticInferenceServiceSettings elasticInferenceServiceSettings, ModelRegistry modelRegistry, Client client, Runnable callback) {
        super(taskFields.id, taskFields.type, taskFields.action, taskFields.description, taskFields.parentTask, taskFields.headers);
        this.serviceComponents = Objects.requireNonNull(serviceComponents);
        this.authorizationHandler = Objects.requireNonNull(authorizationRequestHandler);
        this.sender = Objects.requireNonNull(sender);
        this.elasticInferenceServiceSettings = Objects.requireNonNull(elasticInferenceServiceSettings);
        this.elasticInferenceServiceComponents = new ElasticInferenceServiceComponents(elasticInferenceServiceSettings.getElasticInferenceServiceUrl());
        this.modelRegistry = Objects.requireNonNull(modelRegistry);
        this.client = new OriginSettingClient(Objects.requireNonNull(client), "inference");
        this.callback = callback;
    }

    public void start() {
        if (this.initialized.compareAndSet(false, true)) {
            logger.debug("Initializing EIS authorization logic");
            this.serviceComponents.threadPool().executor("inference_utility").execute(this::scheduleAndSendAuthorizationRequest);
        }
    }

    public void waitForAuthorizationToComplete(TimeValue waitTime) {
        try {
            if (!this.receivedFirstAuthResponseLatch.await(waitTime.getSeconds(), TimeUnit.SECONDS)) {
                throw new IllegalStateException("The wait time has expired for first authorization response to be received.");
            }
        }
        catch (InterruptedException e) {
            throw new IllegalStateException("Waiting for first authorization response to complete was interrupted");
        }
    }

    protected void init(PersistentTasksService persistentTasksService, TaskManager taskManager, String persistentTaskId, long allocationId) {
        super.init(persistentTasksService, taskManager, persistentTaskId, allocationId);
    }

    protected void onCancelled() {
        this.shutdown();
        this.markAsCompleted();
    }

    private void shutdownAndMarkTaskAsFailed(Exception e) {
        this.shutdown();
        this.markAsFailed(e);
    }

    void shutdown() {
        this.shutdown.set(true);
        Scheduler.ScheduledCancellable authTask = this.lastAuthTask.get();
        if (authTask != null) {
            authTask.cancel();
        }
    }

    boolean isShutdown() {
        return this.shutdown.get();
    }

    private void scheduleAuthorizationRequest() {
        try {
            if (!this.elasticInferenceServiceSettings.isPeriodicAuthorizationEnabled()) {
                return;
            }
            Random random = Randomness.get();
            long jitter = (long)((double)this.elasticInferenceServiceSettings.getMaxAuthorizationRequestJitter().millis() * random.nextDouble());
            TimeValue waitTime = TimeValue.timeValueMillis((long)(this.elasticInferenceServiceSettings.getAuthRequestInterval().millis() + jitter));
            logger.debug(() -> Strings.format((String)"Scheduling the next authorization call with request interval: %s ms, jitter: %d ms", (Object[])new Object[]{this.elasticInferenceServiceSettings.getAuthRequestInterval().millis(), jitter}));
            logger.debug(() -> Strings.format((String)"Next authorization call in %d minutes", (Object[])new Object[]{waitTime.getMinutes()}));
            this.lastAuthTask.set(this.serviceComponents.threadPool().schedule(this::scheduleAndSendAuthorizationRequest, waitTime, (Executor)this.serviceComponents.threadPool().executor("inference_utility")));
        }
        catch (Exception e) {
            logger.warn("Failed scheduling authorization request", (Throwable)e);
            this.shutdownAndMarkTaskAsFailed(e);
        }
    }

    private void scheduleAndSendAuthorizationRequest() {
        if (this.shutdown.get()) {
            return;
        }
        this.scheduleAuthorizationRequest();
        this.sendAuthorizationRequest();
    }

    void sendAuthorizationRequest() {
        if (!this.modelRegistry.isReady()) {
            return;
        }
        ActionListener finalListener = ActionListener.running(() -> {
            if (this.callback != null) {
                this.callback.run();
            }
            this.receivedFirstAuthResponseLatch.countDown();
        }).delegateResponse((delegate, e) -> {
            logger.atWarn().withThrowable((Throwable)e).log("Failed processing EIS preconfigured endpoints");
            delegate.onResponse(null);
        });
        SubscribableListener.newForked(authModelListener -> this.authorizationHandler.getAuthorization((ActionListener<ElasticInferenceServiceAuthorizationModel>)authModelListener, this.sender)).andThenApply(this::getNewInferenceEndpointsToStore).andThen((storeListener, newInferenceIds) -> this.storePreconfiguredModels((Set<String>)newInferenceIds, (ActionListener<Void>)storeListener)).addListener(finalListener);
    }

    private Set<String> getNewInferenceEndpointsToStore(ElasticInferenceServiceAuthorizationModel authModel) {
        ElasticInferenceServiceAuthorizationModel scopedAuthModel = authModel.newLimitedToTaskTypes(EnumSet.copyOf(ElasticInferenceService.IMPLEMENTED_TASK_TYPES));
        Set<String> authorizedModelIds = scopedAuthModel.getAuthorizedModelIds();
        Set<String> existingInferenceIds = this.modelRegistry.getInferenceIds();
        Set<String> newInferenceIds = authorizedModelIds.stream().map(InternalPreconfiguredEndpoints::getWithModelName).flatMap(Collection::stream).map(model -> model.configurations().getInferenceEntityId()).collect(Collectors.toSet());
        newInferenceIds.removeAll(existingInferenceIds);
        return newInferenceIds;
    }

    private void storePreconfiguredModels(Set<String> newInferenceIds, ActionListener<Void> listener) {
        if (newInferenceIds.isEmpty()) {
            listener.onResponse(null);
            return;
        }
        logger.info("Storing new EIS preconfigured inference endpoints with inference IDs {}", newInferenceIds);
        List<Model> modelsToAdd = PreconfiguredEndpointModelAdapter.getModels(newInferenceIds, this.elasticInferenceServiceComponents);
        StoreInferenceEndpointsAction.Request storeRequest = new StoreInferenceEndpointsAction.Request(modelsToAdd, TimeValue.THIRTY_SECONDS);
        ActionListener logResultsListener = ActionListener.wrap(responses -> {
            for (ModelStoreResponse response : responses.getResults()) {
                if (response.failed()) {
                    logger.atWarn().withThrowable((Throwable)response.failureCause()).log("Failed to store new EIS preconfigured inference endpoint with inference ID [{}]", (Object)response.inferenceId());
                    continue;
                }
                logger.atInfo().log("Successfully stored EIS preconfigured inference endpoint with inference ID [{}]", (Object)response.inferenceId());
            }
        }, e -> logger.atWarn().withThrowable((Throwable)e).log("Failed to store new EIS preconfigured inference endpoints [{}]", (Object)newInferenceIds));
        this.client.execute((ActionType)StoreInferenceEndpointsAction.INSTANCE, (ActionRequest)storeRequest, ActionListener.runAfter((ActionListener)logResultsListener, () -> listener.onResponse(null)));
    }

    public record TaskFields(long id, String type, String action, String description, TaskId parentTask, Map<String, String> headers) {
    }

    public record Parameters(ServiceComponents serviceComponents, ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, Sender sender, ElasticInferenceServiceSettings elasticInferenceServiceSettings, ModelRegistry modelRegistry, Client client) {
    }
}

