/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.inference;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.function.Consumer;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.CountDownActionListener;
import org.elasticsearch.action.support.ThreadedActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
import org.elasticsearch.xpack.esql.expression.function.FunctionDefinition;
import org.elasticsearch.xpack.esql.expression.function.UnresolvedFunction;
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
import org.elasticsearch.xpack.esql.inference.InferenceResolution;
import org.elasticsearch.xpack.esql.inference.ResolvedInference;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;

/**
 * Resolves the inference endpoint ids referenced by an ES|QL logical plan.
 * <p>
 * Ids are collected from {@link InferencePlan} nodes and from {@link UnresolvedFunction} calls whose
 * registered definition is an {@link InferenceFunction}. Each distinct id is then looked up with
 * {@link GetInferenceModelAction}, and the outcome (the endpoint's task type, or the failure message)
 * is recorded in an {@link InferenceResolution}.
 */
public class InferenceResolver {

    /** Name of the function argument whose (foldable, string) value is the inference endpoint id. */
    private static final String INFERENCE_ID_ARG_NAME = "inference_id";

    private final Client client;
    private final EsqlFunctionRegistry functionRegistry;
    private final ThreadPool threadPool;

    public InferenceResolver(Client client, EsqlFunctionRegistry functionRegistry, ThreadPool threadPool) {
        this.client = client;
        this.functionRegistry = functionRegistry;
        this.threadPool = threadPool;
    }

    /**
     * Collects every inference id referenced by {@code plan} and resolves them.
     *
     * @param plan     the plan to scan for inference ids
     * @param listener notified with the resolution once every distinct id has been looked up
     */
    public void resolveInferenceIds(LogicalPlan plan, ActionListener<InferenceResolution> listener) {
        resolveInferenceIds(collectInferenceIds(plan), listener);
    }

    /**
     * Collects the inference ids referenced by {@code plan}: first those of inference plan nodes, then
     * those passed as the {@code inference_id} argument of inference functions. Duplicates are kept.
     *
     * @param plan the plan to scan
     * @return a mutable list of the ids found, in collection order
     */
    List<String> collectInferenceIds(LogicalPlan plan) {
        List<String> inferenceIds = new ArrayList<>();
        collectInferenceIdsFromInferencePlans(plan, inferenceIds::add);
        collectInferenceIdsFromInferenceFunctions(plan, inferenceIds::add);
        return inferenceIds;
    }

    /**
     * De-duplicates {@code inferenceIds} and resolves each distinct id.
     *
     * @param inferenceIds ids to resolve; {@link Set#copyOf} rejects {@code null} elements
     * @param listener     notified with the resolution
     */
    void resolveInferenceIds(List<String> inferenceIds, ActionListener<InferenceResolution> listener) {
        resolveInferenceIds(Set.copyOf(inferenceIds), listener);
    }

