/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.action;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlockException;
import org.elasticsearch.cluster.block.ClusterBlockLevel;
import org.elasticsearch.cluster.project.ProjectResolver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.SecretSettings;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction;
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentUtils;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.inference.InferenceLicenceCheck;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings;

/**
 * Master-node transport action that applies an update request to an existing
 * inference endpoint and responds with the endpoint's updated configuration.
 */
public class TransportUpdateInferenceModelAction
extends TransportMasterNodeAction<UpdateInferenceModelAction.Request, UpdateInferenceModelAction.Response> {
    private static final Logger logger = LogManager.getLogger(TransportUpdateInferenceModelAction.class);
    // Consulted to verify that the service backing the endpoint is licensed.
    private final XPackLicenseState licenseState;
    // Used to read stored endpoints (with and without secrets) and to persist the updated model.
    private final ModelRegistry modelRegistry;
    // Used to look up the InferenceService implementation named by the stored endpoint.
    private final InferenceServiceRegistry serviceRegistry;
    // Used to execute the trained model deployment update for in-cluster endpoints.
    private final Client client;
    // Supplies the project id when checking for cluster blocks.
    private final ProjectResolver projectResolver;

    /**
     * Injected constructor. Passes the {@code cluster:admin/xpack/inference/update}
     * action name, the request/response readers and
     * {@link EsExecutors#DIRECT_EXECUTOR_SERVICE} to the superclass, then stores the
     * collaborators this action needs.
     *
     * @param transportService forwarded to the superclass
     * @param clusterService   forwarded to the superclass
     * @param threadPool       forwarded to the superclass
     * @param actionFilters    forwarded to the superclass
     * @param licenseState     license state used for the service licence check
     * @param modelRegistry    registry used to read and persist inference endpoints
     * @param serviceRegistry  registry used to resolve inference services by name
     * @param client           client used to update trained model deployments
     * @param projectResolver  resolver used to obtain the project id for block checks
     */
    @Inject
    public TransportUpdateInferenceModelAction(
        TransportService transportService,
        ClusterService clusterService,
        ThreadPool threadPool,
        ActionFilters actionFilters,
        XPackLicenseState licenseState,
        ModelRegistry modelRegistry,
        InferenceServiceRegistry serviceRegistry,
        Client client,
        ProjectResolver projectResolver
    ) {
        super(
            "cluster:admin/xpack/inference/update",
            transportService,
            clusterService,
            threadPool,
            actionFilters,
            UpdateInferenceModelAction.Request::new,
            UpdateInferenceModelAction.Response::new,
            EsExecutors.DIRECT_EXECUTOR_SERVICE
        );
        this.licenseState = licenseState;
        this.modelRegistry = modelRegistry;
        this.serviceRegistry = serviceRegistry;
        this.client = client;
        this.projectResolver = projectResolver;
    }

    /**
     * Updates an existing inference endpoint as a chain of asynchronous steps:
     * <ol>
     *   <li>load the stored endpoint (with secrets) and fail if it does not exist;</li>
     *   <li>resolve the endpoint's service from the service registry and check that it is licensed;</li>
     *   <li>parse the stored configuration, merge in the requested settings and persist the result,
     *       going through the trained model deployment update for in-cluster services;</li>
     *   <li>re-read the stored endpoint and parse its configuration;</li>
     *   <li>wrap that configuration in the response.</li>
     * </ol>
     * A failure in any step is delivered to {@code masterListener}.
     *
     * @param task           the task this operation runs under (not used here)
     * @param request        the update request; supplies the endpoint id, task type and new settings
     * @param state          the cluster state passed in by the framework (not used here)
     * @param masterListener notified with the updated configuration or with the failure
     */
    @Override
    protected void masterOperation(
        Task task,
        UpdateInferenceModelAction.Request request,
        ClusterState state,
        ActionListener<UpdateInferenceModelAction.Response> masterListener
    ) {
        TaskType bodyTaskType = request.getContentAsSettings().taskType();
        TaskType resolvedTaskType = ServiceUtils.resolveTaskType(
            request.getTaskType(),
            bodyTaskType != null ? bodyTaskType.toString() : null
        );
        // Set by the second step and read by the later ones.
        AtomicReference<InferenceService> service = new AtomicReference<>();
        String inferenceEntityId = request.getInferenceEntityId();

        SubscribableListener.<UnparsedModel>newForked(listener -> this.checkEndpointExists(inferenceEntityId, listener))
            .<UnparsedModel>andThen((listener, unparsedModel) -> {
                Optional<InferenceService> optionalService = this.serviceRegistry.getService(unparsedModel.service());
                if (optionalService.isEmpty()) {
                    listener.onFailure(
                        new ElasticsearchStatusException(
                            "Service [{}] not found",
                            RestStatus.INTERNAL_SERVER_ERROR,
                            unparsedModel.service()
                        )
                    );
                    return;
                }
                InferenceService resolvedService = optionalService.get();
                if (!InferenceLicenceCheck.isServiceLicenced(resolvedService.name(), this.licenseState)) {
                    listener.onFailure(InferenceLicenceCheck.complianceException(resolvedService.name()));
                    return;
                }
                service.set(resolvedService);
                listener.onResponse(unparsedModel);
            })
            .<Boolean>andThen((listener, existingUnparsedModel) -> {
                // Copies of the stored maps are handed to the parser rather than the stored maps themselves.
                Model existingParsedModel = service.get()
                    .parsePersistedConfigWithSecrets(
                        request.getInferenceEntityId(),
                        existingUnparsedModel.taskType(),
                        new HashMap<>(existingUnparsedModel.settings()),
                        new HashMap<>(existingUnparsedModel.secrets())
                    );
                Model newModel = this.combineExistingModelWithNewSettings(
                    existingParsedModel,
                    request.getContentAsSettings(),
                    service.get().name(),
                    resolvedTaskType
                );
                if (this.isInClusterService(service.get().name())) {
                    this.updateInClusterEndpoint(request, newModel, existingParsedModel, listener);
                } else {
                    this.modelRegistry.updateModelTransaction(newModel, existingParsedModel, listener);
                }
            })
            .<ModelConfigurations>andThen((listener, didUpdate) -> {
                if (didUpdate) {
                    this.modelRegistry.getModel(inferenceEntityId, ActionListener.wrap(unparsedModel -> {
                        if (unparsedModel == null) {
                            listener.onFailure(
                                new ElasticsearchStatusException(
                                    "Failed to update model, updated model not found",
                                    RestStatus.INTERNAL_SERVER_ERROR
                                )
                            );
                        } else {
                            listener.onResponse(
                                service.get()
                                    .parsePersistedConfig(
                                        request.getInferenceEntityId(),
                                        resolvedTaskType,
                                        new HashMap<>(unparsedModel.settings())
                                    )
                                    .getConfigurations()
                            );
                        }
                    }, listener::onFailure));
                } else {
                    listener.onFailure(new ElasticsearchStatusException("Failed to update model", RestStatus.INTERNAL_SERVER_ERROR));
                }
            })
            .<UpdateInferenceModelAction.Response>andThen(
                (listener, modelConfig) -> listener.onResponse(new UpdateInferenceModelAction.Response(modelConfig))
            )
            .addListener(masterListener);
    }

    /**
     * Builds the model to persist by overlaying the requested settings on the existing model.
     * <p>
     * Secret settings are rebuilt from the requested service settings when the existing model has
     * secrets; service settings are only rebuilt when the existing ones are
     * {@link ElasticsearchInternalServiceSettings}; task settings are rebuilt when the request
     * carries task settings and the existing model has some. Anything not rebuilt is carried over
     * unchanged.
     *
     * @param existingParsedModel the currently stored model, parsed with its secrets
     * @param settingsToUpdate    the settings supplied in the update request
     * @param serviceName         the service name to record in the new configuration
     * @param resolvedTaskType    the task type resolved from the request
     * @return a new model carrying the merged configuration and secrets
     * @throws ElasticsearchStatusException with {@code BAD_REQUEST} if the resolved task type
     *         differs from the existing model's task type
     */
    private Model combineExistingModelWithNewSettings(
        Model existingParsedModel,
        UpdateInferenceModelAction.Settings settingsToUpdate,
        String serviceName,
        TaskType resolvedTaskType
    ) {
        final ModelConfigurations currentConfigs = existingParsedModel.getConfigurations();
        final TaskSettings currentTaskSettings = currentConfigs.getTaskSettings();
        final SecretSettings currentSecrets = existingParsedModel.getSecretSettings();
        final ServiceSettings currentServiceSettings = currentConfigs.getServiceSettings();

        final boolean serviceSettingsRequested = settingsToUpdate.serviceSettings() != null;

        final SecretSettings mergedSecrets = (serviceSettingsRequested && currentSecrets != null)
            ? currentSecrets.newSecretSettings(settingsToUpdate.serviceSettings())
            : currentSecrets;

        ServiceSettings mergedServiceSettings = currentServiceSettings;
        if (serviceSettingsRequested && currentServiceSettings instanceof ElasticsearchInternalServiceSettings) {
            mergedServiceSettings = ((ElasticsearchInternalServiceSettings) currentServiceSettings).updateServiceSettings(
                settingsToUpdate.serviceSettings()
            );
        }

        final TaskSettings mergedTaskSettings = (settingsToUpdate.taskSettings() != null && currentTaskSettings != null)
            ? currentTaskSettings.updatedTaskSettings(settingsToUpdate.taskSettings())
            : currentTaskSettings;

        if (existingParsedModel.getTaskType().equals(resolvedTaskType) == false) {
            throw new ElasticsearchStatusException("Task type must match the task type of the existing endpoint", RestStatus.BAD_REQUEST);
        }

        final ModelConfigurations mergedConfigs = new ModelConfigurations(
            existingParsedModel.getInferenceEntityId(),
            existingParsedModel.getTaskType(),
            serviceName,
            mergedServiceSettings,
            mergedTaskSettings
        );
        return new Model(mergedConfigs, new ModelSecrets(mergedSecrets));
    }

    /**
     * Applies an update to an endpoint served by an in-cluster service.
     * <p>
     * When the endpoint id differs from the deployment id the update is always rejected with
     * {@code CONFLICT}; the stored-model lookup only decides which message is reported. Otherwise
     * the trained model deployment is updated first and, once that succeeds, the new model is
     * persisted through the model registry.
     *
     * @param request             the update request (endpoint id, and raw content for error reporting)
     * @param newModel            the model carrying the merged settings to apply
     * @param existingParsedModel the currently stored model
     * @param listener            receives the result of persisting the model, or the failure
     * @throws IllegalStateException if the existing model is not an Elasticsearch internal model
     * @throws ElasticsearchStatusException if the deployment has no trained model assignment
     */
    private void updateInClusterEndpoint(
        UpdateInferenceModelAction.Request request,
        Model newModel,
        Model existingParsedModel,
        ActionListener<Boolean> listener
    ) {
        final String deploymentId = this.getDeploymentIdForInClusterEndpoint(existingParsedModel);
        final String inferenceEntityId = request.getInferenceEntityId();
        this.throwIfTrainedModelDoesntExist(inferenceEntityId, deploymentId);

        if (inferenceEntityId.equals(deploymentId) == false) {
            // Look for an endpoint whose id is the deployment id to choose the conflict message.
            ActionListener<UnparsedModel> ownerLookup = ActionListener.wrap(
                owner -> listener.onFailure(
                    new ElasticsearchStatusException(
                        "Cannot update inference endpoint [{}] for model deployment [{}] as it was created by another inference endpoint. The model can only be updated using inference endpoint id [{}].",
                        RestStatus.CONFLICT,
                        inferenceEntityId,
                        deploymentId,
                        owner.inferenceEntityId()
                    )
                ),
                failure -> listener.onFailure(
                    failure instanceof ResourceNotFoundException
                        ? new ElasticsearchStatusException(
                            "Cannot update inference endpoint [{}] using model deployment [{}]. The model deployment must be updated through the trained models API.",
                            RestStatus.CONFLICT,
                            inferenceEntityId,
                            deploymentId
                        )
                        : failure
                )
            );
            this.modelRegistry.getModel(deploymentId, ownerLookup);
            return;
        }

        final ServiceSettings candidateSettings = newModel.getServiceSettings();
        if ((candidateSettings instanceof ElasticsearchInternalServiceSettings) == false) {
            listener.onFailure(
                new ElasticsearchStatusException(
                    "Failed to parse [{}] of update request [{}]",
                    RestStatus.BAD_REQUEST,
                    "num_allocations",
                    request.getContent().utf8ToString()
                )
            );
            return;
        }

        final ElasticsearchInternalServiceSettings internalSettings = (ElasticsearchInternalServiceSettings) candidateSettings;
        final UpdateTrainedModelDeploymentAction.Request deploymentUpdate = new UpdateTrainedModelDeploymentAction.Request(deploymentId);
        deploymentUpdate.setNumberOfAllocations(internalSettings.getNumAllocations());
        deploymentUpdate.setAdaptiveAllocationsSettings(internalSettings.getAdaptiveAllocationsSettings());
        deploymentUpdate.setSource(UpdateTrainedModelDeploymentAction.Request.Source.INFERENCE_API);

        logger.info(
            "Updating trained model deployment [{}] for inference entity [{}] with [{}] num_allocations and adaptive allocations [{}]",
            deploymentId,
            request.getInferenceEntityId(),
            internalSettings.getNumAllocations(),
            internalSettings.getAdaptiveAllocationsSettings()
        );
        // Persist the new model only after the deployment update has responded successfully.
        this.client.execute(
            UpdateTrainedModelDeploymentAction.INSTANCE,
            deploymentUpdate,
            listener.delegateFailure(
                (persistListener, response) -> this.modelRegistry.updateModelTransaction(newModel, existingParsedModel, persistListener)
            )
        );
    }

    /**
     * Tells whether the given service name is one of the in-cluster services
     * ({@code elasticsearch} or {@code elser}).
     *
     * @param name the service name to test; a {@code null} name is rejected by the
     *             immutable list lookup with a {@link NullPointerException}
     * @return {@code true} if {@code name} is an in-cluster service name
     */
    private boolean isInClusterService(String name) {
        final List<String> inClusterServiceNames = List.of("elasticsearch", "elser");
        return inClusterServiceNames.indexOf(name) >= 0;
    }

    /**
     * Returns the ML node deployment id of an in-cluster endpoint's model.
     *
     * @param model the parsed model of the endpoint being updated
     * @return the value of {@code mlNodeDeploymentId()} on the model
     * @throws IllegalStateException if {@code model} is not an {@link ElasticsearchInternalModel}
     */
    private String getDeploymentIdForInClusterEndpoint(Model model) {
        if ((model instanceof ElasticsearchInternalModel) == false) {
            throw new IllegalStateException(
                Strings.format(
                    "Cannot update inference endpoint [%s]. Class [%s] is not an Elasticsearch internal model",
                    model.getInferenceEntityId(),
                    model.getClass().getSimpleName()
                )
            );
        }
        return ((ElasticsearchInternalModel) model).mlNodeDeploymentId();
    }

    /**
     * Verifies that the deployment has at least one trained model assignment in the
     * current cluster state.
     *
     * @param inferenceEntityId the endpoint id, reported in the error message
     * @param deploymentId      the deployment id whose assignments are looked up
     * @throws ElasticsearchStatusException if no assignments are found for {@code deploymentId}
     */
    private void throwIfTrainedModelDoesntExist(String inferenceEntityId, String deploymentId) throws ElasticsearchStatusException {
        final List<?> deploymentAssignments = TrainedModelAssignmentUtils.modelAssignments(deploymentId, this.clusterService.state());
        final boolean hasAssignment = deploymentAssignments != null && !deploymentAssignments.isEmpty();
        if (hasAssignment) {
            return;
        }
        throw ExceptionsHelper.entityNotFoundException(
            "Requested model ID [{}] does not have a matching trained model and thus cannot be updated.",
            inferenceEntityId
        );
    }

    /**
     * Loads the stored endpoint together with its secrets and hands it to {@code listener}.
     * <p>
     * A {@code null} result and a {@link ResourceNotFoundException} failure are both reported as
     * the same "does not exist" not-found error; any other failure is passed through unchanged.
     *
     * @param inferenceEntityId the id of the endpoint to load
     * @param listener          receives the stored model, or the failure
     */
    private void checkEndpointExists(String inferenceEntityId, ActionListener<UnparsedModel> listener) {
        final String missingEndpointMessage = "The inference endpoint [{}] does not exist and cannot be updated";

        ActionListener<UnparsedModel> lookupListener = ActionListener.wrap(stored -> {
            if (stored != null) {
                listener.onResponse(stored);
                return;
            }
            listener.onFailure(ExceptionsHelper.entityNotFoundException(missingEndpointMessage, inferenceEntityId));
        },
            failure -> listener.onFailure(
                failure instanceof ResourceNotFoundException
                    ? ExceptionsHelper.entityNotFoundException(missingEndpointMessage, inferenceEntityId)
                    : failure
            )
        );
        this.modelRegistry.getModelWithSecrets(inferenceEntityId, lookupListener);
    }

    /**
     * Creates an {@link XContentParser} over the request body, using the request's content type
     * and {@link XContentParserConfiguration#EMPTY}.
     * <p>
     * NOTE(review): nothing in this class calls this helper.
     *
     * @param request the update request whose content is to be parsed
     * @return a parser over the request content
     * @throws IOException if the parser cannot be created
     */
    private static XContentParser getParser(UpdateInferenceModelAction.Request request) throws IOException {
        final XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY;
        final BytesReference body = request.getContent();
        final XContentType bodyType = request.getContentType();
        return XContentHelper.createParser(parserConfig, body, bodyType);
    }

    /**
     * Checks the cluster state for a global {@link ClusterBlockLevel#METADATA_WRITE} block,
     * using the project id supplied by the project resolver.
     *
     * @param request the update request (not used by the check)
     * @param state   the cluster state whose blocks are inspected
     * @return the result of {@code globalBlockedException} for the metadata-write level
     */
    @Override
    protected ClusterBlockException checkBlock(UpdateInferenceModelAction.Request request, ClusterState state) {
        return state.blocks().globalBlockedException(this.projectResolver.getProjectId(), ClusterBlockLevel.METADATA_WRITE);
    }
}

