/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.expression.function.scalar.string;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.xpack.core.common.chunks.MemoryIndexChunkScorer;
import org.elasticsearch.xpack.core.inference.chunking.SentenceBoundaryChunkingSettings;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.MapExpression;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
import org.elasticsearch.xpack.esql.expression.function.Example;
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.MapParam;
import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
import org.elasticsearch.xpack.esql.expression.function.Options;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.TopSnippetsBytesRefEvaluator;
import org.elasticsearch.xpack.esql.expression.function.scalar.util.ChunkUtils;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;

public class TopSnippets
extends EsqlScalarFunction
implements OptionalArgument {
    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "TopSnippets", TopSnippets::new);
    static final int DEFAULT_NUM_SNIPPETS = 5;
    static final int DEFAULT_WORD_SIZE = 300;
    private final Expression field;
    private final Expression query;
    private final Expression options;
    private static final String NUM_SNIPPETS = "num_snippets";
    private static final String NUM_WORDS = "num_words";
    public static final Map<String, DataType> ALLOWED_OPTIONS = Map.ofEntries(Map.entry("num_snippets", DataType.INTEGER), Map.entry("num_words", DataType.INTEGER));

    @FunctionInfo(appliesTo={@FunctionAppliesTo(lifeCycle=FunctionAppliesToLifecycle.PREVIEW, version="9.3.0")}, returnType={"keyword"}, preview=true, description="Use `TOP_SNIPPETS` to extract the best snippets for a given query string from a text field.", detailedDescription="    TopSnippets can be used on fields from the text famiy like <<text, text>> and <<semantic-text, semantic_text>>.\n    TopSnippets will extract the best snippets for a given query string.\n", examples={@Example(file="top-snippets", tag="top-snippets-with-field", applies_to="stack: preview 9.3.0"), @Example(file="top-snippets", tag="top-snippets-with-options", applies_to="stack: preview 9.3.0")})
    public TopSnippets(Source source, @Param(name="field", type={"keyword", "text"}, description="The input to chunk.") Expression field, @Param(name="query", type={"keyword"}, description="The input text containing only query terms for snippet extraction.\nLucene query syntax, operators, and wildcards are not allowed.\n") Expression query, @MapParam(name="options", description="(Optional) TopSnippets additional options as [function named parameters](/reference/query-languages/esql/esql-syntax.md#esql-function-named-params).", optional=true, params={@MapParam.MapParamEntry(name="num_snippets", type={"integer"}, description="The maximum number of matching snippets to return.", valueHint={"3"}), @MapParam.MapParamEntry(name="num_words", type={"integer"}, description="The maximum number of words to return in each snippet.\nThis allows better control of inference costs by limiting the size of tokens per snippet.\n", valueHint={"300"})}) Expression options) {
        super(source, options == null ? List.of(field, query) : List.of(field, query, options));
        this.field = field;
        this.query = query;
        this.options = options;
    }

    public TopSnippets(StreamInput in) throws IOException {
        this(Source.readFrom((PlanStreamInput)in), (Expression)in.readNamedWriteable(Expression.class), (Expression)in.readNamedWriteable(Expression.class), (Expression)in.readOptionalNamedWriteable(Expression.class));
    }

    public void writeTo(StreamOutput out) throws IOException {
        this.source().writeTo(out);
        out.writeNamedWriteable((NamedWriteable)this.field);
        out.writeNamedWriteable((NamedWriteable)this.query);
        out.writeOptionalNamedWriteable((NamedWriteable)this.options);
    }

    public String getWriteableName() {
        return TopSnippets.ENTRY.name;
    }

    @Override
    public DataType dataType() {
        return DataType.KEYWORD;
    }

    @Override
    protected Expression.TypeResolution resolveType() {
        if (!this.childrenResolved()) {
            return new Expression.TypeResolution("Unresolved children");
        }
        return TypeResolutions.isString(this.field(), this.sourceText(), TypeResolutions.ParamOrdinal.FIRST).and(() -> TypeResolutions.isString(this.query(), this.sourceText(), TypeResolutions.ParamOrdinal.SECOND)).and(() -> Options.resolve(this.options(), this.source(), TypeResolutions.ParamOrdinal.THIRD, ALLOWED_OPTIONS)).and(this::validateOptions);
    }

    private Expression.TypeResolution validateOptions() {
        if (this.options() == null) {
            return Expression.TypeResolution.TYPE_RESOLVED;
        }
        MapExpression optionsMap = (MapExpression)this.options();
        return this.validateOptionValueIsPositiveInteger(optionsMap, NUM_SNIPPETS).and(this.validateOptionValueIsPositiveInteger(optionsMap, NUM_WORDS));
    }

    private Expression.TypeResolution validateOptionValueIsPositiveInteger(MapExpression optionsMap, String paramName) {
        Object value;
        Expression expr = optionsMap.keyFoldedMap().get(paramName);
        if (expr != null && (value = expr.fold(FoldContext.small())) != null && ((Number)value).intValue() <= 0) {
            return new Expression.TypeResolution("'" + paramName + "' option must be a positive integer, found [" + ((Number)value).intValue() + "]");
        }
        return Expression.TypeResolution.TYPE_RESOLVED;
    }

    @Override
    public boolean foldable() {
        return this.field().foldable() && this.query().foldable() && (this.options() == null || this.options().foldable());
    }

    @Override
    public Expression replaceChildren(List<Expression> newChildren) {
        return new TopSnippets(this.source(), newChildren.get(0), newChildren.get(1), newChildren.size() > 2 ? newChildren.get(2) : null);
    }

    @Override
    protected NodeInfo<? extends Expression> info() {
        return NodeInfo.create(this, TopSnippets::new, this.field, this.query, this.options);
    }

    Expression field() {
        return this.field;
    }

    Expression query() {
        return this.query;
    }

    Expression options() {
        return this.options;
    }

    private int numSnippets() {
        return this.extractIntegerOption(NUM_SNIPPETS, 5);
    }

    private int numWords() {
        return this.extractIntegerOption(NUM_WORDS, 300);
    }

    private int extractIntegerOption(String option, int defaultValue) {
        if (this.options == null) {
            return defaultValue;
        }
        MapExpression optionsMap = (MapExpression)this.options;
        Expression expr = optionsMap.keyFoldedMap().get(option);
        if (expr == null) {
            return defaultValue;
        }
        Object value = expr.fold(FoldContext.small());
        return value != null ? ((Number)value).intValue() : defaultValue;
    }

    static void process(BytesRefBlock.Builder builder, BytesRef str, BytesRef query, ChunkingSettings chunkingSettings, MemoryIndexChunkScorer scorer, int numSnippets) {
        String content = str.utf8ToString();
        String queryString = query.utf8ToString();
        List<String> chunks = ChunkUtils.chunkText(content, chunkingSettings);
        List scoredChunks = scorer.scoreChunks(chunks, queryString, numSnippets, false);
        List<String> snippets = scoredChunks.stream().map(MemoryIndexChunkScorer.ScoredChunk::content).limit(numSnippets).toList();
        ChunkUtils.emitChunks(builder, snippets);
    }

    @Override
    public boolean equals(Object o) {
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        TopSnippets chunk = (TopSnippets)o;
        return Objects.equals(this.field, chunk.field) && Objects.equals(this.query, chunk.query) && Objects.equals(this.options, chunk.options);
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.field, this.query, this.options);
    }

    @Override
    public EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) {
        int numSnippets = this.numSnippets();
        int numWords = this.numWords();
        SentenceBoundaryChunkingSettings chunkingSettings = new SentenceBoundaryChunkingSettings(Integer.valueOf(numWords), Integer.valueOf(0));
        MemoryIndexChunkScorer scorer = new MemoryIndexChunkScorer();
        return new TopSnippetsBytesRefEvaluator.Factory(this.source(), toEvaluator.apply(this.field), toEvaluator.apply(this.query), (ChunkingSettings)chunkingSettings, scorer, numSnippets);
    }
}

