/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.rank;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ValidateActions;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.MatchQueryBuilder;
import org.elasticsearch.index.query.MultiMatchQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.index.search.QueryParserHelper;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;

/**
 * Static helpers for compound retrievers that can be configured either with explicit inner
 * retrievers or with a simplified {@code fields} + {@code query} format. The helpers validate that
 * format and expand it into concrete inner retrievers: at most one lexical retriever over the
 * non-inference fields, and at most one semantic retriever over the inference fields.
 */
public class MultiFieldsInnerRetrieverUtils {
    private MultiFieldsInnerRetrieverUtils() {
    }

    /**
     * Validates that exactly one configuration style is used: either explicit inner retrievers, or
     * the {@code fields}/{@code query} format.
     *
     * @param innerRetrievers     the explicitly configured inner retrievers (may be empty)
     * @param fields              optional list of fields to query
     * @param query               optional query text; required whenever {@code fields} is given
     * @param retrieverName       retriever name, used as a prefix in error messages
     * @param retrieversParamName name of the inner-retrievers parameter, used in error messages
     * @param fieldsParamName     name of the fields parameter, used in error messages
     * @param queryParamName      name of the query parameter, used in error messages
     * @param validationException errors accumulated so far; new errors are added to it via
     *                            {@code ValidateActions.addValidationError}
     * @return the accumulated validation errors, or the unchanged {@code validationException} if
     *         no new error was found
     */
    public static ActionRequestValidationException validateParams(
        List<CompoundRetrieverBuilder.RetrieverSource> innerRetrievers,
        @Nullable List<String> fields,
        @Nullable String query,
        String retrieverName,
        String retrieversParamName,
        String fieldsParamName,
        String queryParamName,
        ActionRequestValidationException validationException
    ) {
        if (fields != null || query != null) {
            // The fields/query format is in use.
            if (query == null) {
                // Return early: the checks below assume a query value is present.
                return ValidateActions.addValidationError(
                    String.format(
                        Locale.ROOT,
                        "[%s] [%s] must be provided when [%s] is specified",
                        retrieverName,
                        queryParamName,
                        fieldsParamName
                    ),
                    validationException
                );
            }
            if (query.isEmpty()) {
                validationException = ValidateActions.addValidationError(
                    String.format(Locale.ROOT, "[%s] [%s] cannot be empty", retrieverName, queryParamName),
                    validationException
                );
            }
            if (fields != null && fields.isEmpty()) {
                validationException = ValidateActions.addValidationError(
                    String.format(Locale.ROOT, "[%s] [%s] cannot be empty", retrieverName, fieldsParamName),
                    validationException
                );
            }
            if (!innerRetrievers.isEmpty()) {
                validationException = ValidateActions.addValidationError(
                    String.format(Locale.ROOT, "[%s] cannot combine [%s] and [%s]", retrieverName, retrieversParamName, queryParamName),
                    validationException
                );
            }
        } else if (innerRetrievers.isEmpty()) {
            // Neither style was provided.
            validationException = ValidateActions.addValidationError(
                String.format(Locale.ROOT, "[%s] must provide [%s] or [%s]", retrieverName, retrieversParamName, queryParamName),
                validationException
            );
        }
        return validationException;
    }

    /**
     * Expands the {@code fields}/{@code query} format into inner retrievers.
     *
     * @param fieldsAndWeights         field specs parsed with {@code QueryParserHelper.parseFieldsAndWeights};
     *                                 when {@code null} or empty, each index's default fields are used
     * @param query                    the query text
     * @param indicesMetadata          metadata of every index being queried
     * @param innerNormalizerGenerator combines the per-field semantic retrievers into one compound retriever
     * @param weightValidator          optional check applied to every parsed weight; its exceptions propagate
     * @return up to two retrievers, in this order: a lexical retriever over non-inference fields (if
     *         any), then the retriever produced by {@code innerNormalizerGenerator} over inference
     *         fields (if any)
     * @throws IllegalArgumentException if one inference field (with the same inference ID) has
     *                                  different weights in different indices
     */
    public static List<RetrieverBuilder> generateInnerRetrievers(
        @Nullable List<String> fieldsAndWeights,
        String query,
        Collection<IndexMetadata> indicesMetadata,
        Function<List<WeightedRetrieverSource>, CompoundRetrieverBuilder<?>> innerNormalizerGenerator,
        @Nullable Consumer<Float> weightValidator
    ) {
        Map<String, Float> parsedFieldsAndWeights = fieldsAndWeights != null
            ? QueryParserHelper.parseFieldsAndWeights(fieldsAndWeights)
            : Map.of();
        if (weightValidator != null) {
            parsedFieldsAndWeights.values().forEach(weightValidator);
        }

        // At most one lexical and one semantic retriever are generated.
        List<RetrieverBuilder> innerRetrievers = new ArrayList<>(2);
        RetrieverBuilder lexicalRetriever = generateLexicalRetriever(parsedFieldsAndWeights, indicesMetadata, query, weightValidator);
        if (lexicalRetriever != null) {
            innerRetrievers.add(lexicalRetriever);
        }
        RetrieverBuilder semanticRetriever = generateSemanticRetriever(
            parsedFieldsAndWeights,
            indicesMetadata,
            query,
            innerNormalizerGenerator,
            weightValidator
        );
        if (semanticRetriever != null) {
            innerRetrievers.add(semanticRetriever);
        }
        return innerRetrievers;
    }

