/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.action;

import java.io.IOException;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.Executor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlockException;
import org.elasticsearch.cluster.block.ClusterBlockLevel;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.cluster.project.ProjectResolver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.VersionId;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.features.FeatureService;
import org.elasticsearch.index.mapper.FieldMapper;
import org.elasticsearch.index.mapper.StrictDynamicMappingException;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.MinimalServiceSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentUtils;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.inference.InferenceFeatures;
import org.elasticsearch.xpack.inference.InferenceLicenceCheck;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.common.SemanticTextInfoExtractor;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;

/**
 * Master-node transport action for {@code cluster:admin/xpack/inference/put}, which creates a new
 * inference endpoint.
 *
 * <p>The request is processed in stages, each stage handing its result to the next through an
 * {@link ActionListener}:
 * <ol>
 *   <li>synchronous request checks in {@link #masterOperation} (reserved IDs, task type support,
 *       service name, licence, service version, clash with trained-model assignments);</li>
 *   <li>the target {@link InferenceService} parses the request configuration into a {@link Model};</li>
 *   <li>on the {@code inference_utility} executor, indices whose mappings already reference the
 *       inference ID are checked for settings compatibility with the new model;</li>
 *   <li>unless {@code skipValidationAndStart} is set, the model is validated;</li>
 *   <li>the model is stored in the {@link ModelRegistry};</li>
 *   <li>unless {@code skipValidationAndStart} is set, the model is started.</li>
 * </ol>
 */
