/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.inference.rerank;

import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.xpack.esql.inference.InferenceOperator;
import org.elasticsearch.xpack.esql.inference.InferenceService;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunner;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRunnerConfig;
import org.elasticsearch.xpack.esql.inference.rerank.RerankOperatorOutputBuilder;
import org.elasticsearch.xpack.esql.inference.rerank.RerankOperatorRequestIterator;

/**
 * ES|QL operator that scores input rows against a query text using a rerank inference endpoint.
 * <p>
 * For each input page, {@link #requests(Page)} evaluates {@code rowEncoder} to obtain one
 * {@link BytesRefBlock} of row text. It then groups those rows into batched rerank requests for
 * {@code inferenceId} together with {@code queryText}. {@link #outputBuilder(Page)} collects the
 * returned scores into a {@link DoubleBlock}, which is associated with {@code scoreChannel}
 * (NOTE(review): exact placement of the score column is decided by
 * {@link RerankOperatorOutputBuilder} — confirm there).
 */
public class RerankOperator
extends InferenceOperator {
    /** Default number of encoded rows grouped into a single rerank inference request. */
    private static final int DEFAULT_BATCH_SIZE = 20;
    private final String queryText;
    // Turns each input row into the text sent to the reranker; closed by this operator in doClose().
    private final EvalOperator.ExpressionEvaluator rowEncoder;
    private final int scoreChannel;
    // Rows per inference request. Currently always DEFAULT_BATCH_SIZE, but requests() reads this
    // field so that there is a single source of truth if it ever becomes configurable.
    private final int batchSize = DEFAULT_BATCH_SIZE;

    /**
     * Creates a rerank operator.
     *
     * @param driverContext       driver context supplying the block factory used for output
     * @param bulkInferenceRunner runner that executes the generated inference requests
     * @param inferenceId         id of the inference endpoint to call
     * @param queryText           query text that rows are reranked against
     * @param rowEncoder          evaluator producing one {@link BytesRefBlock} of row text per page;
     *                            this operator takes responsibility for closing it
     * @param scoreChannel        channel index associated with the produced scores
     * @param maxOutstandingPages maximum number of pages allowed in flight (passed to the superclass)
     */
    public RerankOperator(DriverContext driverContext, BulkInferenceRunner bulkInferenceRunner, String inferenceId, String queryText, EvalOperator.ExpressionEvaluator rowEncoder, int scoreChannel, int maxOutstandingPages) {
        super(driverContext, bulkInferenceRunner, inferenceId, maxOutstandingPages);
        this.queryText = queryText;
        this.rowEncoder = rowEncoder;
        this.scoreChannel = scoreChannel;
    }

    /** Releases the row encoder owned by this operator. */
    @Override
    protected void doClose() {
        Releasables.close(this.rowEncoder);
    }

    @Override
    public String toString() {
        return formatDescription(this.inferenceId(), this.queryText, this.scoreChannel);
    }

    /**
     * Builds the batched rerank requests for one input page.
     *
     * @param inputPage page whose rows are encoded and sent for reranking
     * @return iterator over rerank requests of at most {@code batchSize} rows each
     */
    @Override
    protected RerankOperatorRequestIterator requests(Page inputPage) {
        // eval() returns a generic Block; the rerank request iterator consumes row text as BytesRefs.
        // NOTE(review): the evaluated block is presumably released by the iterator — confirm there.
        BytesRefBlock encodedRows = (BytesRefBlock) this.rowEncoder.eval(inputPage);
        return new RerankOperatorRequestIterator(encodedRows, this.inferenceId(), this.queryText, this.batchSize);
    }

    /**
     * Creates the builder that assembles inference scores into the output page.
     *
     * @param input the input page the scores belong to
     * @return an output builder backed by a double block sized to the page's position count
     */
    @Override
    protected RerankOperatorOutputBuilder outputBuilder(Page input) {
        DoubleBlock.Builder outputBlockBuilder = this.blockFactory().newDoubleBlockBuilder(input.getPositionCount());
        return new RerankOperatorOutputBuilder(outputBlockBuilder, input, this.scoreChannel);
    }

    /**
     * Shared description used by both {@link #toString()} and {@link Factory#describe()}, so that
     * the two cannot drift apart. It is deliberately not named {@code describe}: inside the
     * {@code Factory} record, an unqualified call would be shadowed by {@code Factory.describe()}.
     */
    private static String formatDescription(String inferenceId, String queryText, int scoreChannel) {
        return "RerankOperator[inference_id=[" + inferenceId + "], query=[" + queryText + "], score_channel=[" + scoreChannel + "]]";
    }

    /**
     * Factory that creates one {@link RerankOperator} per driver.
     *
     * @param inferenceService  service providing the bulk inference runner
     * @param inferenceId       id of the inference endpoint to call
     * @param queryText         query text that rows are reranked against
     * @param rowEncoderFactory factory for the per-driver row encoder
     * @param scoreChannel      channel index associated with the produced scores
     */
    public record Factory(InferenceService inferenceService, String inferenceId, String queryText, EvalOperator.ExpressionEvaluator.Factory rowEncoderFactory, int scoreChannel) implements Operator.OperatorFactory
    {
        @Override
        public String describe() {
            return formatDescription(this.inferenceId, this.queryText, this.scoreChannel);
        }

        @Override
        public Operator get(DriverContext driverContext) {
            // Each driver gets its own row encoder; the created operator closes it.
            return new RerankOperator(driverContext, this.inferenceService.bulkInferenceRunner(), this.inferenceId, this.queryText, this.rowEncoderFactory.get(driverContext), this.scoreChannel, BulkInferenceRunnerConfig.DEFAULT.maxOutstandingRequests());
        }
    }
}

