/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.search.retriever;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.Strings;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.RankDocsQueryBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.vectors.ExactKnnQueryBuilder;
import org.elasticsearch.search.vectors.KnnSearchBuilder;
import org.elasticsearch.search.vectors.QueryVectorBuilder;
import org.elasticsearch.search.vectors.RescoreVectorBuilder;
import org.elasticsearch.search.vectors.VectorData;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;

/**
 * Retriever that contributes a kNN search over {@code field} to the search request.
 *
 * <p>The query vector is supplied either directly ({@code query_vector}) or by a
 * {@link QueryVectorBuilder} ({@code query_vector_builder}); exactly one of the two must be set.
 * When a builder is used, {@link #rewrite(QueryRewriteContext)} registers an async action that
 * produces the vector, and the rewritten retriever reads it lazily through a {@link Supplier}.
 */
public final class KnnRetrieverBuilder extends RetrieverBuilder {
    public static final String NAME = "knn";
    public static final ParseField FIELD_FIELD = new ParseField("field");
    public static final ParseField K_FIELD = new ParseField("k");
    public static final ParseField NUM_CANDS_FIELD = new ParseField("num_candidates");
    public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector");
    public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder");
    public static final ParseField VECTOR_SIMILARITY = new ParseField("similarity");
    public static final ParseField RESCORE_VECTOR_FIELD = new ParseField("rescore_vector");

