/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.search.diversification.mmr;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.elasticsearch.search.diversification.ResultDiversification;
import org.elasticsearch.search.diversification.ResultDiversificationContext;
import org.elasticsearch.search.diversification.mmr.MMRResultDiversificationContext;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.vectors.VectorData;

/**
 * Result diversification based on Maximal Marginal Relevance (MMR).
 *
 * <p>Documents are selected greedily. The first pick is the document most similar to the query vector
 * (or, when no query similarities are available, the document with the highest score). Every further pick
 * maximises {@code lambda * sim(query, doc) - (1 - lambda) * max(sim(doc, alreadySelected))}, with all
 * similarities computed using {@link VectorSimilarityFunction#MAXIMUM_INNER_PRODUCT}. The selected documents
 * are returned in their original input order.
 */
public class MMRResultDiversification extends ResultDiversification<MMRResultDiversificationContext> {
    private static final VectorSimilarityFunction SIMILARITY_FUNCTION = VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;

    public MMRResultDiversification(MMRResultDiversificationContext context) {
        super(context);
    }

    /**
     * Selects a diverse subset of {@code docs} using MMR.
     *
     * <p>At most {@code min(context.getSize(), docs.length)} documents are returned, but for non-empty input the
     * single most relevant document is always included. Apart from that initial pick, documents for which the
     * context has no field vector are never selected.
     *
     * @param docs the candidate documents; returned as-is when {@code null} or empty
     * @return the selected documents, ordered as they appeared in {@code docs}
     * @throws IOException as declared by {@link ResultDiversification#diversify}
     */
    @Override
    public RankDoc[] diversify(RankDoc[] docs) throws IOException {
        if (docs == null || docs.length == 0) {
            return docs;
        }
        MMRResultDiversificationContext mmrContext = (MMRResultDiversificationContext) this.context;

        // rank -> position in the input array, used to restore input order for the result
        Map<Integer, Integer> docIdIndexMapping = new LinkedHashMap<>();
        for (int i = 0; i < docs.length; ++i) {
            docIdIndexMapping.put(docs[i].rank, i);
        }

        Map<Integer, Float> querySimilarity = this.getQuerySimilarityForDocs(docs, this.context);

        // selectedDocRanks keeps selection order; selectedRankSet gives O(1) membership checks
        List<Integer> selectedDocRanks = new ArrayList<>();
        Set<Integer> selectedRankSet = new HashSet<>();
        Integer firstRank = this.getHighestRelevantDocRank(docs, querySimilarity);
        selectedDocRanks.add(firstRank);
        selectedRankSet.add(firstRank);

        float mmrLambda = mmrContext.getLambda();
        int targetSize = Math.min(mmrContext.getSize(), docs.length);

        // candidate rank -> (selected rank -> similarity), so each pair is computed at most once
        Map<Integer, Map<Integer, Float>> cachedSimilarities = new HashMap<>();

        while (selectedDocRanks.size() < targetSize) {
            int bestDocRank = -1;
            float bestMmrScore = Float.NEGATIVE_INFINITY;
            for (RankDoc doc : docs) {
                int docRank = doc.rank;
                if (selectedRankSet.contains(docRank)) {
                    continue;
                }
                VectorData docVector = mmrContext.getFieldVector(docRank);
                if (docVector == null) {
                    // without a field vector the doc cannot be compared for diversity, so it is skipped
                    continue;
                }
                Map<Integer, Float> cachedScoresForDoc = cachedSimilarities.computeIfAbsent(docRank, k -> new HashMap<>());
                float highestSimilarityToSelected = this.getHighestSimilarityScoreToSelectedVectors(
                    selectedDocRanks,
                    docVector,
                    cachedScoresForDoc
                );
                float querySimilarityScore = querySimilarity.getOrDefault(docRank, 0.0f);
                float mmr = mmrLambda * querySimilarityScore - (1.0f - mmrLambda) * highestSimilarityToSelected;
                if (mmr > bestMmrScore) {
                    bestMmrScore = mmr;
                    bestDocRank = docRank;
                }
            }
            if (bestDocRank < 0) {
                // no eligible candidate left; the state is unchanged, so further passes would find none either
                break;
            }
            selectedDocRanks.add(bestDocRank);
            selectedRankSet.add(bestDocRank);
        }

        // emit the selected docs in their original input order
        return selectedDocRanks.stream()
            .mapToInt(docIdIndexMapping::get)
            .sorted()
            .mapToObj(i -> docs[i])
            .toArray(RankDoc[]::new);
    }