    /**
     * Builds one weighted match-query retriever per distinct (inference field, inference ID) pair.
     * These retrievers are combined with {@code innerNormalizerGenerator}.
     *
     * @return the combined semantic retriever, or {@code null} if no queried index has a matching
     *         inference field
     * @throws IllegalArgumentException if a pair has different weights in two indices
     */
    private static RetrieverBuilder generateSemanticRetriever(
        Map<String, Float> parsedFieldsAndWeights,
        Collection<IndexMetadata> indicesMetadata,
        String query,
        Function<List<WeightedRetrieverSource>, CompoundRetrieverBuilder<?>> innerNormalizerGenerator,
        @Nullable Consumer<Float> weightValidator
    ) {
        Map<Tuple<String, String>, List<String>> groupedIndices = new HashMap<>();
        Map<Tuple<String, String>, Float> groupedWeights = new HashMap<>();
        for (IndexMetadata indexMetadata : indicesMetadata) {
            String indexName = indexMetadata.getIndex().getName();
            inferenceFieldsAndWeightsForIndex(parsedFieldsAndWeights, indexMetadata, weightValidator).forEach((fieldName, weight) -> {
                Tuple<String, String> fieldAndInferenceId = new Tuple<>(
                    fieldName,
                    indexMetadata.getInferenceFields().get(fieldName).getInferenceId()
                );
                List<String> existingIndexNames = groupedIndices.get(fieldAndInferenceId);
                // One query is shared by all indices in a group, so the weight must agree across them.
                if (existingIndexNames != null && !groupedWeights.get(fieldAndInferenceId).equals(weight)) {
                    String conflictingIndexName = existingIndexNames.getFirst();
                    throw new IllegalArgumentException(
                        "field [" + fieldName + "] has different weights in indices [" + conflictingIndexName + "] and [" + indexName + "]"
                    );
                }
                groupedWeights.put(fieldAndInferenceId, weight);
                groupedIndices.computeIfAbsent(fieldAndInferenceId, k -> new ArrayList<>()).add(indexName);
            });
        }
        if (groupedIndices.isEmpty()) {
            return null;
        }

        List<WeightedRetrieverSource> semanticRetrievers = new ArrayList<>(groupedIndices.size());
        groupedIndices.forEach((fieldAndInferenceId, indexNames) -> {
            float weight = groupedWeights.get(fieldAndInferenceId);
            // Declared as QueryBuilder because it may be wrapped in a BoolQueryBuilder below.
            QueryBuilder queryBuilder = new MatchQueryBuilder(fieldAndInferenceId.v1(), query);
            if (indexNames.size() != indicesMetadata.size()) {
                // Only some queried indices belong to this group: restrict the query to them.
                queryBuilder = new BoolQueryBuilder().must(queryBuilder).filter(new TermsQueryBuilder("_index", indexNames));
            }
            StandardRetrieverBuilder retrieverBuilder = new StandardRetrieverBuilder(queryBuilder);
            semanticRetrievers.add(new WeightedRetrieverSource(CompoundRetrieverBuilder.RetrieverSource.from(retrieverBuilder), weight));
        });
        return innerNormalizerGenerator.apply(semanticRetrievers);
    }

    /**
     * Parses the index's {@code IndexSettings.DEFAULT_FIELD_SETTING} into fields and weights, and
     * runs {@code weightValidator} (if given) over every weight.
     */
    private static Map<String, Float> defaultFieldsAndWeightsForIndex(IndexMetadata indexMetadata, @Nullable Consumer<Float> weightValidator) {
        Settings settings = indexMetadata.getSettings();
        List<String> defaultFields = settings.getAsList(
            IndexSettings.DEFAULT_FIELD_SETTING.getKey(),
            IndexSettings.DEFAULT_FIELD_SETTING.getDefault(settings)
        );
        Map<String, Float> fieldsAndWeights = QueryParserHelper.parseFieldsAndWeights(defaultFields);
        if (weightValidator != null) {
            fieldsAndWeights.values().forEach(weightValidator);
        }
        return fieldsAndWeights;
    }

