/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.rank.linear;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.XPackPlugin;
import org.elasticsearch.xpack.rank.RankRRFFeatures;
import org.elasticsearch.xpack.rank.linear.IdentityScoreNormalizer;
import org.elasticsearch.xpack.rank.linear.LinearRankDoc;
import org.elasticsearch.xpack.rank.linear.LinearRetrieverComponent;
import org.elasticsearch.xpack.rank.linear.ScoreNormalizer;
import org.elasticsearch.xpack.rank.rrf.RRFRankPlugin;

/**
 * A compound retriever that fuses the results of several inner retrievers with a weighted linear combination.
 * <p>
 * Each inner retriever's hits are first passed through that retriever's {@link ScoreNormalizer}. A document's final
 * score is the sum, over every inner retriever that returned it, of {@code weight * normalizedScore}. The fused list
 * is then sorted and trimmed to {@code rank_window_size}, and ranks are assigned starting at 1.
 */
public final class LinearRetrieverBuilder extends CompoundRetrieverBuilder<LinearRetrieverBuilder> {

    public static final String NAME = "linear";
    public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers");

    /** Contribution used in place of a normalized score that is {@code NaN}. */
    public static final float DEFAULT_SCORE = 0.0f;

    /** Weight used when no weights are supplied, and in place of a weight that is {@code NaN}. */
    private static final float DEFAULT_WEIGHT = 1.0f;

    /** One weight per inner retriever, in the same order as {@code innerRetrievers}. */
    private final float[] weights;

    /** One non-null normalizer per inner retriever, in the same order as {@code innerRetrievers}. */
    private final ScoreNormalizer[] normalizers;

