/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.search.diversification.mmr;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.elasticsearch.search.diversification.ResultDiversification;
import org.elasticsearch.search.diversification.ResultDiversificationContext;
import org.elasticsearch.search.diversification.mmr.MMRResultDiversificationContext;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.vectors.VectorData;

/**
 * Diversifies a window of {@link RankDoc}s using greedy Maximal Marginal Relevance (MMR).
 *
 * <p>The highest-scoring doc is always selected first. On each following round, the unselected doc with a field
 * vector that maximises
 * {@code lambda * sim(doc, query) - (1 - lambda) * max(sim(doc, s) for s in selected)}
 * is added. Selection stops once {@code context.getSize()} docs (capped at the window size) are chosen, or once no
 * unselected doc with a field vector remains. Similarities use {@link VectorSimilarityFunction#MAXIMUM_INNER_PRODUCT}.
 * A doc without a query similarity (e.g. because the context has no query vector) gets a relevance of 0.
 *
 * <p>MMR only decides which docs survive: the returned docs keep their relative order from the input array.
 */
public class MMRResultDiversification extends ResultDiversification<MMRResultDiversificationContext> {
    private static final VectorSimilarityFunction SIMILARITY_FUNCTION = VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;

    public MMRResultDiversification(MMRResultDiversificationContext context) {
        super(context);
    }

    /**
     * Selects a diversified subset of {@code docs} with greedy MMR.
     *
     * @param docs the candidate docs; returned as-is when {@code null} or empty
     * @return the selected docs, in the same relative order as they appear in {@code docs}
     * @throws IOException as declared by the overridden {@link ResultDiversification} method
     */
    @Override
    public RankDoc[] diversify(RankDoc[] docs) throws IOException {
        if (docs == null || docs.length == 0) {
            return docs;
        }

        // rank -> position in the input array, used to restore the input ordering at the end
        Map<Integer, Integer> docIdIndexMapping = new HashMap<>();
        for (int i = 0; i < docs.length; i++) {
            docIdIndexMapping.put(docs[i].rank, i);
        }

        // the highest-scoring doc is always selected first (docs is non-empty, so max() is present)
        RankDoc highestScoreDoc = Arrays.stream(docs).max(Comparator.comparingDouble(doc -> doc.score)).orElse(docs[0]);
        Set<Integer> selectedDocRanks = new LinkedHashSet<>();
        selectedDocRanks.add(highestScoreDoc.rank);

        boolean useFloat = resolveUseFloat(docs, highestScoreDoc.rank);
        Map<Integer, Float> querySimilarity = getQuerySimilarityForDocs(docs, useFloat, context);

        // candidate rank -> (selected rank -> similarity), so each pair is scored at most once across rounds
        Map<Integer, Map<Integer, Float>> cachedSimilarities = new HashMap<>();
        float lambda = context.getLambda();
        int targetSize = Math.min(context.getSize(), docs.length);

        while (selectedDocRanks.size() < targetSize) {
            RankDoc best = null;
            float bestMmr = Float.NEGATIVE_INFINITY;
            for (RankDoc candidate : docs) {
                int rank = candidate.rank;
                if (selectedDocRanks.contains(rank)) {
                    continue;
                }
                VectorData candidateVector = context.getFieldVector(rank);
                if (candidateVector == null) {
                    continue; // cannot score a doc that has no vector
                }
                Map<Integer, Float> cachedScores = cachedSimilarities.computeIfAbsent(rank, k -> new HashMap<>());
                float redundancy = getHighestScoreForSelectedVectors(
                    rank,
                    selectedDocRanks,
                    context,
                    useFloat,
                    candidateVector,
                    cachedScores
                );
                float relevance = querySimilarity.getOrDefault(rank, 0.0f);
                float mmr = lambda * relevance - (1.0f - lambda) * redundancy;
                if (mmr > bestMmr) {
                    bestMmr = mmr;
                    best = candidate;
                }
            }
            if (best == null) {
                // no unselected doc has a vector; later rounds would see the same candidate set
                break;
            }
            selectedDocRanks.add(best.rank);
        }

        return selectedDocRanks.stream()
            .mapToInt(docIdIndexMapping::get)
            .sorted()
            .mapToObj(i -> docs[i])
            .toArray(RankDoc[]::new);
    }

