/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */

package org.elasticsearch.search.internal;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.IndicesRequest;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchShardTask;
import org.elasticsearch.action.search.SearchType;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.cluster.metadata.AliasMetadata;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.routing.SplitShardCountSummary;
import org.elasticsearch.common.CheckedBiConsumer;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.hash.MessageDigests;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.CheckedFunction;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.mapper.SourceLoader;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.Rewriteable;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.index.seqno.SequenceNumbers;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.indices.AliasFilterParsingException;
import org.elasticsearch.indices.InvalidAliasNameException;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.search.SearchSortValuesAndFormats;
import org.elasticsearch.search.builder.PointInTimeBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.query.QuerySearchResult;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.transport.AbstractTransportRequest;

import java.io.IOException;
import java.util.Map;

import static java.util.Collections.emptyMap;
import static org.elasticsearch.search.internal.SearchContext.TRACK_TOTAL_HITS_DISABLED;

/**
 * Shard-level request that represents a search.
 * It provides all the methods that the {@link SearchContext} needs.
 * It also provides a cache key, derived from its content, that can be used to cache the shard-level response.
 */
public class ShardSearchRequest extends AbstractTransportRequest implements IndicesRequest {

    /**
     * Transport version gating the presence of the {@link SplitShardCountSummary} in the wire format of this request:
     * it is only read from / written to streams whose transport version supports this version.
     */
    public static final TransportVersion SHARD_SEARCH_REQUEST_RESHARD_SHARD_COUNT_SUMMARY = TransportVersion.fromName(
        "shard_search_request_reshard_shard_count_summary"
    );

    /** Shared default used by {@link #computeWaitForCheckpoint} when an index has no checkpoints registered. */
    private static final long[] EMPTY_LONG_ARRAY = new long[0];

    /** Per-thread scratch buffer used by {@link #cacheKey}; it is reset after every use. */
    private static final ThreadLocal<BytesStreamOutput> scratch = ThreadLocal.withInitial(BytesStreamOutput::new);

    private final String clusterAlias;
    private final ShardId shardId;
    private final int shardRequestIndex;
    private final int numberOfShards;
    private final long waitForCheckpoint;
    private final TimeValue waitForCheckpointsTimeout;
    private final SearchType searchType;
    private final TimeValue scroll;
    private final float indexBoost;
    private final long nowInMillis;
    private final boolean allowPartialSearchResults;
    private final OriginalIndices originalIndices;
    private final ShardSearchContextId readerId;
    private final TimeValue keepAlive;
    private final TransportVersion channelVersion;

    // The following fields are not final: they can be changed through their setters after construction.
    private Boolean requestCache;
    private boolean canReturnNullResponseIfMatchNoDocs;
    private SearchSortValuesAndFormats bottomSortValues;

    // These two are mutable because they are subject to rewriting, see RequestRewritable#rewrite.
    private AliasFilter aliasFilter;
    private SearchSourceBuilder source;

    /**
     * Should this request force {@link SourceLoader.Synthetic synthetic source}?
     * Use this to test if the mapping supports synthetic _source and to get a sense
     * of the worst case performance. Fetches with this enabled will be slower than
     * enabling synthetic source natively in the index.
     */
    private final boolean forceSyntheticSource;

    /**
     * Additional metadata specific to the resharding feature. See {@link org.elasticsearch.cluster.routing.SplitShardCountSummary}.
     */
    private final SplitShardCountSummary splitShardCountSummary;

    /**
     * Test only constructor. Delegates to the {@link SearchRequest}-based constructor with no reader id, no keep alive
     * and {@link SplitShardCountSummary#UNSET}.
     */
    public ShardSearchRequest(
        OriginalIndices originalIndices,
        SearchRequest searchRequest,
        ShardId shardId,
        int shardRequestIndex,
        int numberOfShards,
        AliasFilter aliasFilter,
        float indexBoost,
        long nowInMillis,
        @Nullable String clusterAlias
    ) {
        this(
            originalIndices,
            searchRequest,
            shardId,
            shardRequestIndex,
            numberOfShards,
            aliasFilter,
            indexBoost,
            nowInMillis,
            clusterAlias,
            null,
            null,
            SplitShardCountSummary.UNSET
        );
    }