    /**
     * Picks the rank of the first document to select.
     *
     * <p>Among the documents in {@code docs} that have a query similarity, the one with the greatest similarity
     * wins; ties go to the earliest document in {@code docs}. Only ranks present in {@code docs} are considered,
     * so the returned rank always maps back to an input document. If no document has a query similarity, the
     * document with the highest {@code score} is used instead.
     *
     * @param docs non-empty candidate documents
     * @param querySimilarity rank -> similarity to the query vector (may be empty)
     * @return the rank of the document to select first
     */
    private Integer getHighestRelevantDocRank(RankDoc[] docs, Map<Integer, Float> querySimilarity) {
        if (querySimilarity.isEmpty() == false) {
            Integer bestRank = null;
            float bestSimilarity = 0.0f;
            for (RankDoc doc : docs) {
                Float similarity = querySimilarity.get(doc.rank);
                if (similarity != null && (bestRank == null || Float.compare(similarity, bestSimilarity) > 0)) {
                    bestRank = doc.rank;
                    bestSimilarity = similarity;
                }
            }
            if (bestRank != null) {
                return bestRank;
            }
        }
        RankDoc highestScoreDoc = Arrays.stream(docs).max(Comparator.comparingDouble(doc -> doc.score)).orElse(docs[0]);
        return highestScoreDoc.rank;
    }

    /**
     * Returns the maximum similarity between {@code thisDocVector} and the field vectors of the already-selected
     * documents.
     *
     * <p>Similarities are memoised in {@code cachedScoresForDoc}, keyed by the selected document's rank. Selected
     * documents without a field vector are ignored and are not cached.
     *
     * @return the highest similarity found, or {@code 0} if none could be computed
     */
    private float getHighestSimilarityScoreToSelectedVectors(
        List<Integer> selectedDocRanks,
        VectorData thisDocVector,
        Map<Integer, Float> cachedScoresForDoc
    ) {
        MMRResultDiversificationContext mmrContext = (MMRResultDiversificationContext) this.context;
        float highestScore = Float.NEGATIVE_INFINITY;
        for (Integer compareToDocRank : selectedDocRanks) {
            Float similarityScore = cachedScoresForDoc.get(compareToDocRank);
            if (similarityScore == null) {
                VectorData comparisonVector = mmrContext.getFieldVector(compareToDocRank);
                if (comparisonVector == null) {
                    continue;
                }
                similarityScore = this.getVectorComparisonScore(SIMILARITY_FUNCTION, thisDocVector, comparisonVector);
                cachedScoresForDoc.put(compareToDocRank, similarityScore);
            }
            if (similarityScore > highestScore) {
                highestScore = similarityScore;
            }
        }
        return highestScore == Float.NEGATIVE_INFINITY ? 0.0f : highestScore;
    }

    /**
     * Computes the similarity of each document's field vector to the query vector.
     *
     * @param docs the candidate documents
     * @param context supplies the query vector and per-rank field vectors
     * @return rank -> query similarity; empty when there is no query vector, and documents without a field
     *         vector are omitted
     */
    protected Map<Integer, Float> getQuerySimilarityForDocs(RankDoc[] docs, ResultDiversificationContext context) {
        Map<Integer, Float> querySimilarity = new HashMap<>();
        VectorData queryVector = context.getQueryVector();
        if (queryVector == null) {
            return querySimilarity;
        }
        for (RankDoc doc : docs) {
            VectorData vectorData = context.getFieldVector(doc.rank);
            if (vectorData != null) {
                querySimilarity.put(doc.rank, this.getVectorComparisonScore(SIMILARITY_FUNCTION, vectorData, queryVector));
            }
        }
        return querySimilarity;
    }
}