    /**
     * Looks up each id with {@link GetInferenceModelAction} and notifies {@code listener} once every
     * lookup has completed. A failed lookup does not fail {@code listener}; its message is recorded as an
     * error for that id in the resolution instead.
     *
     * @param inferenceIds distinct ids to resolve; when empty, {@link InferenceResolution#EMPTY} is returned immediately
     * @param listener     notified with the built resolution
     */
    void resolveInferenceIds(Set<String> inferenceIds, ActionListener<InferenceResolution> listener) {
        if (inferenceIds.isEmpty()) {
            listener.onResponse(InferenceResolution.EMPTY);
            return;
        }

        final InferenceResolution.Builder inferenceResolutionBuilder = InferenceResolution.builder();
        // Counted down exactly once per id (on success and on failure alike); when the count reaches zero
        // the outer listener receives the built resolution.
        final CountDownActionListener countdownListener = new CountDownActionListener(
            inferenceIds.size(),
            listener.delegateFailureIgnoreResponseAndWrap(l -> l.onResponse(inferenceResolutionBuilder.build()))
        );

        for (String inferenceId : inferenceIds) {
            client.execute(
                GetInferenceModelAction.INSTANCE,
                new GetInferenceModelAction.Request(inferenceId, TaskType.ANY),
                // Handle each response on the search coordination executor rather than the calling thread.
                new ThreadedActionListener<>(threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION), ActionListener.wrap(response -> {
                    // NOTE(review): getFirst() throws if the action returns no endpoint — confirm that
                    // a lookup by explicit id never responds with an empty list.
                    // NOTE(review): the builder is updated from executor threads; this assumes
                    // InferenceResolution.Builder tolerates concurrent updates — confirm.
                    ResolvedInference resolvedInference = new ResolvedInference(
                        inferenceId,
                        response.getEndpoints().getFirst().getTaskType()
                    );
                    inferenceResolutionBuilder.withResolvedInference(resolvedInference);
                    countdownListener.onResponse(null);
                }, e -> {
                    inferenceResolutionBuilder.withError(inferenceId, e.getMessage());
                    countdownListener.onResponse(null);
                }))
            );
        }
    }

    /** Feeds the (folded) inference id of every {@link InferencePlan} node in {@code plan} to {@code c}. */
    private void collectInferenceIdsFromInferencePlans(LogicalPlan plan, Consumer<String> c) {
        plan.forEachUp(InferencePlan.class, inferencePlan -> c.accept(inferenceId(inferencePlan)));
    }

    /**
     * Feeds to {@code c} the inference id of every unresolved function call in {@code plan} that names a
     * registered {@link InferenceFunction} and passes a foldable string {@code inference_id} argument.
     */
    private void collectInferenceIdsFromInferenceFunctions(LogicalPlan plan, Consumer<String> c) {
        EsqlFunctionRegistry snapshotRegistry = functionRegistry.snapshotRegistry();
        plan.forEachExpressionUp(UnresolvedFunction.class, f -> {
            String functionName = snapshotRegistry.resolveAlias(f.name());
            if (snapshotRegistry.functionExists(functionName) == false) {
                return;
            }
            FunctionDefinition def = snapshotRegistry.resolveFunction(functionName);
            if (InferenceFunction.class.isAssignableFrom(def.clazz()) == false) {
                return;
            }
            String inferenceId = inferenceId(f, def);
            if (inferenceId != null) {
                c.accept(inferenceId);
            }
        });
    }

    /** Folds the inference id expression of {@code plan} and converts it to a {@link String}. */
    private static String inferenceId(InferencePlan<?> plan) {
        return BytesRefs.toString(plan.inferenceId().fold(FoldContext.small()));
    }

    /**
     * Extracts the {@code inference_id} argument of an inference function call.
     *
     * @param f   the unresolved call
     * @param def the function's registered definition, whose description gives the argument names
     * @return the folded string value of the first {@code inference_id} argument that is non-null,
     *         foldable and of string type; {@code null} when there is none or when the call supplies
     *         fewer arguments than the signature position being inspected
     */
    private static String inferenceId(UnresolvedFunction f, FunctionDefinition def) {
        EsqlFunctionRegistry.FunctionDescription functionDescription = EsqlFunctionRegistry.description(def);
        List<EsqlFunctionRegistry.ArgSignature> signature = functionDescription.args();
        List<Expression> actualArgs = f.arguments();

        for (int i = 0; i < signature.size(); i++) {
            if (i >= actualArgs.size()) {
                // The call supplies fewer arguments than the signature declares.
                return null;
            }
            if (INFERENCE_ID_ARG_NAME.equals(signature.get(i).name()) == false) {
                continue;
            }
            Expression inferenceId = actualArgs.get(i);
            if (inferenceId != null && inferenceId.foldable() && DataType.isString(inferenceId.dataType())) {
                return BytesRefs.toString(inferenceId.fold(FoldContext.small()));
            }
        }
        return null;
    }

    /**
     * Creates a factory that builds resolvers sharing {@code client} and its thread pool.
     *
     * @param client client used for inference endpoint lookups; its {@link Client#threadPool()} is reused
     */
    public static Factory factory(Client client) {
        return new Factory(client, client.threadPool());
    }

    /** Builds {@link InferenceResolver}s bound to a fixed client and thread pool. */
    public static class Factory {
        private final Client client;
        private final ThreadPool threadPool;

        private Factory(Client client, ThreadPool threadPool) {
            this.client = client;
            this.threadPool = threadPool;
        }

        /**
         * Creates a resolver that uses {@code functionRegistry} to recognise inference functions.
         *
         * @param functionRegistry registry whose snapshot is consulted when collecting function ids
         */
        public InferenceResolver create(EsqlFunctionRegistry functionRegistry) {
            return new InferenceResolver(client, functionRegistry, threadPool);
        }
    }
}