    /**
     * Creates a shard-level request from a top-level {@link SearchRequest}. The search type, source, request cache flag,
     * partial results flag, scroll, wait-for-checkpoint settings and the force synthetic source flag are all taken from
     * {@code searchRequest}; the remaining values are provided by the caller.
     * <p>
     * {@code searchRequest.allowPartialSearchResults()} must not be {@code null} when this constructor is called.
     */
    public ShardSearchRequest(
        OriginalIndices originalIndices,
        SearchRequest searchRequest,
        ShardId shardId,
        int shardRequestIndex,
        int numberOfShards,
        AliasFilter aliasFilter,
        float indexBoost,
        long nowInMillis,
        @Nullable String clusterAlias,
        ShardSearchContextId readerId,
        TimeValue keepAlive,
        SplitShardCountSummary splitShardCountSummary
    ) {
        this(
            originalIndices,
            shardId,
            shardRequestIndex,
            numberOfShards,
            searchRequest.searchType(),
            searchRequest.source(),
            searchRequest.requestCache(),
            aliasFilter,
            indexBoost,
            searchRequest.allowPartialSearchResults(),
            searchRequest.scroll(),
            nowInMillis,
            clusterAlias,
            readerId,
            keepAlive,
            computeWaitForCheckpoint(searchRequest.getWaitForCheckpoints(), shardId, shardRequestIndex),
            searchRequest.getWaitForCheckpointsTimeout(),
            searchRequest.isForceSyntheticSource(),
            splitShardCountSummary
        );
        // If allowPartialSearchResults is unset (ie null), the cluster-level default should have been substituted
        // at this stage. Any NPEs in the above are therefore an error in request preparation logic.
        assert searchRequest.allowPartialSearchResults() != null;
    }

    /**
     * Resolves the checkpoint that the given shard request has to wait for.
     *
     * @param indexToWaitForCheckpoints checkpoints keyed by index name; each array is indexed by shard request index
     * @param shardId                   the shard whose index name is used for the lookup
     * @param shardRequestIndex         position of the checkpoint within the array registered for the index
     * @return the checkpoint found at {@code shardRequestIndex}, or {@link SequenceNumbers#UNASSIGNED_SEQ_NO} if the index
     *         has no (or an empty array of) checkpoints
     */
    public static long computeWaitForCheckpoint(Map<String, long[]> indexToWaitForCheckpoints, ShardId shardId, int shardRequestIndex) {
        final long[] waitForCheckpoints = indexToWaitForCheckpoints.getOrDefault(shardId.getIndex().getName(), EMPTY_LONG_ARRAY);
        if (waitForCheckpoints.length == 0) {
            return SequenceNumbers.UNASSIGNED_SEQ_NO;
        }
        assert waitForCheckpoints.length > shardRequestIndex
            : "shardRequestIndex [" + shardRequestIndex + "] is out of bounds for [" + waitForCheckpoints.length + "] checkpoints";
        return waitForCheckpoints[shardRequestIndex];
    }

    /**
     * Used by ValidateQueryAction, ExplainAction, FieldCaps, TermsEnumAction, lookup join in ESQL.
     * Delegates to the five-argument constructor with no cluster alias and {@link SplitShardCountSummary#UNSET}.
     */
    public ShardSearchRequest(ShardId shardId, long nowInMillis, AliasFilter aliasFilter) {
        // TODO fix SplitShardCountSummary
        this(shardId, nowInMillis, aliasFilter, null, SplitShardCountSummary.UNSET);
    }

