/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.ltr;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.core.Strings;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.search.rescore.RescoreContext;
import org.elasticsearch.search.rescore.Rescorer;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
import org.elasticsearch.xpack.ml.inference.ltr.FeatureExtractor;
import org.elasticsearch.xpack.ml.inference.ltr.LearningToRankRescorerContext;

public class LearningToRankRescorer
implements Rescorer {
    private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = 10;
    public static final LearningToRankRescorer INSTANCE = new LearningToRankRescorer();
    private static final Logger logger = LogManager.getLogger(LearningToRankRescorer.class);

    private LearningToRankRescorer() {
    }

    public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext rescoreContext) throws IOException {
        if (topDocs.scoreDocs.length == 0) {
            return topDocs;
        }
        LearningToRankRescorerContext ltrRescoreContext = (LearningToRankRescorerContext)rescoreContext;
        if (ltrRescoreContext.regressionModelDefinition == null) {
            throw new IllegalStateException("local model reference is null, missing rewriteAndFetch before rescore phase?");
        }
        LocalModel definition = ltrRescoreContext.regressionModelDefinition;
        topDocs = Rescorer.topN((TopDocs)topDocs, (int)rescoreContext.getWindowSize());
        Set topDocIDs = Arrays.stream(topDocs.scoreDocs).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toUnmodifiableSet());
        rescoreContext.setRescoredDocs(topDocIDs);
        ScoreDoc[] hitsToRescore = topDocs.scoreDocs;
        Arrays.sort(hitsToRescore, Comparator.comparingInt(a -> a.doc));
        int readerUpto = -1;
        int endDoc = 0;
        int docBase = 0;
        List leaves = ltrRescoreContext.executionContext.searcher().getIndexReader().leaves();
        LeafReaderContext currentSegment = null;
        boolean changedSegment = true;
        List<FeatureExtractor> featureExtractors = ltrRescoreContext.buildFeatureExtractors(searcher);
        ArrayList<Map> docFeatures = new ArrayList<Map>(topDocIDs.size());
        int featureSize = featureExtractors.stream().mapToInt(fe -> fe.featureNames().size()).sum();
        int count = 0;
        for (int hitUpto = 0; hitUpto < hitsToRescore.length; ++hitUpto) {
            if (count % 10 == 0) {
                rescoreContext.checkCancellation();
            }
            ++count;
            ScoreDoc hit = hitsToRescore[hitUpto];
            int docID = hit.doc;
            while (docID >= endDoc) {
                currentSegment = (LeafReaderContext)leaves.get(++readerUpto);
                endDoc = currentSegment.docBase + currentSegment.reader().maxDoc();
                changedSegment = true;
            }
            assert (currentSegment != null) : "Unexpected null segment";
            if (changedSegment) {
                docBase = currentSegment.docBase;
                for (FeatureExtractor featureExtractor : featureExtractors) {
                    featureExtractor.setNextReader(currentSegment);
                }
                changedSegment = false;
            }
            int targetDoc = docID - docBase;
            Map features = Maps.newMapWithExpectedSize((int)featureSize);
            for (FeatureExtractor featureExtractor : featureExtractors) {
                featureExtractor.addFeatures(features, targetDoc);
            }
            logger.debug(() -> Strings.format((String)"doc [%d] has features [%s]", (Object[])new Object[]{targetDoc, features}));
            docFeatures.add(features);
        }
        for (int i = 0; i < hitsToRescore.length; ++i) {
            if (i % 10 == 0) {
                rescoreContext.checkCancellation();
            }
            Map features = (Map)docFeatures.get(i);
            try {
                InferenceResults results = definition.inferLtr(features, (InferenceConfig)ltrRescoreContext.learningToRankConfig);
                if (results instanceof WarningInferenceResults) {
                    WarningInferenceResults warningInferenceResults = (WarningInferenceResults)results;
                    logger.warn("Failure rescoring doc, warning returned [" + warningInferenceResults.getWarning() + "]");
                    continue;
                }
                Object object = results.predictedValue();
                if (object instanceof Number) {
                    Number prediction = (Number)object;
                    hitsToRescore[i].score = prediction.floatValue();
                    continue;
                }
                logger.warn("Failure rescoring doc, unexpected inference result of kind [" + results.getWriteableName() + "]");
                continue;
            }
            catch (Exception ex) {
                logger.warn("Failure rescoring doc...", (Throwable)ex);
            }
        }
        assert (rescoreContext.getWindowSize() >= hitsToRescore.length) : "unexpected, windows size [" + rescoreContext.getWindowSize() + "] should be gte [" + hitsToRescore.length + "]";
        Arrays.sort(topDocs.scoreDocs, SCORE_DOC_COMPARATOR);
        return topDocs;
    }

    public Explanation explain(int topLevelDocId, IndexSearcher searcher, RescoreContext rescoreContext, Explanation sourceExplanation) throws IOException {
        if (sourceExplanation == null) {
            return Explanation.noMatch((String)"no match found", (Explanation[])new Explanation[0]);
        }
        LearningToRankRescorerContext ltrContext = (LearningToRankRescorerContext)rescoreContext;
        LocalModel localModelDefinition = ltrContext.regressionModelDefinition;
        if (localModelDefinition == null) {
            throw new IllegalStateException("local model reference is null, missing rewriteAndFetch before rescore phase?");
        }
        List leaves = ltrContext.executionContext.searcher().getIndexReader().leaves();
        int endDoc = 0;
        int readerUpto = -1;
        LeafReaderContext currentSegment = null;
        while (topLevelDocId >= endDoc) {
            currentSegment = (LeafReaderContext)leaves.get(++readerUpto);
            endDoc = currentSegment.docBase + currentSegment.reader().maxDoc();
        }
        assert (currentSegment != null) : "Unexpected null segment";
        int targetDoc = topLevelDocId - currentSegment.docBase;
        List<FeatureExtractor> featureExtractors = ltrContext.buildFeatureExtractors(searcher);
        int featureSize = featureExtractors.stream().mapToInt(fe -> fe.featureNames().size()).sum();
        Map features = Maps.newMapWithExpectedSize((int)featureSize);
        for (FeatureExtractor featureExtractor : featureExtractors) {
            featureExtractor.setNextReader(currentSegment);
            featureExtractor.addFeatures(features, targetDoc);
        }
        float ltrScore = ((Number)localModelDefinition.inferLtr(features, (InferenceConfig)ltrContext.learningToRankConfig).predictedValue()).floatValue();
        ArrayList<Explanation> featureExplanations = new ArrayList<Explanation>();
        for (String featureName : features.keySet()) {
            Number featureValue = Objects.requireNonNullElse((Number)features.get(featureName), 0);
            featureExplanations.add(Explanation.match((Number)featureValue, (String)("feature value for [" + featureName + "]"), (Explanation[])new Explanation[0]));
        }
        return Explanation.match((Number)Float.valueOf(ltrScore), (String)("rescored using LTR model " + ltrContext.regressionModelDefinition.getModelId()), (Explanation[])new Explanation[]{Explanation.match((Number)sourceExplanation.getValue(), (String)"first pass query score", (Explanation[])new Explanation[]{sourceExplanation}), Explanation.match((Number)Float.valueOf(0.0f), (String)"extracted features", featureExplanations)});
    }
}

