/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.inference;

import java.util.List;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.AsyncOperator;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutor;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;

public abstract class InferenceOperator
extends AsyncOperator<OngoingInferenceResult> {
    private final String inferenceId;
    private final BlockFactory blockFactory;
    private final BulkInferenceExecutor bulkInferenceExecutor;

    public InferenceOperator(DriverContext driverContext, InferenceRunner inferenceRunner, BulkInferenceExecutionConfig bulkExecutionConfig, ThreadPool threadPool, String inferenceId) {
        super(driverContext, bulkExecutionConfig.workers());
        this.blockFactory = driverContext.blockFactory();
        this.bulkInferenceExecutor = new BulkInferenceExecutor(inferenceRunner, threadPool, bulkExecutionConfig);
        this.inferenceId = inferenceId;
    }

    protected BlockFactory blockFactory() {
        return this.blockFactory;
    }

    protected String inferenceId() {
        return this.inferenceId;
    }

    protected void performAsync(Page input, ActionListener<OngoingInferenceResult> listener) {
        try {
            BulkInferenceRequestIterator requests = this.requests(input);
            listener = ActionListener.releaseAfter(listener, (Releasable)requests);
            this.bulkInferenceExecutor.execute(requests, (ActionListener<List<InferenceAction.Response>>)listener.map(responses -> new OngoingInferenceResult(input, (List<InferenceAction.Response>)responses)));
        }
        catch (Exception e) {
            listener.onFailure(e);
        }
    }

    protected void releaseFetchedOnAnyThread(OngoingInferenceResult ongoingInferenceResult) {
        Releasables.close((Releasable)ongoingInferenceResult);
    }

    public Page getOutput() {
        OngoingInferenceResult ongoingInferenceResult = (OngoingInferenceResult)this.fetchFromBuffer();
        if (ongoingInferenceResult == null) {
            return null;
        }
        try {
            Page page;
            block11: {
                OutputBuilder outputBuilder = this.outputBuilder(ongoingInferenceResult.inputPage);
                try {
                    for (InferenceAction.Response response : ongoingInferenceResult.responses) {
                        outputBuilder.addInferenceResponse(response);
                    }
                    page = outputBuilder.buildOutput();
                    if (outputBuilder == null) break block11;
                }
                catch (Throwable throwable) {
                    if (outputBuilder != null) {
                        try {
                            outputBuilder.close();
                        }
                        catch (Throwable throwable2) {
                            throwable.addSuppressed(throwable2);
                        }
                    }
                    throw throwable;
                }
                outputBuilder.close();
            }
            return page;
        }
        finally {
            this.releaseFetchedOnAnyThread(ongoingInferenceResult);
        }
    }

    protected abstract BulkInferenceRequestIterator requests(Page var1);

    protected abstract OutputBuilder outputBuilder(Page var1);

    public record OngoingInferenceResult(Page inputPage, List<InferenceAction.Response> responses) implements Releasable
    {
        public void close() {
            InferenceOperator.releasePageOnAnyThread((Page)this.inputPage);
        }
    }

    public static interface OutputBuilder
    extends Releasable {
        public void addInferenceResponse(InferenceAction.Response var1);

        public Page buildOutput();

        public static <IR extends InferenceServiceResults> IR inferenceResults(InferenceAction.Response inferenceResponse, Class<IR> clazz) {
            InferenceServiceResults results = inferenceResponse.getResults();
            if (clazz.isInstance(results)) {
                return (IR)((InferenceServiceResults)clazz.cast(results));
            }
            throw new IllegalStateException(LoggerMessageFormat.format((String)"Inference result has wrong type. Got [{}] while expecting [{}]", (String)results.getClass().getName(), (Object[])new Object[]{clazz.getName()}));
        }

        default public void releasePageOnAnyThread(Page page) {
            InferenceOperator.releasePageOnAnyThread((Page)page);
        }
    }
}