    /**
     * Used by ESQL and field_caps API.
     * Builds a request without a source: {@link OriginalIndices#NONE}, shard request index and number of shards set to
     * {@code -1}, {@link SearchType#QUERY_THEN_FETCH}, an index boost of {@code 1.0}, partial results allowed, no scroll,
     * no reader id / keep alive and no checkpoint to wait for.
     */
    public ShardSearchRequest(
        ShardId shardId,
        long nowInMillis,
        AliasFilter aliasFilter,
        String clusterAlias,
        SplitShardCountSummary splitShardCountSummary
    ) {
        this(
            OriginalIndices.NONE,
            shardId,
            -1,
            -1,
            SearchType.QUERY_THEN_FETCH,
            null,
            null,
            aliasFilter,
            1.0f,
            true,
            null,
            nowInMillis,
            clusterAlias,
            null,
            null,
            SequenceNumbers.UNASSIGNED_SEQ_NO,
            SearchService.NO_TIMEOUT,
            false,
            splitShardCountSummary
        );
    }

    /**
     * Canonical constructor that all other (non-deserializing, non-copying) constructors delegate to.
     * The source is assigned through {@link #source(SearchSourceBuilder)}, so a point in time carried by it is replaced
     * by an empty placeholder. The channel version is initialized to {@link TransportVersion#current()}.
     * <p>
     * A non-null {@code keepAlive} requires a non-null {@code readerId}.
     */
    @SuppressWarnings("this-escape")
    public ShardSearchRequest(
        OriginalIndices originalIndices,
        ShardId shardId,
        int shardRequestIndex,
        int numberOfShards,
        SearchType searchType,
        SearchSourceBuilder source,
        Boolean requestCache,
        AliasFilter aliasFilter,
        float indexBoost,
        boolean allowPartialSearchResults,
        TimeValue scroll,
        long nowInMillis,
        @Nullable String clusterAlias,
        ShardSearchContextId readerId,
        TimeValue keepAlive,
        long waitForCheckpoint,
        TimeValue waitForCheckpointsTimeout,
        boolean forceSyntheticSource,
        SplitShardCountSummary splitShardCountSummary
    ) {
        this.shardId = shardId;
        this.shardRequestIndex = shardRequestIndex;
        this.numberOfShards = numberOfShards;
        this.searchType = searchType;
        this.source(source);
        this.requestCache = requestCache;
        this.aliasFilter = aliasFilter;
        this.indexBoost = indexBoost;
        this.allowPartialSearchResults = allowPartialSearchResults;
        this.scroll = scroll;
        this.nowInMillis = nowInMillis;
        this.clusterAlias = clusterAlias;
        this.originalIndices = originalIndices;
        this.readerId = readerId;
        this.keepAlive = keepAlive;
        assert keepAlive == null || readerId != null : "readerId: null keepAlive: " + keepAlive;
        this.channelVersion = TransportVersion.current();
        this.waitForCheckpoint = waitForCheckpoint;
        this.waitForCheckpointsTimeout = waitForCheckpointsTimeout;
        this.forceSyntheticSource = forceSyntheticSource;
        this.splitShardCountSummary = splitShardCountSummary;
    }

    /**
     * Copy constructor. Copies every field of {@code clone}, including its channel version; the source goes through
     * {@link #source(SearchSourceBuilder)} like in the canonical constructor.
     */
    @SuppressWarnings("this-escape")
    public ShardSearchRequest(ShardSearchRequest clone) {
        this.shardId = clone.shardId;
        this.shardRequestIndex = clone.shardRequestIndex;
        this.searchType = clone.searchType;
        this.numberOfShards = clone.numberOfShards;
        this.scroll = clone.scroll;
        this.source(clone.source);
        this.aliasFilter = clone.aliasFilter;
        this.indexBoost = clone.indexBoost;
        this.nowInMillis = clone.nowInMillis;
        this.requestCache = clone.requestCache;
        this.clusterAlias = clone.clusterAlias;
        this.allowPartialSearchResults = clone.allowPartialSearchResults;
        this.canReturnNullResponseIfMatchNoDocs = clone.canReturnNullResponseIfMatchNoDocs;
        this.bottomSortValues = clone.bottomSortValues;
        this.originalIndices = clone.originalIndices;
        this.readerId = clone.readerId;
        this.keepAlive = clone.keepAlive;
        this.channelVersion = clone.channelVersion;
        this.waitForCheckpoint = clone.waitForCheckpoint;
        this.waitForCheckpointsTimeout = clone.waitForCheckpointsTimeout;
        this.forceSyntheticSource = clone.forceSyntheticSource;
        this.splitShardCountSummary = clone.splitShardCountSummary;
    }

