/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.inference;

import java.io.IOException;
import java.util.List;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.vectors.TokenPruningConfig;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.inference.WeightedToken;
import org.elasticsearch.search.vectors.SparseVectorQueryWrapper;

/**
 * Static helpers that turn a list of {@link WeightedToken}s into a Lucene query.
 *
 * <p>Every produced query is a {@link BooleanQuery} of boosted {@code SHOULD} term clauses
 * (one per retained token, boosted by the token's weight) with a minimum-should-match of 1,
 * wrapped in a {@link SparseVectorQueryWrapper} for the target field.
 */
public final class WeightedTokensUtils {
    private WeightedTokensUtils() {
    }

    /**
     * Builds a query containing one boosted term clause for every supplied token.
     *
     * @param fieldName field the resulting wrapper is created for
     * @param tokens    tokens to query; each contributes one {@code SHOULD} clause
     * @param ft        field type used to create the per-token term query
     * @param context   search execution context handed to {@code ft.termQuery}
     * @return the wrapped boolean query
     */
    public static Query queryBuilderWithAllTokens(String fieldName, List<WeightedToken> tokens, MappedFieldType ft, SearchExecutionContext context) {
        BooleanQuery.Builder qb = new BooleanQuery.Builder();
        for (WeightedToken token : tokens) {
            addTokenClause(qb, token, ft, context);
        }
        return wrap(fieldName, qb);
    }

    /**
     * Builds a query from the supplied tokens after applying the pruning rules of {@code tokenPruningConfig}.
     *
     * <p>If the average token frequency ratio of the field is {@code 0} (no terms were found for the field
     * in any leaf), a {@link MatchNoDocsQuery} is returned instead. Otherwise each token is evaluated by
     * {@link #shouldKeepToken}; when the config reports {@code isOnlyScorePrunedTokens()} the selection is
     * inverted, so that only the tokens that would have been dropped are added to the query.
     *
     * @param fieldName          field whose index statistics drive the pruning decision
     * @param tokenPruningConfig pruning thresholds; when {@code null} every token is kept
     * @param tokens             candidate tokens
     * @param ft                 field type used to create the per-token term query
     * @param context            search execution context providing the index reader
     * @return the wrapped boolean query, or a {@link MatchNoDocsQuery} for an empty field
     * @throws IOException if reading index statistics fails
     */
    public static Query queryBuilderWithPrunedTokens(String fieldName, TokenPruningConfig tokenPruningConfig, List<WeightedToken> tokens, MappedFieldType ft, SearchExecutionContext context) throws IOException {
        BooleanQuery.Builder qb = new BooleanQuery.Builder();
        IndexReader reader = context.getIndexReader();
        int fieldDocCount = reader.getDocCount(fieldName);
        float bestWeight = maxWeight(tokens);
        float averageTokenFreqRatio = getAverageTokenFreqRatio(fieldName, reader, fieldDocCount);
        if (averageTokenFreqRatio == 0.0f) {
            return new MatchNoDocsQuery("query is against an empty field");
        }
        // Loop-invariant: decide once whether the keep/prune selection must be inverted.
        boolean onlyScorePrunedTokens = tokenPruningConfig != null && tokenPruningConfig.isOnlyScorePrunedTokens();
        for (WeightedToken token : tokens) {
            boolean keep = shouldKeepToken(fieldName, tokenPruningConfig, reader, token, fieldDocCount, averageTokenFreqRatio, bestWeight);
            // Equivalent to "keep XOR onlyScorePrunedTokens": normally add kept tokens, otherwise add pruned ones.
            if (keep != onlyScorePrunedTokens) {
                addTokenClause(qb, token, ft, context);
            }
        }
        return wrap(fieldName, qb);
    }

    /** Adds a {@code SHOULD} clause for {@code token}: its term query boosted by the token weight. */
    private static void addTokenClause(BooleanQuery.Builder qb, WeightedToken token, MappedFieldType ft, SearchExecutionContext context) {
        qb.add(new BoostQuery(ft.termQuery(token.token(), context), token.weight()), BooleanClause.Occur.SHOULD);
    }

    /** Finishes the boolean query (at least one clause must match) and wraps it for {@code fieldName}. */
    private static Query wrap(String fieldName, BooleanQuery.Builder qb) {
        return new SparseVectorQueryWrapper(fieldName, qb.setMinimumNumberShouldMatch(1).build());
    }

    /** Returns the largest token weight, folded with {@link Math#max(float, float)} starting from {@code 0}. */
    private static float maxWeight(List<WeightedToken> tokens) {
        float best = 0.0f;
        for (WeightedToken token : tokens) {
            best = Math.max(best, token.weight());
        }
        return best;
    }

    /**
     * Computes {@code sumDocFreq / fieldDocCount / maxUniqueTokens} for the field, where
     * {@code maxUniqueTokens} is the largest per-leaf term count reported by {@code Terms.size()}.
     *
     * <p>The maximum is tracked as a {@code long} so that a very large term count is not truncated
     * by an {@code int} narrowing cast. A non-positive {@code size()} never raises the maximum above
     * its initial {@code 0}.
     *
     * @return the ratio, or {@code 0} when no leaf reports any terms for the field
     * @throws IOException if reading terms or statistics fails
     */
    private static float getAverageTokenFreqRatio(String fieldName, IndexReader reader, int fieldDocCount) throws IOException {
        long maxUniqueTokens = 0L;
        for (LeafReaderContext leaf : reader.getContext().leaves()) {
            Terms terms = leaf.reader().terms(fieldName);
            if (terms != null) {
                maxUniqueTokens = Math.max(terms.size(), maxUniqueTokens);
            }
        }
        if (maxUniqueTokens == 0L) {
            return 0.0f;
        }
        return (float) reader.getSumDocFreq(fieldName) / (float) fieldDocCount / (float) maxUniqueTokens;
    }

    /**
     * Decides whether a token survives pruning.
     *
     * <p>With no pruning config every token is kept. A token whose term has a document frequency of
     * {@code 0} is dropped. Otherwise the token is kept when its frequency ratio
     * ({@code docFreq / fieldDocCount}) is below {@code tokensFreqRatioThreshold * averageTokenFreqRatio},
     * or when its weight exceeds {@code tokensWeightThreshold * bestWeight}.
     *
     * @throws IOException if reading the document frequency fails
     */
    private static boolean shouldKeepToken(String fieldName, TokenPruningConfig tokenPruningConfig, IndexReader reader, WeightedToken token, int fieldDocCount, float averageTokenFreqRatio, float bestWeight) throws IOException {
        if (tokenPruningConfig == null) {
            return true;
        }
        int docFreq = reader.docFreq(new Term(fieldName, token.token()));
        if (docFreq == 0) {
            return false;
        }
        float tokenFreqRatio = (float) docFreq / (float) fieldDocCount;
        boolean rareEnough = tokenFreqRatio < tokenPruningConfig.getTokensFreqRatioThreshold() * averageTokenFreqRatio;
        return rareEnough || token.weight() > tokenPruningConfig.getTokensWeightThreshold() * bestWeight;
    }
}

