/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.action.search;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopFieldDocs;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.IndicesRequest;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.search.AbstractSearchAsyncAction;
import org.elasticsearch.action.search.BottomSortValuesCollector;
import org.elasticsearch.action.search.CanMatchPreFilterSearchPhase;
import org.elasticsearch.action.search.FetchSearchPhase;
import org.elasticsearch.action.search.QueryPhaseResultConsumer;
import org.elasticsearch.action.search.RankFeaturePhase;
import org.elasticsearch.action.search.SearchActionListener;
import org.elasticsearch.action.search.SearchPhase;
import org.elasticsearch.action.search.SearchPhaseController;
import org.elasticsearch.action.search.SearchPhaseResults;
import org.elasticsearch.action.search.SearchProgressListener;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchShardIterator;
import org.elasticsearch.action.search.SearchShardTask;
import org.elasticsearch.action.search.SearchTask;
import org.elasticsearch.action.search.SearchTransportService;
import org.elasticsearch.action.search.TransportSearchAction;
import org.elasticsearch.action.support.ChannelActionListener;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.concurrent.CountDown;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.ListenableFuture;
import org.elasticsearch.core.RefCounted;
import org.elasticsearch.core.SimpleRefCounted;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.builder.PointInTimeBuilder;
import org.elasticsearch.search.dfs.AggregatedDfs;
import org.elasticsearch.search.internal.AliasFilter;
import org.elasticsearch.search.internal.ShardSearchContextId;
import org.elasticsearch.search.internal.ShardSearchRequest;
import org.elasticsearch.search.query.QuerySearchResult;
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.LeakTracker;
import org.elasticsearch.transport.SendRequestTransportException;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportActionProxy;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportResponseHandler;
import org.elasticsearch.transport.TransportService;