    /**
     * Deserializing constructor; reads the fields in the order in which {@link #writeTo(StreamOutput)} writes them.
     * The channel version becomes the minimum of the version carried by the request and the transport version of the stream.
     * The {@link SplitShardCountSummary} is only read if the stream's transport version supports it, otherwise it is
     * {@link SplitShardCountSummary#UNSET}.
     */
    public ShardSearchRequest(StreamInput in) throws IOException {
        super(in);
        shardId = new ShardId(in);
        searchType = SearchType.fromId(in.readByte());
        shardRequestIndex = in.readVInt();
        numberOfShards = in.readVInt();
        scroll = in.readOptionalTimeValue();
        source = in.readOptionalWriteable(SearchSourceBuilder::new);
        aliasFilter = AliasFilter.readFrom(in);
        indexBoost = in.readFloat();
        nowInMillis = in.readVLong();
        requestCache = in.readOptionalBoolean();
        clusterAlias = in.readOptionalString();
        allowPartialSearchResults = in.readBoolean();
        canReturnNullResponseIfMatchNoDocs = in.readBoolean();
        bottomSortValues = in.readOptionalWriteable(SearchSortValuesAndFormats::new);
        readerId = in.readOptionalWriteable(ShardSearchContextId::new);
        keepAlive = in.readOptionalTimeValue();
        assert keepAlive == null || readerId != null : "readerId: null keepAlive: " + keepAlive;
        channelVersion = TransportVersion.min(TransportVersion.readVersion(in), in.getTransportVersion());
        waitForCheckpoint = in.readLong();
        waitForCheckpointsTimeout = in.readTimeValue();
        forceSyntheticSource = in.readBoolean();
        if (in.getTransportVersion().supports(SHARD_SEARCH_REQUEST_RESHARD_SHARD_COUNT_SUMMARY)) {
            splitShardCountSummary = new SplitShardCountSummary(in);
        } else {
            splitShardCountSummary = SplitShardCountSummary.UNSET;
        }

        originalIndices = OriginalIndices.readOriginalIndices(in);
    }

    /**
     * Serializes the full request: the parent's content, all the fields of this request and finally the original indices.
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        innerWriteTo(out, false);
        OriginalIndices.writeOriginalIndices(originalIndices, out);
    }

    /**
     * Writes the fields of this request to {@code out}.
     *
     * @param out   the stream to write to
     * @param asKey when {@code true} the output is meant to be used as a cache key (see {@link #cacheKey}) and the
     *              following fields are left out: shard request index, number of shards, {@code nowInMillis},
     *              {@code canReturnNullResponseIfMatchNoDocs}, bottom sort values, reader id and keep alive
     */
    protected final void innerWriteTo(StreamOutput out, boolean asKey) throws IOException {
        shardId.writeTo(out);
        out.writeByte(searchType.id());
        if (asKey == false) {
            out.writeVInt(shardRequestIndex);
            out.writeVInt(numberOfShards);
        }
        out.writeOptionalTimeValue(scroll);
        out.writeOptionalWriteable(source);
        aliasFilter.writeTo(out);
        out.writeFloat(indexBoost);
        if (asKey == false) {
            out.writeVLong(nowInMillis);
        }
        out.writeOptionalBoolean(requestCache);
        out.writeOptionalString(clusterAlias);
        out.writeBoolean(allowPartialSearchResults);
        if (asKey == false) {
            out.writeBoolean(canReturnNullResponseIfMatchNoDocs);
            out.writeOptionalWriteable(bottomSortValues);
            out.writeOptionalWriteable(readerId);
            out.writeOptionalTimeValue(keepAlive);
        }
        TransportVersion.writeVersion(channelVersion, out);
        out.writeLong(waitForCheckpoint);
        out.writeTimeValue(waitForCheckpointsTimeout);
        out.writeBoolean(forceSyntheticSource);
        if (out.getTransportVersion().supports(SHARD_SEARCH_REQUEST_RESHARD_SHARD_COUNT_SUMMARY)) {
            splitShardCountSummary.writeTo(out);
        }
    }

