/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.search.vectors;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.InnerHitBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.Rewriteable;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.search.vectors.QueryVectorBuilder;
import org.elasticsearch.search.vectors.RescoreVectorBuilder;
import org.elasticsearch.search.vectors.VectorData;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.ToXContentFragment;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;

/**
 * A single top-level kNN search clause: the vector field to search, the query vector (given
 * directly or produced asynchronously by a {@link QueryVectorBuilder}), {@code k}, the candidate
 * count, and optional filters, similarity threshold, inner hits, vector rescoring, boost and name.
 *
 * <p>Instances come from {@link Builder#build(int)} (the XContent path), from the public
 * constructors, or from {@link #KnnSearchBuilder(StreamInput)}. The filter list, query name, boost
 * and inner hits remain mutable through their fluent setters.
 */
public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewriteable<KnnSearchBuilder> {
    /** Maximum accepted value for {@code num_candidates}. */
    public static final int NUM_CANDS_LIMIT = 10000;
    /** Multiplier applied to {@code k} to derive {@code num_candidates} when it is not specified. */
    public static final float NUM_CANDS_MULTIPLICATIVE_FACTOR = 1.5f;
    public static final ParseField FIELD_FIELD = new ParseField("field");
    public static final ParseField K_FIELD = new ParseField("k");
    public static final ParseField NUM_CANDS_FIELD = new ParseField("num_candidates");
    public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector");
    public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder");
    public static final ParseField VECTOR_SIMILARITY = new ParseField("similarity");
    public static final ParseField FILTER_FIELD = new ParseField("filter");
    public static final ParseField NAME_FIELD = AbstractQueryBuilder.NAME_FIELD;
    public static final ParseField BOOST_FIELD = AbstractQueryBuilder.BOOST_FIELD;
    public static final ParseField INNER_HITS_FIELD = new ParseField("inner_hits");
    public static final ParseField RESCORE_VECTOR_FIELD = new ParseField("rescore_vector");

    /**
     * Parser for the {@code knn} section. Constructor arguments are positional and follow the
     * declaration order in the static initializer of this class: 0 field, 1 query_vector, 2 k,
     * 3 num_candidates, 4 query_vector_builder, 5 similarity, 6 rescore_vector.
     */
    private static final ConstructingObjectParser<Builder, Void> PARSER = new ConstructingObjectParser<>(
        "knn",
        args -> new Builder().field((String) args[0])
            .queryVector((VectorData) args[1])
            .queryVectorBuilder((QueryVectorBuilder) args[4])
            .k((Integer) args[2])
            .numCandidates((Integer) args[3])
            .similarity((Float) args[5])
            .rescoreVectorBuilder((RescoreVectorBuilder) args[6])
    );

    final String field;
    /** The explicit query vector; an empty vector when the vector comes from a builder or supplier instead. */
    final VectorData queryVector;
    /** Produces the query vector asynchronously during {@link #rewrite}; null when a vector was given directly. */
    final QueryVectorBuilder queryVectorBuilder;
    /**
     * Non-null only on the intermediate instance returned by {@link #rewrite} while the
     * {@link QueryVectorBuilder} is still producing the vector. Such an instance cannot be serialized.
     */
    private final Supplier<float[]> querySupplier;
    final int k;
    final int numCands;
    final Float similarity;
    final List<QueryBuilder> filterQueries;
    String queryName;
    float boost = 1.0f;
    InnerHitBuilder innerHitBuilder;
    private final RescoreVectorBuilder rescoreVectorBuilder;

    /**
     * Parses a {@code knn} object into a {@link Builder}; call {@link Builder#build(int)} to finish.
     *
     * @param parser parser positioned on the knn object
     * @return the populated builder
     * @throws IOException if parsing fails
     */
    public static Builder fromXContent(XContentParser parser) throws IOException {
        return PARSER.parse(parser, null);
    }

    /**
     * Creates a kNN search with an explicit float query vector.
     *
     * @throws NullPointerException if {@code queryVector} is null
     * @throws IllegalArgumentException if {@code k}/{@code numCands} are out of range
     */
    public KnnSearchBuilder(String field, float[] queryVector, int k, int numCands, RescoreVectorBuilder rescoreVectorBuilder, Float similarity) {
        this(field, Objects.requireNonNull(VectorData.fromFloats(queryVector), Strings.format("[%s] cannot be null", QUERY_VECTOR_FIELD.getPreferredName())), null, k, numCands, rescoreVectorBuilder, similarity);
    }

    /**
     * Creates a kNN search with an explicit {@link VectorData} query vector.
     *
     * @throws IllegalArgumentException if {@code queryVector} is null or {@code k}/{@code numCands} are out of range
     */
    public KnnSearchBuilder(String field, VectorData queryVector, int k, int numCands, RescoreVectorBuilder rescoreVectorBuilder, Float similarity) {
        this(field, queryVector, null, k, numCands, rescoreVectorBuilder, similarity);
    }

    /**
     * Creates a kNN search whose query vector is produced by {@code queryVectorBuilder} during rewrite.
     *
     * @throws NullPointerException if {@code queryVectorBuilder} is null
     * @throws IllegalArgumentException if {@code k}/{@code numCands} are out of range
     */
    public KnnSearchBuilder(String field, QueryVectorBuilder queryVectorBuilder, int k, int numCands, RescoreVectorBuilder rescoreVectorBuilder, Float similarity) {
        this(field, null, Objects.requireNonNull(queryVectorBuilder, Strings.format("[%s] cannot be null", QUERY_VECTOR_BUILDER_FIELD.getPreferredName())), k, numCands, rescoreVectorBuilder, similarity);
    }

    /**
     * Creates a kNN search with exactly one of {@code queryVector} / {@code queryVectorBuilder}.
     *
     * @throws IllegalArgumentException if both or neither vector sources are given, or
     *     {@code k}/{@code numCands} are out of range
     */
    public KnnSearchBuilder(String field, VectorData queryVector, QueryVectorBuilder queryVectorBuilder, int k, int numCands, RescoreVectorBuilder rescoreVectorBuilder, Float similarity) {
        this(field, queryVectorBuilder, queryVector, new ArrayList<>(), k, numCands, rescoreVectorBuilder, similarity, null, null, 1.0f);
    }

    /**
     * Builds the intermediate instance used while a {@link QueryVectorBuilder} runs asynchronously.
     * Performs no validation: the values were already validated on the instance being rewritten.
     */
    private KnnSearchBuilder(String field, Supplier<float[]> querySupplier, int k, int numCands, RescoreVectorBuilder rescoreVectorBuilder, List<QueryBuilder> filterQueries, Float similarity) {
        this.field = field;
        this.queryVector = VectorData.fromFloats(new float[0]);
        this.queryVectorBuilder = null;
        this.k = k;
        this.numCands = numCands;
        this.filterQueries = filterQueries;
        this.querySupplier = querySupplier;
        this.similarity = similarity;
        this.rescoreVectorBuilder = rescoreVectorBuilder;
    }

    /**
     * Validating canonical constructor.
     *
     * @throws IllegalArgumentException if {@code k < 1}, {@code numCandidates < k},
     *     {@code numCandidates > NUM_CANDS_LIMIT}, or not exactly one vector source is given
     */
    private KnnSearchBuilder(String field, QueryVectorBuilder queryVectorBuilder, VectorData queryVector, List<QueryBuilder> filterQueries, int k, int numCandidates, RescoreVectorBuilder rescoreVectorBuilder, Float similarity, InnerHitBuilder innerHitBuilder, String queryName, float boost) {
        if (k < 1) {
            throw new IllegalArgumentException("[" + K_FIELD.getPreferredName() + "] must be greater than 0");
        }
        if (numCandidates < k) {
            throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot be less than [" + K_FIELD.getPreferredName() + "]");
        }
        if (numCandidates > NUM_CANDS_LIMIT) {
            throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]");
        }
        if (queryVector == null && queryVectorBuilder == null) {
            throw new IllegalArgumentException(Strings.format("either [%s] or [%s] must be provided", QUERY_VECTOR_BUILDER_FIELD.getPreferredName(), QUERY_VECTOR_FIELD.getPreferredName()));
        }
        if (queryVector != null && queryVectorBuilder != null) {
            throw new IllegalArgumentException(Strings.format("cannot provide both [%s] and [%s]", QUERY_VECTOR_BUILDER_FIELD.getPreferredName(), QUERY_VECTOR_FIELD.getPreferredName()));
        }
        this.field = field;
        // When only a builder is given, keep a non-null empty vector so serialization/XContent never see null.
        this.queryVector = queryVector == null ? VectorData.fromFloats(new float[0]) : queryVector;
        this.queryVectorBuilder = queryVectorBuilder;
        this.k = k;
        this.numCands = numCandidates;
        this.rescoreVectorBuilder = rescoreVectorBuilder;
        this.innerHitBuilder = innerHitBuilder;
        this.similarity = similarity;
        this.queryName = queryName;
        this.boost = boost;
        this.filterQueries = filterQueries;
        this.querySupplier = null;
    }

    /**
     * Deserializes an instance; the field order must mirror {@link #writeTo(StreamOutput)}.
     *
     * @param in stream to read from
     * @throws IOException on read failure
     */
    public KnnSearchBuilder(StreamInput in) throws IOException {
        this.field = in.readString();
        this.k = in.readVInt();
        this.numCands = in.readVInt();
        this.queryVector = in.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0) ? in.readOptionalWriteable(VectorData::new) : VectorData.fromFloats(in.readFloatArray());
        // Copy into an ArrayList so addFilterQuery/addFilterQueries can always mutate it,
        // whatever list implementation the stream hands back.
        this.filterQueries = new ArrayList<>(in.readNamedWriteableCollectionAsList(QueryBuilder.class));
        this.boost = in.readFloat();
        this.queryName = in.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0) ? in.readOptionalString() : null;
        this.queryVectorBuilder = in.getTransportVersion().onOrAfter(TransportVersions.V_8_7_0) ? in.readOptionalNamedWriteable(QueryVectorBuilder.class) : null;
        this.querySupplier = null;
        this.similarity = in.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0) ? in.readOptionalFloat() : null;
        if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_11_X)) {
            this.innerHitBuilder = in.readOptionalWriteable(InnerHitBuilder::new);
        }
        this.rescoreVectorBuilder = in.getTransportVersion().onOrAfter(TransportVersions.KNN_QUERY_RESCORE_OVERSAMPLE) ? in.readOptional(RescoreVectorBuilder::new) : null;
    }

    public int k() {
        return this.k;
    }

    public int getNumCands() {
        return this.numCands;
    }

    public RescoreVectorBuilder getRescoreVectorBuilder() {
        return this.rescoreVectorBuilder;
    }

    public QueryVectorBuilder getQueryVectorBuilder() {
        return this.queryVectorBuilder;
    }

    public VectorData getQueryVector() {
        return this.queryVector;
    }

    public String getField() {
        return this.field;
    }

    /** Returns the live (mutable) filter list. */
    public List<QueryBuilder> getFilterQueries() {
        return this.filterQueries;
    }

    /** Appends one filter; {@code filterQuery} must not be null. */
    public KnnSearchBuilder addFilterQuery(QueryBuilder filterQuery) {
        Objects.requireNonNull(filterQuery);
        this.filterQueries.add(filterQuery);
        return this;
    }

    /** Appends all given filters; {@code filterQueries} must not be null. */
    public KnnSearchBuilder addFilterQueries(List<QueryBuilder> filterQueries) {
        Objects.requireNonNull(filterQueries);
        this.filterQueries.addAll(filterQueries);
        return this;
    }

    public KnnSearchBuilder queryName(String queryName) {
        this.queryName = queryName;
        return this;
    }

    public String queryName() {
        return this.queryName;
    }

    public KnnSearchBuilder boost(float boost) {
        this.boost = boost;
        return this;
    }

    public float boost() {
        return this.boost;
    }

    public KnnSearchBuilder innerHit(InnerHitBuilder innerHitBuilder) {
        this.innerHitBuilder = innerHitBuilder;
        return this;
    }

    public InnerHitBuilder innerHit() {
        return this.innerHitBuilder;
    }

    /**
     * Rewrites this search.
     *
     * <p>Three cases are handled:
     * <ul>
     *   <li>Intermediate instance with a pending supplier: returns {@code this} until the vector
     *       arrives, then a concrete instance carrying it.</li>
     *   <li>Instance with a {@link QueryVectorBuilder}: registers an async action that builds the
     *       vector and returns an intermediate instance backed by it.</li>
     *   <li>Otherwise: rewrites each filter and returns a new instance only if any filter changed.</li>
     * </ul>
     */
    @Override
    public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException {
        if (this.querySupplier != null) {
            float[] resolvedVector = this.querySupplier.get();
            if (resolvedVector == null) {
                // The async vector build has not delivered a value yet.
                return this;
            }
            return new KnnSearchBuilder(this.field, resolvedVector, this.k, this.numCands, this.rescoreVectorBuilder, this.similarity).boost(this.boost).queryName(this.queryName).addFilterQueries(this.filterQueries).innerHit(this.innerHitBuilder);
        }
        if (this.queryVectorBuilder != null) {
            SetOnce<float[]> toSet = new SetOnce<>();
            ctx.registerAsyncAction((c, l) -> this.queryVectorBuilder.buildVector((Client) c, l.delegateFailureAndWrap((ll, v) -> {
                toSet.set(v);
                if (v == null) {
                    ll.onFailure(new IllegalArgumentException(Strings.format("[%s] with name [%s] returned null query_vector", QUERY_VECTOR_BUILDER_FIELD.getPreferredName(), this.queryVectorBuilder.getWriteableName())));
                    return;
                }
                ll.onResponse(null);
            })));
            // Copy the filters so the intermediate instance does not share a mutable list with this one.
            return new KnnSearchBuilder(this.field, toSet::get, this.k, this.numCands, this.rescoreVectorBuilder, new ArrayList<>(this.filterQueries), this.similarity).boost(this.boost).queryName(this.queryName).innerHit(this.innerHitBuilder);
        }
        boolean changed = false;
        List<QueryBuilder> rewrittenQueries = new ArrayList<>(this.filterQueries.size());
        for (QueryBuilder query : this.filterQueries) {
            QueryBuilder rewrittenQuery = query.rewrite(ctx);
            if (rewrittenQuery != query) {
                changed = true;
            }
            rewrittenQueries.add(rewrittenQuery);
        }
        if (changed) {
            return new KnnSearchBuilder(this.field, this.queryVector, this.k, this.numCands, this.rescoreVectorBuilder, this.similarity).boost(this.boost).queryName(this.queryName).addFilterQueries(rewrittenQueries).innerHit(this.innerHitBuilder);
        }
        return this;
    }

    /**
     * Converts this search into a {@link KnnVectorQueryBuilder}.
     *
     * @throws IllegalArgumentException if a {@link QueryVectorBuilder} is still present (not rewritten)
     */
    public KnnVectorQueryBuilder toQueryBuilder() {
        if (this.queryVectorBuilder != null) {
            throw new IllegalArgumentException("missing rewrite");
        }
        // NOTE(review): numCands is deliberately passed as both k and numCands of the query builder,
        // presumably so each shard collects numCands hits and the final top-k cut happens later.
        // Confirm against the caller before changing.
        return new KnnVectorQueryBuilder(this.field, this.queryVector, this.numCands, this.numCands, this.rescoreVectorBuilder, this.similarity).boost(this.boost).queryName(this.queryName).addFilterQueries(this.filterQueries);
    }

    public Float getSimilarity() {
        return this.similarity;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        KnnSearchBuilder that = (KnnSearchBuilder) o;
        return this.k == that.k
            && this.numCands == that.numCands
            && Objects.equals(this.rescoreVectorBuilder, that.rescoreVectorBuilder)
            && Objects.equals(this.field, that.field)
            && Objects.equals(this.queryVector, that.queryVector)
            && Objects.equals(this.queryVectorBuilder, that.queryVectorBuilder)
            && Objects.equals(this.querySupplier, that.querySupplier)
            && Objects.equals(this.filterQueries, that.filterQueries)
            && Objects.equals(this.similarity, that.similarity)
            && Objects.equals(this.innerHitBuilder, that.innerHitBuilder)
            && Objects.equals(this.queryName, that.queryName)
            && this.boost == that.boost;
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.field, this.k, this.numCands, this.querySupplier, this.queryVectorBuilder, this.rescoreVectorBuilder, this.similarity, Objects.hashCode(this.queryVector), Objects.hashCode(this.filterQueries), this.innerHitBuilder, this.queryName, Float.valueOf(this.boost));
    }

    /**
     * Writes the fields of the {@code knn} object (without surrounding braces).
     * {@code boost} is written only when it differs from 1.0.
     * Other optional fields are written only when set.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.field(FIELD_FIELD.getPreferredName(), this.field);
        builder.field(K_FIELD.getPreferredName(), this.k);
        builder.field(NUM_CANDS_FIELD.getPreferredName(), this.numCands);
        if (this.queryVectorBuilder != null) {
            builder.startObject(QUERY_VECTOR_BUILDER_FIELD.getPreferredName());
            builder.field(this.queryVectorBuilder.getWriteableName(), this.queryVectorBuilder);
            builder.endObject();
        } else {
            builder.field(QUERY_VECTOR_FIELD.getPreferredName(), this.queryVector);
        }
        if (this.similarity != null) {
            builder.field(VECTOR_SIMILARITY.getPreferredName(), this.similarity);
        }
        if (!this.filterQueries.isEmpty()) {
            builder.startArray(FILTER_FIELD.getPreferredName());
            for (QueryBuilder filterQuery : this.filterQueries) {
                filterQuery.toXContent(builder, params);
            }
            builder.endArray();
        }
        if (this.innerHitBuilder != null) {
            builder.field(INNER_HITS_FIELD.getPreferredName(), this.innerHitBuilder, params);
        }
        if (this.boost != 1.0f) {
            builder.field(BOOST_FIELD.getPreferredName(), this.boost);
        }
        if (this.queryName != null) {
            builder.field(NAME_FIELD.getPreferredName(), this.queryName);
        }
        if (this.rescoreVectorBuilder != null) {
            builder.field(RESCORE_VECTOR_FIELD.getPreferredName(), this.rescoreVectorBuilder);
        }
        return builder;
    }

    /**
     * Serializes this instance; the field order must mirror {@link #KnnSearchBuilder(StreamInput)}.
     *
     * @throws IllegalStateException if this is an intermediate instance with a pending vector supplier
     * @throws IllegalArgumentException if a {@link QueryVectorBuilder} is sent to a node older than 8.7.0
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        if (this.querySupplier != null) {
            throw new IllegalStateException("missing a rewriteAndFetch?");
        }
        out.writeString(this.field);
        out.writeVInt(this.k);
        out.writeVInt(this.numCands);
        if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0)) {
            out.writeOptionalWriteable(this.queryVector);
        } else {
            out.writeFloatArray(this.queryVector.asFloatVector());
        }
        out.writeNamedWriteableCollection(this.filterQueries);
        out.writeFloat(this.boost);
        if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) {
            out.writeOptionalString(this.queryName);
        }
        if (out.getTransportVersion().before(TransportVersions.V_8_7_0) && this.queryVectorBuilder != null) {
            throw new IllegalArgumentException(Strings.format("cannot serialize [%s] to older node of version [%s]", QUERY_VECTOR_BUILDER_FIELD.getPreferredName(), out.getTransportVersion()));
        }
        if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_7_0)) {
            out.writeOptionalNamedWriteable(this.queryVectorBuilder);
        }
        if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) {
            out.writeOptionalFloat(this.similarity);
        }
        if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_11_X)) {
            out.writeOptionalWriteable(this.innerHitBuilder);
        }
        if (out.getTransportVersion().onOrAfter(TransportVersions.KNN_QUERY_RESCORE_OVERSAMPLE)) {
            out.writeOptionalWriteable(this.rescoreVectorBuilder);
        }
    }

    static {
        // Constructor-arg declarations: their order defines the args[] indices used by PARSER.
        PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD_FIELD);
        PARSER.declareField(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> VectorData.parseXContent(p), QUERY_VECTOR_FIELD, ObjectParser.ValueType.OBJECT_ARRAY_STRING_OR_NUMBER);
        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), K_FIELD);
        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_CANDS_FIELD);
        PARSER.declareNamedObject(ConstructingObjectParser.optionalConstructorArg(), (p, c, n) -> p.namedObject(QueryVectorBuilder.class, n, c), QUERY_VECTOR_BUILDER_FIELD);
        PARSER.declareFloat(ConstructingObjectParser.optionalConstructorArg(), VECTOR_SIMILARITY);
        PARSER.declareField(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> RescoreVectorBuilder.fromXContent(p), RESCORE_VECTOR_FIELD, ObjectParser.ValueType.OBJECT);
        // Remaining fields are applied directly to the Builder.
        PARSER.declareFieldArray(Builder::addFilterQueries, (p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p), FILTER_FIELD, ObjectParser.ValueType.OBJECT_ARRAY);
        PARSER.declareString(Builder::queryName, NAME_FIELD);
        PARSER.declareFloat(Builder::boost, BOOST_FIELD);
        PARSER.declareField(Builder::innerHit, (p, c) -> InnerHitBuilder.fromXContent(p), INNER_HITS_FIELD, ObjectParser.ValueType.OBJECT);
    }

    /**
     * Mutable accumulator filled by the XContent parser.
     *
     * <p>{@code k} and {@code num_candidates} may be left unset. They are resolved in
     * {@link #build(int)}.
     */
    public static class Builder {
        private String field;
        private VectorData queryVector;
        private QueryVectorBuilder queryVectorBuilder;
        private Integer k;
        private Integer numCandidates;
        private Float similarity;
        private final List<QueryBuilder> filterQueries = new ArrayList<>();
        private String queryName;
        private float boost = 1.0f;
        private InnerHitBuilder innerHitBuilder;
        private RescoreVectorBuilder rescoreVectorBuilder;

        public Builder addFilterQueries(List<QueryBuilder> filterQueries) {
            Objects.requireNonNull(filterQueries);
            this.filterQueries.addAll(filterQueries);
            return this;
        }

        public Builder field(String field) {
            this.field = field;
            return this;
        }

        public Builder queryName(String queryName) {
            this.queryName = queryName;
            return this;
        }

        public Builder boost(float boost) {
            this.boost = boost;
            return this;
        }

        public Builder innerHit(InnerHitBuilder innerHitBuilder) {
            this.innerHitBuilder = innerHitBuilder;
            return this;
        }

        public Builder queryVector(VectorData queryVector) {
            this.queryVector = queryVector;
            return this;
        }

        public Builder queryVectorBuilder(QueryVectorBuilder queryVectorBuilder) {
            this.queryVectorBuilder = queryVectorBuilder;
            return this;
        }

        public Builder k(Integer k) {
            this.k = k;
            return this;
        }

        public Builder numCandidates(Integer numCands) {
            this.numCandidates = numCands;
            return this;
        }

        public Builder similarity(Float similarity) {
            this.similarity = similarity;
            return this;
        }

        public Builder rescoreVectorBuilder(RescoreVectorBuilder rescoreVectorBuilder) {
            this.rescoreVectorBuilder = rescoreVectorBuilder;
            return this;
        }

        /**
         * Resolves defaults and builds a validated {@link KnnSearchBuilder}.
         *
         * <ul>
         *   <li>When {@code k} is unset, it defaults to {@code size}; a negative {@code size} is
         *       treated as 10.</li>
         *   <li>When {@code num_candidates} is unset, it defaults to
         *       {@code round(min(NUM_CANDS_LIMIT, NUM_CANDS_MULTIPLICATIVE_FACTOR * k))}.</li>
         * </ul>
         *
         * @param size the request's size; negative means "not set"
         * @return a new instance owning its own copy of the filter list
         * @throws IllegalArgumentException if the resolved values fail validation
         */
        public KnnSearchBuilder build(int size) {
            int requestSize = size < 0 ? 10 : size;
            int adjustedK = this.k == null ? requestSize : this.k;
            int adjustedNumCandidates = this.numCandidates == null
                ? Math.round(Math.min(NUM_CANDS_LIMIT, NUM_CANDS_MULTIPLICATIVE_FACTOR * adjustedK))
                : this.numCandidates;
            // Copy the filters: the built instance's addFilterQuery must not mutate this Builder
            // (or any other instance built from it).
            return new KnnSearchBuilder(this.field, this.queryVectorBuilder, this.queryVector, new ArrayList<>(this.filterQueries), adjustedK, adjustedNumCandidates, this.rescoreVectorBuilder, this.similarity, this.innerHitBuilder, this.queryName, this.boost);
        }
    }
}

