/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.action;

import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.stream.Collectors;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.IndicesRequest;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.ProjectState;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.cluster.project.ProjectResolver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.index.IndexNotFoundException;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsAction;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;

/**
 * Transport action that reports, for every requested local index, the inference fields matching the
 * requested field patterns. When the request carries a non-blank query, it also runs inference on that
 * query once per distinct search inference ID found among those fields and returns the results.
 *
 * <p>Index expressions that target remote clusters are rejected with an {@link IllegalArgumentException}.
 */
public class TransportGetInferenceFieldsAction
extends HandledTransportAction<GetInferenceFieldsAction.Request, GetInferenceFieldsAction.Response> {
    private final TransportService transportService;
    private final ClusterService clusterService;
    private final ProjectResolver projectResolver;
    private final IndexNameExpressionResolver indexNameExpressionResolver;
    private final Client client;

    @Inject
    public TransportGetInferenceFieldsAction(TransportService transportService, ActionFilters actionFilters, ClusterService clusterService, ProjectResolver projectResolver, IndexNameExpressionResolver indexNameExpressionResolver, Client client) {
        super("cluster:internal/xpack/inference/fields/get", transportService, actionFilters, GetInferenceFieldsAction.Request::new, EsExecutors.DIRECT_EXECUTOR_SERVICE);
        this.transportService = transportService;
        this.clusterService = clusterService;
        this.projectResolver = projectResolver;
        this.indexNameExpressionResolver = indexNameExpressionResolver;
        this.client = client;
    }

    /**
     * Resolves the request's index expressions, collects the matching inference fields of every concrete
     * local index and, if the request has a non-blank query, runs inference on it before responding.
     * Any exception raised while doing so is reported through {@code listener.onFailure}.
     */
    @Override
    protected void doExecute(Task task, GetInferenceFieldsAction.Request request, ActionListener<GetInferenceFieldsAction.Response> listener) {
        Set<String> indices = request.getIndices();
        Map<String, Float> fields = request.getFields();
        boolean resolveWildcards = request.resolveWildcards();
        boolean useDefaultFields = request.useDefaultFields();
        String query = request.getQuery();
        IndicesOptions indicesOptions = request.getIndicesOptions();
        try {
            Map<String, OriginalIndices> groupedIndices = this.transportService.getRemoteClusterService()
                .groupIndices(indicesOptions, indices.toArray(new String[0]), true);
            // The entry keyed by "" holds the local indices; any entry left afterwards targets a remote cluster.
            OriginalIndices localIndices = groupedIndices.remove("");
            if (!groupedIndices.isEmpty()) {
                throw new IllegalArgumentException("GetInferenceFieldsAction does not support remote indices");
            }
            // Take ONE snapshot and use it both to resolve concrete indices and to read their metadata, so an
            // index removed by a concurrent cluster-state update cannot cause a spurious IndexNotFoundException.
            ProjectState projectState = this.projectResolver.getProjectState(this.clusterService.state());
            String[] concreteLocalIndices = this.indexNameExpressionResolver.concreteIndexNames(projectState.metadata(), localIndices);

            Map<String, List<GetInferenceFieldsAction.ExtendedInferenceFieldMetadata>> inferenceFieldsMap = new HashMap<>(concreteLocalIndices.length);
            for (String index : concreteLocalIndices) {
                inferenceFieldsMap.put(index, getInferenceFieldMetadata(projectState, index, fields, resolveWildcards, useDefaultFields));
            }

            if (query != null && !query.isBlank()) {
                // Several fields may share a search inference ID; run inference only once per distinct ID.
                Set<String> inferenceIds = inferenceFieldsMap.values()
                    .stream()
                    .flatMap(Collection::stream)
                    .map(eifm -> eifm.inferenceFieldMetadata().getSearchInferenceId())
                    .collect(Collectors.toSet());
                this.getInferenceResults(query, inferenceIds, inferenceFieldsMap, listener);
            } else {
                listener.onResponse(new GetInferenceFieldsAction.Response(inferenceFieldsMap, Map.of()));
            }
        } catch (Exception e) {
            listener.onFailure(e);
        }
    }

    /**
     * Returns the inference fields of {@code index} that match {@code fields}. Each field is paired with the
     * float value that {@code IndexMetadata#getMatchingInferenceFields} associates with it (presumably the
     * per-field weight taken from {@code fields}; confirm).
     *
     * @param projectState the snapshot that was used to resolve {@code index}
     * @throws IndexNotFoundException if {@code index} is absent from {@code projectState}
     */
    private static List<GetInferenceFieldsAction.ExtendedInferenceFieldMetadata> getInferenceFieldMetadata(
        ProjectState projectState,
        String index,
        Map<String, Float> fields,
        boolean resolveWildcards,
        boolean useDefaultFields
    ) {
        IndexMetadata indexMetadata = projectState.metadata().indices().get(index);
        if (indexMetadata == null) {
            throw new IndexNotFoundException(index);
        }
        return indexMetadata.getMatchingInferenceFields(fields, resolveWildcards, useDefaultFields)
            .entrySet()
            .stream()
            .map(e -> new GetInferenceFieldsAction.ExtendedInferenceFieldMetadata(e.getKey(), e.getValue()))
            .toList();
    }

    /**
     * Runs inference on {@code query} for every ID in {@code inferenceIds}, in parallel. The listener
     * responds once all requests have completed and fails if any request fails. If there are no IDs, it
     * responds immediately with an empty results map.
     */
    private void getInferenceResults(
        String query,
        Set<String> inferenceIds,
        Map<String, List<GetInferenceFieldsAction.ExtendedInferenceFieldMetadata>> inferenceFieldsMap,
        ActionListener<GetInferenceFieldsAction.Response> listener
    ) {
        if (inferenceIds.isEmpty()) {
            listener.onResponse(new GetInferenceFieldsAction.Response(inferenceFieldsMap, Map.of()));
            return;
        }
        GroupedActionListener<Tuple<String, InferenceResults>> groupedListener = new GroupedActionListener<Tuple<String, InferenceResults>>(
            inferenceIds.size(),
            listener.delegateFailureAndWrap((l, idsAndResults) -> {
                Map<String, InferenceResults> inferenceResultsMap = new HashMap<>(inferenceIds.size());
                for (Tuple<String, InferenceResults> idAndResult : idsAndResults) {
                    inferenceResultsMap.put(idAndResult.v1(), idAndResult.v2());
                }
                l.onResponse(new GetInferenceFieldsAction.Response(inferenceFieldsMap, inferenceResultsMap));
            })
        );
        // Build every request before dispatching any, so a construction failure aborts before work is sent out.
        List<InferenceAction.Request> inferenceRequests = inferenceIds.stream()
            .map(i -> new InferenceAction.Request(TaskType.ANY, i, null, null, null, List.of(query), Map.of(), InputType.INTERNAL_SEARCH, null, false))
            .toList();
        for (InferenceAction.Request inferenceRequest : inferenceRequests) {
            String inferenceId = inferenceRequest.getInferenceEntityId();
            ActionListener<InferenceAction.Response> perIdListener = groupedListener.delegateFailureAndWrap(
                (l, r) -> l.onResponse(Tuple.tuple(inferenceId, validateAndConvertInferenceResults(r.getResults(), inferenceId)))
            );
            ClientHelper.executeAsyncWithOrigin(this.client, "ml", InferenceAction.INSTANCE, inferenceRequest, perIdListener);
        }
    }

    /**
     * Converts service results for a single-input inference call into one {@link InferenceResults}.
     * Because exactly one input (the query) is sent, exactly one result is expected. Zero or several
     * results are returned as an {@link ErrorInferenceResults} rather than thrown.
     */
    private static InferenceResults validateAndConvertInferenceResults(InferenceServiceResults inferenceServiceResults, String inferenceId) {
        List<? extends InferenceResults> inferenceResultsList = inferenceServiceResults.transformToCoordinationFormat();
        if (inferenceResultsList.isEmpty()) {
            return new ErrorInferenceResults(new IllegalArgumentException("No inference results retrieved for inference ID [" + inferenceId + "]"));
        }
        if (inferenceResultsList.size() > 1) {
            return new ErrorInferenceResults(new IllegalStateException(inferenceResultsList.size() + " inference results retrieved for inference ID [" + inferenceId + "]"));
        }
        return inferenceResultsList.getFirst();
    }
}

