/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.queries;

import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ResolvedIndices;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.regex.Regex;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.index.search.QueryParserHelper;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.inference.InferenceException;
import org.elasticsearch.xpack.inference.queries.FullyQualifiedInferenceId;
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;

/**
 * Base class for query builders that wrap ("intercept") an original query of type {@code T} so that inference
 * results for the query text can be obtained before the query is turned into a concrete query over
 * inference and non-inference fields.
 * <p>
 * {@link #doRewrite} drives two phases:
 * <ol>
 *   <li>When the rewrite context exposes resolved indices (coordinator side), the inference IDs of the queried
 *       inference fields are collected and inference results for {@link #getQuery()} are requested through
 *       {@code SemanticQueryBuilder}. Pending results are tracked in a {@link SetOnce} supplier and later stored
 *       in {@link #inferenceResultsMap}.</li>
 *   <li>When an index-metadata rewrite context is available, the queried fields are split into inference and
 *       non-inference fields and passed to {@link #queryFields} to build the final query.</li>
 * </ol>
 * If no inference fields are queried and the request is not cross-cluster, the original query is returned as-is.
 *
 * @param <T> the type of the wrapped original query
 */
public abstract class InterceptedInferenceQueryBuilder<T extends AbstractQueryBuilder<T>>
    extends AbstractQueryBuilder<InterceptedInferenceQueryBuilder<T>> {

    public static final NodeFeature NEW_SEMANTIC_QUERY_INTERCEPTORS = new NodeFeature("search.new_semantic_query_interceptors");

    /** Transport version from which inference results are serialized keyed by {@link FullyQualifiedInferenceId}. */
    static final TransportVersion INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS = TransportVersion.fromName(
        "inference_results_map_with_cluster_alias"
    );

    /** The intercepted query; never {@code null}. */
    protected final T originalQuery;
    /** Inference results keyed by fully qualified inference ID; may be {@code null} if none have been gathered yet. */
    protected final Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap;
    /** Pending inference results; must be {@code null} when this builder is serialized. */
    protected final SetOnce<Map<FullyQualifiedInferenceId, InferenceResults>> inferenceResultsMapSupplier;
    /** Whether this query is (part of) a cross-cluster search request. */
    protected final boolean ccsRequest;

    /**
     * Wraps {@code originalQuery} without any inference results.
     *
     * @param originalQuery the query to intercept; must not be {@code null}
     */
    protected InterceptedInferenceQueryBuilder(T originalQuery) {
        this(originalQuery, null);
    }

    /**
     * Wraps {@code originalQuery} with an optional, defensively copied, set of inference results.
     *
     * @param originalQuery       the query to intercept; must not be {@code null}
     * @param inferenceResultsMap inference results to carry, or {@code null}
     * @throws NullPointerException if {@code originalQuery} is {@code null}
     */
    protected InterceptedInferenceQueryBuilder(T originalQuery, Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap) {
        Objects.requireNonNull(originalQuery, "original query must not be null");
        this.originalQuery = originalQuery;
        this.inferenceResultsMap = inferenceResultsMap != null ? Map.copyOf(inferenceResultsMap) : null;
        this.inferenceResultsMapSupplier = null;
        this.ccsRequest = false;
    }

    /**
     * Deserializes a builder previously written by {@link #doWriteTo}, handling older wire formats.
     *
     * @param in the stream to read from
     * @throws IOException if reading fails
     */
    @SuppressWarnings("unchecked") // the wrapped query was written from a T, so it is read back as a T
    protected InterceptedInferenceQueryBuilder(StreamInput in) throws IOException {
        super(in);
        this.originalQuery = (T) in.readNamedWriteable(QueryBuilder.class);
        if (in.getTransportVersion().supports(INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS)) {
            this.inferenceResultsMap = in.readOptional(
                i1 -> i1.readImmutableMap(FullyQualifiedInferenceId::new, i2 -> i2.readNamedWriteable(InferenceResults.class))
            );
        } else {
            // Older senders key the results by plain inference ID strings; convert them to fully qualified IDs.
            this.inferenceResultsMap = SemanticQueryBuilder.convertFromBwcInferenceResultsMap(
                in.readOptional(i1 -> i1.readImmutableMap(i2 -> i2.readNamedWriteable(InferenceResults.class)))
            );
        }
        // The flag is only on the wire for versions that support it; short-circuiting skips the read otherwise.
        this.ccsRequest = in.getTransportVersion().supports(SemanticQueryBuilder.SEMANTIC_SEARCH_CCS_SUPPORT) && in.readBoolean();
        this.inferenceResultsMapSupplier = null;
    }

    /**
     * Copy constructor used by {@link #copy} implementations; keeps {@code other}'s original query and replaces
     * the inference state.
     *
     * @param other                       the builder whose original query is reused
     * @param inferenceResultsMap         inference results to carry (stored as given, not copied)
     * @param inferenceResultsMapSupplier pending inference results, or {@code null}
     * @param ccsRequest                  whether this is a cross-cluster search request
     */
    protected InterceptedInferenceQueryBuilder(
        InterceptedInferenceQueryBuilder<T> other,
        Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap,
        SetOnce<Map<FullyQualifiedInferenceId, InferenceResults>> inferenceResultsMapSupplier,
        boolean ccsRequest
    ) {
        this.originalQuery = other.originalQuery;
        this.inferenceResultsMap = inferenceResultsMap;
        this.inferenceResultsMapSupplier = inferenceResultsMapSupplier;
        this.ccsRequest = ccsRequest;
    }

    /** Returns the fields targeted by the original query mapped to their weights; may be empty. */
    protected abstract Map<String, Float> getFields();

    /** Returns the query text for which inference results are requested. */
    protected abstract String getQuery();

    /**
     * Gives subclasses a chance to rewrite for backward compatibility before inference results are gathered.
     *
     * @return {@code this} when no BwC rewrite applies, otherwise the query to use instead
     */
    protected abstract QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext);

    /**
     * Creates a copy of this builder with the given inference state.
     *
     * @param inferenceResultsMap         inference results to carry
     * @param inferenceResultsMapSupplier pending inference results, or {@code null}
     * @param ccsRequest                  whether this is a cross-cluster search request
     */
    protected abstract QueryBuilder copy(
        Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap,
        SetOnce<Map<FullyQualifiedInferenceId, InferenceResults>> inferenceResultsMapSupplier,
        boolean ccsRequest
    );

    /**
     * Builds the final query for one index.
     *
     * @param inferenceFields      queried inference fields mapped to their weights
     * @param nonInferenceFields   the remaining queried fields mapped to their weights
     * @param indexMetadataContext the index-metadata rewrite context
     */
    protected abstract QueryBuilder queryFields(
        Map<String, Float> inferenceFields,
        Map<String, Float> nonInferenceFields,
        QueryRewriteContext indexMetadataContext
    );

    /** Returns whether field names from {@link #getFields()} may be wildcard patterns to resolve. */
    protected abstract boolean resolveWildcards();

    /** Returns whether the index default-field setting is used when {@link #getFields()} is empty. */
    protected abstract boolean useDefaultFields();

    /**
     * Returns an inference ID that replaces the per-field inference IDs when requesting inference, or
     * {@code null} (the default) to use the inference IDs of the queried fields.
     */
    protected FullyQualifiedInferenceId getInferenceIdOverride() {
        return null;
    }

    /**
     * Hook invoked on the coordinator before inference results are requested. It does nothing by default.
     *
     * @param resolvedIndices the indices resolved for the request
     */
    protected void coordinatorNodeValidate(ResolvedIndices resolvedIndices) {}

    /**
     * Serializes this builder.
     *
     * @throws IllegalStateException    if inference results are still pending (supplier not yet consumed)
     * @throws IllegalArgumentException if the target version cannot represent remote-cluster results or CCS
     */
    @Override
    protected void doWriteTo(StreamOutput out) throws IOException {
        if (this.inferenceResultsMapSupplier != null) {
            throw new IllegalStateException("inferenceResultsMapSupplier must be null, can't serialize suppliers, missing a rewriteAndFetch?");
        }
        out.writeNamedWriteable(this.originalQuery);
        if (out.getTransportVersion().supports(INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS)) {
            out.writeOptional((o, v) -> o.writeMap(v, StreamOutput::writeWriteable, StreamOutput::writeNamedWriteable), this.inferenceResultsMap);
        } else {
            // The older format only has plain inference ID keys, so only local-cluster (empty alias) results fit.
            out.writeOptional((o1, v) -> o1.writeMap(v, (o2, id) -> {
                if (!id.clusterAlias().equals("")) {
                    throw new IllegalArgumentException("Cannot serialize remote cluster inference results in a mixed-version cluster");
                }
                o2.writeString(id.inferenceId());
            }, StreamOutput::writeNamedWriteable), this.inferenceResultsMap);
        }
        if (out.getTransportVersion().supports(SemanticQueryBuilder.SEMANTIC_SEARCH_CCS_SUPPORT)) {
            out.writeBoolean(this.ccsRequest);
        } else if (this.ccsRequest) {
            throw new IllegalArgumentException(
                "One or more nodes does not support "
                    + this.originalQuery.getName()
                    + " query cross-cluster search when querying a [semantic_text] field. Please update all nodes to at least Elasticsearch "
                    + SemanticQueryBuilder.SEMANTIC_SEARCH_CCS_SUPPORT.toReleaseVersion()
                    + "."
            );
        }
    }

    /** Renders the wrapped original query under this builder's name. */
    @Override
    protected void doXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.field(this.getName(), this.originalQuery);
    }

    /** Always throws: this builder must be rewritten into another query type before execution. */
    @Override
    protected Query doToQuery(SearchExecutionContext context) {
        throw new UnsupportedOperationException("Query should be rewritten to a different type");
    }

    @Override
    protected boolean doEquals(InterceptedInferenceQueryBuilder<T> other) {
        return Objects.equals(this.originalQuery, other.originalQuery)
            && Objects.equals(this.inferenceResultsMap, other.inferenceResultsMap)
            && Objects.equals(this.inferenceResultsMapSupplier, other.inferenceResultsMapSupplier)
            && this.ccsRequest == other.ccsRequest;
    }

    @Override
    protected int doHashCode() {
        return Objects.hash(this.originalQuery, this.inferenceResultsMap, this.inferenceResultsMapSupplier, this.ccsRequest);
    }

    /** Delegates to the wrapped original query. */
    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return this.originalQuery.getMinimalSupportedVersion();
    }

    /**
     * Builds the final query when an index-metadata context is available. Otherwise, if indices have been
     * resolved, gathers inference results. In any other case returns {@code this} unchanged.
     */
    @Override
    protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException {
        QueryRewriteContext indexMetadataContext = queryRewriteContext.convertToIndexMetadataContext();
        if (indexMetadataContext != null) {
            return this.doRewriteBuildQuery(indexMetadataContext);
        }
        ResolvedIndices resolvedIndices = queryRewriteContext.getResolvedIndices();
        if (resolvedIndices != null) {
            return this.doRewriteGetInferenceResults(queryRewriteContext);
        }
        return this;
    }

    /** Splits the queried fields into inference and non-inference fields for this index and builds the query. */
    private QueryBuilder doRewriteBuildQuery(QueryRewriteContext indexMetadataContext) {
        Map<String, Float> queryFields = this.getFields();
        if (this.useDefaultFields() && queryFields.isEmpty()) {
            queryFields = getDefaultFields(indexMetadataContext.getIndexSettings().getSettings());
        }
        Map<String, Float> inferenceFieldsToQuery = getInferenceFieldsMap(indexMetadataContext, queryFields, this.resolveWildcards());
        Map<String, Float> nonInferenceFieldsToQuery = new HashMap<>(queryFields);
        nonInferenceFieldsToQuery.keySet().removeAll(inferenceFieldsToQuery.keySet());
        return this.queryFields(inferenceFieldsToQuery, nonInferenceFieldsToQuery, indexMetadataContext);
    }

    /** Coordinator-side rewrite: determines the needed inference IDs and requests or consumes inference results. */
    private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) {
        QueryBuilder rewrittenBwC = this.doRewriteBwC(queryRewriteContext);
        if (rewrittenBwC != this) {
            return rewrittenBwC;
        }

        ResolvedIndices resolvedIndices = queryRewriteContext.getResolvedIndices();
        Set<FullyQualifiedInferenceId> inferenceIds = getInferenceIdsForFields(
            resolvedIndices.getConcreteLocalIndicesMetadata().values(),
            queryRewriteContext.getLocalClusterAlias(),
            this.getFields(),
            this.resolveWildcards(),
            this.useDefaultFields()
        );
        // Only local index metadata is inspected above, so a CCS request keeps the wrapper even when no local
        // inference fields were found.
        if (inferenceIds.isEmpty() && !this.ccsRequest) {
            return this.originalQuery;
        }

        this.coordinatorNodeValidate(resolvedIndices);

        final boolean isCcsRequest = this.ccsRequest || !resolvedIndices.getRemoteClusterIndices().isEmpty();
        if (isCcsRequest && !queryRewriteContext.isCcsMinimizeRoundTrips().booleanValue()) {
            throw new IllegalArgumentException(
                this.originalQuery.getName()
                    + " query does not support cross-cluster search when querying a [semantic_text] field when [ccs_minimize_roundtrips] is false"
            );
        }

        if (this.inferenceResultsMapSupplier != null) {
            // Results were requested in an earlier round; let SemanticQueryBuilder build a copy from the supplier.
            return SemanticQueryBuilder.getNewInferenceResultsFromSupplier(
                this.inferenceResultsMapSupplier,
                this,
                m -> this.copy(m, null, isCcsRequest)
            );
        }

        FullyQualifiedInferenceId inferenceIdOverride = this.getInferenceIdOverride();
        if (inferenceIdOverride != null) {
            inferenceIds = Set.of(inferenceIdOverride);
        }

        SetOnce<Map<FullyQualifiedInferenceId, InferenceResults>> newInferenceResultsMapSupplier = SemanticQueryBuilder
            .getInferenceResults(queryRewriteContext, inferenceIds, this.inferenceResultsMap, this.getQuery());
        if (newInferenceResultsMapSupplier != null) {
            return this.copy(this.inferenceResultsMap, newInferenceResultsMapSupplier, isCcsRequest);
        }
        if (this.inferenceResultsMap == null) {
            return this.copy(Map.of(), null, isCcsRequest);
        }
        inferenceResultsErrorCheck(this.inferenceResultsMap);
        return this;
    }

    /**
     * Collects the search inference IDs of all inference fields that are queried in the given indices.
     *
     * @param indexMetadataCollection metadata of the indices to inspect
     * @param clusterAlias            cluster alias recorded in each returned ID
     * @param fields                  queried fields (names or, if {@code resolveWildcards}, patterns)
     * @param resolveWildcards        whether field names may be match-all or simple-match patterns
     * @param useDefaultFields        whether to use each index's default fields when {@code fields} is empty
     * @return a mutable set of the inference IDs found; empty if no inference field is queried
     */
    private static Set<FullyQualifiedInferenceId> getInferenceIdsForFields(
        Collection<IndexMetadata> indexMetadataCollection,
        String clusterAlias,
        Map<String, Float> fields,
        boolean resolveWildcards,
        boolean useDefaultFields
    ) {
        Set<FullyQualifiedInferenceId> fullyQualifiedInferenceIds = new HashSet<>();
        for (IndexMetadata indexMetadata : indexMetadataCollection) {
            Map<String, Float> indexQueryFields = useDefaultFields && fields.isEmpty()
                ? getDefaultFields(indexMetadata.getSettings())
                : fields;
            Map<String, InferenceFieldMetadata> indexInferenceFields = indexMetadata.getInferenceFields();
            for (String indexQueryField : indexQueryFields.keySet()) {
                if (indexInferenceFields.containsKey(indexQueryField)) {
                    InferenceFieldMetadata inferenceFieldMetadata = indexInferenceFields.get(indexQueryField);
                    fullyQualifiedInferenceIds.add(new FullyQualifiedInferenceId(clusterAlias, inferenceFieldMetadata.getSearchInferenceId()));
                } else if (resolveWildcards) {
                    if (Regex.isMatchAllPattern(indexQueryField)) {
                        for (InferenceFieldMetadata ifm : indexInferenceFields.values()) {
                            fullyQualifiedInferenceIds.add(new FullyQualifiedInferenceId(clusterAlias, ifm.getSearchInferenceId()));
                        }
                    } else if (Regex.isSimpleMatchPattern(indexQueryField)) {
                        for (InferenceFieldMetadata ifm : indexInferenceFields.values()) {
                            if (Regex.simpleMatch(indexQueryField, ifm.getName())) {
                                fullyQualifiedInferenceIds.add(new FullyQualifiedInferenceId(clusterAlias, ifm.getSearchInferenceId()));
                            }
                        }
                    }
                }
            }
        }
        return fullyQualifiedInferenceIds;
    }

    /** Returns the queried inference fields of the current index, keyed by field name, with their weights. */
    private static Map<String, Float> getInferenceFieldsMap(
        QueryRewriteContext indexMetadataContext,
        Map<String, Float> queryFields,
        boolean resolveWildcards
    ) {
        Map<String, InferenceFieldMetadata> indexInferenceFields = indexMetadataContext.getMappingLookup().inferenceFields();
        Map<InferenceFieldMetadata, Float> matchingInferenceFields = IndexMetadata.getMatchingInferenceFields(
            indexInferenceFields,
            queryFields,
            resolveWildcards
        );
        return matchingInferenceFields.entrySet()
            .stream()
            .collect(Collectors.toMap(e -> e.getKey().getName(), Map.Entry::getValue));
    }

    /** Parses the index default-field setting from {@code settings} into field names mapped to weights. */
    private static Map<String, Float> getDefaultFields(Settings settings) {
        List<String> defaultFieldsList = settings.getAsList(
            IndexSettings.DEFAULT_FIELD_SETTING.getKey(),
            IndexSettings.DEFAULT_FIELD_SETTING.getDefault(settings)
        );
        return QueryParserHelper.parseFieldsAndWeights(defaultFieldsList);
    }

    /**
     * Fails if any stored inference result is an error or a warning.
     *
     * @throws InferenceException    for the first {@link ErrorInferenceResults} encountered
     * @throws IllegalStateException for the first {@link WarningInferenceResults} encountered
     */
    private static void inferenceResultsErrorCheck(Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap) {
        for (Map.Entry<FullyQualifiedInferenceId, InferenceResults> entry : inferenceResultsMap.entrySet()) {
            String inferenceId = entry.getKey().inferenceId();
            InferenceResults inferenceResults = entry.getValue();
            if (inferenceResults instanceof ErrorInferenceResults) {
                ErrorInferenceResults errorInferenceResults = (ErrorInferenceResults) inferenceResults;
                throw new InferenceException(
                    "Inference ID [" + inferenceId + "] query inference error",
                    errorInferenceResults.getException(),
                    new Object[0]
                );
            }
            if (inferenceResults instanceof WarningInferenceResults) {
                WarningInferenceResults warningInferenceResults = (WarningInferenceResults) inferenceResults;
                throw new IllegalStateException(
                    "Inference ID [" + inferenceId + "] query inference warning: " + warningInferenceResults.getWarning()
                );
            }
        }
    }
}