    /**
     * Returns the index's inference fields that match the requested fields (or the index defaults
     * when none were requested), keyed by inference field name.
     */
    private static Map<String, Float> inferenceFieldsAndWeightsForIndex(
        Map<String, Float> parsedFieldsAndWeights,
        IndexMetadata indexMetadata,
        @Nullable Consumer<Float> weightValidator
    ) {
        Map<String, Float> fieldsAndWeightsToQuery = parsedFieldsAndWeights.isEmpty()
            ? defaultFieldsAndWeightsForIndex(indexMetadata, weightValidator)
            : parsedFieldsAndWeights;
        // NOTE(review): the literal `true` presumably enables wildcard resolution; confirm against IndexMetadata.
        return IndexMetadata.getMatchingInferenceFields(indexMetadata.getInferenceFields(), fieldsAndWeightsToQuery, true)
            .entrySet()
            .stream()
            .collect(Collectors.toMap(e -> e.getKey().getName(), Map.Entry::getValue));
    }

    /**
     * Returns the requested fields (or the index defaults when none were requested), minus any
     * field name that is an inference field of the index. The result is always a fresh mutable
     * map owned by the caller.
     */
    private static Map<String, Float> nonInferenceFieldsAndWeightsForIndex(
        Map<String, Float> fieldsAndWeightsToQuery,
        IndexMetadata indexMetadata,
        @Nullable Consumer<Float> weightValidator
    ) {
        // Copy in both branches: the set of keys is mutated below, and the mutability of the
        // defaults map is not guaranteed by this class.
        Map<String, Float> nonInferenceFields = new HashMap<>(
            fieldsAndWeightsToQuery.isEmpty() ? defaultFieldsAndWeightsForIndex(indexMetadata, weightValidator) : fieldsAndWeightsToQuery
        );
        nonInferenceFields.keySet().removeAll(indexMetadata.getInferenceFields().keySet());
        return nonInferenceFields;
    }

    /**
     * Builds a single lexical retriever over the non-inference fields of all queried indices.
     * Indices with an identical field/weight map share one {@code most_fields} multi-match query.
     * If there are several groups, their queries are OR-ed together in a bool {@code should}.
     *
     * @return the lexical retriever, or {@code null} if no index has any non-inference field to query
     */
    private static RetrieverBuilder generateLexicalRetriever(
        Map<String, Float> fieldsAndWeightsToQuery,
        Collection<IndexMetadata> indicesMetadata,
        String query,
        @Nullable Consumer<Float> weightValidator
    ) {
        Map<Map<String, Float>, List<String>> groupedIndices = new HashMap<>();
        for (IndexMetadata indexMetadata : indicesMetadata) {
            Map<String, Float> fieldsAndWeights = nonInferenceFieldsAndWeightsForIndex(fieldsAndWeightsToQuery, indexMetadata, weightValidator);
            if (!fieldsAndWeights.isEmpty()) {
                groupedIndices.computeIfAbsent(fieldsAndWeights, k -> new ArrayList<>()).add(indexMetadata.getIndex().getName());
            }
        }
        if (groupedIndices.isEmpty()) {
            return null;
        }

        // Declared as QueryBuilder: an entry may be a bare multi-match or a filtered bool query.
        List<QueryBuilder> lexicalQueryBuilders = new ArrayList<>(groupedIndices.size());
        for (Map.Entry<Map<String, Float>, List<String>> entry : groupedIndices.entrySet()) {
            List<String> indices = entry.getValue();
            QueryBuilder queryBuilder = new MultiMatchQueryBuilder(query).type(MultiMatchQueryBuilder.Type.MOST_FIELDS)
                .fields(entry.getKey());
            if (indices.size() != indicesMetadata.size()) {
                // Only some queried indices share this field set: restrict the query to them.
                queryBuilder = new BoolQueryBuilder().must(queryBuilder).filter(new TermsQueryBuilder("_index", indices));
            }
            lexicalQueryBuilders.add(queryBuilder);
        }

        if (lexicalQueryBuilders.size() == 1) {
            return new StandardRetrieverBuilder(lexicalQueryBuilders.getFirst());
        }
        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
        lexicalQueryBuilders.forEach(boolQueryBuilder::should);
        return new StandardRetrieverBuilder(boolQueryBuilder);
    }

    /** A generated inner retriever together with the weight of the field it queries. */
    public record WeightedRetrieverSource(CompoundRetrieverBuilder.RetrieverSource retrieverSource, float weight) {
    }
}

