/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.queries;

import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import org.apache.lucene.search.join.ScoreMode;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ResolvedIndices;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
import org.elasticsearch.index.query.NestedQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.MinimalServiceSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.WeightedToken;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.queries.FullyQualifiedInferenceId;
import org.elasticsearch.xpack.inference.queries.InterceptedInferenceQueryBuilder;
import org.elasticsearch.xpack.inference.queries.LegacySemanticSparseVectorQueryRewriteInterceptor;

/**
 * Intercepted wrapper around a {@link SparseVectorQueryBuilder} that rewrites the original query per index.
 * <p>
 * For a {@code semantic_text} field the query becomes a nested sparse-vector query over the field's chunk
 * embeddings. For any other mapped field it becomes a plain sparse-vector query. For an unmapped field it
 * becomes a match-none query.
 * <p>
 * When no query vector is supplied on the original query, the vector is looked up in the inference results
 * held by the base class.
 */
public class InterceptedInferenceSparseVectorQueryBuilder
extends InterceptedInferenceQueryBuilder<SparseVectorQueryBuilder> {
    public static final String NAME = "intercepted_inference_sparse_vector";

    /** Legacy interceptor used when the minimum transport version does not support the new interceptors. */
    private static final QueryRewriteInterceptor BWC_INTERCEPTOR = new LegacySemanticSparseVectorQueryRewriteInterceptor();

    private static final TransportVersion NEW_SEMANTIC_QUERY_INTERCEPTORS = TransportVersion.fromName("new_semantic_query_interceptors");

    public InterceptedInferenceSparseVectorQueryBuilder(SparseVectorQueryBuilder originalQuery) {
        super(originalQuery);
    }

    public InterceptedInferenceSparseVectorQueryBuilder(SparseVectorQueryBuilder originalQuery, Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap) {
        super(originalQuery, inferenceResultsMap);
    }

    public InterceptedInferenceSparseVectorQueryBuilder(StreamInput in) throws IOException {
        super(in);
    }

    /** Copy constructor used by {@link #copy(Map, SetOnce, boolean)}. */
    private InterceptedInferenceSparseVectorQueryBuilder(InterceptedInferenceQueryBuilder<SparseVectorQueryBuilder> other, Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap, SetOnce<Map<FullyQualifiedInferenceId, InferenceResults>> inferenceResultsMapSupplier, boolean ccsRequest) {
        super(other, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest);
    }

    /**
     * Returns the single field targeted by the original query.
     *
     * @return the field mapped to a boost of {@code 1.0}
     */
    @Override
    protected Map<String, Float> getFields() {
        return Map.of(getField(), 1.0f);
    }

    /**
     * Returns the query text of the original sparse-vector query.
     *
     * @return the query text; may be {@code null}
     */
    @Override
    protected String getQuery() {
        return originalQuery.getQuery();
    }

    /**
     * Returns the inference endpoint explicitly requested on the original query, if any.
     *
     * @return the explicitly requested endpoint keyed under the empty cluster alias, or {@code null} if the
     *         original query names no inference ID
     */
    @Override
    protected FullyQualifiedInferenceId getInferenceIdOverride() {
        String originalInferenceId = originalQuery.getInferenceId();
        if (originalInferenceId == null) {
            return null;
        }
        // NOTE(review): the empty alias presumably denotes the local cluster; confirm against FullyQualifiedInferenceId.
        return new FullyQualifiedInferenceId("", originalInferenceId);
    }

    /**
     * Rejects query text that cannot be turned into a vector.
     * <p>
     * The query is rejected when it carries query text and no explicit inference ID, and the target field is
     * not an inference field in one of the concrete local indices.
     *
     * @param resolvedIndices indices resolved for this request
     * @throws IllegalArgumentException if some local index has no inference field with the target field name
     */
    @Override
    protected void coordinatorNodeValidate(ResolvedIndices resolvedIndices) {
        // These conditions do not depend on the index, so check them once up front.
        // With no query text, or an explicit inference ID, there is nothing left to validate.
        if (originalQuery.getQuery() == null || originalQuery.getInferenceId() != null) {
            return;
        }
        Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
        for (IndexMetadata indexMetadata : indexMetadataCollection) {
            InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(getField());
            if (inferenceFieldMetadata == null) {
                throw new IllegalArgumentException(SparseVectorQueryBuilder.INFERENCE_ID_FIELD.getPreferredName() + " required to perform vector search on query string");
            }
        }
    }

    /**
     * Falls back to the legacy interceptor when the minimum transport version does not support the new
     * semantic query interceptors.
     *
     * @param queryRewriteContext the current rewrite context
     * @return {@code this} when the new interceptors are supported, otherwise the legacy rewrite of the
     *         original query
     * @throws IOException if the legacy rewrite fails
     */
    @Override
    protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) throws IOException {
        if (queryRewriteContext.getMinTransportVersion().supports(NEW_SEMANTIC_QUERY_INTERCEPTORS)) {
            return this;
        }
        return BWC_INTERCEPTOR.interceptAndRewrite(queryRewriteContext, originalQuery);
    }

    /**
     * Creates a copy of this builder carrying the given inference state.
     */
    @Override
    protected InterceptedInferenceQueryBuilder<SparseVectorQueryBuilder> copy(Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap, SetOnce<Map<FullyQualifiedInferenceId, InferenceResults>> inferenceResultsMapSupplier, boolean ccsRequest) {
        return new InterceptedInferenceSparseVectorQueryBuilder(this, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest);
    }

    /**
     * Builds the per-index query for the target field.
     * <p>
     * The field mapping is resolved from {@code indexMetadataContext}. The {@code inferenceFields} and
     * {@code nonInferenceFields} arguments are not consulted by this implementation.
     *
     * @return a match-none query if the field is unmapped, a nested chunk query for {@code semantic_text}
     *         fields, or a plain sparse-vector query otherwise
     */
    @Override
    protected QueryBuilder queryFields(Map<String, Float> inferenceFields, Map<String, Float> nonInferenceFields, QueryRewriteContext indexMetadataContext) {
        MappedFieldType fieldType = indexMetadataContext.getFieldType(getField());
        if (fieldType == null) {
            return new MatchNoneQueryBuilder();
        }
        if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType) {
            SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType = (SemanticTextFieldMapper.SemanticTextFieldType) fieldType;
            return querySemanticTextField(indexMetadataContext.getLocalClusterAlias(), semanticTextFieldType);
        }
        return queryNonSemanticTextField();
    }

    @Override
    protected boolean resolveWildcards() {
        return false;
    }

    @Override
    protected boolean useDefaultFields() {
        return false;
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /** Returns the field name targeted by the original query. */
    private String getField() {
        return originalQuery.getFieldName();
    }

    /**
     * Builds a nested sparse-vector query over the chunk embeddings of a {@code semantic_text} field.
     *
     * @param clusterAlias alias used to look up inference results when no inference ID override is set
     * @param semanticTextFieldType the resolved field type
     * @return a match-none query if the field has no model settings yet, otherwise the nested query
     * @throws IllegalArgumentException if the field's model is not a sparse-embedding model
     */
    private QueryBuilder querySemanticTextField(String clusterAlias, SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) {
        MinimalServiceSettings modelSettings = semanticTextFieldType.getModelSettings();
        if (modelSettings == null) {
            return new MatchNoneQueryBuilder();
        }
        if (modelSettings.taskType() != TaskType.SPARSE_EMBEDDING) {
            throw new IllegalArgumentException("Field [" + getField() + "] does not use a [" + TaskType.SPARSE_EMBEDDING + "] model");
        }
        List<WeightedToken> queryVector = originalQuery.getQueryVectors();
        if (queryVector == null) {
            // Prefer an explicitly requested endpoint, otherwise use the field's search inference ID.
            FullyQualifiedInferenceId fullyQualifiedInferenceId = getInferenceIdOverride();
            if (fullyQualifiedInferenceId == null) {
                fullyQualifiedInferenceId = new FullyQualifiedInferenceId(clusterAlias, semanticTextFieldType.getSearchInferenceId());
            }
            queryVector = getQueryVector(fullyQualifiedInferenceId);
        }
        SparseVectorQueryBuilder innerSparseVectorQuery = new SparseVectorQueryBuilder(SemanticTextField.getEmbeddingsFieldName(getField()), queryVector, null, null, originalQuery.shouldPruneTokens(), originalQuery.getTokenPruningConfig());
        // Boost and query name go on the outer nested query so they apply to the query as a whole.
        return QueryBuilders.nestedQuery(SemanticTextField.getChunksFieldName(getField()), innerSparseVectorQuery, ScoreMode.Max)
            .boost(originalQuery.boost())
            .queryName(originalQuery.queryName());
    }

    /**
     * Builds a plain sparse-vector query for a field that is not {@code semantic_text}.
     *
     * @return the sparse-vector query with the original boost and query name
     * @throws IllegalArgumentException if neither a query vector nor an inference ID was supplied
     */
    private QueryBuilder queryNonSemanticTextField() {
        List<WeightedToken> queryVector = originalQuery.getQueryVectors();
        if (queryVector == null) {
            FullyQualifiedInferenceId fullyQualifiedInferenceId = getInferenceIdOverride();
            if (fullyQualifiedInferenceId == null) {
                throw new IllegalArgumentException("Either query vector or inference ID must be specified");
            }
            queryVector = getQueryVector(fullyQualifiedInferenceId);
        }
        return new SparseVectorQueryBuilder(getField(), queryVector, null, null, originalQuery.shouldPruneTokens(), originalQuery.getTokenPruningConfig())
            .boost(originalQuery.boost())
            .queryName(originalQuery.queryName());
    }

    /**
     * Looks up the weighted tokens produced by the given inference endpoint.
     *
     * @param fullyQualifiedInferenceId key into the inference results map
     * @return the weighted tokens of the stored text-expansion result
     * @throws IllegalStateException if no result is stored under the key
     * @throws IllegalArgumentException if the stored result is not a {@link TextExpansionResults}
     */
    private List<WeightedToken> getQueryVector(FullyQualifiedInferenceId fullyQualifiedInferenceId) {
        InferenceResults inferenceResults = inferenceResultsMap.get(fullyQualifiedInferenceId);
        if (inferenceResults == null) {
            throw new IllegalStateException("Could not find inference results from inference endpoint [" + fullyQualifiedInferenceId + "]");
        }
        if (!(inferenceResults instanceof TextExpansionResults)) {
            throw new IllegalArgumentException("Expected query inference results to be of type [text_expansion_result], got [" + inferenceResults.getWriteableName() + "]. Are you specifying a compatible inference endpoint? Has the inference endpoint configuration changed?");
        }
        return ((TextExpansionResults) inferenceResults).getWeightedTokens();
    }
}

