/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.rank.linear;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ResolvedIndices;
import org.elasticsearch.action.ValidateActions;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.XPackPlugin;
import org.elasticsearch.xpack.rank.MultiFieldsInnerRetrieverUtils;
import org.elasticsearch.xpack.rank.RankRRFFeatures;
import org.elasticsearch.xpack.rank.linear.IdentityScoreNormalizer;
import org.elasticsearch.xpack.rank.linear.LinearRankDoc;
import org.elasticsearch.xpack.rank.linear.LinearRetrieverComponent;
import org.elasticsearch.xpack.rank.linear.ScoreNormalizer;
import org.elasticsearch.xpack.rank.rrf.RRFRankPlugin;

/**
 * Compound retriever that combines the results of its inner retrievers into a single ranking.
 *
 * <p>Each inner retriever's scores are first passed through its {@link ScoreNormalizer}. A document's final score is
 * the weighted sum of those normalized scores across all inner retrievers that returned it.
 *
 * <p>Besides an explicit {@code retrievers} list, the builder accepts a {@code fields} + {@code query} form. In that
 * form a top-level {@code normalizer} is required. During rewrite, the form is expanded into inner retrievers via
 * {@link MultiFieldsInnerRetrieverUtils}.
 */
public final class LinearRetrieverBuilder extends CompoundRetrieverBuilder<LinearRetrieverBuilder> {
    public static final NodeFeature MULTI_FIELDS_QUERY_FORMAT_SUPPORT = new NodeFeature(
        "linear_retriever.multi_fields_query_format_support"
    );
    public static final NodeFeature LINEAR_RETRIEVER_MINSCORE_FIX = new NodeFeature("linear_retriever_minscore_fix");

    public static final String NAME = "linear";

    public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers");
    public static final ParseField FIELDS_FIELD = new ParseField("fields");
    public static final ParseField QUERY_FIELD = new ParseField("query");
    public static final ParseField NORMALIZER_FIELD = new ParseField("normalizer");

    /** Score contributed by an inner retriever whose normalized score for a document is NaN. */
    public static final float DEFAULT_SCORE = 0.0f;

    /** Weight applied when an inner retriever's weight is NaN, and the default per-retriever weight. */
    private static final float DEFAULT_WEIGHT = 1.0f;

    /** One weight per inner retriever, index-aligned with {@code innerRetrievers}. */
    private final float[] weights;
    /** One normalizer per inner retriever, index-aligned with {@code innerRetrievers}; a null entry means identity. */
    private final ScoreNormalizer[] normalizers;
    /** Fields for the multi-fields query format, or null when not used. */
    private final List<String> fields;
    /** Query text for the multi-fields query format, or null when not used. */
    private final String query;
    /** Top-level normalizer, used only with the multi-fields query format. */
    private final ScoreNormalizer normalizer;