    @SuppressWarnings("unchecked") // args[0] is produced by the declareObjectArray(...) registration below
    static final ConstructingObjectParser<LinearRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
        NAME,
        false,
        args -> {
            List<LinearRetrieverComponent> retrieverComponents = (List<LinearRetrieverComponent>) args[0];
            // rank_window_size is optional; 10 is used when it is absent.
            int rankWindowSize = args[1] == null ? 10 : (Integer) args[1];
            List<CompoundRetrieverBuilder.RetrieverSource> innerRetrievers = new ArrayList<>(retrieverComponents.size());
            float[] weights = new float[retrieverComponents.size()];
            ScoreNormalizer[] normalizers = new ScoreNormalizer[retrieverComponents.size()];
            int index = 0;
            for (LinearRetrieverComponent component : retrieverComponents) {
                innerRetrievers.add(new CompoundRetrieverBuilder.RetrieverSource(component.retriever, null));
                weights[index] = component.weight;
                normalizers[index] = component.normalizer;
                index++;
            }
            return new LinearRetrieverBuilder(innerRetrievers, rankWindowSize, weights, normalizers);
        }
    );

    // Must follow the PARSER initializer: static initializers run in textual order.
    static {
        PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), LinearRetrieverComponent::fromXContent, RETRIEVERS_FIELD);
        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
        RetrieverBuilder.declareBaseParserFields(NAME, PARSER);
    }

    /** Returns {@code size} weights, all equal to {@link #DEFAULT_WEIGHT}. */
    private static float[] getDefaultWeight(int size) {
        float[] weights = new float[size];
        Arrays.fill(weights, DEFAULT_WEIGHT);
        return weights;
    }

    /** Returns {@code size} normalizers, all equal to {@link IdentityScoreNormalizer#INSTANCE}. */
    private static ScoreNormalizer[] getDefaultNormalizers(int size) {
        ScoreNormalizer[] normalizers = new ScoreNormalizer[size];
        Arrays.fill(normalizers, IdentityScoreNormalizer.INSTANCE);
        return normalizers;
    }

    /**
     * Returns a copy of {@code normalizers} in which every {@code null} entry is replaced by the identity normalizer,
     * so that no later code path has to guard against {@code null}.
     */
    private static ScoreNormalizer[] resolveNormalizers(ScoreNormalizer[] normalizers) {
        ScoreNormalizer[] resolved = new ScoreNormalizer[normalizers.length];
        for (int i = 0; i < normalizers.length; i++) {
            resolved[i] = normalizers[i] == null ? IdentityScoreNormalizer.INSTANCE : normalizers[i];
        }
        return resolved;
    }

    /**
     * Parses a {@code linear} retriever definition.
     *
     * @param parser  the parser positioned on the retriever body
     * @param context the retriever parsing context
     * @return the parsed retriever
     * @throws ParsingException if the cluster does not support the linear retriever feature
     * @throws IOException      if reading from {@code parser} fails
     */
    public static LinearRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
        if (context.clusterSupportsFeature(RankRRFFeatures.LINEAR_RETRIEVER_SUPPORTED) == false) {
            throw new ParsingException(parser.getTokenLocation(), "unknown retriever [linear]");
        }
        // Rejects the request with a license compliance exception when the current license does not allow the feature.
        if (RRFRankPlugin.LINEAR_RETRIEVER_FEATURE.check(XPackPlugin.getSharedLicenseState()) == false) {
            throw LicenseUtils.newComplianceException("linear retriever");
        }
        return PARSER.apply(parser, context);
    }

    /** Creates a linear retriever in which every inner retriever has the default weight and the identity normalizer. */
    LinearRetrieverBuilder(List<CompoundRetrieverBuilder.RetrieverSource> innerRetrievers, int rankWindowSize) {
        this(innerRetrievers, rankWindowSize, getDefaultWeight(innerRetrievers.size()), getDefaultNormalizers(innerRetrievers.size()));
    }

    /**
     * Creates a linear retriever.
     *
     * @param innerRetrievers the retrievers whose results are combined
     * @param rankWindowSize  maximum number of combined results to return
     * @param weights         one weight per inner retriever; copied
     * @param normalizers     one normalizer per inner retriever; copied, with {@code null} entries replaced by the
     *                        identity normalizer
     * @throws IllegalArgumentException if {@code weights} or {@code normalizers} does not have exactly one entry per
     *                                  inner retriever
     */
    public LinearRetrieverBuilder(
        List<CompoundRetrieverBuilder.RetrieverSource> innerRetrievers,
        int rankWindowSize,
        float[] weights,
        ScoreNormalizer[] normalizers
    ) {
        super(innerRetrievers, rankWindowSize);
        if (weights.length != innerRetrievers.size()) {
            throw new IllegalArgumentException("The number of weights must match the number of inner retrievers");
        }
        if (normalizers.length != innerRetrievers.size()) {
            throw new IllegalArgumentException("The number of normalizers must match the number of inner retrievers");
        }
        this.weights = Arrays.copyOf(weights, weights.length);
        this.normalizers = resolveNormalizers(normalizers);
    }

    /** Returns a copy with new children and pre-filters that keeps this retriever's weights, normalizers and name. */
    @Override
    protected LinearRetrieverBuilder clone(
        List<CompoundRetrieverBuilder.RetrieverSource> newChildRetrievers,
        List<QueryBuilder> newPreFilterQueryBuilders
    ) {
        LinearRetrieverBuilder clone = new LinearRetrieverBuilder(newChildRetrievers, rankWindowSize, weights, normalizers);
        clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
        clone.retrieverName = retrieverName;
        return clone;
    }

    /** Enables score tracking on the source builder, since the combination below consumes the inner scores. */
    @Override
    protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) {
        sourceBuilder.trackScores(true);
        return sourceBuilder;
    }

    /**
     * Combines the per-retriever results into a single ranked list.
     * <p>
     * For each inner retriever, its hits are normalized with that retriever's normalizer. Each document then accumulates
     * {@code weight * normalizedScore}. A {@code NaN} normalized score contributes {@link #DEFAULT_SCORE}, and a
     * {@code NaN} weight is replaced by {@link #DEFAULT_WEIGHT}.
     *
     * @param rankResults one array of hits per inner retriever, in inner-retriever order
     * @param isExplain   whether to record per-retriever normalized scores, weights and normalizer names on each doc
     * @return at most {@code rankWindowSize} docs, sorted, with 1-based ranks assigned
     */
    @Override
    protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, boolean isExplain) {
        Map<RankDoc.RankKey, LinearRankDoc> docsToRankResults = Maps.newMapWithExpectedSize(rankWindowSize);
        final String[] normalizerNames = Arrays.stream(normalizers).map(ScoreNormalizer::getName).toArray(String[]::new);
        for (int result = 0; result < rankResults.size(); result++) {
            final int resultIndex = result;
            ScoreDoc[] originalScoreDocs = rankResults.get(result);
            ScoreDoc[] normalizedScoreDocs = normalizers[result].normalizeScores(originalScoreDocs);
            final float weight = Float.isNaN(weights[result]) ? DEFAULT_WEIGHT : weights[result];
            for (int scoreDocIndex = 0; scoreDocIndex < normalizedScoreDocs.length; scoreDocIndex++) {
                // The doc identity is read from the original hits; this assumes normalizeScores preserves their
                // order and length. NOTE(review): confirm against ScoreNormalizer implementations.
                ScoreDoc original = originalScoreDocs[scoreDocIndex];
                LinearRankDoc rankDoc = docsToRankResults.computeIfAbsent(
                    new RankDoc.RankKey(original.doc, original.shardIndex),
                    key -> {
                        if (isExplain) {
                            LinearRankDoc doc = new LinearRankDoc(key.doc(), 0f, key.shardIndex(), weights, normalizerNames);
                            doc.normalizedScores = new float[rankResults.size()];
                            return doc;
                        }
                        return new LinearRankDoc(key.doc(), 0f, key.shardIndex());
                    }
                );
                float normalizedScore = normalizedScoreDocs[scoreDocIndex].score;
                if (isExplain) {
                    rankDoc.normalizedScores[resultIndex] = normalizedScore;
                }
                float docScore = Float.isNaN(normalizedScore) ? DEFAULT_SCORE : normalizedScore;
                rankDoc.score += weight * docScore;
            }
        }
        LinearRankDoc[] sortedResults = docsToRankResults.values().toArray(LinearRankDoc[]::new);
        // Uses RankDoc's natural ordering (presumably highest score first; confirm in RankDoc#compareTo).
        Arrays.sort(sortedResults);
        // Trim to the rank window and assign 1-based ranks.
        LinearRankDoc[] topResults = new LinearRankDoc[Math.min(rankWindowSize, sortedResults.length)];
        for (int rank = 0; rank < topResults.length; rank++) {
            topResults[rank] = sortedResults[rank];
            topResults[rank].rank = rank + 1;
        }
        return topResults;
    }

    @Override
    public String getName() {
        return NAME;
    }

    /**
     * Writes the inner retrievers, each with its weight and normalizer name, followed by {@code rank_window_size}.
     * The {@code retrievers} array is omitted when there are no inner retrievers.
     */
    @Override
    public void doToXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        if (innerRetrievers.isEmpty() == false) {
            builder.startArray(RETRIEVERS_FIELD.getPreferredName());
            int index = 0;
            for (CompoundRetrieverBuilder.RetrieverSource entry : innerRetrievers) {
                builder.startObject();
                builder.field(LinearRetrieverComponent.RETRIEVER_FIELD.getPreferredName(), (ToXContent) entry.retriever());
                builder.field(LinearRetrieverComponent.WEIGHT_FIELD.getPreferredName(), weights[index]);
                builder.field(LinearRetrieverComponent.NORMALIZER_FIELD.getPreferredName(), normalizers[index].getName());
                builder.endObject();
                index++;
            }
            builder.endArray();
        }
        builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
    }
}