    /**
     * Returns the indices of the original request, or {@code null} if this request has no original indices.
     */
    @Override
    public String[] indices() {
        if (originalIndices == null) {
            return null;
        }
        return originalIndices.indices();
    }

    /**
     * Returns the indices options of the original request, or {@code null} if this request has no original indices.
     */
    @Override
    public IndicesOptions indicesOptions() {
        if (originalIndices == null) {
            return null;
        }
        return originalIndices.indicesOptions();
    }

    public ShardId shardId() {
        return shardId;
    }

    public SearchSourceBuilder source() {
        return source;
    }

    public AliasFilter getAliasFilter() {
        return aliasFilter;
    }

    public void setAliasFilter(AliasFilter aliasFilter) {
        this.aliasFilter = aliasFilter;
    }

    /**
     * Sets the source of this request. If the given source carries a point in time, a shallow copy of the source with
     * an empty point in time is stored instead, hence {@link #source()} may return a different instance than the one
     * provided here.
     */
    public void source(SearchSourceBuilder source) {
        if (source != null && source.pointInTimeBuilder() != null) {
            // Discard the actual point in time as data nodes don't use it to reduce the memory usage and the serialization cost
            // of shard-level search requests. However, we need to assign as a dummy PIT instead of null as we verify PIT for
            // slice requests on data nodes.
            source = source.shallowCopy();
            source.pointInTimeBuilder(new PointInTimeBuilder(BytesArray.EMPTY));
        }
        this.source = source;
    }

    /**
     * Returns the shard request ordinal that is used by the main search request
     * to reference this shard.
     */
    public int shardRequestIndex() {
        return shardRequestIndex;
    }

    public int numberOfShards() {
        return numberOfShards;
    }

    public SearchType searchType() {
        return searchType;
    }

    public float indexBoost() {
        return indexBoost;
    }

    public long nowInMillis() {
        return nowInMillis;
    }

    public Boolean requestCache() {
        return requestCache;
    }

    public void requestCache(Boolean requestCache) {
        this.requestCache = requestCache;
    }

    public boolean allowPartialSearchResults() {
        return allowPartialSearchResults;
    }

    public TimeValue scroll() {
        return scroll;
    }

    /**
     * Sets the bottom sort values that can be used by the searcher to filter documents
     * that are after it. This value is computed by coordinating nodes that throttle the
     * query phase. After a partial merge of successful shards the sort values of the
     * bottom top document are passed as a hint on subsequent shard requests.
     */
    public void setBottomSortValues(SearchSortValuesAndFormats values) {
        this.bottomSortValues = values;
    }

    public SearchSortValuesAndFormats getBottomSortValues() {
        return bottomSortValues;
    }

    /**
     * Returns true if the caller can handle null response {@link QuerySearchResult#nullInstance()}.
     * Defaults to false since the coordinator node needs at least one shard response to build the global
     * response.
     */
    public boolean canReturnNullResponseIfMatchNoDocs() {
        return canReturnNullResponseIfMatchNoDocs;
    }

    public void canReturnNullResponseIfMatchNoDocs(boolean value) {
        this.canReturnNullResponseIfMatchNoDocs = value;
    }

    /**
     * Returns a non-null value if this request should execute using a specific point-in-time reader;
     * otherwise, using the most up to date point-in-time reader.
     */
    public ShardSearchContextId readerId() {
        return readerId;
    }

    /**
     * Returns a non-null value to specify the time to live of the point-in-time reader that is used to execute this request.
     */
    public TimeValue keepAlive() {
        return keepAlive;
    }