    static final ConstructingObjectParser<LinearRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
        NAME,
        false,
        args -> {
            @SuppressWarnings("unchecked")
            List<LinearRetrieverComponent> retrieverComponents = args[0] == null
                ? List.of()
                : (List<LinearRetrieverComponent>) args[0];
            @SuppressWarnings("unchecked")
            List<String> fields = (List<String>) args[1];
            String query = (String) args[2];
            ScoreNormalizer normalizer = args[3] == null ? null : ScoreNormalizer.valueOf((String) args[3]);
            // 10 is the rank window size used when none is supplied in the request.
            int rankWindowSize = args[4] == null ? 10 : (Integer) args[4];

            int count = retrieverComponents.size();
            float[] weights = new float[count];
            ScoreNormalizer[] normalizers = new ScoreNormalizer[count];
            List<CompoundRetrieverBuilder.RetrieverSource> innerRetrievers = new ArrayList<>(count);
            for (int i = 0; i < count; i++) {
                LinearRetrieverComponent component = retrieverComponents.get(i);
                innerRetrievers.add(CompoundRetrieverBuilder.RetrieverSource.from(component.retriever));
                weights[i] = component.weight;
                normalizers[i] = component.normalizer;
            }
            return new LinearRetrieverBuilder(innerRetrievers, fields, query, normalizer, rankWindowSize, weights, normalizers);
        }
    );

    static {
        PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), LinearRetrieverComponent::fromXContent, RETRIEVERS_FIELD);
        PARSER.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), FIELDS_FIELD);
        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), QUERY_FIELD);
        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), NORMALIZER_FIELD);
        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
        RetrieverBuilder.declareBaseParserFields(NAME, PARSER);
    }

    /** Returns one {@link #DEFAULT_WEIGHT} per inner retriever (an empty array for a null list). */
    private static float[] getDefaultWeight(List<CompoundRetrieverBuilder.RetrieverSource> innerRetrievers) {
        int size = innerRetrievers != null ? innerRetrievers.size() : 0;
        float[] weights = new float[size];
        Arrays.fill(weights, DEFAULT_WEIGHT);
        return weights;
    }

    /** Returns one identity normalizer per inner retriever (an empty array for a null list). */
    private static ScoreNormalizer[] getDefaultNormalizers(List<CompoundRetrieverBuilder.RetrieverSource> innerRetrievers) {
        int size = innerRetrievers != null ? innerRetrievers.size() : 0;
        ScoreNormalizer[] normalizers = new ScoreNormalizer[size];
        Arrays.fill(normalizers, IdentityScoreNormalizer.INSTANCE);
        return normalizers;
    }

    /**
     * Parses a {@code linear} retriever.
     *
     * @throws ParsingException if the cluster does not support the linear retriever
     * @throws org.elasticsearch.ElasticsearchSecurityException (via {@link LicenseUtils}) if the license does not allow it
     */
    public static LinearRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
        if (!context.clusterSupportsFeature(RankRRFFeatures.LINEAR_RETRIEVER_SUPPORTED)) {
            throw new ParsingException(parser.getTokenLocation(), "unknown retriever [linear]");
        }
        if (!RRFRankPlugin.LINEAR_RETRIEVER_FEATURE.check(XPackPlugin.getSharedLicenseState())) {
            throw LicenseUtils.newComplianceException("linear retriever");
        }
        return PARSER.apply(parser, context);
    }

    /** Creates a builder with default weights and identity normalizers for every inner retriever. */
    LinearRetrieverBuilder(List<CompoundRetrieverBuilder.RetrieverSource> innerRetrievers, int rankWindowSize) {
        this(innerRetrievers, null, null, null, rankWindowSize, getDefaultWeight(innerRetrievers), getDefaultNormalizers(innerRetrievers));
    }

    /** Creates a builder over explicit inner retrievers with per-retriever weights and normalizers. */
    public LinearRetrieverBuilder(
        List<CompoundRetrieverBuilder.RetrieverSource> innerRetrievers,
        int rankWindowSize,
        float[] weights,
        ScoreNormalizer[] normalizers
    ) {
        this(innerRetrievers, null, null, null, rankWindowSize, weights, normalizers);
    }

    /**
     * Primary constructor.
     *
     * @throws IllegalArgumentException if {@code weights} or {@code normalizers} does not have exactly one entry per
     *     inner retriever
     */
    public LinearRetrieverBuilder(
        List<CompoundRetrieverBuilder.RetrieverSource> innerRetrievers,
        List<String> fields,
        String query,
        ScoreNormalizer normalizer,
        int rankWindowSize,
        float[] weights,
        ScoreNormalizer[] normalizers
    ) {
        // Always work on a private, mutable copy of the inner retrievers.
        super(innerRetrievers == null ? new ArrayList<>() : new ArrayList<>(innerRetrievers), rankWindowSize);
        if (weights.length != this.innerRetrievers.size()) {
            throw new IllegalArgumentException("The number of weights must match the number of inner retrievers");
        }
        if (normalizers.length != this.innerRetrievers.size()) {
            throw new IllegalArgumentException("The number of normalizers must match the number of inner retrievers");
        }
        this.fields = fields == null ? null : List.copyOf(fields);
        this.query = query;
        this.normalizer = normalizer;
        this.weights = weights;
        this.normalizers = normalizers;
    }

    /**
     * Full constructor that also carries the base retriever settings (min score, name, pre-filters).
     *
     * @throws IllegalArgumentException if {@code minScore} is negative, or on a weights/normalizers size mismatch
     */
    public LinearRetrieverBuilder(
        List<CompoundRetrieverBuilder.RetrieverSource> innerRetrievers,
        List<String> fields,
        String query,
        ScoreNormalizer normalizer,
        int rankWindowSize,
        float[] weights,
        ScoreNormalizer[] normalizers,
        Float minScore,
        String retrieverName,
        List<QueryBuilder> preFilterQueryBuilders
    ) {
        this(innerRetrievers, fields, query, normalizer, rankWindowSize, weights, normalizers);
        if (minScore != null && minScore < 0.0f) {
            throw new IllegalArgumentException("[min_score] must be greater than or equal to 0, was: [" + minScore + "]");
        }
        this.minScore = minScore;
        this.retrieverName = retrieverName;
        this.preFilterQueryBuilders = preFilterQueryBuilders;
    }

    /**
     * Validates the request.
     *
     * <p>Beyond the base and multi-fields checks, a top-level {@code normalizer} is required with {@code query}. It is
     * rejected when an explicit {@code retrievers} list is given.
     */
    @Override
    public ActionRequestValidationException validate(
        SearchSourceBuilder source,
        ActionRequestValidationException validationException,
        boolean isScroll,
        boolean allowPartialSearchResults
    ) {
        validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults);
        validationException = MultiFieldsInnerRetrieverUtils.validateParams(
            this.innerRetrievers,
            this.fields,
            this.query,
            this.getName(),
            RETRIEVERS_FIELD.getPreferredName(),
            FIELDS_FIELD.getPreferredName(),
            QUERY_FIELD.getPreferredName(),
            validationException
        );
        if (this.query != null && this.normalizer == null) {
            validationException = ValidateActions.addValidationError(
                String.format(
                    Locale.ROOT,
                    "[%s] [%s] must be provided when [%s] is specified",
                    this.getName(),
                    NORMALIZER_FIELD.getPreferredName(),
                    QUERY_FIELD.getPreferredName()
                ),
                validationException
            );
        } else if (!this.innerRetrievers.isEmpty() && this.normalizer != null) {
            validationException = ValidateActions.addValidationError(
                String.format(
                    Locale.ROOT,
                    "[%s] [%s] cannot be provided when [%s] is specified",
                    this.getName(),
                    NORMALIZER_FIELD.getPreferredName(),
                    RETRIEVERS_FIELD.getPreferredName()
                ),
                validationException
            );
        }
        return validationException;
    }

    @Override
    protected LinearRetrieverBuilder clone(
        List<CompoundRetrieverBuilder.RetrieverSource> newChildRetrievers,
        List<QueryBuilder> newPreFilterQueryBuilders
    ) {
        return new LinearRetrieverBuilder(
            newChildRetrievers,
            this.fields,
            this.query,
            this.normalizer,
            this.rankWindowSize,
            this.weights,
            this.normalizers,
            this.minScore,
            this.retrieverName,
            newPreFilterQueryBuilders
        );
    }

    /** Enables score tracking, since the combined score is computed from the inner retrievers' scores. */
    @Override
    protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) {
        sourceBuilder.trackScores(true);
        return sourceBuilder;
    }

    /**
     * Combines the inner retrievers' result sets into one ranking.
     *
     * <p>The combination proceeds in four steps:
     * <ol>
     *   <li>Each result set is normalized with its normalizer; a null normalizer means identity.</li>
     *   <li>Each document's score is the sum of {@code weight * normalizedScore} over the result sets containing it.
     *       A NaN score contributes {@link #DEFAULT_SCORE}, and a NaN weight is replaced by {@link #DEFAULT_WEIGHT}.</li>
     *   <li>Documents are sorted by the {@link LinearRankDoc} natural order, truncated to {@code rankWindowSize} and
     *       assigned 1-based ranks.</li>
     *   <li>If a {@code min_score} is set, documents scoring below it are dropped.</li>
     * </ol>
     */
    @Override
    protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, boolean isExplain) {
        Map<RankDoc.RankKey, LinearRankDoc> docsToRankResults = Maps.newMapWithExpectedSize(this.rankWindowSize);

        // Resolve null normalizers to identity once, so the explain names and the scoring agree.
        ScoreNormalizer[] effectiveNormalizers = new ScoreNormalizer[this.normalizers.length];
        String[] normalizerNames = new String[this.normalizers.length];
        for (int i = 0; i < this.normalizers.length; i++) {
            effectiveNormalizers[i] = this.normalizers[i] == null ? IdentityScoreNormalizer.INSTANCE : this.normalizers[i];
            normalizerNames[i] = effectiveNormalizers[i].getName();
        }

        for (int result = 0; result < rankResults.size(); ++result) {
            final int resultIndex = result;
            ScoreDoc[] originalScoreDocs = rankResults.get(result);
            ScoreDoc[] normalizedScoreDocs = effectiveNormalizers[result].normalizeScores(originalScoreDocs);
            for (int scoreDocIndex = 0; scoreDocIndex < normalizedScoreDocs.length; ++scoreDocIndex) {
                ScoreDoc original = originalScoreDocs[scoreDocIndex];
                LinearRankDoc rankDoc = docsToRankResults.computeIfAbsent(
                    new RankDoc.RankKey(original.doc, original.shardIndex),
                    key -> {
                        if (isExplain) {
                            LinearRankDoc doc = new LinearRankDoc(key.doc(), 0.0f, key.shardIndex(), this.weights, normalizerNames);
                            doc.normalizedScores = new float[rankResults.size()];
                            return doc;
                        }
                        return new LinearRankDoc(key.doc(), 0.0f, key.shardIndex());
                    }
                );
                float normalizedScore = normalizedScoreDocs[scoreDocIndex].score;
                if (isExplain) {
                    rankDoc.normalizedScores[resultIndex] = normalizedScore;
                }
                // A result set without usable scores (NaN) does not contribute to the final score.
                float docScore = Float.isNaN(normalizedScore) ? DEFAULT_SCORE : normalizedScore;
                float weight = Float.isNaN(this.weights[result]) ? DEFAULT_WEIGHT : this.weights[result];
                rankDoc.score += weight * docScore;
            }
        }

        LinearRankDoc[] sortedResults = docsToRankResults.values().toArray(LinearRankDoc[]::new);
        Arrays.sort(sortedResults);
        // Trim to the rank window so each shard does not always return every candidate.
        LinearRankDoc[] topResults = new LinearRankDoc[Math.min(this.rankWindowSize, sortedResults.length)];
        for (int rank = 0; rank < topResults.length; ++rank) {
            topResults[rank] = sortedResults[rank];
            topResults[rank].rank = rank + 1;
        }
        if (this.minScore != null) {
            float threshold = this.minScore;
            topResults = Arrays.stream(topResults).filter(doc -> doc.score >= threshold).toArray(LinearRankDoc[]::new);
        }
        return topResults;
    }

    /**
     * Expands the multi-fields query format ({@code fields} + {@code query}) into explicit inner retrievers.
     *
     * <p>The expansion only happens when resolved indices are available and a query is set. It uses one default weight
     * and the top-level normalizer for each generated retriever. The original min score, retriever name and pre-filters
     * are carried over. If no inner retrievers are generated, a match-none retriever is returned instead.
     *
     * @throws IllegalArgumentException if {@code query} targets multiple local indices or any remote indices, or if a
     *     per-field weight is negative
     */
    @Override
    protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
        ResolvedIndices resolvedIndices = ctx.getResolvedIndices();
        if (resolvedIndices == null || this.query == null) {
            return this;
        }

        var localIndicesMetadata = resolvedIndices.getConcreteLocalIndicesMetadata();
        if (localIndicesMetadata.size() > 1) {
            throw new IllegalArgumentException("[linear] cannot specify [" + QUERY_FIELD.getPreferredName() + "] when querying multiple indices");
        }
        if (!resolvedIndices.getRemoteClusterIndices().isEmpty()) {
            throw new IllegalArgumentException("[linear] cannot specify [" + QUERY_FIELD.getPreferredName() + "] when querying remote indices");
        }

        List<CompoundRetrieverBuilder.RetrieverSource> fieldsInnerRetrievers = MultiFieldsInnerRetrieverUtils.generateInnerRetrievers(
            this.fields,
            this.query,
            localIndicesMetadata.values(),
            r -> {
                List<CompoundRetrieverBuilder.RetrieverSource> retrievers = new ArrayList<>(r.size());
                float[] groupWeights = new float[r.size()];
                ScoreNormalizer[] groupNormalizers = new ScoreNormalizer[r.size()];
                int index = 0;
                for (MultiFieldsInnerRetrieverUtils.WeightedRetrieverSource weightedRetriever : r) {
                    retrievers.add(weightedRetriever.retrieverSource());
                    groupWeights[index] = weightedRetriever.weight();
                    groupNormalizers[index] = this.normalizer;
                    ++index;
                }
                return new LinearRetrieverBuilder(retrievers, this.rankWindowSize, groupWeights, groupNormalizers);
            },
            w -> {
                if (w < 0.0f) {
                    throw new IllegalArgumentException("[linear] per-field weights must be non-negative");
                }
            }
        ).stream().map(CompoundRetrieverBuilder.RetrieverSource::from).toList();

        if (fieldsInnerRetrievers.isEmpty()) {
            // Nothing to search (e.g. no field matched); return a retriever that matches no documents.
            return new StandardRetrieverBuilder(new MatchNoneQueryBuilder());
        }

        float[] weights = new float[fieldsInnerRetrievers.size()];
        Arrays.fill(weights, DEFAULT_WEIGHT);
        ScoreNormalizer[] normalizers = new ScoreNormalizer[fieldsInnerRetrievers.size()];
        Arrays.fill(normalizers, this.normalizer);
        // Carry over min_score and the retriever name; previously both were silently dropped by this rewrite.
        return new LinearRetrieverBuilder(
            fieldsInnerRetrievers,
            null,
            null,
            this.normalizer,
            this.rankWindowSize,
            weights,
            normalizers,
            this.minScore,
            this.retrieverName,
            new ArrayList<>(this.preFilterQueryBuilders)
        );
    }

    @Override
    public String getName() {
        return NAME;
    }

    /** Returns the internal per-retriever weights array (not a copy). */
    float[] getWeights() {
        return this.weights;
    }

    /** Returns the internal per-retriever normalizers array (not a copy). */
    ScoreNormalizer[] getNormalizers() {
        return this.normalizers;
    }

    /**
     * Writes the retriever body.
     *
     * <p>The output holds the {@code retrievers} array (each entry with its retriever, weight and normalizer name), the
     * optional {@code fields}, {@code query} and {@code normalizer}, and {@code rank_window_size}. A null per-retriever
     * normalizer is written as identity, matching how it is applied when combining results.
     */
    @Override
    public void doToXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        if (!this.innerRetrievers.isEmpty()) {
            builder.startArray(RETRIEVERS_FIELD.getPreferredName());
            for (int i = 0; i < this.innerRetrievers.size(); i++) {
                CompoundRetrieverBuilder.RetrieverSource entry = this.innerRetrievers.get(i);
                ScoreNormalizer componentNormalizer = this.normalizers[i] == null ? IdentityScoreNormalizer.INSTANCE : this.normalizers[i];
                builder.startObject();
                builder.field(LinearRetrieverComponent.RETRIEVER_FIELD.getPreferredName(), entry.retriever());
                builder.field(LinearRetrieverComponent.WEIGHT_FIELD.getPreferredName(), this.weights[i]);
                builder.field(LinearRetrieverComponent.NORMALIZER_FIELD.getPreferredName(), componentNormalizer.getName());
                builder.endObject();
            }
            builder.endArray();
        }
        if (this.fields != null) {
            builder.startArray(FIELDS_FIELD.getPreferredName());
            for (String field : this.fields) {
                builder.value(field);
            }
            builder.endArray();
        }
        if (this.query != null) {
            builder.field(QUERY_FIELD.getPreferredName(), this.query);
        }
        if (this.normalizer != null) {
            builder.field(NORMALIZER_FIELD.getPreferredName(), this.normalizer.getName());
        }
        builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), this.rankWindowSize);
    }

    // NOTE(review): the cast assumes the caller has already verified that o is a LinearRetrieverBuilder — confirm
    // against the base class equals().
    @Override
    public boolean doEquals(Object o) {
        LinearRetrieverBuilder that = (LinearRetrieverBuilder) o;
        return super.doEquals(o)
            && Arrays.equals(this.weights, that.weights)
            && Arrays.equals(this.normalizers, that.normalizers)
            && Objects.equals(this.fields, that.fields)
            && Objects.equals(this.query, that.query)
            && Objects.equals(this.normalizer, that.normalizer);
    }

    @Override
    public int doHashCode() {
        return Objects.hash(
            super.doHashCode(),
            Arrays.hashCode(this.weights),
            Arrays.hashCode(this.normalizers),
            this.fields,
            this.query,
            this.normalizer
        );
    }
}

