/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.registry;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.ResourceAlreadyExistsException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.action.support.master.AcknowledgedRequest;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.ElasticsearchClient;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.cluster.AckedBatchedClusterStateUpdateTask;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateAckListener;
import org.elasticsearch.cluster.ClusterStateListener;
import org.elasticsearch.cluster.ClusterStateTaskExecutor;
import org.elasticsearch.cluster.ClusterStateTaskListener;
import org.elasticsearch.cluster.SimpleBatchedAckListenerTaskExecutor;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.cluster.metadata.ProjectId;
import org.elasticsearch.cluster.metadata.ProjectMetadata;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.cluster.service.MasterServiceTaskQueue;
import org.elasticsearch.common.Priority;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.gateway.GatewayService;
import org.elasticsearch.index.engine.VersionConflictEngineException;
import org.elasticsearch.index.query.ConstantScoreQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.reindex.BulkByScrollResponse;
import org.elasticsearch.index.reindex.DeleteByQueryAction;
import org.elasticsearch.index.reindex.DeleteByQueryRequest;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.MinimalServiceSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.core.inference.results.ModelStoreResponse;
import org.elasticsearch.xpack.inference.registry.ClearInferenceEndpointCacheAction;
import org.elasticsearch.xpack.inference.registry.ModelRegistryMetadata;
import org.elasticsearch.xpack.inference.services.ServiceUtils;