public class SearchQueryThenFetchAsyncAction
extends AbstractSearchAsyncAction<SearchPhaseResult> {
    /** Class-level logger; used by the static data-node code paths and by {@code onShardResult}. */
    private static final Logger logger = LogManager.getLogger(SearchQueryThenFetchAsyncAction.class);
    /** Transport version a connection must support before {@code doRun} sends it a batched per-node query request. */
    private static final TransportVersion BATCHED_QUERY_PHASE_VERSION = TransportVersion.fromName("batched_query_phase_version");
    /** Progress listener taken from the search task; receives shard query failure notifications. */
    private final SearchProgressListener progressListener;
    /** Result of {@code SearchPhaseController.getTopDocsSize(request)}; used to size the bottom sort collector. */
    private final int topDocsSize;
    /** Result of {@code request.resolveTrackTotalHitsUpTo()}; consulted when rewriting shard requests. */
    private final int trackTotalHitsUpTo;
    /** Created lazily in {@code onShardResult} on the first shard result carrying {@code TopFieldDocs}; {@code null} until then. */
    private volatile BottomSortValuesCollector bottomSortCollector;
    /** Client handed to {@code nextPhase} when the follow-up phase is built. */
    private final Client client;
    /** When {@code false}, {@code doRun} delegates to the superclass instead of grouping shards per node. */
    private final boolean batchQueryPhase;
    /** Transport action name under which the per-node query handler is registered and invoked. */
    private static final String NODE_SEARCH_ACTION_NAME = "indices:data/read/search[query][n]";

    /**
     * Builds the "query" phase action. All shared collaborators are handed to the superclass; this
     * constructor additionally caches the values derived from the request and task that the query
     * phase needs, and announces the shard list to the progress listener unless that listener is the
     * no-op instance.
     *
     * @param client          client later passed to {@link #nextPhase}
     * @param batchQueryPhase whether {@link #doRun} should group shard queries per node
     */
    SearchQueryThenFetchAsyncAction(
        Logger logger,
        NamedWriteableRegistry namedWriteableRegistry,
        SearchTransportService searchTransportService,
        BiFunction<String, String, Transport.Connection> nodeIdToConnection,
        Map<String, AliasFilter> aliasFilter,
        Map<String, Float> concreteIndexBoosts,
        Executor executor,
        SearchPhaseResults<SearchPhaseResult> resultConsumer,
        SearchRequest request,
        ActionListener<SearchResponse> listener,
        List<SearchShardIterator> shardsIts,
        TransportSearchAction.SearchTimeProvider timeProvider,
        ClusterState clusterState,
        SearchTask task,
        SearchResponse.Clusters clusters,
        Client client,
        boolean batchQueryPhase
    ) {
        super(
            "query",
            logger,
            namedWriteableRegistry,
            searchTransportService,
            nodeIdToConnection,
            aliasFilter,
            concreteIndexBoosts,
            executor,
            request,
            listener,
            shardsIts,
            timeProvider,
            clusterState,
            task,
            resultConsumer,
            request.getMaxConcurrentShardRequests(),
            clusters
        );
        this.client = client;
        this.batchQueryPhase = batchQueryPhase;
        topDocsSize = SearchPhaseController.getTopDocsSize(request);
        trackTotalHitsUpTo = request.resolveTrackTotalHitsUpTo();
        progressListener = task.getProgressListener();
        if (progressListener == SearchProgressListener.NOOP) {
            // nobody is listening, so there is nothing to announce
            return;
        }
        notifyListShards(progressListener, clusters, request, shardsIts);
    }

    /**
     * Sends the query for a single shard over the given connection. The shard request built by the
     * superclass is first passed through {@link #tryRewriteWithUpdatedSortValue} using the current
     * bottom sort collector (which may still be {@code null}).
     */
    @Override
    protected void executePhaseOnShard(SearchShardIterator shardIt, Transport.Connection connection, SearchActionListener<SearchPhaseResult> listener) {
        // read the volatile collector once, before the shard request is built
        final BottomSortValuesCollector collector = bottomSortCollector;
        final ShardSearchRequest baseRequest = super.buildShardSearchRequest(shardIt, listener.requestIndex);
        final ShardSearchRequest shardRequest = tryRewriteWithUpdatedSortValue(collector, trackTotalHitsUpTo, baseRequest);
        getSearchTransport().sendExecuteQuery(connection, shardRequest, getTask(), listener);
    }

    /**
     * Reports a failed shard query to the progress listener.
     *
     * @param shardIndex  index of the failed shard
     * @param shardTarget target the query was sent to
     * @param exc         the failure
     */
    @Override
    protected void onShardGroupFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {
        progressListener.notifyQueryFailure(shardIndex, shardTarget, exc);
    }

    /**
     * Handles a shard's query result. When the result carries unconsumed {@link TopFieldDocs} and
     * the request is not a scroll, the top docs are fed into the bottom sort collector (created on
     * first use) before the result is passed on to the superclass.
     *
     * <p>A failure while feeding the collector is logged at debug level and otherwise ignored; the
     * result is still forwarded to the superclass.
     */
    @Override
    protected void onShardResult(SearchPhaseResult result) {
        final QuerySearchResult queryResult = result.queryResult();
        final boolean hasFieldSortedTopDocs = queryResult.isNull() == false
            && getRequest().scroll() == null
            && queryResult.hasConsumedTopDocs() == false
            && queryResult.topDocs() != null
            && queryResult.topDocs().topDocs.getClass() == TopFieldDocs.class;
        if (hasFieldSortedTopDocs) {
            final TopFieldDocs fieldDocs = (TopFieldDocs) queryResult.topDocs().topDocs;
            if (bottomSortCollector == null) {
                // lazily create the collector; re-check under the lock so only one instance is published
                synchronized (this) {
                    if (bottomSortCollector == null) {
                        bottomSortCollector = new BottomSortValuesCollector(topDocsSize, fieldDocs.fields);
                    }
                }
            }
            try {
                bottomSortCollector.consumeTopDocs(fieldDocs, queryResult.sortValueFormats());
            } catch (Exception e) {
                logger.debug(
                    "failed to consume top docs for shard [{}] with sort fields [{}]: {}",
                    result.getShardIndex(),
                    Arrays.toString(fieldDocs.fields),
                    e
                );
            }
        }
        super.onShardResult(result);
    }

    /**
     * Chooses the phase that follows the query phase: a {@link RankFeaturePhase} when the request
     * source yields a rank-feature coordinator context, otherwise a {@link FetchSearchPhase}.
     *
     * @param client        client used to resolve the coordinator context
     * @param context       the running search action
     * @param queryResults  results collected by the query phase
     * @param aggregatedDfs aggregated DFS data to hand to the next phase; may be {@code null}
     * @return the phase to run next
     */
    static SearchPhase nextPhase(Client client, AbstractSearchAsyncAction<?> context, SearchPhaseResults<SearchPhaseResult> queryResults, AggregatedDfs aggregatedDfs) {
        final RankFeaturePhaseRankCoordinatorContext rankContext = RankFeaturePhase.coordinatorContext(context.getRequest().source(), client);
        if (rankContext != null) {
            return new RankFeaturePhase(queryResults, aggregatedDfs, context, rankContext);
        }
        return new FetchSearchPhase(queryResults, aggregatedDfs, context, null);
    }

    /**
     * Returns the phase that follows this query phase, built from this action's client and results
     * with no aggregated DFS data.
     */
    @Override
    protected SearchPhase getNextPhase() {
        final AggregatedDfs noAggregatedDfs = null;
        return nextPhase(client, this, results, noAggregatedDfs);
    }

    /**
     * Applies what the bottom sort collector has learned so far to a shard request, in place.
     *
     * <ul>
     *   <li>If total-hit tracking is bounded (anything other than {@code Integer.MAX_VALUE}) and the
     *       collector has already seen more hits than that bound, the request gets a shallow copy
     *       of its source with total-hit tracking switched off.</li>
     *   <li>If the collector has bottom sort values, they are set on the request.</li>
     * </ul>
     *
     * @param bottomSortCollector collector to read from; {@code null} leaves the request untouched
     * @param trackTotalHitsUpTo  resolved total-hits tracking bound
     * @param request             shard request to adjust
     * @return the same {@code request} instance
     */
    private static ShardSearchRequest tryRewriteWithUpdatedSortValue(BottomSortValuesCollector bottomSortCollector, int trackTotalHitsUpTo, ShardSearchRequest request) {
        if (bottomSortCollector != null) {
            final boolean boundedHitTracking = trackTotalHitsUpTo != Integer.MAX_VALUE;
            if (boundedHitTracking && bottomSortCollector.getTotalHits() > trackTotalHitsUpTo) {
                request.source(request.source().shallowCopy().trackTotalHits(false));
            }
            if (bottomSortCollector.getBottomSortValues() != null) {
                request.setBottomSortValues(bottomSortCollector.getBottomSortValues());
            }
        }
        return request;
    }

    /**
     * Tells whether the given reader context belongs to the point-in-time the request searches.
     *
     * @return {@code false} when the request has no point-in-time builder; otherwise whether the
     *         builder's decoded search context id contains {@code contextId}
     */
    private static boolean isPartOfPIT(SearchRequest request, ShardSearchContextId contextId, NamedWriteableRegistry namedWriteableRegistry) {
        if (request.pointInTimeBuilder() == null) {
            return false;
        }
        return request.pointInTimeBuilder().getSearchContextId(namedWriteableRegistry).contains(contextId);
    }

    /**
     * Starts the query phase. Without batching this simply delegates to the superclass. With
     * batching, shards are grouped by the (cluster alias, node id) of their first routing target
     * and each group is sent as one {@link NodeQueryRequest}; shards on the local node and groups
     * of a single shard are executed through the regular per-shard path.
     *
     * @param shardIndexMap maps each shard iterator to its shard index
     */
    @Override
    protected void doRun(Map<SearchShardIterator, Integer> shardIndexMap) {
        if (!this.batchQueryPhase) {
            super.doRun(shardIndexMap);
            return;
        }
        AbstractSearchAsyncAction.doCheckNoMissingShards(this.getName(), this.request, this.shardsIts);
        // one pending request per target node, filled while walking the shard iterators below
        HashMap<CanMatchPreFilterSearchPhase.SendingTarget, NodeQueryRequest> perNodeQueries = new HashMap<CanMatchPreFilterSearchPhase.SendingTarget, NodeQueryRequest>();
        String localNodeId = this.searchTransportService.transportService().getLocalNode().getId();
        int numberOfShardsTotal = this.shardsIts.size();
        for (int i = 0; i < numberOfShardsTotal; ++i) {
            SearchShardIterator shardRoutings = (SearchShardIterator)this.shardsIts.get(i);
            assert (!shardRoutings.skip());
            assert (shardIndexMap.containsKey(shardRoutings));
            int shardIndex = shardIndexMap.get(shardRoutings);
            SearchShardTarget routing2 = shardRoutings.nextOrNull();
            if (routing2 == null) {
                // the iterator has no target to offer for this shard
                this.failOnUnavailable(shardIndex, shardRoutings);
                continue;
            }
            String nodeId = routing2.getNodeId();
            if (localNodeId.equals(nodeId)) {
                // shards on the local node are not batched; run them directly
                this.performPhaseOnShard(shardIndex, shardRoutings, routing2);
                continue;
            }
            NodeQueryRequest perNodeRequest = perNodeQueries.computeIfAbsent(new CanMatchPreFilterSearchPhase.SendingTarget(routing2.getClusterAlias(), nodeId), t -> new NodeQueryRequest(this.request, numberOfShardsTotal, this.timeProvider.absoluteStartMillis(), t.clusterAlias()));
            String indexUUID = routing2.getShardId().getIndex().getUUID();
            perNodeRequest.shards.add(new ShardToQuery(this.concreteIndexBoosts.getOrDefault(indexUUID, Float.valueOf(1.0f)).floatValue(), this.getOriginalIndices(shardIndex).indices(), shardIndex, routing2.getShardId(), shardRoutings.getSearchContextId()));
            AliasFilter filterForAlias = this.aliasFilter.getOrDefault(indexUUID, AliasFilter.EMPTY);
            // only non-empty alias filters are recorded, once per index UUID
            if (filterForAlias == AliasFilter.EMPTY) continue;
            perNodeRequest.aliasFilters.putIfAbsent(indexUUID, filterForAlias);
        }
        perNodeQueries.forEach((routing, request) -> {
            Transport.Connection connection;
            if (request.shards.size() == 1) {
                // a single shard gains nothing from batching
                this.executeAsSingleRequest((CanMatchPreFilterSearchPhase.SendingTarget)routing, request.shards.get(0));
                return;
            }
            final String nodeId = routing.nodeId();
            try {
                connection = this.getConnection(routing.clusterAlias(), routing.nodeId());
            }
            catch (Exception e) {
                // no connection: fail every shard of this group individually
                this.onNodeQueryFailure(e, (NodeQueryRequest)request, (CanMatchPreFilterSearchPhase.SendingTarget)routing);
                return;
            }
            // fall back to per-shard requests unless both the transport version and the node version allow batching
            if (!connection.getTransportVersion().supports(BATCHED_QUERY_PHASE_VERSION) || connection.getNode().getVersionInformation().nodeVersion().before(Version.V_8_19_0)) {
                this.executeWithoutBatching((CanMatchPreFilterSearchPhase.SendingTarget)routing, (NodeQueryRequest)request);
                return;
            }
            this.searchTransportService.transportService().sendChildRequest(connection, NODE_SEARCH_ACTION_NAME, (TransportRequest)request, this.task, new TransportResponseHandler<NodeQueryResponse>(){

                @Override
                public NodeQueryResponse read(StreamInput in) throws IOException {
                    return new NodeQueryResponse(in);
                }

                @Override
                public Executor executor() {
                    return EsExecutors.DIRECT_EXECUTOR_SERVICE;
                }

                @Override
                public void handleResponse(NodeQueryResponse response) {
                    SearchPhaseResults searchPhaseResults = SearchQueryThenFetchAsyncAction.this.results;
                    if (searchPhaseResults instanceof QueryPhaseResultConsumer) {
                        QueryPhaseResultConsumer queryPhaseResultConsumer = (QueryPhaseResultConsumer)searchPhaseResults;
                        // hand over the node's partial merge result before processing the per-shard entries
                        queryPhaseResultConsumer.addBatchedPartialResult(response.topDocsStats, response.mergeResult);
                    }
                    // response.results[i] is matched with request.shards.get(i) by position
                    for (int i = 0; i < response.results.length; ++i) {
                        ShardToQuery s = request.shards.get(i);
                        int shardIdx = s.shardIndex;
                        SearchShardTarget target = new SearchShardTarget(routing.nodeId(), s.shardId, routing.clusterAlias());
                        Object object = response.results[i];
                        if (object instanceof Exception) {
                            Exception e = (Exception)object;
                            SearchQueryThenFetchAsyncAction.this.onShardFailure(shardIdx, target, SearchQueryThenFetchAsyncAction.this.shardIterators[shardIdx], e);
                            continue;
                        }
                        object = response.results[i];
                        if (object instanceof SearchPhaseResult) {
                            SearchPhaseResult q = (SearchPhaseResult)object;
                            // set the shard index carried in the request and the target before handing the result on
                            q.setShardIndex(shardIdx);
                            q.setSearchShardTarget(target);
                            SearchQueryThenFetchAsyncAction.this.onShardResult(q);
                            continue;
                        }
                        assert (false) : "impossible [" + String.valueOf(response.results[i]) + "]";
                    }
                }

                @Override
                public void handleException(TransportException e) {
                    Exception cause = (Exception)ExceptionsHelper.unwrapCause(e);
                    logger.debug("handling node search exception coming from [" + nodeId + "]", (Throwable)cause);
                    if (e instanceof SendRequestTransportException || cause instanceof TaskCancelledException) {
                        // send failure or cancellation: treat as per-shard failures instead of failing the phase
                        SearchQueryThenFetchAsyncAction.this.onNodeQueryFailure(e, request, routing);
                    } else {
                        // any other failure is recorded on the result consumer (first one wins) and fails the whole phase
                        SearchPhaseResults searchPhaseResults = SearchQueryThenFetchAsyncAction.this.results;
                        if (searchPhaseResults instanceof QueryPhaseResultConsumer) {
                            QueryPhaseResultConsumer queryPhaseResultConsumer = (QueryPhaseResultConsumer)searchPhaseResults;
                            queryPhaseResultConsumer.failure.compareAndSet(null, cause);
                        }
                        SearchQueryThenFetchAsyncAction.this.onPhaseFailure(SearchQueryThenFetchAsyncAction.this.getName(), "", cause);
                    }
                }
            });
        });
    }

    /**
     * Runs every shard of a per-node request through the regular single-shard path, in list order.
     *
     * @param targetNode node (and cluster alias) the shards were grouped under
     * @param request    the per-node request whose shards should be executed individually
     */
    private void executeWithoutBatching(CanMatchPreFilterSearchPhase.SendingTarget targetNode, NodeQueryRequest request) {
        request.shards.forEach(shard -> executeAsSingleRequest(targetNode, shard));
    }

    /**
     * Executes the query phase for one shard on the given node via {@code performPhaseOnShard}.
     *
     * @param targetNode node (and cluster alias) to query
     * @param shard      the shard to query, carrying its shard index and shard id
     */
    private void executeAsSingleRequest(CanMatchPreFilterSearchPhase.SendingTarget targetNode, ShardToQuery shard) {
        final int shardIdx = shard.shardIndex;
        final var shardIt = shardIterators[shardIdx];
        final SearchShardTarget shardTarget = new SearchShardTarget(targetNode.nodeId(), shard.shardId, targetNode.clusterAlias());
        performPhaseOnShard(shardIdx, shardIt, shardTarget);
    }

    /**
     * Reports the same failure for every shard contained in a per-node request.
     *
     * @param e       the failure to report
     * @param request per-node request whose shards all failed
     * @param target  node (and cluster alias) the request was meant for
     */
    private void onNodeQueryFailure(Exception e, NodeQueryRequest request, CanMatchPreFilterSearchPhase.SendingTarget target) {
        request.shards.forEach(shard -> {
            final SearchShardTarget failedTarget = new SearchShardTarget(target.nodeId(), shard.shardId, target.clusterAlias());
            onShardFailure(shard.shardIndex, failedTarget, shardIterators[shard.shardIndex], e);
        });
    }

    /**
     * Registers the handler for the per-node query action together with its proxy action.
     *
     * <p>For each incoming {@link NodeQueryRequest} the handler builds a {@code QueryPerNodeState}
     * and starts a number of {@link #executeShardTasks} loops equal to the smallest of: the
     * request's max concurrent shard requests, the number of shards in the request, and the
     * maximum size of the "search" thread pool.
     */
    static void registerNodeSearchAction(SearchTransportService searchTransportService, SearchService searchService, SearchPhaseController searchPhaseController, NamedWriteableRegistry namedWriteableRegistry) {
        final TransportService transport = searchTransportService.transportService();
        final ThreadPool pool = transport.getThreadPool();
        final Dependencies dependencies = new Dependencies(searchService, pool.executor("search"));
        final int searchPoolMax = pool.info("search").getMax();
        transport.registerRequestHandler(NODE_SEARCH_ACTION_NAME, EsExecutors.DIRECT_EXECUTOR_SERVICE, NodeQueryRequest::new, (request, channel, task) -> {
            final CancellableTask cancellableTask = (CancellableTask) task;
            final int shardCount = request.shards.size();
            final int parallelism = Math.min(request.searchRequest.getMaxConcurrentShardRequests(), Math.min(shardCount, searchPoolMax));
            final QueryPhaseResultConsumer consumer = new QueryPhaseResultConsumer(
                request.searchRequest,
                dependencies.executor,
                searchService.getCircuitBreaker(),
                searchPhaseController,
                cancellableTask::isCancelled,
                SearchProgressListener.NOOP,
                shardCount,
                e -> logger.error("failed to merge on data node", e)
            );
            final QueryPerNodeState state = new QueryPerNodeState(consumer, request, cancellableTask, channel, dependencies, namedWriteableRegistry);
            for (int started = 0; started < parallelism; started++) {
                executeShardTasks(state);
            }
        });
        TransportActionProxy.registerProxyAction(transport, NODE_SEARCH_ACTION_NAME, true, NodeQueryResponse::new);
    }

    /**
     * Frees the reader context behind a shard result on the local node, unless it must be kept.
     * The context is kept when the result has no search context, when the request is a scroll, or
     * when the context belongs to the request's point-in-time.
     *
     * @param searchService          service that owns the reader contexts
     * @param request                per-node request the result belongs to
     * @param result                 shard result; its query result is used, else its rank feature result
     * @param namedWriteableRegistry registry needed to decode the point-in-time id
     */
    private static void releaseLocalContext(SearchService searchService, NodeQueryRequest request, SearchPhaseResult result, NamedWriteableRegistry namedWriteableRegistry) {
        final SearchPhaseResult phaseResult = result.queryResult() != null ? result.queryResult() : result.rankFeatureResult();
        if (phaseResult == null || phaseResult.hasSearchContext() == false) {
            return;
        }
        if (request.searchRequest.scroll() != null) {
            return;
        }
        if (isPartOfPIT(request.searchRequest, phaseResult.getContextId(), namedWriteableRegistry)) {
            return;
        }
        searchService.freeReaderContext(phaseResult.getContextId());
    }

    /**
     * Creates the shard-level search request used on the data node.
     *
     * @param hasResponse whether at least one shard has already produced a response; only then, and
     *                    only for non-scroll requests, is the shard allowed to return a null
     *                    response when it matches no documents
     * @return the configured shard request
     */
    private static ShardSearchRequest buildShardSearchRequest(ShardId shardId, String clusterAlias, int shardIndex, ShardSearchContextId searchContextId, OriginalIndices originalIndices, AliasFilter aliasFilter, TimeValue searchContextKeepAlive, float indexBoost, SearchRequest searchRequest, int totalShardCount, long absoluteStartMillis, boolean hasResponse) {
        final ShardSearchRequest shardRequest = new ShardSearchRequest(
            originalIndices,
            searchRequest,
            shardId,
            shardIndex,
            totalShardCount,
            aliasFilter,
            indexBoost,
            absoluteStartMillis,
            clusterAlias,
            searchContextId,
            searchContextKeepAlive
        );
        final boolean mayReturnNull = hasResponse && shardRequest.scroll() == null;
        shardRequest.canReturnNullResponseIfMatchNoDocs(mayReturnNull);
        return shardRequest;
    }

    /**
     * Worker loop that executes the shard queries of a per-node request one after another.
     *
     * <p>Each iteration claims the next unprocessed position in the request's shard list through
     * {@code state.currentShardIndex} and starts the query phase for it. If the query has already
     * completed by the time the call returns, the loop carries on with the next shard. Otherwise a
     * continuation is registered that re-enters this method once the shard is done, and the
     * current invocation stops.
     *
     * <p>A failure, whether thrown while starting the query or reported through the listener, is
     * stored in {@code state.failures} under the shard's position and counted via
     * {@code state.onShardDone()}.
     *
     * @param state shared per-node execution state
     */
    private static void executeShardTasks(final QueryPerNodeState state) {
        final int totalShardCount = state.searchRequest.shards.size();
        int idx;
        while ((idx = state.currentShardIndex.getAndIncrement()) < totalShardCount) {
            final int dataNodeLocalIdx = idx;
            // completed (always with null) once this shard has been fully handled, successfully or not
            final ListenableFuture<Void> doneFuture = new ListenableFuture<>();
            try {
                final NodeQueryRequest nodeQueryRequest = state.searchRequest;
                final SearchRequest searchRequest = nodeQueryRequest.searchRequest;
                final PointInTimeBuilder pitBuilder = searchRequest.pointInTimeBuilder();
                final ShardToQuery shardToQuery = nodeQueryRequest.shards.get(dataNodeLocalIdx);
                final ShardId shardId = shardToQuery.shardId;
                state.dependencies.searchService.executeQueryPhase(
                    tryRewriteWithUpdatedSortValue(
                        state.bottomSortCollector,
                        state.trackTotalHitsUpTo,
                        buildShardSearchRequest(
                            shardId,
                            nodeQueryRequest.localClusterAlias,
                            shardToQuery.shardIndex,
                            shardToQuery.contextId,
                            new OriginalIndices(shardToQuery.originalIndices, nodeQueryRequest.indicesOptions()),
                            nodeQueryRequest.aliasFilters.getOrDefault(shardId.getIndex().getUUID(), AliasFilter.EMPTY),
                            pitBuilder == null ? null : pitBuilder.getKeepAlive(),
                            shardToQuery.boost,
                            searchRequest,
                            nodeQueryRequest.totalShards,
                            nodeQueryRequest.absoluteStartMillis,
                            state.hasResponse.getAcquire()
                        )
                    ),
                    state.task,
                    // the listener is indexed by the position in the node-local shard list, not the coordinator shard index
                    new SearchActionListener<SearchPhaseResult>(
                        new SearchShardTarget(null, shardToQuery.shardId, nodeQueryRequest.localClusterAlias),
                        dataNodeLocalIdx
                    ) {

                        @Override
                        protected void innerOnResponse(SearchPhaseResult searchPhaseResult) {
                            try {
                                state.consumeResult(searchPhaseResult.queryResult());
                            } catch (Exception e) {
                                setFailure(state, dataNodeLocalIdx, e);
                            } finally {
                                doneFuture.onResponse(null);
                            }
                        }

                        private void setFailure(QueryPerNodeState failedState, int failedIdx, Exception e) {
                            failedState.failures.put(failedIdx, e);
                            failedState.onShardDone();
                        }

                        @Override
                        public void onFailure(Exception e) {
                            setFailure(state, dataNodeLocalIdx, e);
                            doneFuture.onResponse(null);
                        }
                    }
                );
            } catch (Exception e) {
                // starting the query failed synchronously: record it and move on to the next shard
                state.failures.put(dataNodeLocalIdx, e);
                state.onShardDone();
                continue;
            }
            if (doneFuture.isDone()) {
                continue;
            }
            // the shard is still running: resume the loop from its completion and stop this invocation
            doneFuture.addListener(ActionListener.running(() -> executeShardTasks(state)));
            break;
        }
    }

    /**
     * Transport request that carries the query for several shards located on the same node.
     *
     * <p>Wire format, in order: the parent request fields, the shard list, the search request, the
     * alias filters keyed by index UUID, the total shard count, the absolute start time in
     * milliseconds and the optional cluster alias.
     */
    public static final class NodeQueryRequest extends TransportRequest implements IndicesRequest {
        private final List<ShardToQuery> shards;
        private final SearchRequest searchRequest;
        private final Map<String, AliasFilter> aliasFilters;
        private final int totalShards;
        private final long absoluteStartMillis;
        private final String localClusterAlias;

        /** Creates an empty request; shards and alias filters are added by the sender afterwards. */
        private NodeQueryRequest(SearchRequest searchRequest, int totalShards, long absoluteStartMillis, String localClusterAlias) {
            this.searchRequest = searchRequest;
            this.totalShards = totalShards;
            this.absoluteStartMillis = absoluteStartMillis;
            this.localClusterAlias = localClusterAlias;
            this.shards = new ArrayList<>();
            this.aliasFilters = new HashMap<>();
        }

        /** Reads a request from the wire; the shard list and alias filter map are read as immutable collections. */
        private NodeQueryRequest(StreamInput in) throws IOException {
            super(in);
            shards = in.readCollectionAsImmutableList(ShardToQuery::readFrom);
            searchRequest = new SearchRequest(in);
            aliasFilters = in.readImmutableMap(AliasFilter::readFrom);
            totalShards = in.readVInt();
            absoluteStartMillis = in.readLong();
            localClusterAlias = in.readOptionalString();
        }

        /** Creates the {@link SearchShardTask} under which this request runs on the receiving node. */
        @Override
        public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
            return new SearchShardTask(id, type, action, "NodeQueryRequest", parentTaskId, headers);
        }

        /** Writes the request in the same field order the stream constructor reads it. */
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            super.writeTo(out);
            out.writeCollection(shards);
            searchRequest.writeTo(out, true);
            out.writeMap(aliasFilters, (stream, filter) -> filter.writeTo(stream));
            out.writeVInt(totalShards);
            out.writeLong(absoluteStartMillis);
            out.writeOptionalString(localClusterAlias);
        }

        /** Returns the distinct original index names of all shards in this request, in encounter order. */
        @Override
        public String[] indices() {
            return shards.stream()
                .map(ShardToQuery::originalIndices)
                .flatMap(Arrays::stream)
                .distinct()
                .toArray(String[]::new);
        }

        /** Returns the indices options of the wrapped search request. */
        @Override
        public IndicesOptions indicesOptions() {
            return searchRequest.indicesOptions();
        }
    }

    /**
     * One shard inside a {@link NodeQueryRequest}.
     *
     * @param boost           index boost to apply to the shard
     * @param originalIndices original index names the shard was resolved from
     * @param shardIndex      the shard's index as supplied by the sender of the request
     * @param shardId         id of the shard to query
     * @param contextId       existing search context to use; optional on the wire, so it may be {@code null}
     */
    private record ShardToQuery(float boost, String[] originalIndices, int shardIndex, ShardId shardId, ShardSearchContextId contextId) implements Writeable {

        /** Reads a shard entry; fields are read in the order {@link #writeTo} writes them. */
        static ShardToQuery readFrom(StreamInput in) throws IOException {
            final float boost = in.readFloat();
            final String[] originalIndices = in.readStringArray();
            final int shardIndex = in.readVInt();
            final ShardId shardId = new ShardId(in);
            final ShardSearchContextId contextId = in.readOptionalWriteable(ShardSearchContextId::new);
            return new ShardToQuery(boost, originalIndices, shardIndex, shardId, contextId);
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeFloat(boost);
            out.writeStringArray(originalIndices);
            out.writeVInt(shardIndex);
            shardId.writeTo(out);
            out.writeOptionalWriteable(contextId);
        }
    }

    /**
     * Collaborators shared by all per-node query executions on a node.
     *
     * @param searchService service used to run the query phase and to free reader contexts
     * @param executor      executor handed to the data-node result consumer
     */
    private record Dependencies(SearchService searchService, Executor executor) {}

    private static final class QueryPerNodeState {
        private static final QueryPhaseResultConsumer.MergeResult EMPTY_PARTIAL_MERGE_RESULT = new QueryPhaseResultConsumer.MergeResult(List.of(), null, null, 0L);
        private final AtomicInteger currentShardIndex = new AtomicInteger();
        private final QueryPhaseResultConsumer queryPhaseResultConsumer;
        private final NodeQueryRequest searchRequest;
        private final CancellableTask task;
        private final ConcurrentHashMap<Integer, Exception> failures = new ConcurrentHashMap();
        private final Dependencies dependencies;
        private final AtomicBoolean hasResponse = new AtomicBoolean(false);
        private final int trackTotalHitsUpTo;
        private final int topDocsSize;
        private final CountDown countDown;
        private final TransportChannel channel;
        private volatile BottomSortValuesCollector bottomSortCollector;
        private final NamedWriteableRegistry namedWriteableRegistry;

        private QueryPerNodeState(QueryPhaseResultConsumer queryPhaseResultConsumer, NodeQueryRequest searchRequest, CancellableTask task, TransportChannel channel, Dependencies dependencies, NamedWriteableRegistry namedWriteableRegistry) {
            this.queryPhaseResultConsumer = queryPhaseResultConsumer;
            this.searchRequest = searchRequest;
            this.trackTotalHitsUpTo = searchRequest.searchRequest.resolveTrackTotalHitsUpTo();
            this.topDocsSize = SearchPhaseController.getTopDocsSize(searchRequest.searchRequest);
            this.task = task;
            this.countDown = new CountDown(queryPhaseResultConsumer.getNumShards());
            this.channel = channel;
            this.dependencies = dependencies;
            this.namedWriteableRegistry = namedWriteableRegistry;
        }

        void onShardDone() {
            if (!this.countDown.countDown()) {
                return;
            }
            ChannelActionListener<TransportResponse> channelListener = new ChannelActionListener<TransportResponse>(this.channel);
            try (QueryPhaseResultConsumer queryPhaseResultConsumer = this.queryPhaseResultConsumer;){
                QueryPhaseResultConsumer.MergeResult mergeResult;
                Exception failure = this.queryPhaseResultConsumer.failure.get();
                if (failure != null) {
                    this.handleMergeFailure(failure, channelListener, this.namedWriteableRegistry);
                    return;
                }
                try {
                    mergeResult = Objects.requireNonNullElse(this.queryPhaseResultConsumer.consumePartialMergeResultDataNode(), EMPTY_PARTIAL_MERGE_RESULT);
                }
                catch (Exception e) {
                    this.handleMergeFailure(e, channelListener, this.namedWriteableRegistry);
                    if (queryPhaseResultConsumer != null) {
                        queryPhaseResultConsumer.close();
                    }
                    return;
                }
                BitSet relevantShardIndices = new BitSet(this.searchRequest.shards.size());
                if (mergeResult.reducedTopDocs() != null) {
                    for (ScoreDoc scoreDoc : mergeResult.reducedTopDocs().scoreDocs) {
                        int localIndex = scoreDoc.shardIndex;
                        scoreDoc.shardIndex = this.searchRequest.shards.get((int)localIndex).shardIndex;
                        relevantShardIndices.set(localIndex);
                    }
                }
                Object[] results = new Object[this.queryPhaseResultConsumer.getNumShards()];
                for (int i = 0; i < results.length; ++i) {
                    SearchPhaseResult result = (SearchPhaseResult)this.queryPhaseResultConsumer.results.get(i);
                    if (result == null) {
                        results[i] = this.failures.get(i);
                    } else {
                        QuerySearchResult q;
                        if (result instanceof QuerySearchResult && (q = (QuerySearchResult)result).getContextId() != null && !relevantShardIndices.get(q.getShardIndex()) && !q.hasSuggestHits() && q.getRankShardResult() == null && this.searchRequest.searchRequest.scroll() == null && !SearchQueryThenFetchAsyncAction.isPartOfPIT(this.searchRequest.searchRequest, q.getContextId(), this.namedWriteableRegistry) && this.dependencies.searchService.freeReaderContext(q.getContextId())) {
                            q.clearContextId();
                        }
                        results[i] = result;
                    }
                    assert (results[i] != null);
                }
                ActionListener.respondAndRelease(channelListener, new NodeQueryResponse(mergeResult, results, this.queryPhaseResultConsumer.topDocsStats));
            }
        }

        /**
         * Passes every successful result currently held by the result consumer to
         * {@code releaseLocalContext} and then reports {@code e} to the channel listener.
         *
         * @param e                      the exception forwarded to {@code channelListener.onFailure}
         * @param channelListener        listener that is notified of the failure
         * @param namedWriteableRegistry registry handed through to {@code releaseLocalContext}
         */
        private void handleMergeFailure(Exception e, ChannelActionListener<TransportResponse> channelListener, NamedWriteableRegistry namedWriteableRegistry) {
            // Release first, notify second: the listener is only failed once every successful
            // result has been handed to releaseLocalContext.
            this.queryPhaseResultConsumer.getSuccessfulResults().forEach(successful -> {
                SearchQueryThenFetchAsyncAction.releaseLocalContext(this.dependencies.searchService, this.searchRequest, successful, namedWriteableRegistry);
            });
            channelListener.onFailure(e);
        }

        /**
         * Consumes one shard-level query result.
         *
         * <p>Steps, in order:
         * <ol>
         *   <li>flips {@code hasResponse} from {@code false} to {@code true} (no-op if already set);</li>
         *   <li>if the result is non-null, the request is not a scroll, the top docs have not been
         *       consumed yet and they are exactly of class {@link TopFieldDocs}, feeds them to the
         *       lazily created {@link BottomSortValuesCollector};</li>
         *   <li>always forwards the result to {@code queryPhaseResultConsumer}, with
         *       {@code onShardDone} as the completion callback.</li>
         * </ol>
         *
         * <p>Decompiler note: CFR reported "Removed try catching itself - possible behaviour change"
         * for this method.
         *
         * @param queryResult the shard query result to consume
         */
        void consumeResult(QuerySearchResult queryResult) {
            this.hasResponse.compareAndExchangeRelease(false, true);
            // Evaluation order matters: the null/consumed checks guard the later topDocs() dereferences.
            boolean feedsBottomSort = !queryResult.isNull()
                && this.searchRequest.searchRequest.scroll() == null
                && !queryResult.hasConsumedTopDocs()
                && queryResult.topDocs() != null
                && queryResult.topDocs().topDocs.getClass() == TopFieldDocs.class;
            if (feedsBottomSort) {
                TopFieldDocs fieldDocs = (TopFieldDocs)queryResult.topDocs().topDocs;
                // Lazily create the collector: unsynchronized read first, then re-check under the lock.
                // NOTE(review): the unsynchronized read presumably relies on the field being volatile - confirm.
                BottomSortValuesCollector collector = this.bottomSortCollector;
                if (collector == null) {
                    synchronized (this) {
                        collector = this.bottomSortCollector;
                        if (collector == null) {
                            collector = new BottomSortValuesCollector(this.topDocsSize, fieldDocs.fields);
                            this.bottomSortCollector = collector;
                        }
                    }
                }
                collector.consumeTopDocs(fieldDocs, queryResult.sortValueFormats());
            }
            this.queryPhaseResultConsumer.consumeResult(queryResult, this::onShardDone);
        }
    }

    /**
     * Transport response that carries, for one node, an array of per-shard outcomes together with
     * the node-level merge result and top-docs statistics.
     *
     * <p>Each slot of {@link #getResults()} holds either a {@link QuerySearchResult} or an
     * {@link Exception}; on the wire every slot is prefixed by a boolean that is {@code true} for a
     * result and {@code false} for an exception.
     *
     * <p>The response is ref-counted. While it has references it holds one reference on every
     * {@link QuerySearchResult} in the array; when {@link #decRef()} reports the final release those
     * references are dropped and the array slots are cleared.
     */
    public static final class NodeQueryResponse
    extends TransportResponse {
        private final RefCounted refCounted = LeakTracker.wrap((RefCounted)new SimpleRefCounted());
        /** Per-shard outcomes: a {@link QuerySearchResult} or an {@link Exception} in each slot. */
        private final Object[] results;
        private final SearchPhaseController.TopDocsStats topDocsStats;
        private final QueryPhaseResultConsumer.MergeResult mergeResult;

        /**
         * Deserializes a response: first the result array, then the merge result, then the stats.
         *
         * <p>If reading the merge result or the stats fails, the {@link QuerySearchResult} instances
         * that were already read are released before the exception propagates; otherwise nothing
         * would ever release them because the half-built response is never handed to a caller.
         *
         * @param in stream to read from
         * @throws IOException if any part of the response cannot be read
         */
        NodeQueryResponse(StreamInput in) throws IOException {
            this.results = in.readArray(i -> i.readBoolean() ? new QuerySearchResult(i) : i.readException(), Object[]::new);
            boolean success = false;
            try {
                this.mergeResult = QueryPhaseResultConsumer.MergeResult.readFrom(in);
                this.topDocsStats = SearchPhaseController.TopDocsStats.readFrom(in);
                success = true;
            } finally {
                if (!success) {
                    // NOTE(review): the leak-tracked refCounted itself is not released on this path - confirm
                    // whether that needs handling as well.
                    NodeQueryResponse.releaseResults(this.results);
                }
            }
        }

        /**
         * Builds a response from locally produced outcomes. Takes one additional reference on every
         * {@link QuerySearchResult} in {@code results}; it is dropped again by {@link #decRef()} on
         * final release.
         *
         * @param mergeResult  node-level merge result to send
         * @param results      per-shard outcomes; the array is stored as-is (not copied) and must not
         *                     contain {@code null} (checked by assertion)
         * @param topDocsStats top-docs statistics to send
         */
        NodeQueryResponse(QueryPhaseResultConsumer.MergeResult mergeResult, Object[] results, SearchPhaseController.TopDocsStats topDocsStats) {
            this.results = results;
            for (Object result : results) {
                if (!(result instanceof QuerySearchResult)) continue;
                QuerySearchResult r = (QuerySearchResult)result;
                r.incRef();
            }
            this.mergeResult = mergeResult;
            this.topDocsStats = topDocsStats;
            assert (Arrays.stream(results).noneMatch(Objects::isNull)) : Arrays.toString(results);
        }

        /**
         * @return the internal array of per-shard outcomes (not a copy); its slots are set to
         *         {@code null} once the response has been fully released
         */
        public Object[] getResults() {
            return this.results;
        }

        /**
         * Writes the result array (boolean discriminator + payload per slot), then the merge result,
         * then the top-docs statistics - the same order the stream constructor reads them in.
         */
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeArray((o, v) -> {
                if (v instanceof Exception) {
                    Exception e = (Exception)v;
                    o.writeBoolean(false);
                    o.writeException(e);
                } else {
                    o.writeBoolean(true);
                    assert (v instanceof QuerySearchResult) : v;
                    ((QuerySearchResult)v).writeTo(o);
                }
            }, this.results);
            this.mergeResult.writeTo(out);
            this.topDocsStats.writeTo(out);
        }

        @Override
        public void incRef() {
            this.refCounted.incRef();
        }

        @Override
        public boolean tryIncRef() {
            return this.refCounted.tryIncRef();
        }

        @Override
        public boolean hasReferences() {
            return this.refCounted.hasReferences();
        }

        /**
         * Drops one reference on this response. When the underlying ref count reports the final
         * release, every held {@link QuerySearchResult} is released and the array is cleared.
         *
         * @return {@code true} if this call performed the final release
         */
        @Override
        public boolean decRef() {
            if (this.refCounted.decRef()) {
                NodeQueryResponse.releaseResults(this.results);
                return true;
            }
            return false;
        }

        /** Drops one reference on each {@link QuerySearchResult} in {@code results} and nulls every slot. */
        private static void releaseResults(Object[] results) {
            for (int i = 0; i < results.length; ++i) {
                Object object = results[i];
                if (object instanceof QuerySearchResult) {
                    QuerySearchResult r = (QuerySearchResult)object;
                    r.decRef();
                }
                results[i] = null;
            }
        }
    }
}

