/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.inference;

import java.lang.runtime.SwitchBootstraps;
import java.util.List;
import java.util.Objects;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.NoopCircuitBreaker;
import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BlockUtils;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.indices.breaker.AllCircuitBreakerStats;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.indices.breaker.CircuitBreakerStats;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.evaluator.EvalMapper;
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding;
import org.elasticsearch.xpack.esql.inference.InferenceService;
import org.elasticsearch.xpack.esql.inference.textembedding.TextEmbeddingOperator;

public class InferenceFunctionEvaluator {
    private static final Factory FACTORY = new Factory();
    private final FoldContext foldContext;
    private final InferenceOperatorProvider inferenceOperatorProvider;

    public static Factory factory() {
        return FACTORY;
    }

    InferenceFunctionEvaluator(FoldContext foldContext, InferenceOperatorProvider inferenceOperatorProvider) {
        this.foldContext = foldContext;
        this.inferenceOperatorProvider = inferenceOperatorProvider;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public void fold(InferenceFunction<?> f, ActionListener<Expression> listener) {
        if (!f.foldable()) {
            listener.onFailure((Exception)new IllegalArgumentException("Inference function must be foldable"));
            return;
        }
        if (f.dataType() == DataType.NULL) {
            listener.onResponse((Object)Literal.of(f, null));
            return;
        }
        NoopCircuitBreaker breaker = new NoopCircuitBreaker("request");
        BigArrays bigArrays = new BigArrays(null, new CircuitBreakerService(this, (CircuitBreaker)breaker){
            final /* synthetic */ CircuitBreaker val$breaker;
            {
                this.val$breaker = circuitBreaker;
            }

            public CircuitBreaker getBreaker(String name) {
                if (!name.equals("request")) {
                    throw new UnsupportedOperationException("Only REQUEST circuit breaker is supported");
                }
                return this.val$breaker;
            }

            public AllCircuitBreakerStats stats() {
                throw new UnsupportedOperationException("Circuit breaker stats not supported in fold context");
            }

            public CircuitBreakerStats stats(String name) {
                throw new UnsupportedOperationException("Circuit breaker stats not supported in fold context");
            }
        }, "request").withCircuitBreaking();
        DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory((CircuitBreaker)breaker, bigArrays));
        try {
            Operator inferenceOperator = this.inferenceOperatorProvider.getOperator(f, driverContext);
            try {
                inferenceOperator.addInput(new Page(1, new Block[0]));
                driverContext.waitForAsyncActions(listener.delegateFailureIgnoreResponseAndWrap(l -> {
                    try {
                        Page output = inferenceOperator.getOutput();
                        if (output == null) {
                            l.onFailure((Exception)new IllegalStateException("Expected output page from inference operator"));
                            return;
                        }
                        output.allowPassingToDifferentDriver();
                        l = ActionListener.releaseBefore((Releasable)output, (ActionListener)l);
                        if (output.getPositionCount() != 1 || output.getBlockCount() != 1) {
                            l.onFailure((Exception)new IllegalStateException("Expected a single block with a single value from inference operator"));
                            return;
                        }
                        l.onResponse((Object)Literal.of((Expression)f, (Object)this.processValue(f.dataType(), BlockUtils.toJavaObject((Block)output.getBlock(0), (int)0))));
                    }
                    catch (Exception e) {
                        l.onFailure(e);
                    }
                    finally {
                        Releasables.close((Releasable)inferenceOperator);
                    }
                }));
            }
            catch (Exception e) {
                Releasables.close((Releasable)inferenceOperator);
                listener.onFailure(e);
            }
        }
        catch (Exception e) {
            listener.onFailure(e);
        }
        finally {
            driverContext.finish();
        }
    }

    private Object processValue(DataType dataType, Object value) {
        if (dataType == DataType.DENSE_VECTOR && !(value instanceof List)) {
            value = List.of(value);
        }
        return value;
    }

    public static class Factory {
        private Factory() {
        }

        public InferenceFunctionEvaluator create(FoldContext foldContext, InferenceService inferenceService) {
            return new InferenceFunctionEvaluator(foldContext, this.createInferenceOperatorProvider(foldContext, inferenceService));
        }

        private InferenceOperatorProvider createInferenceOperatorProvider(FoldContext foldContext, InferenceService inferenceService) {
            return (arg_0, arg_1) -> this.lambda$createInferenceOperatorProvider$0(inferenceService, foldContext, arg_0, arg_1);
        }

        private String inferenceId(InferenceFunction<?> f, FoldContext foldContext) {
            return BytesRefs.toString((Object)f.inferenceId().fold(foldContext));
        }

        private EvalOperator.ExpressionEvaluator.Factory expressionEvaluatorFactory(Expression e, FoldContext foldContext) {
            assert (e.foldable()) : "Input expression must be foldable";
            return EvalMapper.toEvaluator(foldContext, (Expression)Literal.of((FoldContext)foldContext, (Expression)e), null);
        }

        /*
         * Unable to fully structure code
         */
        private /* synthetic */ Operator lambda$createInferenceOperatorProvider$0(InferenceService inferenceService, FoldContext foldContext, InferenceFunction inferenceFunction, DriverContext driverContext) {
            v0 = inferenceFunction;
            Objects.requireNonNull(v0);
            selector0$temp = v0;
            index$1 = 0;
            switch (SwitchBootstraps.typeSwitch("typeSwitch", new Object[]{TextEmbedding.class}, (Object)selector0$temp, index$1)) {
                case 0: {
                    ** break;
                }
                default: {
                    throw new IllegalArgumentException("Unknown inference function: " + inferenceFunction.getClass().getName());
                }
lbl-1000:
                // 1 sources

                {
                    textEmbedding = (TextEmbedding)selector0$temp;
                    operatorFactory = new TextEmbeddingOperator.Factory(inferenceService, this.inferenceId(inferenceFunction, foldContext), this.expressionEvaluatorFactory(textEmbedding.inputText(), foldContext));
                }
            }
            return operatorFactory.get(driverContext);
        }
    }

    static interface InferenceOperatorProvider {
        public Operator getOperator(InferenceFunction<?> var1, DriverContext var2);
    }
}

