/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.CountDownActionListener;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator;
import org.elasticsearch.xpack.esql.optimizer.LogicalPreOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer.LogicalPlanPreOptimizerRule;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;

public class FoldInferenceFunctions
implements LogicalPlanPreOptimizerRule {
    private final InferenceFunctionEvaluator inferenceFunctionEvaluator;

    public FoldInferenceFunctions(LogicalPreOptimizerContext preOptimizerContext) {
        this(InferenceFunctionEvaluator.factory().create(preOptimizerContext.foldCtx(), preOptimizerContext.inferenceService()));
    }

    protected FoldInferenceFunctions(InferenceFunctionEvaluator inferenceFunctionEvaluator) {
        this.inferenceFunctionEvaluator = inferenceFunctionEvaluator;
    }

    @Override
    public void apply(LogicalPlan plan, ActionListener<LogicalPlan> listener) {
        this.foldInferenceFunctions(plan, listener);
    }

    private void foldInferenceFunctions(LogicalPlan plan, ActionListener<LogicalPlan> listener) {
        List<InferenceFunction<?>> inferenceFunctions = this.collectFoldableInferenceFunctions(plan);
        if (inferenceFunctions.isEmpty()) {
            listener.onResponse((Object)plan);
            return;
        }
        HashMap inferenceFunctionsToResults = new HashMap();
        CountDownActionListener completionListener = new CountDownActionListener(inferenceFunctions.size(), listener.delegateFailureIgnoreResponseAndWrap(l -> {
            LogicalPlan transformedPlan = (LogicalPlan)((Object)((Object)plan.transformExpressionsUp(InferenceFunction.class, f -> (Expression)inferenceFunctionsToResults.getOrDefault(f, f))));
            this.foldInferenceFunctions(transformedPlan, (ActionListener<LogicalPlan>)l);
        }));
        for (InferenceFunction<?> inferenceFunction : inferenceFunctions) {
            this.inferenceFunctionEvaluator.fold(inferenceFunction, (ActionListener<Expression>)completionListener.delegateFailureAndWrap((l, result) -> {
                inferenceFunctionsToResults.put(inferenceFunction, result);
                l.onResponse(null);
            }));
        }
    }

    private List<InferenceFunction<?>> collectFoldableInferenceFunctions(LogicalPlan plan) {
        ArrayList inferenceFunctions = new ArrayList();
        plan.forEachExpressionUp(InferenceFunction.class, f -> {
            if (f.foldable() && !f.hasNestedInferenceFunction()) {
                inferenceFunctions.add((InferenceFunction<?>)((Object)f));
            }
        });
        return inferenceFunctions;
    }
}