public class TransportPutInferenceModelAction
    extends TransportMasterNodeAction<PutInferenceModelAction.Request, PutInferenceModelAction.Response> {

    private static final Logger logger = LogManager.getLogger(TransportPutInferenceModelAction.class);

    private final XPackLicenseState licenseState;
    private final ModelRegistry modelRegistry;
    private final InferenceServiceRegistry serviceRegistry;
    private final ProjectResolver projectResolver;
    private final FeatureService featureService;
    // Backed by the dynamic InferencePlugin.SKIP_VALIDATE_AND_START setting; it can be changed at runtime
    // through the settings-update consumer registered in the constructor, hence volatile.
    private volatile boolean skipValidationAndStart;

    @Inject
    public TransportPutInferenceModelAction(
        TransportService transportService,
        ClusterService clusterService,
        ThreadPool threadPool,
        ActionFilters actionFilters,
        XPackLicenseState licenseState,
        ModelRegistry modelRegistry,
        InferenceServiceRegistry serviceRegistry,
        Settings settings,
        ProjectResolver projectResolver,
        FeatureService featureService
    ) {
        super(
            "cluster:admin/xpack/inference/put",
            transportService,
            clusterService,
            threadPool,
            actionFilters,
            PutInferenceModelAction.Request::new,
            PutInferenceModelAction.Response::new,
            EsExecutors.DIRECT_EXECUTOR_SERVICE
        );
        this.licenseState = licenseState;
        this.modelRegistry = modelRegistry;
        this.serviceRegistry = serviceRegistry;
        this.skipValidationAndStart = InferencePlugin.SKIP_VALIDATE_AND_START.get(settings);
        clusterService.getClusterSettings()
            .addSettingsUpdateConsumer(InferencePlugin.SKIP_VALIDATE_AND_START, this::setSkipValidationAndStart);
        this.projectResolver = projectResolver;
        this.featureService = featureService;
    }

    /**
     * Performs the synchronous request checks and then hands off to {@link #parseAndStoreModel}.
     * Every rejection is reported through {@code listener} and ends the method early.
     *
     * @param task     the task for this request (unused here)
     * @param request  the put-endpoint request
     * @param state    the cluster state snapshot used for all checks in this method
     * @param listener completed with the created endpoint's configuration, or with a failure
     * @throws IOException if the request body cannot be parsed into a map
     */
    @Override
    protected void masterOperation(
        Task task,
        PutInferenceModelAction.Request request,
        ClusterState state,
        ActionListener<PutInferenceModelAction.Response> listener
    ) throws Exception {
        String inferenceEntityId = request.getInferenceEntityId();
        if (this.modelRegistry.containsPreconfiguredInferenceEndpointId(inferenceEntityId)) {
            listener.onFailure(
                new ElasticsearchStatusException(
                    "[{}] is a reserved inference ID. Cannot create a new inference endpoint with a reserved ID.",
                    RestStatus.BAD_REQUEST,
                    inferenceEntityId
                )
            );
            return;
        }

        Map<String, Object> requestAsMap = requestToMap(request);
        // "task_type" in the body (if any) is reconciled with the task type from the request itself.
        TaskType resolvedTaskType = ServiceUtils.resolveTaskType(request.getTaskType(), (String) requestAsMap.remove("task_type"));
        if (resolvedTaskType == TaskType.EMBEDDING && !this.featureService.clusterHasFeature(state, InferenceFeatures.EMBEDDING_TASK_TYPE)) {
            listener.onFailure(
                new ElasticsearchStatusException(
                    "task_type ["
                        + TaskType.EMBEDDING
                        + "] is not supported by all nodes in the cluster; please complete upgrades before creating an endpoint with this task_type",
                    RestStatus.BAD_REQUEST
                )
            );
            return;
        }

        String serviceName = (String) requestAsMap.remove("service");
        if (serviceName == null) {
            listener.onFailure(
                new ElasticsearchStatusException("Inference endpoint configuration is missing the [service] setting", RestStatus.BAD_REQUEST)
            );
            return;
        }
        if (!InferenceLicenceCheck.isServiceLicenced(serviceName, this.licenseState)) {
            listener.onFailure(InferenceLicenceCheck.complianceException(serviceName));
            return;
        }
        // For these two services the service name is put back into the config map handed to the parser.
        if (List.of("elser", "elasticsearch").contains(serviceName)) {
            requestAsMap.put("service", serviceName);
        }

        Optional<InferenceService> service = this.serviceRegistry.getService(serviceName);
        if (service.isEmpty()) {
            listener.onFailure(new ElasticsearchStatusException("Unknown service [{}]", RestStatus.BAD_REQUEST, serviceName));
            return;
        }
        InferenceService inferenceService = service.get();

        if (inferenceService.getMinimalSupportedVersion().after(state.getMinTransportVersion())) {
            logger.warn(
                Strings.format(
                    "Service [%s] requires version [%s] but minimum cluster version is [%s]",
                    serviceName,
                    inferenceService.getMinimalSupportedVersion(),
                    state.getMinTransportVersion()
                )
            );
            listener.onFailure(
                new ElasticsearchStatusException(
                    Strings.format(
                        "All nodes in the cluster are not aware of the service [%s]. "
                            + "Wait for the cluster to finish upgrading and try again.",
                        serviceName
                    ),
                    RestStatus.BAD_REQUEST
                )
            );
            return;
        }

        // Use the same cluster state snapshot as every other check in this method.
        List<?> assignments = TrainedModelAssignmentUtils.modelAssignments(inferenceEntityId, state);
        if (assignments != null && !assignments.isEmpty()) {
            listener.onFailure(
                ExceptionsHelper.badRequestException(
                    "Inference endpoint IDs must be unique. Requested inference endpoint ID [{}] matches existing trained model ID(s) but must not.",
                    inferenceEntityId
                )
            );
            return;
        }

        parseAndStoreModel(inferenceService, inferenceEntityId, resolvedTaskType, requestAsMap, request.getTimeout(), state.metadata(), listener);
    }

    /**
     * Chains parse, existing-usage check, validation, storage and start of the new model.
     * The listeners are declared in reverse order of execution because each one feeds the next.
     */
    private void parseAndStoreModel(
        InferenceService service,
        String inferenceEntityId,
        TaskType taskType,
        Map<String, Object> config,
        TimeValue timeout,
        Metadata metadata,
        ActionListener<PutInferenceModelAction.Response> listener
    ) {
        // Last stage: persist the (possibly validated) model, then start it.
        ActionListener<Model> storeModelListener = listener.delegateFailureAndWrap(
            (delegate, verifiedModel) -> this.modelRegistry.storeModel(
                verifiedModel,
                ActionListener.<Boolean>wrap(
                    stored -> startInferenceEndpoint(service, timeout, verifiedModel, delegate),
                    e -> delegate.onFailure(translateStoreFailure(e))
                ),
                timeout
            )
        );

        // Validate the model against the service unless validation is disabled by setting.
        ActionListener<Model> modelValidatingListener = listener.delegateFailureAndWrap((delegate, model) -> {
            if (this.skipValidationAndStart) {
                storeModelListener.onResponse(model);
            } else {
                ModelValidatorBuilder.buildModelValidator(model.getTaskType(), service)
                    .validate(service, model, timeout, storeModelListener);
            }
        });

        // Inspect index mappings off the calling thread. Exceptions thrown inside the runnable would
        // otherwise escape on the pool thread and leave the listener uncompleted, so route them here.
        ActionListener<Model> existingUsesListener = listener.delegateFailureAndWrap(
            (delegate, model) -> this.threadPool.executor("inference_utility").execute(() -> {
                try {
                    checkForExistingUsesOfInferenceId(metadata, model, modelValidatingListener);
                } catch (Exception e) {
                    delegate.onFailure(e);
                }
            })
        );

        service.parseRequestConfig(inferenceEntityId, taskType, config, existingUsesListener);
    }

    /**
     * Rewrites a store failure caused by a strict-mapping rejection of {@code chunking_settings} into a
     * user-facing BAD_REQUEST; any other failure is returned unchanged.
     */
    private static Exception translateStoreFailure(Exception e) {
        Throwable cause = e.getCause();
        if (cause instanceof StrictDynamicMappingException) {
            // getMessage() may be null; guard so the failure path cannot itself throw.
            String causeMessage = cause.getMessage();
            if (causeMessage != null && causeMessage.contains("chunking_settings")) {
                return new ElasticsearchStatusException(
                    "One or more nodes in your cluster does not support chunking_settings. Please update all nodes in your cluster to the latest version to use chunking_settings.",
                    RestStatus.BAD_REQUEST
                );
            }
        }
        return e;
    }

    /**
     * Passes {@code model} on to {@code modelValidatingListener} when no index that references its
     * inference ID has incompatible settings; otherwise fails with BAD_REQUEST naming those indices.
     */
    private void checkForExistingUsesOfInferenceId(Metadata metadata, Model model, ActionListener<Model> modelValidatingListener) {
        Set<String> inferenceEntityIdSet = Set.of(model.getInferenceEntityId());
        Set<String> indicesWithIncompatibleMappings = findIndicesWithIncompatibleMappings(model, metadata, inferenceEntityIdSet);
        if (indicesWithIncompatibleMappings.isEmpty()) {
            modelValidatingListener.onResponse(model);
        } else {
            modelValidatingListener.onFailure(
                new ElasticsearchStatusException(
                    buildErrorString(model.getInferenceEntityId(), indicesWithIncompatibleMappings),
                    RestStatus.BAD_REQUEST
                )
            );
        }
    }

    /**
     * Returns the keys of the per-index settings map (index names, per the lambda's naming) whose
     * existing {@link MinimalServiceSettings} cannot be merged with the settings of {@code model}.
     */
    private Set<String> findIndicesWithIncompatibleMappings(Model model, Metadata metadata, Set<String> inferenceEntityIdSet) {
        Map<String, MinimalServiceSettings> serviceSettingsMap = SemanticTextInfoExtractor
            .getModelSettingsForIndicesReferencingInferenceEndpoints(metadata, inferenceEntityIdSet);
        Set<String> incompatibleIndices = new HashSet<>();
        if (!serviceSettingsMap.isEmpty()) {
            MinimalServiceSettings newSettings = new MinimalServiceSettings(model);
            serviceSettingsMap.forEach((indexName, existingSettings) -> {
                if (!SemanticTextFieldMapper.canMergeModelSettings(existingSettings, newSettings, new FieldMapper.Conflicts(""))) {
                    incompatibleIndices.add(indexName);
                }
            });
        }
        return incompatibleIndices;
    }

    /** Builds the BAD_REQUEST message listing indices whose mappings conflict with the new endpoint. */
    private static String buildErrorString(String inferenceId, Set<String> indicesWithIncompatibleMappings) {
        return "Inference endpoint ["
            + inferenceId
            + "] could not be created because the inference_id is being used in mappings with incompatible settings for indices: "
            + indicesWithIncompatibleMappings
            + ". Please either use a different inference_id or update the index mappings to refer to a different inference_id.";
    }

    /**
     * Responds with the stored model's configuration. Unless {@code skipValidationAndStart} is set, the
     * service is asked to start the model first and the response is sent once that completes.
     */
    private void startInferenceEndpoint(
        InferenceService service,
        TimeValue timeout,
        Model model,
        ActionListener<PutInferenceModelAction.Response> listener
    ) {
        if (this.skipValidationAndStart) {
            listener.onResponse(new PutInferenceModelAction.Response(model.getConfigurations()));
        } else {
            service.start(model, timeout, listener.map(started -> new PutInferenceModelAction.Response(model.getConfigurations())));
        }
    }

    /** Parses the request body into a mutable-by-contract map using the request's declared content type. */
    private Map<String, Object> requestToMap(PutInferenceModelAction.Request request) throws IOException {
        try (
            XContentParser parser = XContentHelper.createParser(
                XContentParserConfiguration.EMPTY,
                request.getContent(),
                request.getContentType()
            )
        ) {
            return parser.map();
        }
    }

    /** Settings-update consumer for {@code InferencePlugin.SKIP_VALIDATE_AND_START}. */
    private void setSkipValidationAndStart(boolean skipValidationAndStart) {
        this.skipValidationAndStart = skipValidationAndStart;
    }

    /** Rejects the request when a global METADATA_WRITE block is present for the resolved project. */
    @Override
    protected ClusterBlockException checkBlock(PutInferenceModelAction.Request request, ClusterState state) {
        return state.blocks().globalBlockedException(this.projectResolver.getProjectId(), ClusterBlockLevel.METADATA_WRITE);
    }
}