    /**
     * Parser for the {@code knn} retriever. The {@code args} indices follow the declaration
     * order in the static initializer at the bottom of this class:
     * 0=field, 1=query_vector, 2=query_vector_builder, 3=k, 4=num_candidates,
     * 5=similarity, 6=rescore_vector.
     */
    public static final ConstructingObjectParser<KnnRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
        NAME,
        args -> {
            // query_vector is declared with declareFloatArray, so the parsed value is a list of Floats.
            @SuppressWarnings("unchecked")
            List<Float> vector = (List<Float>) args[1];
            float[] vectorArray = null;
            if (vector != null) {
                vectorArray = new float[vector.size()];
                for (int i = 0; i < vector.size(); i++) {
                    vectorArray[i] = vector.get(i);
                }
            }
            // The constructor takes rescore_vector before similarity, hence args[6] before args[5].
            return new KnnRetrieverBuilder(
                (String) args[0],
                vectorArray,
                (QueryVectorBuilder) args[2],
                (Integer) args[3],
                (Integer) args[4],
                (RescoreVectorBuilder) args[6],
                (Float) args[5]
            );
        }
    );

    private final String field;
    /**
     * Source of the query vector. Null when the vector must still be produced by
     * {@link #queryVectorBuilder}. After {@link #rewrite} it wraps a value set by an async action,
     * so {@code get()} may return null until that action has completed.
     */
    private final Supplier<float[]> queryVector;
    /** Builder that computes the query vector; null once the vector is known or supplied directly. */
    private final QueryVectorBuilder queryVectorBuilder;
    private final int k;
    private final int numCands;
    private final RescoreVectorBuilder rescoreVectorBuilder;
    private final Float similarity;

    /**
     * Parses a {@code knn} retriever definition.
     *
     * @param parser  parser positioned on the retriever body
     * @param context retriever parsing context
     * @return the parsed retriever
     * @throws IOException if reading from {@code parser} fails
     */
    public static KnnRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
        return PARSER.apply(parser, context);
    }

    /**
     * Creates a kNN retriever.
     *
     * @param field                name of the vector field to search
     * @param queryVector          explicit query vector, or null when {@code queryVectorBuilder} is given
     * @param queryVectorBuilder   builder that computes the query vector, or null when {@code queryVector} is given
     * @param k                    value passed as {@code k} to the generated {@link KnnSearchBuilder}
     * @param numCands             value passed as {@code num_candidates} to the generated {@link KnnSearchBuilder}
     * @param rescoreVectorBuilder optional rescore configuration, may be null
     * @param similarity           optional similarity threshold, may be null
     * @throws IllegalArgumentException if neither or both of {@code queryVector} and {@code queryVectorBuilder} are set
     */
    public KnnRetrieverBuilder(String field, float[] queryVector, QueryVectorBuilder queryVectorBuilder, int k, int numCands, RescoreVectorBuilder rescoreVectorBuilder, Float similarity) {
        if (queryVector == null && queryVectorBuilder == null) {
            throw new IllegalArgumentException(Strings.format("either [%s] or [%s] must be provided", QUERY_VECTOR_FIELD.getPreferredName(), QUERY_VECTOR_BUILDER_FIELD.getPreferredName()));
        }
        if (queryVector != null && queryVectorBuilder != null) {
            throw new IllegalArgumentException(Strings.format("only one of [%s] and [%s] must be provided", QUERY_VECTOR_FIELD.getPreferredName(), QUERY_VECTOR_BUILDER_FIELD.getPreferredName()));
        }
        this.field = field;
        this.queryVector = queryVector != null ? () -> queryVector : null;
        this.queryVectorBuilder = queryVectorBuilder;
        this.k = k;
        this.numCands = numCands;
        this.similarity = similarity;
        this.rescoreVectorBuilder = rescoreVectorBuilder;
    }

    /**
     * Copy constructor used during rewrite. Copies every setting of {@code clone} (including the
     * retriever name and pre-filters) but takes the query-vector source from the arguments.
     * The constructor's "exactly one source" validation is intentionally not applied here.
     */
    private KnnRetrieverBuilder(KnnRetrieverBuilder clone, Supplier<float[]> queryVector, QueryVectorBuilder queryVectorBuilder) {
        this.queryVector = queryVector;
        this.queryVectorBuilder = queryVectorBuilder;
        this.field = clone.field;
        this.k = clone.k;
        this.numCands = clone.numCands;
        this.similarity = clone.similarity;
        this.retrieverName = clone.retrieverName;
        this.preFilterQueryBuilders = clone.preFilterQueryBuilders;
        this.rescoreVectorBuilder = clone.rescoreVectorBuilder;
    }

    @Override
    public String getName() {
        return NAME;
    }

    /**
     * Rewrites this retriever one step.
     *
     * <ol>
     *   <li>If the pre-filters rewrite to a different list, returns a copy carrying the rewritten filters.
     *       The copy keeps {@code queryVectorBuilder}, so the vector is built on a later call.</li>
     *   <li>Otherwise, if a {@code queryVectorBuilder} is set, registers an async action on {@code ctx}
     *       that builds the vector. It returns a copy whose {@code queryVector} supplier reads the result
     *       and which has no builder. A null vector fails the action with an {@link IllegalArgumentException}.</li>
     *   <li>Otherwise delegates to {@code super.rewrite(ctx)}.</li>
     * </ol>
     */
    @Override
    public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
        List<QueryBuilder> rewrittenFilters = this.rewritePreFilters(ctx);
        if (rewrittenFilters != this.preFilterQueryBuilders) {
            KnnRetrieverBuilder rewritten = new KnnRetrieverBuilder(this, this.queryVector, this.queryVectorBuilder);
            rewritten.preFilterQueryBuilders = rewrittenFilters;
            return rewritten;
        }
        if (this.queryVectorBuilder != null) {
            SetOnce<float[]> toSet = new SetOnce<>();
            ctx.registerAsyncAction((c, l) -> this.queryVectorBuilder.buildVector((Client) c, l.delegateFailureAndWrap((ll, v) -> {
                // Record the result (possibly null) before signalling, so the supplier below can read it.
                toSet.set(v);
                if (v == null) {
                    ll.onFailure(new IllegalArgumentException(Strings.format("[%s] with name [%s] returned null query_vector", QUERY_VECTOR_BUILDER_FIELD.getPreferredName(), this.queryVectorBuilder.getWriteableName())));
                    return;
                }
                ll.onResponse(null);
            })));
            return new KnnRetrieverBuilder(this, toSet::get, null);
        }
        return super.rewrite(ctx);
    }

    /**
     * Returns a query that matches the already-materialized {@code rankDocs}. The query is
     * wrapped with the pre-filters if there are any, and tagged with the retriever name.
     */
    @Override
    public QueryBuilder topDocsQuery() {
        assert (this.queryVector != null) : "query vector must be materialized at this point";
        assert (this.rankDocs != null) : "rankDocs should have been materialized by now";
        return this.withPreFiltersAndName(new RankDocsQueryBuilder(this.rankDocs, null, true));
    }

    /**
     * Like {@link #topDocsQuery()}, but also attaches an {@link ExactKnnQueryBuilder} over the
     * materialized query vector, so that explain output can reflect the vector score.
     */
    @Override
    public QueryBuilder explainQuery() {
        assert (this.queryVector != null) : "query vector must be materialized at this point";
        assert (this.rankDocs != null) : "rankDocs should have been materialized by now";
        QueryBuilder[] sources = new QueryBuilder[] {
            new ExactKnnQueryBuilder(VectorData.fromFloats(this.queryVector.get()), this.field, this.similarity) };
        return this.withPreFiltersAndName(new RankDocsQueryBuilder(this.rankDocs, sources, true));
    }

    /**
     * Returns {@code rankDocsQuery} tagged with the retriever name when there are no pre-filters.
     * Otherwise returns a bool query that requires {@code rankDocsQuery} and applies every
     * pre-filter as a filter clause, tagged with the retriever name.
     */
    private QueryBuilder withPreFiltersAndName(RankDocsQueryBuilder rankDocsQuery) {
        if (this.preFilterQueryBuilders.isEmpty()) {
            return rankDocsQuery.queryName(this.retrieverName);
        }
        BoolQueryBuilder res = new BoolQueryBuilder().must(rankDocsQuery);
        this.preFilterQueryBuilders.forEach(res::filter);
        return res.queryName(this.retrieverName);
    }

    /**
     * Appends a {@link KnnSearchBuilder} for this retriever to the knn searches of
     * {@code searchSourceBuilder}. It carries the pre-filters and the retriever name when present.
     * The existing list is copied before appending rather than mutated in place.
     */
    @Override
    public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
        assert (this.queryVector != null) : "query vector must be materialized at this point.";
        KnnSearchBuilder knnSearchBuilder = new KnnSearchBuilder(this.field, VectorData.fromFloats(this.queryVector.get()), null, this.k, this.numCands, this.rescoreVectorBuilder, this.similarity);
        if (this.preFilterQueryBuilders != null) {
            knnSearchBuilder.addFilterQueries(this.preFilterQueryBuilders);
        }
        if (this.retrieverName != null) {
            knnSearchBuilder.queryName(this.retrieverName);
        }
        List<KnnSearchBuilder> knnSearchBuilders = new ArrayList<>(searchSourceBuilder.knnSearch());
        knnSearchBuilders.add(knnSearchBuilder);
        searchSourceBuilder.knnSearch(knnSearchBuilders);
    }

    /** Returns the rescore configuration, or null if none was given. */
    RescoreVectorBuilder rescoreVectorBuilder() {
        return this.rescoreVectorBuilder;
    }

    /**
     * Writes the retriever-specific fields. Optional fields are written only when set.
     * The query vector is read through its supplier.
     */
    @Override
    public void doToXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.field(FIELD_FIELD.getPreferredName(), this.field);
        builder.field(K_FIELD.getPreferredName(), this.k);
        builder.field(NUM_CANDS_FIELD.getPreferredName(), this.numCands);
        if (this.queryVector != null) {
            builder.field(QUERY_VECTOR_FIELD.getPreferredName(), this.queryVector.get());
        }
        if (this.queryVectorBuilder != null) {
            builder.field(QUERY_VECTOR_BUILDER_FIELD.getPreferredName(), this.queryVectorBuilder);
        }
        if (this.similarity != null) {
            builder.field(VECTOR_SIMILARITY.getPreferredName(), this.similarity);
        }
        if (this.rescoreVectorBuilder != null) {
            builder.field(RESCORE_VECTOR_FIELD.getPreferredName(), this.rescoreVectorBuilder);
        }
    }

    /**
     * Compares the retriever-specific fields. Query vectors are compared by the contents of the
     * supplied arrays; two absent suppliers are equal.
     */
    @Override
    public boolean doEquals(Object o) {
        KnnRetrieverBuilder that = (KnnRetrieverBuilder) o;
        boolean sameVector = (this.queryVector == null && that.queryVector == null)
            || (this.queryVector != null && that.queryVector != null && Arrays.equals(this.queryVector.get(), that.queryVector.get()));
        return this.k == that.k
            && this.numCands == that.numCands
            && Objects.equals(this.field, that.field)
            && sameVector
            && Objects.equals(this.queryVectorBuilder, that.queryVectorBuilder)
            && Objects.equals(this.similarity, that.similarity)
            && Objects.equals(this.rescoreVectorBuilder, that.rescoreVectorBuilder);
    }

    /** Hash consistent with {@link #doEquals(Object)}: the vector contributes via its array contents. */
    @Override
    public int doHashCode() {
        int result = Objects.hash(this.field, this.queryVectorBuilder, this.k, this.numCands, this.rescoreVectorBuilder, this.similarity);
        result = 31 * result + Arrays.hashCode(this.queryVector != null ? this.queryVector.get() : null);
        return result;
    }

    // Declaration order defines the PARSER args[] indices; do not reorder.
    static {
        PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD_FIELD);
        PARSER.declareFloatArray(ConstructingObjectParser.optionalConstructorArg(), QUERY_VECTOR_FIELD);
        PARSER.declareNamedObject(ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(QueryVectorBuilder.class, n, c), QUERY_VECTOR_BUILDER_FIELD);
        PARSER.declareInt(ConstructingObjectParser.constructorArg(), K_FIELD);
        PARSER.declareInt(ConstructingObjectParser.constructorArg(), NUM_CANDS_FIELD);
        PARSER.declareFloat(ConstructingObjectParser.optionalConstructorArg(), VECTOR_SIMILARITY);
        PARSER.declareField(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> RescoreVectorBuilder.fromXContent(p), RESCORE_VECTOR_FIELD, ObjectParser.ValueType.OBJECT);
        RetrieverBuilder.declareBaseParserFields(NAME, PARSER);
    }
}