public class ModelRegistry
implements ClusterStateListener {
    private static final String TASK_TYPE_FIELD = "task_type";
    private static final String MODEL_ID_FIELD = "model_id";
    private static final Logger logger = LogManager.getLogger(ModelRegistry.class);
    /** Client for all search/index/bulk requests; the constructor wraps it so requests carry the "inference" origin. */
    private final OriginSettingClient client;
    /** Registered default (preconfigured) endpoints, keyed by inference id. */
    private final Map<String, InferenceService.DefaultConfigId> defaultConfigIds;
    /** Master-service queue that applies metadata tasks to the "model_registry" project custom. */
    private final MasterServiceTaskQueue<MetadataTask> metadataTaskQueue;
    // NOTE(review): not used in this part of the file; presumably guards a one-time metadata upgrade — confirm.
    private final AtomicBoolean upgradeMetadataInProgress = new AtomicBoolean(false);
    // Ids of endpoints with an update in progress (see updateModelTransaction). Backed by a ConcurrentHashMap so
    // add/remove are thread-safe. NOTE(review): the name suggests deletions consult it too — that code is not shown here.
    private final Set<String> preventDeletionLock = Collections.newSetFromMap(new ConcurrentHashMap<>());
    private final ClusterService clusterService;
    // Last cluster metadata observed by the registry; null until the first one is recorded
    // (getMinimalServiceSettings fails while it is null).
    private final AtomicReference<Metadata> lastMetadata = new AtomicReference<>();

    /**
     * Converts the raw configuration/secrets maps of a stored endpoint into an {@link UnparsedModel}.
     * <p>
     * The {@code model_id}, {@code service} and {@code task_type} entries are read through
     * {@code ServiceUtils.removeStringOrThrowIfNull} — presumably removing them from the config map and failing
     * when one is absent (per the helper's name; verify). The same config map instance is then handed to the model.
     *
     * @param modelConfigMap the endpoint's configuration and secrets maps
     * @return the unparsed model
     * @throws ElasticsearchStatusException with status BAD_REQUEST if the config map is null
     */
    public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) {
        var config = modelConfigMap.config();
        if (config == null) {
            throw new ElasticsearchStatusException("Missing config map", RestStatus.BAD_REQUEST);
        }
        // Read order matters for which missing field is reported first.
        String endpointId = ServiceUtils.removeStringOrThrowIfNull(config, MODEL_ID_FIELD);
        String serviceName = ServiceUtils.removeStringOrThrowIfNull(config, "service");
        TaskType parsedTaskType = TaskType.fromString(ServiceUtils.removeStringOrThrowIfNull(config, TASK_TYPE_FIELD));
        return new UnparsedModel(endpointId, parsedTaskType, serviceName, config, modelConfigMap.secrets());
    }

    /**
     * Creates the registry.
     *
     * @param clusterService used to create the master-service task queue that applies registry metadata updates;
     *                       must not be null
     * @param client         wrapped in an {@link OriginSettingClient} so every request is sent with the
     *                       {@code "inference"} origin
     */
    public ModelRegistry(ClusterService clusterService, Client client) {
        this.clusterService = Objects.requireNonNull(clusterService);
        this.client = new OriginSettingClient(client, "inference");
        this.defaultConfigIds = new ConcurrentHashMap<>();
        // Each task rewrites the "model_registry" custom of the task's project. The task itself is returned
        // as the ack listener for the resulting state.
        SimpleBatchedAckListenerTaskExecutor<MetadataTask> executor = new SimpleBatchedAckListenerTaskExecutor<MetadataTask>() {
            @Override
            public Tuple<ClusterState, ClusterStateAckListener> executeTask(MetadataTask task, ClusterState clusterState)
                throws Exception {
                ProjectMetadata projectMetadata = clusterState.metadata().getProject(task.getProjectId());
                ModelRegistryMetadata updated = task.executeTask(ModelRegistryMetadata.fromState(projectMetadata));
                ProjectMetadata.Builder newProjectMetadata = ProjectMetadata.builder(projectMetadata)
                    .putCustom("model_registry", updated);
                return new Tuple<>(ClusterState.builder(clusterState).putProjectMetadata(newProjectMetadata).build(), task);
            }
        };
        this.metadataTaskQueue = clusterService.createTaskQueue("model_registry", Priority.NORMAL, executor);
    }

    /**
     * Tells whether an inference id refers to a preconfigured endpoint.
     * <p>
     * An id qualifies if it is a registered default endpoint. It also qualifies if the last observed cluster metadata
     * lists it among the inference ids of the {@code "elastic"} service. Before any metadata has been observed,
     * only default endpoints are recognised.
     *
     * @param inferenceEntityId the endpoint id to test
     * @return {@code true} if the id is preconfigured
     */
    public boolean containsPreconfiguredInferenceEndpointId(String inferenceEntityId) {
        if (defaultConfigIds.containsKey(inferenceEntityId)) {
            return true;
        }
        Metadata observed = lastMetadata.get();
        if (observed == null) {
            return false;
        }
        ModelRegistryMetadata registryState = ModelRegistryMetadata.fromState(observed.getProject(ProjectId.DEFAULT));
        return registryState.getServiceInferenceIds("elastic").contains(inferenceEntityId);
    }

    /**
     * Registers a default endpoint id unless one with the same inference id is already registered;
     * an existing registration is left untouched.
     *
     * @param defaultConfigId the default endpoint to register
     */
    public synchronized void putDefaultIdIfAbsent(InferenceService.DefaultConfigId defaultConfigId) {
        String key = defaultConfigId.inferenceId();
        defaultConfigIds.putIfAbsent(key, defaultConfigId);
    }

    /**
     * Registers a default endpoint id, refusing duplicates.
     *
     * @param defaultConfigId the default endpoint to register
     * @throws IllegalStateException if a default endpoint with the same inference id is already registered
     */
    public synchronized void addDefaultIds(InferenceService.DefaultConfigId defaultConfigId) throws IllegalStateException {
        String key = defaultConfigId.inferenceId();
        InferenceService.DefaultConfigId alreadyRegistered = defaultConfigIds.get(key);
        if (alreadyRegistered == null) {
            defaultConfigIds.put(key, defaultConfigId);
            return;
        }
        throw new IllegalStateException(
            "Cannot add default endpoint to the inference endpoint registry with duplicate inference id ["
                + key
                + "] declared by service ["
                + defaultConfigId.service().name()
                + "]. The inference Id is already use by ["
                + alreadyRegistered.service().name()
                + "] service."
        );
    }

    /** Unregisters every default endpoint id. */
    public void clearDefaultIds() {
        defaultConfigIds.clear();
    }

    /**
     * Returns the minimal service settings of an inference endpoint.
     * <p>
     * Registered default endpoints are answered from their registration. All other ids are looked up in the
     * registry metadata of the last observed cluster state.
     *
     * @param inferenceEntityId the endpoint id
     * @return the settings, or {@code null} when the id is unknown and the registry metadata is not yet upgraded
     * @throws IllegalStateException     if no cluster state has been observed yet
     * @throws ResourceNotFoundException if the registry metadata is upgraded and does not know the id
     */
    public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId) throws ResourceNotFoundException {
        Metadata observed = lastMetadata.get();
        if (observed == null) {
            throw new IllegalStateException("initial cluster state not set yet");
        }
        InferenceService.DefaultConfigId registeredDefault = defaultConfigIds.get(inferenceEntityId);
        if (registeredDefault != null) {
            return registeredDefault.settings();
        }
        ModelRegistryMetadata registryState = ModelRegistryMetadata.fromState(observed.getProject(ProjectId.DEFAULT));
        MinimalServiceSettings found = registryState.getMinimalServiceSettings(inferenceEntityId);
        boolean unknownAfterUpgrade = registryState.isUpgraded() && found == null;
        if (unknownAfterUpgrade) {
            throw new ResourceNotFoundException(inferenceEntityId + " does not exist in this cluster.");
        }
        return found;
    }

    /**
     * Returns the ids of all known inference endpoints.
     * <p>
     * The result contains the ids in the registry metadata of the last observed cluster state (if any) plus all
     * registered default endpoint ids. It is a fresh, mutable set owned by the caller.
     *
     * @return the union of metadata ids and default ids
     */
    public Set<String> getInferenceIds() {
        // Copying the concurrent key set directly is enough; no intermediate immutable copy is needed.
        Set<String> ids = new HashSet<>(defaultConfigIds.keySet());
        Metadata observed = lastMetadata.get();
        if (observed != null) {
            ModelRegistryMetadata state = ModelRegistryMetadata.fromState(observed.getProject(ProjectId.DEFAULT));
            ids.addAll(state.getInferenceIds());
        }
        return ids;
    }

    /**
     * Loads an inference endpoint together with its secrets.
     * <p>
     * The configuration and secrets indices are searched by document id. Up to two hits are requested: one
     * configuration hit and one secrets hit are combined via {@code createModelConfigMap}.
     * <p>
     * If nothing is found and the id is a registered default endpoint, the default configuration is resolved
     * through its service and persisted. Otherwise the listener fails with a {@link ResourceNotFoundException}.
     * Search failures are logged and reported as an {@link ElasticsearchException} wrapping the cause.
     *
     * @param inferenceEntityId the endpoint id
     * @param listener          receives the unparsed model including secrets
     */
    public void getModelWithSecrets(String inferenceEntityId, ActionListener<UnparsedModel> listener) {
        QueryBuilder queryBuilder = documentIdQuery(inferenceEntityId);
        SearchRequest modelSearch = client.prepareSearch(".inference*", ".secrets-inference*")
            .setQuery(queryBuilder)
            .setSize(2)
            .setAllowPartialSearchResults(false)
            .request();

        // Passing the listener inline lets the compiler type it as ActionListener<SearchResponse>.
        client.search(modelSearch, ActionListener.wrap(searchResponse -> {
            if (searchResponse.getHits().getHits().length == 0) {
                InferenceService.DefaultConfigId maybeDefault = defaultConfigIds.get(inferenceEntityId);
                if (maybeDefault != null) {
                    getDefaultConfig(true, maybeDefault, listener);
                } else {
                    listener.onFailure(inferenceNotFoundException(inferenceEntityId));
                }
                return;
            }
            listener.onResponse(unparsedModelFromMap(createModelConfigMap(searchResponse.getHits(), inferenceEntityId)));
        }, e -> {
            logger.warn(Strings.format("Failed to load inference endpoint with secrets [%s]", inferenceEntityId), e);
            listener.onFailure(
                new ElasticsearchException(
                    Strings.format("Failed to load inference endpoint with secrets [%s], error: [%s]", inferenceEntityId, e.getMessage()),
                    e
                )
            );
        }));
    }

    /**
     * Loads the stored configuration (without secrets) of an inference endpoint.
     * <p>
     * If no document is found and the id is a registered default endpoint, the default configuration is resolved
     * through its service and persisted. Otherwise the listener fails with a {@link ResourceNotFoundException}.
     * Search failures are logged and reported as an {@link ElasticsearchException} wrapping the cause.
     *
     * @param inferenceEntityId the endpoint id
     * @param listener          receives the unparsed model
     */
    public void getModel(String inferenceEntityId, ActionListener<UnparsedModel> listener) {
        QueryBuilder queryBuilder = documentIdQuery(inferenceEntityId);
        SearchRequest modelSearch = client.prepareSearch(".inference*")
            .setQuery(queryBuilder)
            .setSize(1)
            .setTrackTotalHits(false)
            .request();

        // Passing the listener inline lets the compiler type it as ActionListener<SearchResponse>.
        client.search(modelSearch, ActionListener.wrap(searchResponse -> {
            if (searchResponse.getHits().getHits().length == 0) {
                InferenceService.DefaultConfigId maybeDefault = defaultConfigIds.get(inferenceEntityId);
                if (maybeDefault != null) {
                    getDefaultConfig(true, maybeDefault, listener);
                } else {
                    listener.onFailure(inferenceNotFoundException(inferenceEntityId));
                }
                return;
            }
            List<UnparsedModel> modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream()
                .map(ModelRegistry::unparsedModelFromMap)
                .toList();
            assert modelConfigs.size() == 1;
            listener.onResponse(modelConfigs.get(0));
        }, e -> {
            logger.warn(Strings.format("Failed to load inference endpoint [%s]", inferenceEntityId), e);
            listener.onFailure(
                new ElasticsearchException(
                    Strings.format("Failed to load inference endpoint [%s], error: [%s]", inferenceEntityId, e.getMessage()),
                    e
                )
            );
        }));
    }

    /**
     * Builds the exception reported when an endpoint id is neither stored nor a registered default endpoint.
     *
     * @param inferenceEntityId the missing endpoint id, substituted into the {@code {}} placeholder
     * @return the not-found exception (not thrown)
     */
    private ResourceNotFoundException inferenceNotFoundException(String inferenceEntityId) {
        String template = "Inference endpoint not found [{}]";
        return new ResourceNotFoundException(template, inferenceEntityId);
    }

    /**
     * Lists the endpoints of one task type.
     * <p>
     * Stored endpoints whose {@code task_type} matches are fetched: at most 10000, sorted by {@code model_id}.
     * Registered default endpoints of the same task type that are not stored yet are resolved, persisted and
     * merged in. When anything is merged in, the combined list is re-sorted by inference id.
     *
     * @param taskType the task type to match
     * @param listener receives the matching endpoints
     */
    public void getModelsByTaskType(TaskType taskType, ActionListener<List<UnparsedModel>> listener) {
        ConstantScoreQueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(
            QueryBuilders.termsQuery(TASK_TYPE_FIELD, taskType.toString())
        );
        SearchRequest modelSearch = client.prepareSearch(".inference*")
            .setQuery(queryBuilder)
            .setSize(10000)
            .setTrackTotalHits(false)
            .addSort(MODEL_ID_FIELD, SortOrder.ASC)
            .request();

        // Passing the listener inline lets the compiler type it as ActionListener<SearchResponse>.
        client.search(modelSearch, listener.delegateFailureAndWrap((delegate, searchResponse) -> {
            List<UnparsedModel> modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream()
                .map(ModelRegistry::unparsedModelFromMap)
                .toList();
            List<InferenceService.DefaultConfigId> defaultConfigsForTaskType = taskTypeMatchedDefaults(
                taskType,
                defaultConfigIds.values()
            );
            addAllDefaultConfigsIfMissing(true, modelConfigs, defaultConfigsForTaskType, delegate);
        }));
    }

    /**
     * Lists all endpoints.
     * <p>
     * Stored documents that have a {@code task_type} field are fetched: at most 10000, sorted by {@code model_id}.
     * Every registered default endpoint that is not stored yet is then resolved and merged in. When anything is
     * merged in, the combined list is re-sorted by inference id.
     *
     * @param persistDefaultEndpoints whether missing default endpoints are stored as they are resolved
     * @param listener                receives all endpoints
     */
    public void getAllModels(boolean persistDefaultEndpoints, ActionListener<List<UnparsedModel>> listener) {
        ConstantScoreQueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.existsQuery(TASK_TYPE_FIELD));
        SearchRequest modelSearch = client.prepareSearch(".inference*")
            .setQuery(queryBuilder)
            .setSize(10000)
            .setTrackTotalHits(false)
            .addSort(MODEL_ID_FIELD, SortOrder.ASC)
            .request();

        // Passing the listener inline lets the compiler type it as ActionListener<SearchResponse>.
        client.search(modelSearch, listener.delegateFailureAndWrap((delegate, searchResponse) -> {
            List<UnparsedModel> foundConfigs = parseHitsAsModels(searchResponse.getHits()).stream()
                .map(ModelRegistry::unparsedModelFromMap)
                .toList();
            addAllDefaultConfigsIfMissing(persistDefaultEndpoints, foundConfigs, defaultConfigIds.values(), delegate);
        }));
    }

    /**
     * Completes {@code listener} with {@code foundConfigs} plus every default endpoint in {@code matchedDefaults}
     * that is not already among them.
     * <p>
     * If nothing is missing, {@code foundConfigs} is returned unchanged. Otherwise each missing default is
     * resolved via {@code getDefaultConfig}. Once all of them have completed, the combined list is sorted by
     * inference id. A failure to resolve any default fails the listener.
     *
     * @param persistDefaultEndpoints whether resolved defaults are stored
     * @param foundConfigs            endpoints already loaded from the index
     * @param matchedDefaults         candidate default endpoints
     * @param listener                receives the combined list
     */
    private void addAllDefaultConfigsIfMissing(
        boolean persistDefaultEndpoints,
        List<UnparsedModel> foundConfigs,
        Collection<InferenceService.DefaultConfigId> matchedDefaults,
        ActionListener<List<UnparsedModel>> listener
    ) {
        Set<String> foundIds = foundConfigs.stream().map(UnparsedModel::inferenceEntityId).collect(Collectors.toSet());
        List<InferenceService.DefaultConfigId> missing = matchedDefaults.stream()
            .filter(d -> !foundIds.contains(d.inferenceId()))
            .toList();
        if (missing.isEmpty()) {
            listener.onResponse(foundConfigs);
            return;
        }

        // Completes once every missing default has been resolved.
        GroupedActionListener<UnparsedModel> groupedListener = new GroupedActionListener<UnparsedModel>(
            missing.size(),
            listener.delegateFailure((delegate, resolvedDefaults) -> {
                List<UnparsedModel> allConfigs = new ArrayList<>(foundConfigs.size() + resolvedDefaults.size());
                allConfigs.addAll(foundConfigs);
                allConfigs.addAll(resolvedDefaults);
                allConfigs.sort(Comparator.comparing(UnparsedModel::inferenceEntityId));
                delegate.onResponse(allConfigs);
            })
        );
        for (InferenceService.DefaultConfigId required : missing) {
            getDefaultConfig(persistDefaultEndpoints, required, groupedListener);
        }
    }

    /**
     * Resolves the configuration of a default endpoint.
     * <p>
     * The endpoint's service is asked for its default configs, and the first one whose inference id matches is
     * used. When {@code persistDefaultEndpoints} is true, the model is stored first via
     * {@code storeDefaultEndpoint}, which only logs storage failures. The listener is completed afterwards either
     * way.
     *
     * @param persistDefaultEndpoints whether to store the resolved model
     * @param defaultConfig           the default endpoint to resolve
     * @param listener                receives the unparsed model, or an {@link IllegalStateException} if the
     *                                service offers no config with that inference id
     */
    private void getDefaultConfig(
        boolean persistDefaultEndpoints,
        InferenceService.DefaultConfigId defaultConfig,
        ActionListener<UnparsedModel> listener
    ) {
        defaultConfig.service().defaultConfigs(listener.delegateFailureAndWrap((delegate, models) -> {
            for (Model model : models) {
                if (!model.getInferenceEntityId().equals(defaultConfig.inferenceId())) {
                    continue;
                }
                if (persistDefaultEndpoints) {
                    // Conversion happens after storing, so a conversion failure cannot prevent the store.
                    storeDefaultEndpoint(model, () -> delegate.onResponse(modelToUnparsedModel(model)));
                } else {
                    delegate.onResponse(modelToUnparsedModel(model));
                }
                return;
            }
            delegate.onFailure(
                new IllegalStateException("Configuration not found for default inference id [" + defaultConfig.inferenceId() + "]")
            );
        }));
    }

    /**
     * Stores a default endpoint on a best-effort basis, without updating the cluster state.
     * <p>
     * The outcome is only logged: success and {@link ResourceAlreadyExistsException} at debug, anything else at
     * error. {@code runAfter} runs after the store completes, whether it succeeded or failed.
     *
     * @param preconfigured the default endpoint model to store
     * @param runAfter      continuation to run once the store has completed
     */
    private void storeDefaultEndpoint(Model preconfigured, Runnable runAfter) {
        ActionListener<Boolean> responseListener = ActionListener.wrap(
            success -> logger.debug("Added default inference endpoint [{}]", preconfigured.getInferenceEntityId()),
            exception -> {
                if (exception instanceof ResourceAlreadyExistsException) {
                    logger.debug("Default inference id [{}] already exists", preconfigured.getInferenceEntityId());
                } else {
                    logger.error("Failed to store default inference id [" + preconfigured.getInferenceEntityId() + "]", exception);
                }
            }
        );
        storeModel(preconfigured, false, ActionListener.runAfter(responseListener, runAfter), AcknowledgedRequest.DEFAULT_ACK_TIMEOUT);
    }

    /**
     * Wraps the source of each search hit, in hit order, as a {@link ModelConfigMap} with an empty, immutable
     * secrets map.
     *
     * @param hits search hits from the configuration index
     * @return one config map per hit
     */
    private ArrayList<ModelConfigMap> parseHitsAsModels(SearchHits hits) {
        ArrayList<ModelConfigMap> parsed = new ArrayList<>();
        hits.forEach(hit -> parsed.add(new ModelConfigMap(hit.getSourceAsMap(), Map.of())));
        return parsed;
    }

    /**
     * Combines the configuration hit and the secrets hit of one endpoint into a {@link ModelConfigMap}.
     * <p>
     * Each hit is classified by the prefix of its index name ({@code .inference} or {@code .secrets-inference}).
     * Exactly one hit of each kind must be present.
     *
     * @param hits              hits from a combined configuration + secrets search
     * @param inferenceEntityId the endpoint id, used in log and error messages
     * @return the configuration and secrets source maps
     * @throws IllegalArgumentException if a hit comes from an index with neither prefix
     * @throws IllegalStateException    if either kind of hit is missing or appears more than once
     */
    private ModelConfigMap createModelConfigMap(SearchHits hits, String inferenceEntityId) {
        Map<String, SearchHit> mappedHits = new HashMap<>();
        // Collectors.toMap would throw an opaque "Duplicate key" error here, so duplicates are tracked
        // explicitly and reported with the descriptive message below.
        boolean duplicateHit = false;
        for (SearchHit hit : hits.getHits()) {
            String index = hit.getIndex();
            String prefix;
            if (index.startsWith(".inference")) {
                prefix = ".inference";
            } else if (index.startsWith(".secrets-inference")) {
                prefix = ".secrets-inference";
            } else {
                logger.warn(Strings.format("Found invalid index for inference endpoint [%s] at index [%s]", inferenceEntityId, index));
                throw new IllegalArgumentException(
                    Strings.format(
                        "Invalid result while loading inference endpoint [%s] index: [%s]. Try deleting and reinitializing the service",
                        inferenceEntityId,
                        index
                    )
                );
            }
            if (mappedHits.putIfAbsent(prefix, hit) != null) {
                duplicateHit = true;
            }
        }

        SearchHit configHit = mappedHits.get(".inference");
        SearchHit secretsHit = mappedHits.get(".secrets-inference");
        if (duplicateHit || configHit == null || secretsHit == null) {
            logger.warn(
                Strings.format(
                    "Failed to load inference endpoint [%s], found endpoint parts from index prefixes: [%s]",
                    inferenceEntityId,
                    mappedHits.keySet()
                )
            );
            throw new IllegalStateException(
                Strings.format(
                    "Failed to load inference endpoint [%s]. Endpoint is in an invalid state, try deleting and reinitializing the service",
                    inferenceEntityId
                )
            );
        }
        return new ModelConfigMap(configHit.getSourceAsMap(), secretsHit.getSourceAsMap());
    }

    /**
     * Replaces the stored configuration and secrets documents of an existing inference endpoint.
     * <p>
     * The configuration document is written first, then the secrets document. If the secrets write reports item
     * failures, the configuration of {@code existingModel} is written back (rollback). {@code finalListener}
     * receives:
     * <ul>
     *   <li>{@code true} when both writes succeed;</li>
     *   <li>{@code false} when the secrets write failed but the rollback succeeded;</li>
     *   <li>a failure when an update of the same endpoint is already in progress, when the configuration write or
     *       the rollback reports item failures, or when any step fails with an exception.</li>
     * </ul>
     * While the update runs, the endpoint id is held in {@code preventDeletionLock}. It is released on every
     * path that ends the update.
     *
     * @param newModel      the endpoint definition to store
     * @param existingModel the currently stored definition, used to roll back the configuration document
     * @param finalListener receives the outcome described above
     */
    public void updateModelTransaction(Model newModel, Model existingModel, ActionListener<Boolean> finalListener) {
        String inferenceEntityId = newModel.getConfigurations().getInferenceEntityId();
        logger.info("Attempting to store update to inference endpoint [{}]", inferenceEntityId);

        // Set#add on this concurrent set is atomic, so only one caller can claim the id. A separate
        // contains()/add() pair would let two concurrent updates of the same endpoint both proceed.
        if (!preventDeletionLock.add(inferenceEntityId)) {
            // Parameterized logging: Strings.format uses %s placeholders and would print "{}" literally.
            logger.warn("Attempted to update endpoint [{}] that is already being updated", inferenceEntityId);
            finalListener.onFailure(
                new ElasticsearchStatusException(
                    "Endpoint [{}] is currently being updated. Try again once the update completes",
                    RestStatus.CONFLICT,
                    inferenceEntityId
                )
            );
            return;
        }

        SubscribableListener.<BulkResponse>newForked(subListener -> {
            // Step 1: overwrite the configuration document.
            IndexRequestBuilder configRequestBuilder = createIndexRequestBuilder(
                inferenceEntityId,
                ".inference",
                newModel.getConfigurations(),
                true,
                client
            );
            client.prepareBulk().add(configRequestBuilder).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).execute(subListener);
        }).<BulkResponse>andThen((subListener, configResponse) -> {
            // Step 2: if the configuration was written, overwrite the secrets document.
            if (configResponse.hasFailures()) {
                String failureMessage = configResponse.buildFailureMessage();
                logger.error(Strings.format("Failed to update inference endpoint [%s] due to [%s]", inferenceEntityId, failureMessage));
                preventDeletionLock.remove(inferenceEntityId);
                finalListener.onFailure(
                    new ElasticsearchStatusException(
                        Strings.format("Failed to update inference endpoint [%s] due to [%s]", inferenceEntityId, failureMessage),
                        RestStatus.INTERNAL_SERVER_ERROR
                    )
                );
                // subListener is deliberately left incomplete: the chain ends here.
                return;
            }
            IndexRequestBuilder secretsRequestBuilder = createIndexRequestBuilder(
                inferenceEntityId,
                ".secrets-inference",
                newModel.getSecrets(),
                true,
                client
            );
            client.prepareBulk().add(secretsRequestBuilder).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).execute(subListener);
        }).<BulkResponse>andThen((subListener, secretsResponse) -> {
            // Step 3: finish on success; otherwise roll the configuration back to existingModel.
            if (!secretsResponse.hasFailures()) {
                preventDeletionLock.remove(inferenceEntityId);
                refreshInferenceEndpointCache();
                finalListener.onResponse(true);
                return;
            }
            IndexRequestBuilder rollbackRequestBuilder = createIndexRequestBuilder(
                inferenceEntityId,
                ".inference",
                existingModel.getConfigurations(),
                true,
                client
            );
            logger.error("Failed to update inference endpoint secrets [{}], attempting rolling back to previous state", inferenceEntityId);
            client.prepareBulk().add(rollbackRequestBuilder).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).execute(subListener);
        }).<Void>andThen((unused, rollbackResponse) -> {
            // Step 4: report the outcome of the rollback.
            preventDeletionLock.remove(inferenceEntityId);
            if (rollbackResponse.hasFailures()) {
                String failureMessage = rollbackResponse.buildFailureMessage();
                logger.error(Strings.format("Failed to update inference endpoint [%s] due to [%s]", inferenceEntityId, failureMessage));
                finalListener.onFailure(
                    new ElasticsearchStatusException(
                        Strings.format(
                            "Failed to rollback while handling failure to update inference endpoint [%s]. Endpoint may be in an inconsistent state due to [%s]",
                            inferenceEntityId,
                            failureMessage
                        ),
                        RestStatus.INTERNAL_SERVER_ERROR
                    )
                );
            } else {
                logger.warn("Failed to update inference endpoint [{}], successfully rolled back to previous state", inferenceEntityId);
                finalListener.onResponse(false);
            }
        }).addListener(ActionListener.wrap(ignored -> {}, e -> {
            // Reached only when a step completes exceptionally: a bulk request failed outright, or building a
            // request threw. Every other outcome has already notified finalListener. Release the id so later
            // updates are not blocked forever, and surface the failure instead of dropping it.
            preventDeletionLock.remove(inferenceEntityId);
            finalListener.onFailure(e);
        }));
    }

    /**
     * Stores one endpoint's configuration and secrets and also updates the cluster state for it.
     *
     * @param model    the endpoint to store
     * @param listener receives {@code true} on success; fails with {@link ResourceAlreadyExistsException} on a
     *                 version conflict, or with an {@link ElasticsearchStatusException} otherwise
     * @param timeout  timeout passed through to the cluster state update
     */
    public void storeModel(Model model, ActionListener<Boolean> listener, TimeValue timeout) {
        final boolean alsoUpdateClusterState = true;
        storeModel(model, alsoUpdateClusterState, listener, timeout);
    }

    /**
     * Stores a single endpoint and maps the per-endpoint result to a boolean outcome.
     * <p>
     * The first failed store response decides the failure. If its unwrapped cause is a
     * {@link VersionConflictEngineException}, the listener fails with {@link ResourceAlreadyExistsException};
     * otherwise it fails with a 500 {@link ElasticsearchStatusException} that carries the cause.
     *
     * @param model              the endpoint to store
     * @param updateClusterState whether the cluster state is also updated for the stored endpoint
     * @param listener           receives {@code true} on success
     * @param timeout            timeout passed through to the cluster state update
     */
    private void storeModel(Model model, boolean updateClusterState, ActionListener<Boolean> listener, TimeValue timeout) {
        storeModels(List.of(model), updateClusterState, listener.delegateFailureAndWrap((delegate, responses) -> {
            Optional<ModelStoreResponse> firstFailureResponse = responses.stream().filter(ModelStoreResponse::failed).findFirst();
            if (firstFailureResponse.isEmpty()) {
                delegate.onResponse(Boolean.TRUE);
                return;
            }
            ModelStoreResponse failureItem = firstFailureResponse.get();
            if (ExceptionsHelper.unwrapCause(failureItem.failureCause()) instanceof VersionConflictEngineException) {
                delegate.onFailure(new ResourceAlreadyExistsException("Inference endpoint [{}] already exists", failureItem.inferenceId()));
                return;
            }
            delegate.onFailure(
                new ElasticsearchStatusException(
                    Strings.format("Failed to store inference endpoint [%s]", failureItem.inferenceId()),
                    RestStatus.INTERNAL_SERVER_ERROR,
                    failureItem.failureCause()
                )
            );
        }), timeout);
    }

    /**
     * Stores several endpoints in one bulk request and also updates the cluster state for the ones stored
     * successfully.
     *
     * @param models   the endpoints to store (duplicates are stored once)
     * @param listener receives one {@link ModelStoreResponse} per stored endpoint
     * @param timeout  timeout passed through to the cluster state update
     */
    public void storeModels(List<Model> models, ActionListener<List<ModelStoreResponse>> listener, TimeValue timeout) {
        final boolean alsoUpdateClusterState = true;
        storeModels(models, alsoUpdateClusterState, listener, timeout);
    }

    /**
     * Writes the configuration and secrets documents of every distinct model in a single bulk request with
     * immediate refresh. Writes are create-only: the {@code false} argument disallows overwriting.
     * An empty input completes the listener immediately with an empty list.
     *
     * @param models             the endpoints to store; duplicates (by {@code equals}) are dropped
     * @param updateClusterState whether successfully stored endpoints are also recorded in the cluster state
     * @param listener           receives one {@link ModelStoreResponse} per endpoint
     * @param timeout            timeout passed through to the cluster state update
     */
    private void storeModels(
        List<Model> models,
        boolean updateClusterState,
        ActionListener<List<ModelStoreResponse>> listener,
        TimeValue timeout
    ) {
        if (models.isEmpty()) {
            listener.onResponse(List.of());
            return;
        }
        List<Model> uniqueModels = models.stream().distinct().toList();
        BulkRequestBuilder bulk = client.prepareBulk().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
        // Item order matters: getResponseInfo reads each configuration item followed immediately by the
        // secrets item of the same endpoint.
        uniqueModels.forEach(model -> {
            String id = model.getInferenceEntityId();
            bulk.add(createIndexRequestBuilder(id, ".inference", model.getConfigurations(), false, client));
            bulk.add(createIndexRequestBuilder(id, ".secrets-inference", model.getSecrets(), false, client));
        });
        bulk.execute(getStoreMultipleModelsListener(uniqueModels, updateClusterState, listener, timeout));
    }

    /**
     * Builds the listener for the bulk request issued by storeModels.
     * <p>
     * On a non-empty bulk response, the per-endpoint results are computed by getResponseInfo. When
     * {@code updateClusterState} is true, the successfully stored models are first passed to updateClusterState.
     * The results then go through a cleanup step, which deletes endpoints whose store failed while
     * {@code modifiedIndex()} is set. After that, {@code listener} receives the per-endpoint
     * {@link ModelStoreResponse}s.
     * <p>
     * An empty bulk response, or a failed bulk request, fails {@code listener} with an INTERNAL_SERVER_ERROR status.
     *
     * @param models             the models whose documents were sent in the bulk request
     * @param updateClusterState whether to update the cluster state for successfully stored models
     * @param listener           receives the per-endpoint store responses
     * @param timeout            timeout passed to updateClusterState
     * @return the listener to pass to the bulk request
     */
    private ActionListener<BulkResponse> getStoreMultipleModelsListener(List<Model> models, boolean updateClusterState, ActionListener<List<ModelStoreResponse>> listener, TimeValue timeout) {
        // Cleanup step: delete endpoints that failed but are flagged as having modified an index. A failed delete is
        // only logged; the storage responses are returned to the caller either way.
        ActionListener cleanupListener = listener.delegateFailureAndWrap((delegate, responses) -> {
            Set<String> inferenceIdsToBeRemoved = responses.stream().filter(r -> r.modifiedIndex() && r.modelStoreResponse().failed()).map(r -> r.modelStoreResponse().inferenceId()).collect(Collectors.toSet());
            List<ModelStoreResponse> storageResponses = responses.stream().map(StoreResponseWithIndexInfo::modelStoreResponse).toList();
            ActionListener deleteListener = ActionListener.wrap(ignored -> delegate.onResponse((Object)storageResponses), e -> {
                logger.atWarn().withThrowable((Throwable)e).log("Failed to clean up partially stored inference endpoints {}. The service may be in an inconsistent state. Please try deleting and re-adding the endpoints.", (Object)inferenceIdsToBeRemoved);
                delegate.onResponse((Object)storageResponses);
            });
            this.deleteModels(inferenceIdsToBeRemoved, (ActionListener<Boolean>)deleteListener);
        });
        return ActionListener.wrap(bulkItemResponses -> {
            // Maps each bulk item's document id back to its inference id; on duplicates the first id is kept.
            Map<String, String> docIdToInferenceId = models.stream().collect(Collectors.toMap(m -> Model.documentId((String)m.getInferenceEntityId()), Model::getInferenceEntityId, (id1, id2) -> {
                logger.warn("Encountered duplicate inference ids when storing endpoints: [{}]", id1);
                return id1;
            }));
            Map<String, Model> inferenceIdToModel = models.stream().collect(Collectors.toMap(Model::getInferenceEntityId, Function.identity(), (id1, id2) -> id1));
            // An empty response gives no per-item information, so the whole request is reported as failed.
            if (bulkItemResponses.getItems().length == 0) {
                String inferenceEntityIds = String.join((CharSequence)", ", models.stream().map(Model::getInferenceEntityId).toList());
                logger.warn("Storing inference endpoints [{}] failed, no items were received from the bulk response", (Object)inferenceEntityIds);
                listener.onFailure((Exception)new ElasticsearchStatusException("Failed to store inference endpoints [{}], empty bulk response received.", RestStatus.INTERNAL_SERVER_ERROR, new Object[]{inferenceEntityIds}));
                return;
            }
            ResponseInfo responseInfo = ModelRegistry.getResponseInfo(bulkItemResponses, docIdToInferenceId, inferenceIdToModel);
            if (updateClusterState) {
                // NOTE(review): as written, a failed cluster-state update is forwarded through cleanupListener's failure
                // path straight to listener, so the delete/cleanup step does not run in that case — confirm intended.
                this.updateClusterState(responseInfo.successfullyStoredModels(), (ActionListener<AcknowledgedResponse>)cleanupListener.delegateFailureIgnoreResponseAndWrap(delegate -> delegate.onResponse(responseInfo.responses())), timeout);
            } else {
                cleanupListener.onResponse(responseInfo.responses());
            }
        }, e -> {
            // The bulk request failed as a whole: report every endpoint as failed with the cause attached.
            String errorMessage = Strings.format((String)"Failed to store inference endpoints [%s]", (Object[])new Object[]{models.stream().map(Model::getInferenceEntityId).collect(Collectors.joining(", "))});
            logger.warn(errorMessage, (Throwable)e);
            listener.onFailure((Exception)new ElasticsearchStatusException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR, (Throwable)e, new Object[0]));
        });
    }

    /**
     * Interprets a bulk store response whose items come in (configuration, secrets) pairs.
     * <p>
     * For each pair one {@link StoreResponseWithIndexInfo} is produced:
     * <ul>
     *   <li>If the configuration item failed, its response is reported. The index flag is set when the secrets item succeeded.</li>
     *   <li>If only the secrets item failed, the secrets response is reported with the flag set.</li>
     *   <li>If both succeeded, the configuration response is reported and the matching model, if found, is collected.</li>
     * </ul>
     * If the item count is odd, the trailing unpaired configuration item is reported and processing stops.
     * A successful trailing item is replaced by a synthetic INTERNAL_SERVER_ERROR response.
     *
     * @param bulkResponse       the bulk response holding the paired items
     * @param docIdToInferenceId maps document ids back to inference ids
     * @param inferenceIdToModel maps inference ids to the models that were submitted
     * @return the per-endpoint responses and the models whose configuration and secrets were both stored
     */
    private static ResponseInfo getResponseInfo(BulkResponse bulkResponse, Map<String, String> docIdToInferenceId, Map<String, Model> inferenceIdToModel) {
        List<StoreResponseWithIndexInfo> storeResults = new ArrayList<>();
        List<Model> storedModels = new ArrayList<>();
        BulkItemResponse[] items = bulkResponse.getItems();
        int itemCount = items.length;

        for (int configIndex = 0; configIndex < itemCount; configIndex += 2) {
            BulkItemResponse configItem = items[configIndex];
            ModelStoreResponse configResult = createModelStoreResponse(configItem, docIdToInferenceId);
            Model model = getModelFromMap(docIdToInferenceId.get(configItem.getId()), inferenceIdToModel);
            boolean configFailed = configResult.failed();
            int secretsIndex = configIndex + 1;

            if (secretsIndex >= itemCount) {
                // The secrets half of this pair is missing: report it and stop.
                logger.error("Expected an even number of bulk response items, got [{}]", itemCount);
                ModelStoreResponse trailingResult = configFailed
                    ? configResult
                    : new ModelStoreResponse(
                        configResult.inferenceId(),
                        RestStatus.INTERNAL_SERVER_ERROR,
                        new IllegalStateException("Failed to receive part of bulk response")
                    );
                storeResults.add(new StoreResponseWithIndexInfo(trailingResult, configFailed == false));
                return new ResponseInfo(storeResults, storedModels);
            }

            ModelStoreResponse secretsResult = createModelStoreResponse(items[secretsIndex], docIdToInferenceId);
            assert secretsResult.inferenceId().equals(configResult.inferenceId()) : "Mismatched inference ids in bulk response items, configuration id [" + configResult.inferenceId() + "] secrets id [" + secretsResult.inferenceId() + "]";

            if (configFailed) {
                storeResults.add(new StoreResponseWithIndexInfo(configResult, secretsResult.failed() == false));
            } else if (secretsResult.failed()) {
                storeResults.add(new StoreResponseWithIndexInfo(secretsResult, true));
            } else {
                storeResults.add(new StoreResponseWithIndexInfo(configResult, true));
                if (model != null) {
                    storedModels.add(model);
                }
            }
        }
        return new ResponseInfo(storeResults, storedModels);
    }

    /**
     * Converts one bulk item into a {@link ModelStoreResponse}.
     * <p>
     * The inference id is looked up from the item's document id. If no mapping exists, a warning is logged
     * and {@code "unknown"} is used. A failed item carries its failure cause; any other item carries no cause.
     *
     * @param item               the bulk item to convert
     * @param docIdToInferenceId maps document ids back to inference ids
     * @return the store response for the item, using the item's status
     */
    private static ModelStoreResponse createModelStoreResponse(BulkItemResponse item, Map<String, String> docIdToInferenceId) {
        String docId = item.getId();
        String mappedInferenceId = docIdToInferenceId.get(docId);
        if (mappedInferenceId == null) {
            logger.warn("Failed to find inference id for document id [{}]", docId);
        }
        String inferenceId = mappedInferenceId != null ? mappedInferenceId : "unknown";

        BulkItemResponse.Failure failure = item.getFailure();
        Exception failureCause = null;
        if (item.isFailed() && failure != null) {
            logger.warn(
                Strings.format(
                    "Failed to store document id: [%s] inference id: [%s] index: [%s] bulk failure message [%s]",
                    docId,
                    inferenceId,
                    item.getIndex(),
                    item.getFailureMessage()
                )
            );
            failureCause = failure.getCause();
        }
        return new ModelStoreResponse(inferenceId, item.status(), failureCause);
    }

    /**
     * Looks up a model by inference id, tolerating a missing id.
     *
     * @param inferenceId        the id to look up; may be {@code null}
     * @param inferenceIdToModel the models keyed by inference id
     * @return the mapped model, or {@code null} when the id is {@code null} or not present
     */
    private static Model getModelFromMap(@Nullable String inferenceId, Map<String, Model> inferenceIdToModel) {
        return inferenceId == null ? null : inferenceIdToModel.get(inferenceId);
    }

    /**
     * Records the minimal service settings of the given models in the model registry cluster metadata.
     * <p>
     * If the update fails, or cannot be submitted, the models' documents are deleted as a best-effort rollback.
     * {@code listener} is then failed with an INTERNAL_SERVER_ERROR wrapping the original cause. The failure
     * is reported whether or not the delete succeeds.
     *
     * @param models   the models whose settings should be added to cluster metadata
     * @param listener notified with the acknowledged response, or with the failure described above
     * @param timeout  timeout passed to the master task queue when submitting the update
     */
    private void updateClusterState(List<Model> models, ActionListener<AcknowledgedResponse> listener, TimeValue timeout) {
        Set<String> inferenceIdsSet = models.stream().map(Model::getInferenceEntityId).collect(Collectors.toSet());
        ActionListener<AcknowledgedResponse> storeListener = listener.delegateResponse((delegate, exc) -> {
            logger.warn(
                Strings.format("Failed to add minimal service settings to cluster state for inference endpoints %s", inferenceIdsSet),
                exc
            );
            // ActionListener.running fires on both outcomes of the delete, so the caller is always failed afterwards.
            this.deleteModels(
                inferenceIdsSet,
                ActionListener.running(
                    () -> delegate.onFailure(
                        new ElasticsearchStatusException(
                            Strings.format(
                                "Failed to add the inference endpoints %s. The service may be in an inconsistent state. "
                                    + "Please try deleting and re-adding the endpoints.",
                                inferenceIdsSet
                            ),
                            RestStatus.INTERNAL_SERVER_ERROR,
                            exc
                        )
                    )
                )
            );
        });
        try {
            List<ModelAndSettings> modelsAndSettings = models.stream()
                .map(model -> new ModelAndSettings(model.getInferenceEntityId(), new MinimalServiceSettings(model)))
                .toList();
            this.metadataTaskQueue.submitTask(
                Strings.format("add model metadata for %s", inferenceIdsSet),
                new AddModelMetadataTask(ProjectId.DEFAULT, modelsAndSettings, storeListener),
                timeout
            );
        } catch (Exception exc) {
            // Submission failures go through the same rollback path as execution failures.
            storeListener.onFailure(exc);
        }
    }

    /**
     * Reports whether the registry can serve requests.
     *
     * @return {@code true} once cluster metadata has been observed and the cluster state no longer carries
     *         the {@code STATE_NOT_RECOVERED_BLOCK} global block
     */
    public boolean isReady() {
        return this.lastMetadata.get() != null
            && this.clusterService.state().blocks().hasGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK) == false;
    }

    /**
     * Forgets the given default endpoint ids and deletes their stored documents without updating cluster state.
     *
     * @param inferenceEntityIds the default endpoint ids to remove; an empty set completes immediately with {@code true}
     * @param listener           notified with the outcome of the deletion
     */
    public synchronized void removeDefaultConfigs(Set<String> inferenceEntityIds, ActionListener<Boolean> listener) {
        if (inferenceEntityIds.isEmpty() == false) {
            this.defaultConfigIds.keySet().removeAll(inferenceEntityIds);
            this.deleteModels(inferenceEntityIds, false, listener);
        } else {
            listener.onResponse(Boolean.TRUE);
        }
    }

    /**
     * Deletes a single inference endpoint and removes it from cluster metadata.
     *
     * @param inferenceEntityId the endpoint to delete
     * @param listener          notified with the outcome of the deletion
     */
    public void deleteModel(String inferenceEntityId, ActionListener<Boolean> listener) {
        Set<String> singleId = Set.of(inferenceEntityId);
        this.deleteModels(singleId, true, listener);
    }

    /**
     * Deletes the given inference endpoints and removes them from cluster metadata.
     *
     * @param inferenceEntityIds the endpoints to delete
     * @param listener           notified with the outcome of the deletion
     */
    public void deleteModels(Set<String> inferenceEntityIds, ActionListener<Boolean> listener) {
        final boolean alsoUpdateClusterState = true;
        this.deleteModels(inferenceEntityIds, alsoUpdateClusterState, listener);
    }

    /**
     * Deletes the documents of the given endpoints with a delete-by-query and optionally updates cluster metadata.
     * <p>
     * Fails with CONFLICT if any id is in {@code preventDeletionLock}. After the delete listener completes,
     * a refresh of the inference endpoint cache is triggered.
     *
     * @param inferenceEntityIds the endpoints to delete; an empty set completes immediately with {@code true}
     * @param updateClusterState whether to also remove the endpoints from cluster metadata
     * @param listener           notified with the outcome
     */
    private void deleteModels(Set<String> inferenceEntityIds, boolean updateClusterState, ActionListener<Boolean> listener) {
        if (inferenceEntityIds.isEmpty()) {
            listener.onResponse(Boolean.TRUE);
            return;
        }

        var lockedInferenceIds = new HashSet<String>(inferenceEntityIds);
        lockedInferenceIds.retainAll(this.preventDeletionLock);
        if (lockedInferenceIds.isEmpty() == false) {
            String conflictMessage = org.elasticsearch.common.Strings.format(
                "The inference endpoint(s) %s are currently being updated, please wait until after they are finished updating to delete.",
                lockedInferenceIds
            );
            listener.onFailure(new ElasticsearchStatusException(conflictMessage, RestStatus.CONFLICT));
            return;
        }

        DeleteByQueryRequest deleteRequest = createDeleteRequest(inferenceEntityIds);
        ActionListener<BulkByScrollResponse> deleteListener = this.getDeleteModelClusterStateListener(inferenceEntityIds, updateClusterState, listener);
        this.client.execute(DeleteByQueryAction.INSTANCE, deleteRequest, ActionListener.runAfter(deleteListener, this::refreshInferenceEndpointCache));
    }

    /**
     * Builds the listener that runs after the delete-by-query for {@code inferenceEntityIds} completes.
     * <p>
     * When {@code updateClusterState} is false, the listener completes with {@code true}. Otherwise it submits
     * a cluster-metadata task removing the endpoints and completes with the acknowledgement flag. A failure to
     * submit or run that task is reported as INTERNAL_SERVER_ERROR. Delete-by-query failures are passed through unchanged.
     *
     * @param inferenceEntityIds the endpoints being deleted
     * @param updateClusterState whether cluster metadata must also be updated
     * @param listener           the caller's listener
     * @return a listener for the delete-by-query response
     */
    private ActionListener<BulkByScrollResponse> getDeleteModelClusterStateListener(final Set<String> inferenceEntityIds, final boolean updateClusterState, final ActionListener<Boolean> listener) {
        return new ActionListener<BulkByScrollResponse>() {

            @Override
            public void onResponse(BulkByScrollResponse bulkByScrollResponse) {
                if (updateClusterState == false) {
                    listener.onResponse(Boolean.TRUE);
                    return;
                }
                ActionListener<AcknowledgedResponse> clusterStateListener = new ActionListener<AcknowledgedResponse>() {

                    @Override
                    public void onResponse(AcknowledgedResponse acknowledgedResponse) {
                        listener.onResponse(acknowledgedResponse.isAcknowledged());
                    }

                    @Override
                    public void onFailure(Exception exc) {
                        // Set#toString already brackets its elements, so the format uses a bare %s like the sibling messages.
                        listener.onFailure(
                            new ElasticsearchStatusException(
                                Strings.format(
                                    "Failed to delete the inference endpoint %s. The service may be in an inconsistent state. "
                                        + "Please try deleting the endpoint again.",
                                    inferenceEntityIds
                                ),
                                RestStatus.INTERNAL_SERVER_ERROR,
                                exc
                            )
                        );
                    }
                };
                try {
                    ModelRegistry.this.metadataTaskQueue.submitTask(
                        "delete models [" + inferenceEntityIds + "]",
                        new DeleteModelMetadataTask(ProjectId.DEFAULT, inferenceEntityIds, clusterStateListener),
                        null
                    );
                } catch (Exception exc) {
                    clusterStateListener.onFailure(exc);
                }
            }

            @Override
            public void onFailure(Exception exc) {
                listener.onFailure(exc);
            }
        };
    }

    /**
     * Asks for the inference endpoint cache to be cleared.
     * <p>
     * This is best-effort: both outcomes are only logged at debug level.
     */
    private void refreshInferenceEndpointCache() {
        var clearCacheRequest = new ClearInferenceEndpointCacheAction.Request();
        this.client.execute(
            ClearInferenceEndpointCacheAction.INSTANCE,
            clearCacheRequest,
            ActionListener.wrap(
                ignored -> logger.debug("Successfully refreshed inference endpoint cache."),
                failure -> logger.atDebug().withThrowable(failure).log("Failed to refresh inference endpoint cache.")
            )
        );
    }

    /**
     * Builds a delete-by-query over the {@code .inference*} and {@code .secrets-inference*} index patterns.
     * <p>
     * The query matches the endpoints' document ids. It does not abort on version conflicts and requests a refresh.
     *
     * @param inferenceEntityIds the endpoints whose documents should be deleted
     * @return the configured request
     */
    private static DeleteByQueryRequest createDeleteRequest(Set<String> inferenceEntityIds) {
        DeleteByQueryRequest deleteRequest = new DeleteByQueryRequest();
        deleteRequest.setAbortOnVersionConflict(false);
        deleteRequest.indices(".inference*", ".secrets-inference*");
        deleteRequest.setQuery(documentIdsQuery(inferenceEntityIds));
        deleteRequest.setRefresh(true);
        return deleteRequest;
    }

    /**
     * Builds an index request that stores {@code body} as the document for {@code inferenceId} in {@code indexName}.
     * <p>
     * The body is serialized with the {@code for_index} parameter set to {@code true}. This presumably selects
     * the storage form of the content; confirm against the implementations of {@code body}. The document id is
     * derived with {@code Model.documentId}.
     *
     * @param inferenceId      the inference endpoint id the document belongs to
     * @param indexName        the target index
     * @param body             the content to serialize as the document source
     * @param allowOverwriting when {@code false} the request is create-only, so an existing document is not replaced
     * @param client           the client the builder is bound to
     * @return the configured index request builder
     * @throws ElasticsearchException if creating or writing the JSON builder fails with an IOException
     */
    public static IndexRequestBuilder createIndexRequestBuilder(String inferenceId, String indexName, ToXContentObject body, boolean allowOverwriting, Client client) {
        // try-with-resources so that builder creation is also covered by the IOException handler,
        // and the builder is closed on every path.
        try (XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()) {
            XContentBuilder source = body.toXContent(xContentBuilder, new ToXContent.MapParams(Map.of("for_index", Boolean.TRUE.toString())));
            return new IndexRequestBuilder(client).setIndex(indexName)
                .setCreate(allowOverwriting == false)
                .setId(Model.documentId(inferenceId))
                .setSource(source);
        } catch (IOException ex) {
            throw new ElasticsearchException("Unexpected serialization exception for index [{}] inference ID [{}]", ex, indexName, inferenceId);
        }
    }

    /**
     * Converts a {@link Model} into an {@link UnparsedModel}.
     * <p>
     * The model's configurations are serialized to JSON with {@code for_index=true} and parsed back into a map.
     * That map is handed to {@code unparsedModelFromMap} with an empty secrets map.
     *
     * @param model the model to convert
     * @return the unparsed representation of the model's configuration
     * @throws ElasticsearchException if serialization fails with an IOException; the IOException is the cause
     */
    private static UnparsedModel modelToUnparsedModel(Model model) {
        try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
            model.getConfigurations().toXContent(builder, new ToXContent.MapParams(Map.of("for_index", Boolean.TRUE.toString())));
            Map<String, Object> modelConfigMap = XContentHelper.convertToMap(BytesReference.bytes(builder), false, builder.contentType()).v2();
            return unparsedModelFromMap(new ModelConfigMap(modelConfigMap, new HashMap<String, Object>()));
        } catch (IOException ex) {
            // Pass ex as the cause. Passing it after the id would make it an ignored message argument and lose the stack trace.
            throw new ElasticsearchException("[{}] Error serializing inference endpoint configuration", ex, model.getInferenceEntityId());
        }
    }

    /**
     * Builds a constant-score ids query matching the document of a single inference endpoint.
     *
     * @param inferenceEntityId the endpoint id, converted with {@code Model.documentId}
     * @return the query
     */
    private static QueryBuilder documentIdQuery(String inferenceEntityId) {
        String documentId = Model.documentId(inferenceEntityId);
        QueryBuilder idsQuery = QueryBuilders.idsQuery().addIds(documentId);
        return QueryBuilders.constantScoreQuery(idsQuery);
    }

    /**
     * Builds a constant-score ids query matching the documents of all given inference endpoints.
     *
     * @param inferenceEntityIds the endpoint ids, each converted with {@code Model.documentId}
     * @return the query
     */
    private static QueryBuilder documentIdsQuery(Set<String> inferenceEntityIds) {
        String[] documentIds = inferenceEntityIds.stream()
            .map(Model::documentId)
            .toArray(String[]::new);
        QueryBuilder idsQuery = QueryBuilders.idsQuery().addIds(documentIds);
        return QueryBuilders.constantScoreQuery(idsQuery);
    }

    /**
     * Finds the first default configuration whose inference id equals {@code inferenceId}.
     *
     * @param inferenceId      the id to match
     * @param defaultConfigIds the candidates, searched in list order
     * @return the first match, or an empty Optional
     */
    static Optional<InferenceService.DefaultConfigId> idMatchedDefault(String inferenceId, List<InferenceService.DefaultConfigId> defaultConfigIds) {
        for (InferenceService.DefaultConfigId candidate : defaultConfigIds) {
            if (candidate.inferenceId().equals(inferenceId)) {
                return Optional.of(candidate);
            }
        }
        return Optional.empty();
    }

    /**
     * Selects the default configurations whose settings declare exactly {@code taskType}.
     *
     * @param taskType         the task type to match
     * @param defaultConfigIds the candidates
     * @return a new mutable list of the matches, in iteration order
     */
    static List<InferenceService.DefaultConfigId> taskTypeMatchedDefaults(TaskType taskType, Collection<InferenceService.DefaultConfigId> defaultConfigIds) {
        List<InferenceService.DefaultConfigId> matches = new ArrayList<>();
        for (InferenceService.DefaultConfigId candidate : defaultConfigIds) {
            if (candidate.settings().taskType().equals(taskType)) {
                matches.add(candidate);
            }
        }
        return matches;
    }

    /**
     * Tracks the latest cluster metadata. On the elected master, it also runs a one-time upgrade.
     * <p>
     * The upgrade fetches all inference endpoints and submits their minimal service settings (default endpoints
     * excluded) to the model registry cluster metadata. It is skipped when:
     * <ul>
     *   <li>the cluster state is not yet recovered,</li>
     *   <li>more than one project exists, or</li>
     *   <li>the registry metadata already reports being upgraded.</li>
     * </ul>
     * {@code upgradeMetadataInProgress} prevents overlapping attempts. It is reset when an attempt finishes
     * or fails, so a later cluster change can retry.
     *
     * @param event the cluster state change being applied
     */
    public void clusterChanged(ClusterChangedEvent event) {
        // isReady() relies on lastMetadata being non-null.
        if (this.lastMetadata.get() == null || event.metadataChanged()) {
            this.lastMetadata.set(event.state().metadata());
        }
        if (event.localNodeMaster() == false) {
            return;
        }
        if (event.state().blocks().hasGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK)) {
            return;
        }
        // Only the single-project case (ProjectId.DEFAULT) is handled here.
        if (event.state().metadata().projects().size() > 1) {
            return;
        }
        ModelRegistryMetadata state = ModelRegistryMetadata.fromState(event.state().projectState().metadata());
        if (state.isUpgraded()) {
            return;
        }
        if (this.upgradeMetadataInProgress.compareAndSet(false, true) == false) {
            return;
        }
        this.client.execute(
            GetInferenceModelAction.INSTANCE,
            new GetInferenceModelAction.Request("*", TaskType.ANY, false),
            new ActionListener<GetInferenceModelAction.Response>() {

                @Override
                public void onResponse(GetInferenceModelAction.Response response) {
                    try {
                        HashMap<String, MinimalServiceSettings> map = new HashMap<String, MinimalServiceSettings>();
                        for (ModelConfigurations model : response.getEndpoints()) {
                            if (ModelRegistry.this.defaultConfigIds.containsKey(model.getInferenceEntityId())) {
                                continue;
                            }
                            map.put(
                                model.getInferenceEntityId(),
                                new MinimalServiceSettings(
                                    model.getService(),
                                    model.getTaskType(),
                                    model.getServiceSettings().dimensions(),
                                    model.getServiceSettings().similarity(),
                                    model.getServiceSettings().elementType()
                                )
                            );
                        }
                        ModelRegistry.this.metadataTaskQueue.submitTask(
                            "model registry auto upgrade",
                            new UpgradeModelsMetadataTask(
                                ProjectId.DEFAULT,
                                map,
                                ActionListener.running(() -> ModelRegistry.this.upgradeMetadataInProgress.set(false))
                            ),
                            null
                        );
                    } catch (Exception e) {
                        // Without this, a failure here would leave the in-progress flag set and block all future retries.
                        onFailure(e);
                    }
                }

                @Override
                public void onFailure(Exception e) {
                    ModelRegistry.logger.debug("Failed to upgrade the model registry metadata", e);
                    ModelRegistry.this.upgradeMetadataInProgress.set(false);
                }
            }
        );
    }

    /**
     * Raw maps describing an inference endpoint: its configuration map and its secrets map.
     */
    public record ModelConfigMap(Map<String, Object> config, Map<String, Object> secrets) {
    }

    /**
     * The store outcome reported for one inference endpoint.
     * <p>
     * {@code modifiedIndex} is true when at least one of the endpoint's bulk items (configuration or secrets)
     * did not fail, as computed by {@code getResponseInfo}. This presumably means some index was written and
     * may need cleanup.
     */
    private record StoreResponseWithIndexInfo(ModelStoreResponse modelStoreResponse, boolean modifiedIndex) {
    }

    /**
     * Aggregated result of interpreting a bulk store response.
     * <p>
     * It holds one entry per endpoint, plus the models whose configuration and secrets documents were both
     * stored successfully.
     */
    private record ResponseInfo(List<StoreResponseWithIndexInfo> responses, List<Model> successfullyStoredModels) {
    }

    /**
     * Cluster-metadata task that adds endpoints and their minimal service settings to the model registry metadata.
     */
    private static class AddModelMetadataTask extends MetadataTask {
        // Private copy so later changes to the caller's list do not affect the task.
        private final List<ModelAndSettings> models;

        AddModelMetadataTask(ProjectId projectId, List<ModelAndSettings> models, ActionListener<AcknowledgedResponse> listener) {
            super(projectId, listener);
            this.models = new ArrayList<ModelAndSettings>(models);
        }

        @Override
        ModelRegistryMetadata executeTask(ModelRegistryMetadata current) {
            return current.withAddedModels(models);
        }
    }

    /**
     * An inference endpoint id paired with the minimal service settings to add to the model registry metadata.
     */
    public record ModelAndSettings(String inferenceEntityId, MinimalServiceSettings settings) {
    }

    /**
     * Cluster-metadata task that removes the given endpoints from the model registry metadata.
     */
    private static class DeleteModelMetadataTask extends MetadataTask {
        private final Set<String> inferenceEntityIds;

        DeleteModelMetadataTask(ProjectId projectId, Set<String> idsToRemove, ActionListener<AcknowledgedResponse> listener) {
            super(projectId, listener);
            inferenceEntityIds = idsToRemove;
        }

        @Override
        ModelRegistryMetadata executeTask(ModelRegistryMetadata current) {
            return current.withRemovedModel(inferenceEntityIds);
        }
    }

    /**
     * Cluster-metadata task that applies the one-time upgrade.
     * <p>
     * It merges settings read from the inference index, keyed by endpoint id, into the model registry metadata.
     */
    private static class UpgradeModelsMetadataTask extends MetadataTask {
        private final Map<String, MinimalServiceSettings> fromIndex;

        UpgradeModelsMetadataTask(ProjectId projectId, Map<String, MinimalServiceSettings> settingsFromIndex, ActionListener<AcknowledgedResponse> listener) {
            super(projectId, listener);
            fromIndex = settingsFromIndex;
        }

        @Override
        ModelRegistryMetadata executeTask(ModelRegistryMetadata current) {
            return current.withUpgradedModels(fromIndex);
        }
    }

    /**
     * Base class for the model registry cluster-metadata tasks.
     * <p>
     * Each task targets a project and computes the next {@link ModelRegistryMetadata} from the current one.
     * The task is passed a fixed 30-second timeout, presumably the acknowledgement timeout of
     * {@code AckedBatchedClusterStateUpdateTask}; confirm against that class.
     */
    private static abstract class MetadataTask extends AckedBatchedClusterStateUpdateTask {
        private final ProjectId projectId;

        MetadataTask(ProjectId projectId, ActionListener<AcknowledgedResponse> listener) {
            super(TimeValue.THIRTY_SECONDS, listener);
            this.projectId = projectId;
        }

        /**
         * Computes the updated registry metadata.
         *
         * @param current the registry metadata currently in the cluster state
         * @return the metadata to store
         */
        abstract ModelRegistryMetadata executeTask(ModelRegistryMetadata current);

        public ProjectId getProjectId() {
            return projectId;
        }
    }
}