    public long waitForCheckpoint() {
        return waitForCheckpoint;
    }

    public TimeValue getWaitForCheckpointsTimeout() {
        return waitForCheckpointsTimeout;
    }

    /**
     * Returns the cache key for this shard search request, based on its content: the SHA-256 digest of the bytes written
     * by {@link #innerWriteTo(StreamOutput, boolean)} in key mode, followed by whatever the differentiator appends.
     *
     * @param differentiator optional hook that can write additional content into the key; may be {@code null}
     */
    public BytesReference cacheKey(CheckedBiConsumer<ShardSearchRequest, StreamOutput, IOException> differentiator) throws IOException {
        BytesStreamOutput out = scratch.get();
        try {
            this.innerWriteTo(out, true);
            if (differentiator != null) {
                differentiator.accept(this, out);
            }
            return new BytesArray(MessageDigests.digest(out.bytes(), MessageDigests.sha256()));
        } finally {
            // the buffer is shared by all the requests handled by this thread, always leave it empty
            out.reset();
        }
    }

    public String getClusterAlias() {
        return clusterAlias;
    }

    public SplitShardCountSummary getSplitShardCountSummary() {
        return splitShardCountSummary;
    }

    @Override
    public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
        return new SearchShardTask(id, type, action, getDescription(), parentTaskId, headers);
    }

    @Override
    public String getDescription() {
        // Shard id is enough here, the request itself can be found by looking at the parent task description
        return "shardId[" + shardId() + "]";
    }

    /**
     * Returns a {@link Rewriteable} that rewrites the source and the alias filter of this request in place.
     */
    @SuppressWarnings("rawtypes")
    public Rewriteable<Rewriteable> getRewriteable() {
        return new RequestRewritable(this);
    }

    /**
     * Rewrites the source and the alias filter of the wrapped request. The request itself is mutated: the rewritten
     * source and alias filter are set back on it.
     */
    @SuppressWarnings("rawtypes")
    static class RequestRewritable implements Rewriteable<Rewriteable> {

        final ShardSearchRequest request;

        RequestRewritable(ShardSearchRequest request) {
            this.request = request;
        }

        /**
         * Rewrites the source (if any) and the alias filter of the request. When the context can be converted to a
         * {@link SearchExecutionContext} and the primary field sort reports that the shard is disjoint from the bottom
         * sort values of the request, the source is additionally replaced by a copy that either matches no documents
         * (total hits not tracked, no suggest and no aggregations) or has a size of 0, and the bottom sort values are cleared.
         *
         * @return this instance if neither the source nor the alias filter of the request changed, a new instance otherwise
         */
        @Override
        public Rewriteable rewrite(QueryRewriteContext ctx) throws IOException {
            SearchSourceBuilder newSource = request.source() == null ? null : Rewriteable.rewrite(request.source(), ctx);
            AliasFilter newAliasFilter = Rewriteable.rewrite(request.getAliasFilter(), ctx);
            SearchExecutionContext searchExecutionContext = ctx.convertToSearchExecutionContext();
            if (searchExecutionContext != null) {
                final FieldSortBuilder primarySort = FieldSortBuilder.getPrimaryFieldSortOrNull(newSource);
                if (primarySort != null && primarySort.isBottomSortShardDisjoint(searchExecutionContext, request.getBottomSortValues())) {
                    assert newSource != null : "source should contain a primary sort field";
                    newSource = newSource.shallowCopy();
                    int trackTotalHitsUpTo = SearchRequest.resolveTrackTotalHitsUpTo(request.scroll, request.source);
                    if (trackTotalHitsUpTo == TRACK_TOTAL_HITS_DISABLED
                        && newSource.suggest() == null
                        && newSource.aggregations() == null) {
                        newSource.query(new MatchNoneQueryBuilder());
                    } else {
                        newSource.size(0);
                    }
                    request.source(newSource);
                    request.setBottomSortValues(null);
                }
            }
            if (newSource == request.source() && newAliasFilter == request.getAliasFilter()) {
                return this;
            } else {
                request.source(newSource);
                request.setAliasFilter(newAliasFilter);
                return new RequestRewritable(request);
            }
        }
    }

    /**
     * Returns the filter associated with listed filtering aliases.
     * <p>
     * The list of filtering aliases should be obtained by calling Metadata.filteringAliases.
     * Returns {@code null} if no filtering is required.</p>
     * <p>
     * A single alias yields its own parsed filter. Several aliases are combined as {@code should} clauses of a
     * {@link BoolQueryBuilder}; if any of them has no filter, {@code null} is returned.</p>
     *
     * @param filterParser parses the serialized filter of an alias into a {@link QueryBuilder}
     * @param metadata     metadata of the index the aliases belong to
     * @param aliasNames   names of the filtering aliases; {@code null} or empty means no filtering
     * @throws InvalidAliasNameException   if one of the alias names is not an alias of the index
     * @throws AliasFilterParsingException if the filter of an alias cannot be parsed
     */
    public static QueryBuilder parseAliasFilter(
        CheckedFunction<BytesReference, QueryBuilder, IOException> filterParser,
        IndexMetadata metadata,
        String... aliasNames
    ) {
        if (aliasNames == null || aliasNames.length == 0) {
            return null;
        }
        Index index = metadata.getIndex();
        Map<String, AliasMetadata> aliases = metadata.getAliases();
        if (aliasNames.length == 1) {
            return parseAliasFilter(filterParser, resolveAlias(aliases, index, aliasNames[0]), index);
        }
        // we need to bench here a bit, to see maybe it makes sense to use OrFilter
        BoolQueryBuilder combined = new BoolQueryBuilder();
        for (String aliasName : aliasNames) {
            QueryBuilder parsedFilter = parseAliasFilter(filterParser, resolveAlias(aliases, index, aliasName), index);
            if (parsedFilter == null) {
                // The filter might be null only if filter was removed after filteringAliases was called
                return null;
            }
            combined.should(parsedFilter);
        }
        return combined;
    }

    /**
     * Looks up the metadata of the alias with the given name, failing with the name of the alias that is actually missing.
     */
    private static AliasMetadata resolveAlias(Map<String, AliasMetadata> aliases, Index index, String aliasName) {
        AliasMetadata alias = aliases.get(aliasName);
        if (alias == null) {
            // This shouldn't happen unless alias disappeared after filteringAliases was called.
            throw new InvalidAliasNameException(index, aliasName, "Unknown alias name was passed to alias Filter");
        }
        return alias;
    }

    /**
     * Parses the filter of a single alias, returns {@code null} if the alias has no filter.
     */
    private static QueryBuilder parseAliasFilter(
        CheckedFunction<BytesReference, QueryBuilder, IOException> filterParser,
        AliasMetadata alias,
        Index index
    ) {
        if (alias.filter() == null) {
            return null;
        }
        try {
            return filterParser.apply(alias.filter().compressedReference());
        } catch (IOException ex) {
            throw new AliasFilterParsingException(index, alias.getAlias(), "Invalid alias filter", ex);
        }
    }

    /**
     * Returns the runtime mappings defined in the source, or an empty map if this request has no source.
     */
    public final Map<String, Object> getRuntimeMappings() {
        return source == null ? emptyMap() : source.runtimeMappings();
    }

    /**
     * Returns the minimum version of the channel that the request has been passed. If the request never passes around, then the channel
     * version is {@link TransportVersion#current()}; otherwise, it's the minimum transport version of the coordinating node and data node
     * (and the proxy node in case the request is sent to the proxy node of the remote cluster before reaching the data node).
     */
    public TransportVersion getChannelVersion() {
        return channelVersion;
    }

    /**
     * Should this request force {@link SourceLoader.Synthetic synthetic source}?
     * Use this to test if the mapping supports synthetic _source and to get a sense
     * of the worst case performance. Fetches with this enabled will be slower than
     * enabling synthetic source natively in the index.
     */
    public boolean isForceSyntheticSource() {
        return forceSyntheticSource;
    }
}
