/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.expression.function.vector;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.vectors.ExactKnnQueryBuilder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.search.vectors.VectorData;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.capabilities.PostOptimizationVerificationAware;
import org.elasticsearch.xpack.esql.capabilities.TranslationAware;
import org.elasticsearch.xpack.esql.common.Failure;
import org.elasticsearch.xpack.esql.common.Failures;
import org.elasticsearch.xpack.esql.core.InvalidArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.MapExpression;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.querydsl.query.Query;
import org.elasticsearch.xpack.esql.core.tree.Node;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.Check;
import org.elasticsearch.xpack.esql.expression.function.Example;
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.MapParam;
import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
import org.elasticsearch.xpack.esql.expression.function.Options;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.expression.function.fulltext.SingleFieldFullTextFunction;
import org.elasticsearch.xpack.esql.expression.function.vector.VectorFunction;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
import org.elasticsearch.xpack.esql.querydsl.query.KnnQuery;

public class Knn
extends SingleFieldFullTextFunction
implements OptionalArgument,
VectorFunction,
PostOptimizationVerificationAware {
    // Registry entry that maps the writeable name "Knn" to the readFrom deserializer.
    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom);
    // k value attached through withImplicitK; may be null (postOptimizationVerification reports a failure in that case).
    // It is not written by writeTo and readFrom always restores it as null.
    private final transient Integer implicitK;
    // Extra expressions merged into translatable() and turned into filter queries by translate().
    private final List<Expression> filterExpressions;
    // Name of the "min_candidates" option.
    public static final String MIN_CANDIDATES_OPTION = "min_candidates";
    // Option names accepted in the options map, each mapped to the data type expected for its value.
    public static final Map<String, DataType> ALLOWED_OPTIONS = Map.ofEntries(Map.entry(KnnVectorQueryBuilder.K_FIELD.getPreferredName(), DataType.INTEGER), Map.entry("min_candidates", DataType.INTEGER), Map.entry(KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD.getPreferredName(), DataType.FLOAT), Map.entry(KnnVectorQueryBuilder.VISIT_PERCENTAGE_FIELD.getPreferredName(), DataType.FLOAT), Map.entry(AbstractQueryBuilder.BOOST_FIELD.getPreferredName(), DataType.FLOAT), Map.entry("rescore_oversample", DataType.FLOAT));

    /**
     * Creates a Knn function from its three user-visible arguments.
     * Delegates to the full constructor with no implicit k, no query builder and an empty filter list.
     *
     * @param source  source location of the function call
     * @param field   expression for the field the query targets
     * @param query   expression for the query vector
     * @param options optional options map expression; may be null
     */
    @FunctionInfo(returnType={"boolean"}, preview=true, description="Finds the k nearest vectors to a query vector, as measured by a similarity metric. knn function finds nearest vectors through approximate search on indexed dense_vectors or semantic_text fields.", examples={@Example(file="knn-function", tag="knn-function")}, appliesTo={@FunctionAppliesTo(lifeCycle=FunctionAppliesToLifecycle.PREVIEW, version="9.2.0")})
    public Knn(Source source, @Param(name="field", type={"dense_vector", "text"}, description="Field that the query will target. knn function can be used with dense_vector or semantic_text fields. Other text fields are not allowed") Expression field, @Param(name="query", type={"dense_vector"}, description="Vector value to find top nearest neighbours for.") Expression query, @MapParam(name="options", params={@MapParam.MapParamEntry(name="k", type={"integer"}, valueHint={"10"}, description="The number of nearest neighbors to return from each shard. Elasticsearch collects k results from each shard, then merges them to find the global top results. This value must be less than or equal to num_candidates. This value is automatically set with any LIMIT applied to the function."), @MapParam.MapParamEntry(name="boost", type={"float"}, valueHint={"2.5"}, description="Floating point number used to decrease or increase the relevance scores of the query.Defaults to 1.0."), @MapParam.MapParamEntry(name="min_candidates", type={"integer"}, valueHint={"10"}, description="The minimum number of nearest neighbor candidates to consider per shard while doing knn search.  KNN may use a higher number of candidates in case the query can't use a approximate results. Cannot exceed 10,000. Increasing min_candidates tends to improve the accuracy of the final results. Defaults to 1.5 * k (or LIMIT) used for the query."), @MapParam.MapParamEntry(name="visit_percentage", type={"float"}, valueHint={"10"}, description="The percentage of vectors to explore per shard while doing knn search with bbq_disk. Must be between 0 and 100. 0 will default to using num_candidates for calculating the percent visited. Increasing visit_percentage tends to improve the accuracy of the final results. If visit_percentage is set for bbq_disk, num_candidates is ignored. 
Defaults to ~1% per shard for every 1 million vectors"), @MapParam.MapParamEntry(name="similarity", type={"double"}, valueHint={"0.01"}, description="The minimum similarity required for a document to be considered a match. The similarity value calculated relates to the raw similarity used, not the document score."), @MapParam.MapParamEntry(name="rescore_oversample", type={"double"}, valueHint={"3.5"}, description="Applies the specified oversampling for rescoring quantized vectors. See [oversampling and rescoring quantized vectors](docs-content://solutions/search/vector/knn.md#dense-vector-knn-search-rescoring) for details.")}, description="(Optional) kNN additional options as <<esql-function-named-params,function named parameters>>. See [knn query](/reference/query-languages/query-dsl/query-dsl-knn-query.md) for more information.", optional=true) Expression options) {
        this(source, field, query, options, null, null, List.of());
    }

    /**
     * Creates a Knn function with every piece of state supplied explicitly.
     *
     * @param source            source location of the function call
     * @param field             expression for the field the query targets
     * @param query             expression for the query vector
     * @param options           options map expression, or null when none were given
     * @param implicitK         k value to use when the options carry no explicit k; may be null
     * @param queryBuilder      query builder handed to the superclass; may be null
     * @param filterExpressions expressions to apply as filters to the knn query
     */
    public Knn(Source source, Expression field, Expression query, Expression options, Integer implicitK, QueryBuilder queryBuilder, List<Expression> filterExpressions) {
        super(source, field, query, options, expressionList(field, query, options), queryBuilder);
        this.filterExpressions = filterExpressions;
        this.implicitK = implicitK;
    }

    /**
     * Builds the child list handed to the superclass: field, then query, then options when present.
     *
     * @return a new mutable list of two or three expressions
     */
    private static List<Expression> expressionList(Expression field, Expression query, Expression options) {
        List<Expression> children = new ArrayList<>();
        children.add(field);
        children.add(query);
        if (options == null) {
            return children;
        }
        children.add(options);
        return children;
    }

    /**
     * @return the implicit k value attached to this function, or null when none has been set
     */
    public Integer implicitK() {
        return implicitK;
    }

    /**
     * @return the filter expressions attached to this function, exactly as supplied to the constructor
     */
    public List<Expression> filterExpressions() {
        return filterExpressions;
    }

    /**
     * Returns a copy of this function that carries the given implicit k.
     *
     * @param k the k value to attach; must not be null
     * @return a new Knn identical to this one except for its implicit k
     */
    public Knn withImplicitK(Integer k) {
        Check.notNull(k, "k must not be null");
        Knn copy = new Knn(source(), field(), query(), options(), k, queryBuilder(), filterExpressions());
        return copy;
    }

    /**
     * Returns the query vector as a list of numbers.
     *
     * @return the value of the query literal, cast to a list
     * @throws EsqlIllegalArgumentException when the query expression is not a Literal
     */
    @Override
    public List<Number> queryAsObject() {
        Expression queryExpression = query();
        if (!(queryExpression instanceof Literal)) {
            throw new EsqlIllegalArgumentException(LoggerMessageFormat.format(null, "Query value must be a list of numbers in [{}], found [{}]", new Object[]{source(), queryExpression}));
        }
        // The literal's value is untyped; only the outer List cast is checked at runtime.
        @SuppressWarnings("unchecked")
        List<Number> vector = (List<Number>) ((Literal) queryExpression).value();
        return vector;
    }

    /**
     * Returns a copy of this function that uses the given query builder; all other state is kept.
     */
    @Override
    public Expression replaceQueryBuilder(QueryBuilder queryBuilder) {
        Knn copy = new Knn(source(), field(), query(), options(), implicitK(), queryBuilder, filterExpressions());
        return copy;
    }

    /**
     * Combines the superclass's translatability with that of every filter expression.
     *
     * @return the superclass result merged, in order, with the result for each filter expression
     */
    @Override
    public TranslationAware.Translatable translatable(LucenePushdownPredicates pushdownPredicates) {
        TranslationAware.Translatable combined = super.translatable(pushdownPredicates);
        for (Expression filter : filterExpressions()) {
            TranslationAware.Translatable filterResult = TranslationAware.translatable(filter, pushdownPredicates);
            combined = combined.merge(filterResult);
        }
        return combined;
    }

    /**
     * Translates this function into a KnnQuery.
     * An explicit "k" option takes precedence over the implicit k. Filter expressions are added
     * as filter queries only when they are TranslationAware and report Translatable.YES;
     * all others are skipped here.
     */
    @Override
    protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
        assert implicitK() != null : "Knn function must have a k value set before translation";
        FieldAttribute targetField = fieldAsFieldAttribute(field());
        Check.notNull(targetField, "Knn must have a field attribute as the first argument");
        String fieldName = getNameFromFieldAttribute(targetField);
        float[] queryVector = queryAsFloats();

        List<QueryBuilder> filterQueries = new ArrayList<>();
        for (Expression filter : filterExpressions()) {
            if (filter instanceof TranslationAware) {
                TranslationAware translatableFilter = (TranslationAware) filter;
                if (translatableFilter.translatable(pushdownPredicates) == TranslationAware.Translatable.YES) {
                    filterQueries.add(handler.asQuery(pushdownPredicates, filter).toQueryBuilder());
                }
            }
        }

        Map<String, Object> options = queryOptions();
        Integer explicitK = (Integer) options.get(KnnVectorQueryBuilder.K_FIELD.getPreferredName());
        Integer k;
        if (explicitK != null) {
            k = explicitK;
        } else {
            k = implicitK();
        }
        return new KnnQuery(source(), fieldName, queryVector, k, options, filterQueries);
    }

    /**
     * Converts the query vector returned by queryAsObject() into a primitive float array.
     *
     * @return a new array holding the float value of each element, in list order
     */
    private float[] queryAsFloats() {
        // Typed as List<Number>: size() and get() are not available on a plain Object reference.
        List<Number> queryFolded = queryAsObject();
        float[] queryAsFloats = new float[queryFolded.size()];
        for (int i = 0; i < queryAsFloats.length; i++) {
            queryAsFloats[i] = queryFolded.get(i).floatValue();
        }
        return queryAsFloats;
    }

    /**
     * Returns a copy of this function that uses the given filter expressions; all other state is kept.
     */
    public Expression withFilters(List<Expression> filterExpressions) {
        Knn copy = new Knn(source(), field(), query(), options(), implicitK(), queryBuilder(), filterExpressions);
        return copy;
    }

    /**
     * Collects the options given to this function into a mutable map keyed by option name.
     *
     * @return the populated options map; empty when no options expression was supplied
     * @throws InvalidArgumentException propagated from Options.populateMap
     */
    private Map<String, Object> queryOptions() throws InvalidArgumentException {
        Map<String, Object> options = new HashMap<>();
        if (options() != null) {
            // The options map is the third argument (field, query, options), so diagnostics must name that position.
            Options.populateMap((MapExpression) options(), options, source(), TypeResolutions.ParamOrdinal.THIRD, ALLOWED_OPTIONS);
        }
        return options;
    }

    /**
     * Builds the ExactKnnQueryBuilder for this function from the query vector, the target
     * field name and the optional similarity option.
     *
     * @return an exact knn query builder over the target field
     */
    @Override
    protected QueryBuilder evaluatorQueryBuilder() {
        FieldAttribute fieldAttribute = fieldAsFieldAttribute(field());
        Check.notNull(fieldAttribute, "Knn must have a field attribute as the first argument");
        String fieldName = getNameFromFieldAttribute(fieldAttribute);
        Map<String, Object> opts = queryOptions();
        // The options map is keyed by option name, so look up by the preferred name.
        // Using the ParseField object itself as the key never matches and silently drops the similarity.
        Object similarity = opts.get(KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD.getPreferredName());
        // Accept any Number so the value need not already be boxed as a Float.
        Float vectorSimilarity = similarity == null ? null : Float.valueOf(((Number) similarity).floatValue());
        return new ExactKnnQueryBuilder(VectorData.fromFloats(queryAsFloats()), fieldName, vectorSimilarity);
    }

    /**
     * Adds a failure when no implicit k has been set on this function.
     */
    @Override
    public void postOptimizationVerification(Failures failures) {
        if (implicitK() != null) {
            return;
        }
        failures.add(Failure.fail(this, "Knn function must be used with a LIMIT clause after it to set the number of nearest neighbors to find", new Object[0]));
    }

    /**
     * Returns a copy of this function built from the given children: field at index 0, query at
     * index 1 and, when present, options at index 2. Implicit k, query builder and filters are kept.
     */
    public Expression replaceChildren(List<Expression> newChildren) {
        Expression newField = newChildren.get(0);
        Expression newQuery = newChildren.get(1);
        Expression newOptions = null;
        if (newChildren.size() > 2) {
            newOptions = newChildren.get(2);
        }
        return new Knn(source(), newField, newQuery, newOptions, implicitK(), queryBuilder(), filterExpressions());
    }

    /**
     * Builds the NodeInfo for this node from the Knn constructor reference and its six properties:
     * field, query, options, implicit k, query builder and filter expressions.
     */
    protected NodeInfo<? extends Expression> info() {
        return NodeInfo.create((Node)this, Knn::new, (Object)this.field(), (Object)this.query(), (Object)this.options(), (Object)this.implicitK(), (Object)this.queryBuilder(), this.filterExpressions());
    }

    /**
     * @return the name under which this expression is registered in ENTRY
     */
    public String getWriteableName() {
        return ENTRY.name;
    }

    /**
     * Reads a Knn from the stream in the order source, field, query, optional query builder,
     * filter expressions. Options and implicit k are not read from the stream; both are passed
     * to the constructor as null.
     *
     * @param in the stream to read from; it is cast to PlanStreamInput to read the source
     * @return the deserialized Knn
     * @throws IOException if reading from the stream fails
     */
    private static Knn readFrom(StreamInput in) throws IOException {
        Source source = Source.readFrom((StreamInput)((PlanStreamInput)in));
        Expression field = (Expression)in.readNamedWriteable(Expression.class);
        Expression query = (Expression)in.readNamedWriteable(Expression.class);
        QueryBuilder queryBuilder = (QueryBuilder)in.readOptionalNamedWriteable(QueryBuilder.class);
        List filterExpressions = in.readNamedWriteableCollectionAsList(Expression.class);
        return new Knn(source, field, query, null, null, queryBuilder, filterExpressions);
    }

    /**
     * Writes this Knn to the stream in the order source, field, query, optional query builder,
     * filter expressions. The options expression and the implicit k are not written.
     *
     * @param out the stream to write to
     * @throws IOException if writing to the stream fails
     */
    public void writeTo(StreamOutput out) throws IOException {
        this.source().writeTo(out);
        out.writeNamedWriteable((NamedWriteable)this.field());
        out.writeNamedWriteable((NamedWriteable)this.query());
        out.writeOptionalNamedWriteable((NamedWriteable)this.queryBuilder());
        out.writeNamedWriteableCollection(this.filterExpressions());
    }

    /**
     * @return the data types accepted for the field argument: DENSE_VECTOR, TEXT and NULL
     */
    @Override
    protected Set<DataType> getFieldDataTypes() {
        Set<DataType> fieldTypes = Set.of(DataType.DENSE_VECTOR, DataType.TEXT, DataType.NULL);
        return fieldTypes;
    }

    /**
     * @return the data types accepted for the query argument: DENSE_VECTOR only
     */
    @Override
    protected Set<DataType> getQueryDataTypes() {
        Set<DataType> queryTypes = Set.of(DataType.DENSE_VECTOR);
        return queryTypes;
    }

    /**
     * @return the shared ALLOWED_OPTIONS map of option names to expected value types
     */
    @Override
    protected Map<String, DataType> getAllowedOptions() {
        return Knn.ALLOWED_OPTIONS;
    }

    /**
     * Two Knn instances are equal when they have the same runtime class, the superclass considers
     * them equal, and their implicit k and filter expressions are equal.
     */
    @Override
    public boolean equals(Object o) {
        if (o == null) {
            return false;
        }
        if (getClass() != o.getClass()) {
            return false;
        }
        Knn other = (Knn) o;
        if (!super.equals(other)) {
            return false;
        }
        if (!Objects.equals(implicitK(), other.implicitK())) {
            return false;
        }
        return Objects.equals(filterExpressions(), other.filterExpressions());
    }

    /**
     * @return a hash over field, query, query builder, implicit k and filter expressions
     */
    @Override
    public int hashCode() {
        int hash = Objects.hash(field(), query(), queryBuilder(), implicitK(), filterExpressions());
        return hash;
    }
}

