/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.inference;

import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.CountDownActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.inference.InferenceResolution;
import org.elasticsearch.xpack.esql.inference.ResolvedInference;
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;

/**
 * Runs inference-related actions for ES|QL through the wrapped {@link Client}: resolving the inference
 * endpoint ids referenced by {@link InferencePlan}s and executing {@link InferenceAction} requests.
 */
public class InferenceRunner {
    private final Client client;
    private final ThreadPool threadPool;

    /**
     * @param client     client used to execute model-lookup and inference actions
     * @param threadPool thread pool exposed to callers via {@link #threadPool()}
     */
    public InferenceRunner(Client client, ThreadPool threadPool) {
        this.client = client;
        this.threadPool = threadPool;
    }

    /**
     * Returns the thread pool supplied at construction.
     */
    public ThreadPool threadPool() {
        return this.threadPool;
    }

    /**
     * Resolves the distinct inference endpoint ids referenced by {@code plans}.
     *
     * @param plans    inference plans whose {@code inferenceId()} expressions are folded to string ids;
     *                 duplicate ids are collected into a set and therefore looked up only once
     * @param listener notified with the resulting {@link InferenceResolution}
     */
    public void resolveInferenceIds(List<InferencePlan<?>> plans, ActionListener<InferenceResolution> listener) {
        resolveInferenceIds(plans.stream().map(InferenceRunner::planInferenceId).collect(Collectors.toSet()), listener);
    }

    /**
     * Looks up each id with a {@link GetInferenceModelAction} request (task type {@link TaskType#ANY}) and
     * completes {@code listener} with the built resolution once every lookup has finished.
     * <p>
     * A failed lookup is recorded as an error for that id in the resolution instead of failing
     * {@code listener}, so one bad id does not hide the outcome of the others. An empty id set completes
     * immediately with {@link InferenceResolution#EMPTY}.
     *
     * @param inferenceIds distinct inference endpoint ids to resolve
     * @param listener     notified with the resulting {@link InferenceResolution}
     */
    private void resolveInferenceIds(Set<String> inferenceIds, ActionListener<InferenceResolution> listener) {
        if (inferenceIds.isEmpty()) {
            listener.onResponse(InferenceResolution.EMPTY);
            return;
        }

        // NOTE(review): lookup callbacks may complete on different threads and all update this shared
        // builder; this relies on InferenceResolution.Builder tolerating concurrent updates — confirm.
        final InferenceResolution.Builder inferenceResolutionBuilder = InferenceResolution.builder();

        // Completes the caller's listener with the built resolution after inferenceIds.size() count-downs.
        final CountDownActionListener countdownListener = new CountDownActionListener(
            inferenceIds.size(),
            ActionListener.wrap(ignored -> listener.onResponse(inferenceResolutionBuilder.build()), listener::onFailure)
        );

        for (String inferenceId : inferenceIds) {
            client.execute(
                GetInferenceModelAction.INSTANCE,
                new GetInferenceModelAction.Request(inferenceId, TaskType.ANY),
                ActionListener.wrap(response -> {
                    // NOTE(review): getFirst() throws on an empty endpoint list; presumably ActionListener.wrap
                    // routes that exception to the failure branch below — confirm.
                    final ResolvedInference resolvedInference = new ResolvedInference(
                        inferenceId,
                        response.getEndpoints().getFirst().getTaskType()
                    );
                    inferenceResolutionBuilder.withResolvedInference(resolvedInference);
                    countdownListener.onResponse(null);
                }, e -> {
                    inferenceResolutionBuilder.withError(inferenceId, e.getMessage());
                    // Failures count down as well, so the overall listener still completes.
                    countdownListener.onResponse(null);
                })
            );
        }
    }

    /**
     * Folds a plan's inference id expression and converts the folded value to a string.
     */
    private static String planInferenceId(InferencePlan<?> plan) {
        return BytesRefs.toString(plan.inferenceId().fold(FoldContext.small()));
    }

    /**
     * Executes {@code request} as an {@link InferenceAction} via
     * {@link ClientHelper#executeAsyncWithOrigin} using the {@code "inference"} origin.
     *
     * @param request  the inference request to execute
     * @param listener notified with the inference response or failure
     */
    public void doInference(InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
        ClientHelper.executeAsyncWithOrigin(client, "inference", InferenceAction.INSTANCE, request, listener);
    }
}