    /**
     * Decides whether vectors are compared as float or byte vectors.
     *
     * <p>The vector of {@code preferredRank} is consulted first. If that doc has no vector, the first doc in
     * {@code docs} that has one is used instead. If no doc has a vector, the flag is never used by the caller,
     * so {@code true} is returned.
     */
    private boolean resolveUseFloat(RankDoc[] docs, int preferredRank) {
        VectorData preferred = context.getFieldVector(preferredRank);
        if (preferred != null) {
            return preferred.isFloat();
        }
        for (RankDoc doc : docs) {
            VectorData vector = context.getFieldVector(doc.rank);
            if (vector != null) {
                return vector.isFloat();
            }
        }
        return true;
    }

    /**
     * Computes the MMR redundancy term: the highest similarity between a candidate and any already-selected doc.
     *
     * <p>Selected docs without a field vector are skipped. Computed scores are memoised in
     * {@code cachedScoresForDoc}, keyed by the selected doc's rank.
     *
     * @return the maximum similarity, or {@code 0} when no selected doc could be compared (no redundancy penalty)
     */
    private float getHighestScoreForSelectedVectors(
        int docRank,
        Set<Integer> selectedDocRanks,
        MMRResultDiversificationContext context,
        boolean useFloat,
        VectorData thisDocVector,
        Map<Integer, Float> cachedScoresForDoc
    ) {
        float highestScore = Float.NEGATIVE_INFINITY;
        for (Integer selectedRank : selectedDocRanks) {
            if (selectedRank == docRank) {
                continue;
            }
            Float score = cachedScoresForDoc.get(selectedRank);
            if (score == null) {
                VectorData selectedVector = context.getFieldVector(selectedRank);
                if (selectedVector == null) {
                    continue;
                }
                score = compareVectors(useFloat, thisDocVector, selectedVector);
                cachedScoresForDoc.put(selectedRank, score);
            }
            // explicit '>' (rather than Math.max) so a NaN score is ignored instead of poisoning the result
            if (score > highestScore) {
                highestScore = score;
            }
        }
        return highestScore == Float.NEGATIVE_INFINITY ? 0.0f : highestScore;
    }

    /** Scores two vectors with {@link #SIMILARITY_FUNCTION}, using the float or byte comparison as requested. */
    private float compareVectors(boolean useFloat, VectorData first, VectorData second) {
        return useFloat
            ? getFloatVectorComparisonScore(SIMILARITY_FUNCTION, first, second)
            : getByteVectorComparisonScore(SIMILARITY_FUNCTION, first, second);
    }

    /**
     * Computes the similarity between each doc's field vector and the query vector.
     *
     * @param docs the docs to score; docs without a field vector are omitted from the result
     * @param useFloat whether to compare as float vectors ({@code true}) or byte vectors ({@code false})
     * @param context supplies the query vector and the per-doc field vectors
     * @return doc rank -> query similarity; empty when the context has no query vector
     */
    protected Map<Integer, Float> getQuerySimilarityForDocs(RankDoc[] docs, boolean useFloat, ResultDiversificationContext context) {
        Map<Integer, Float> querySimilarity = new HashMap<>();
        VectorData queryVector = context.getQueryVector();
        if (queryVector == null) {
            return querySimilarity;
        }
        for (RankDoc doc : docs) {
            VectorData vectorData = context.getFieldVector(doc.rank);
            if (vectorData != null) {
                querySimilarity.put(doc.rank, compareVectors(useFloat, vectorData, queryVector));
            }
        }
        return querySimilarity;
    }
}

