/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.action.search;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReferenceArray;
import java.util.function.BiFunction;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.FixedBitSet;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.CanMatchNodeRequest;
import org.elasticsearch.action.search.CanMatchNodeResponse;
import org.elasticsearch.action.search.SearchPhase;
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchShardIterator;
import org.elasticsearch.action.search.SearchTask;
import org.elasticsearch.action.search.SearchTransportService;
import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.action.search.TransportSearchAction;
import org.elasticsearch.action.search.VersionMismatchException;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.CountDown;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.Types;
import org.elasticsearch.index.query.CoordinatorRewriteContext;
import org.elasticsearch.index.query.CoordinatorRewriteContextProvider;
import org.elasticsearch.search.CanMatchShardResponse;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.internal.AliasFilter;
import org.elasticsearch.search.internal.ShardSearchRequest;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.MinAndMax;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.Transport;

/**
 * Coordinator-side "can_match" pre-filter phase.
 *
 * <p>For every shard iterator of the search request this phase decides whether the shard can possibly match,
 * so that shards which cannot are marked as skipped before the next phase runs. It proceeds as follows:
 * <ol>
 *   <li>iterators that were already pre-filtered keep their existing {@code skip()} verdict;</li>
 *   <li>the remaining iterators are rewritten on the coordinator via
 *       {@link SearchService#queryStillMatchesAfterRewrite} when a {@link CoordinatorRewriteContext} is
 *       available, and shards rewritten to "no match" are dropped;</li>
 *   <li>the still-candidate shards are grouped by target node and sent a {@link CanMatchNodeRequest};
 *       shards whose request failed are retried in a new {@link Round} against the next target returned
 *       by {@link SearchShardIterator#nextOrNull()}.</li>
 * </ol>
 * When all rounds are done, non-matching iterators are marked as skipped and, if every reported
 * {@link MinAndMax} has the same value type, the iterators are sorted by those bounds before being passed to
 * the listener.
 *
 * <p>Failures are handled conservatively: a shard whose can_match request failed is still recorded as a
 * possible match.
 */
final class CanMatchPreFilterSearchPhase {
    private final Logger logger;
    private final SearchRequest request;
    /** Shard iterators targeted by the search; passed back (marked and possibly as a sorted copy) to the listener. */
    private final List<SearchShardIterator> shardsIts;
    private final ActionListener<List<SearchShardIterator>> listener;
    private final TransportSearchAction.SearchTimeProvider timeProvider;
    private final BiFunction<String, String, Transport.Connection> nodeIdToConnection;
    private final SearchTransportService searchTransportService;
    /** Maps each iterator to its shard request index: its position in the sorted (natural) order of all iterators. */
    private final Map<SearchShardIterator, Integer> shardItIndexMap;
    private final Map<String, Float> concreteIndexBoosts;
    private final Map<String, AliasFilter> aliasFilter;
    private final SearchTask task;
    private final Executor executor;
    /** When {@code true}, at least one shard is kept even if no shard reported a possible match. */
    private final boolean requireAtLeastOneMatch;
    /** Bit {@code i} is set when the shard with shard request index {@code i} may match (or its request failed). */
    private final FixedBitSet possibleMatches;
    /** Estimated sort bounds reported per shard request index; {@code null} when none were reported. */
    private final MinAndMax<?>[] minAndMaxes;
    /** Number of positive results recorded; only ever compared against zero. Guarded by {@code this}. */
    private int numPossibleMatches;
    private final CoordinatorRewriteContextProvider coordinatorRewriteContextProvider;
    private static final float DEFAULT_INDEX_BOOST = 1.0f;

    CanMatchPreFilterSearchPhase(
        Logger logger,
        SearchTransportService searchTransportService,
        BiFunction<String, String, Transport.Connection> nodeIdToConnection,
        Map<String, AliasFilter> aliasFilter,
        Map<String, Float> concreteIndexBoosts,
        Executor executor,
        SearchRequest request,
        List<SearchShardIterator> shardsIts,
        TransportSearchAction.SearchTimeProvider timeProvider,
        SearchTask task,
        boolean requireAtLeastOneMatch,
        CoordinatorRewriteContextProvider coordinatorRewriteContextProvider,
        ActionListener<List<SearchShardIterator>> listener
    ) {
        this.logger = logger;
        this.searchTransportService = searchTransportService;
        this.nodeIdToConnection = nodeIdToConnection;
        this.request = request;
        this.listener = listener;
        this.shardsIts = shardsIts;
        this.timeProvider = timeProvider;
        this.concreteIndexBoosts = concreteIndexBoosts;
        this.aliasFilter = aliasFilter;
        this.task = task;
        this.requireAtLeastOneMatch = requireAtLeastOneMatch;
        this.coordinatorRewriteContextProvider = coordinatorRewriteContextProvider;
        this.executor = executor;
        final int size = shardsIts.size();
        this.possibleMatches = new FixedBitSet(size);
        this.minAndMaxes = new MinAndMax<?>[size];
        // The shard request index is the iterator's position in sorted (natural) order, so it does not depend on
        // the order in which shardsIts was supplied.
        final SearchShardIterator[] naturalOrder = shardsIts.toArray(new SearchShardIterator[0]);
        Arrays.sort(naturalOrder);
        final Map<SearchShardIterator, Integer> indexByIterator = Maps.newHashMapWithExpectedSize(naturalOrder.length);
        for (int j = 0; j < naturalOrder.length; j++) {
            indexByIterator.put(naturalOrder[j], j);
        }
        this.shardItIndexMap = indexByIterator;
    }

    /**
     * Checks that the current thread belongs to the {@code search_coordination} pool; meant to be used inside
     * {@code assert} statements.
     */
    private static boolean assertSearchCoordinationThread() {
        return ThreadPool.assertCurrentThreadPool("search_coordination");
    }

    /**
     * Resolves as many shards as possible on the coordinator, then either finishes the phase (nothing left to
     * ask the data nodes) or starts the first {@link Round} for the remaining candidates.
     */
    private void runCoordinatorRewritePhase() {
        assert assertSearchCoordinationThread();
        final List<SearchShardIterator> matchedShardLevelRequests = new ArrayList<>();
        for (SearchShardIterator searchShardIterator : shardsIts) {
            final CanMatchNodeRequest canMatchNodeRequest = new CanMatchNodeRequest(
                request,
                searchShardIterator.getOriginalIndices().indicesOptions(),
                Collections.emptyList(),
                shardsIts.size(),
                timeProvider.absoluteStartMillis(),
                searchShardIterator.getClusterAlias()
            );
            final ShardSearchRequest shardRequest = canMatchNodeRequest.createShardSearchRequest(
                buildShardLevelRequest(searchShardIterator)
            );
            if (searchShardIterator.prefiltered()) {
                // Already decided earlier: keep that verdict.
                consumeResult(searchShardIterator.skip() == false, shardRequest);
                continue;
            }
            boolean canMatch = true;
            final CoordinatorRewriteContext coordinatorRewriteContext = coordinatorRewriteContextProvider.getCoordinatorRewriteContext(
                shardRequest.shardId().getIndex()
            );
            if (coordinatorRewriteContext != null) {
                try {
                    canMatch = SearchService.queryStillMatchesAfterRewrite(shardRequest, coordinatorRewriteContext);
                } catch (Exception ignored) {
                    // Best effort: if the coordinator-side rewrite fails, keep treating the shard as a potential
                    // match and let the data node decide.
                }
            }
            if (canMatch) {
                matchedShardLevelRequests.add(searchShardIterator);
            } else {
                consumeResult(false, shardRequest);
            }
        }
        if (matchedShardLevelRequests.isEmpty()) {
            finishPhase();
        } else {
            // Missing shards are only checked for the shards that will actually be asked.
            checkNoMissingShards(matchedShardLevelRequests);
            new Round(matchedShardLevelRequests).run();
        }
    }

    /** Records a coordinator-side verdict for {@code request}'s shard (no sort bounds attached). */
    private void consumeResult(boolean canMatch, ShardSearchRequest request) {
        final CanMatchShardResponse result = new CanMatchShardResponse(canMatch, null);
        result.setShardIndex(request.shardRequestIndex());
        consumeResult(result, () -> {});
    }

    /**
     * Records a shard response if it is positive or carries sort bounds, then always runs {@code next},
     * even if recording the result throws.
     */
    private void consumeResult(CanMatchShardResponse result, Runnable next) {
        try {
            final boolean canMatch = result.canMatch();
            final MinAndMax<?> minAndMax = result.estimatedMinAndMax();
            if (canMatch || minAndMax != null) {
                consumeResult(result.getShardIndex(), canMatch, minAndMax);
            }
        } finally {
            next.run();
        }
    }

    /** Stores the verdict and bounds for the shard with the given shard request index. */
    private synchronized void consumeResult(int shardIndex, boolean canMatch, MinAndMax<?> minAndMax) {
        if (canMatch) {
            possibleMatches.set(shardIndex);
            numPossibleMatches++;
        }
        minAndMaxes[shardIndex] = minAndMax;
    }

    private void checkNoMissingShards(List<SearchShardIterator> shards) {
        assert assertSearchCoordinationThread();
        SearchPhase.doCheckNoMissingShards("can_match", request, shards);
    }

    /**
     * Advances every iterator to its next target and groups the iterators by (cluster alias, node id).
     * Iterators with no target left are collected under {@code SendingTarget(null, null)}.
     */
    private Map<SendingTarget, List<SearchShardIterator>> groupByNode(List<SearchShardIterator> shards) {
        final Map<SendingTarget, List<SearchShardIterator>> requests = new HashMap<>();
        for (SearchShardIterator shardRoutings : shards) {
            assert shardRoutings.skip() == false;
            assert shardItIndexMap.containsKey(shardRoutings);
            final SearchShardTarget target = shardRoutings.nextOrNull();
            final SendingTarget sendingTarget = target == null
                ? new SendingTarget(null, null)
                : new SendingTarget(target.getClusterAlias(), target.getNodeId());
            requests.computeIfAbsent(sendingTarget, t -> new ArrayList<>()).add(shardRoutings);
        }
        return requests;
    }

    /**
     * Builds one node-level request for all iterators of {@code entry}; the iterators are expected (and
     * asserted) to share indices options and cluster alias.
     */
    private CanMatchNodeRequest createCanMatchRequest(Map.Entry<SendingTarget, List<SearchShardIterator>> entry) {
        final SearchShardIterator first = entry.getValue().get(0);
        final List<CanMatchNodeRequest.Shard> shardLevelRequests = entry.getValue().stream().map(this::buildShardLevelRequest).toList();
        assert entry.getValue().stream().allMatch(Objects::nonNull);
        assert entry.getValue()
            .stream()
            .allMatch(ssi -> Objects.equals(ssi.getOriginalIndices().indicesOptions(), first.getOriginalIndices().indicesOptions()));
        assert entry.getValue().stream().allMatch(ssi -> Objects.equals(ssi.getClusterAlias(), first.getClusterAlias()));
        return new CanMatchNodeRequest(
            request,
            first.getOriginalIndices().indicesOptions(),
            shardLevelRequests,
            shardsIts.size(),
            timeProvider.absoluteStartMillis(),
            first.getClusterAlias()
        );
    }

    /** Hands the final (skip-marked, possibly sorted) iterator list to the listener. */
    private void finishPhase() {
        listener.onResponse(getIterator(shardsIts));
    }

    /**
     * Builds the per-shard part of a can_match request for {@code shardIt}.
     *
     * @param shardIt an iterator that is one of the iterators this phase was created with
     * @return the shard-level request, using the iterator's shard request index
     */
    public CanMatchNodeRequest.Shard buildShardLevelRequest(SearchShardIterator shardIt) {
        final String indexUuid = shardIt.shardId().getIndex().getUUID();
        final AliasFilter filter = aliasFilter.get(indexUuid);
        assert filter != null;
        final float indexBoost = concreteIndexBoosts.getOrDefault(indexUuid, DEFAULT_INDEX_BOOST);
        final int shardRequestIndex = shardItIndexMap.get(shardIt);
        return new CanMatchNodeRequest.Shard(
            shardIt.getOriginalIndices().indices(),
            shardIt.shardId(),
            shardRequestIndex,
            filter,
            indexBoost,
            shardIt.getSearchContextId(),
            shardIt.getSearchContextKeepAlive(),
            ShardSearchRequest.computeWaitForCheckpoint(request.getWaitForCheckpoints(), shardIt.shardId(), shardRequestIndex)
        );
    }

    /**
     * Returns {@code false} if some iterator has target nodes but none of them is connected with a node version
     * on or after {@code request.minCompatibleShardNode()}. An unknown connection ({@code null}) counts as
     * compatible, and iterators without target nodes are ignored.
     */
    private boolean checkMinimumVersion(List<SearchShardIterator> shardsIts) {
        for (SearchShardIterator it : shardsIts) {
            if (it.getTargetNodeIds().isEmpty()) {
                continue;
            }
            final boolean isCompatible = it.getTargetNodeIds().stream().anyMatch(nodeId -> {
                final Transport.Connection conn = nodeIdToConnection.apply(it.getClusterAlias(), nodeId);
                return conn == null || conn.getNode().getVersion().onOrAfter(request.minCompatibleShardNode());
            });
            if (isCompatible == false) {
                return false;
            }
        }
        return true;
    }

    /**
     * Starts the phase. With no shards the listener is completed immediately on the calling thread; otherwise
     * the version check and the coordinator rewrite run on {@link #executor}. Any exception thrown there is
     * reported to the listener as a {@link SearchPhaseExecutionException}.
     */
    public void start() {
        if (shardsIts.isEmpty()) {
            finishPhase();
            return;
        }
        executor.execute(new AbstractRunnable() {

            @Override
            public void onFailure(Exception e) {
                onPhaseFailure("start", e);
            }

            @Override
            protected void doRun() {
                assert assertSearchCoordinationThread();
                final Version version = request.minCompatibleShardNode();
                if (version != null
                    && Version.CURRENT.minimumCompatibilityVersion().equals(version) == false
                    && checkMinimumVersion(shardsIts) == false) {
                    throw new VersionMismatchException(
                        "One of the shards is incompatible with the required minimum version [{}]",
                        request.minCompatibleShardNode()
                    );
                }
                runCoordinatorRewritePhase();
            }
        });
    }

    /** Logs (at debug level) and reports a phase-level failure to the listener. */
    private void onPhaseFailure(String msg, Exception cause) {
        if (logger.isDebugEnabled()) {
            logger.debug(() -> Strings.format("Failed to execute [%s] while running [can_match] phase", request), cause);
        }
        listener.onFailure(new SearchPhaseExecutionException("can_match", msg, cause, ShardSearchFailure.EMPTY_ARRAY));
    }

    /**
     * Marks every iterator that cannot match as skipped and returns the iterators, sorted by their reported
     * bounds when {@link #shouldSortShards} allows it.
     *
     * <p>NOTE(review): {@link #possibleMatches} and {@link #minAndMaxes} are written by shard request index
     * (natural order, see {@link #shardItIndexMap}) but are read here, and in {@link #sortShards}, by position
     * in {@code shardsIts}. This is only consistent when {@code shardsIts} is already in natural order — confirm
     * that callers guarantee this.
     */
    private synchronized List<SearchShardIterator> getIterator(List<SearchShardIterator> shardsIts) {
        if (requireAtLeastOneMatch && numPossibleMatches == 0) {
            // Nothing matched, but at least one shard must still be queried: keep the first iterator that has
            // any copy available (falling back to position 0), and un-skip it.
            int shardIndexToQuery = 0;
            for (int i = 0; i < shardsIts.size(); i++) {
                final SearchShardIterator it = shardsIts.get(i);
                if (it.size() > 0) {
                    shardIndexToQuery = i;
                    it.skip(false);
                    break;
                }
            }
            possibleMatches.set(shardIndexToQuery);
        }
        int i = 0;
        for (SearchShardIterator iter : shardsIts) {
            iter.reset();
            final boolean match = possibleMatches.get(i++);
            if (match) {
                assert iter.skip() == false;
            } else {
                iter.skip(true);
            }
        }
        if (shouldSortShards(minAndMaxes) == false) {
            return shardsIts;
        }
        final FieldSortBuilder fieldSort = FieldSortBuilder.getPrimaryFieldSortOrNull(request.source());
        return sortShards(shardsIts, minAndMaxes, fieldSort.order());
    }

    /**
     * Returns a new list with the iterators ordered by their bounds (using the {@link MinAndMax} comparator for
     * {@code order}); ties are broken by the iterators' natural order.
     */
    private static List<SearchShardIterator> sortShards(List<SearchShardIterator> shardsIts, MinAndMax<?>[] minAndMaxes, SortOrder order) {
        final int bound = shardsIts.size();
        final List<Integer> toSort = new ArrayList<>(bound);
        for (int i = 0; i < bound; i++) {
            toSort.add(i);
        }
        final Comparator<? super MinAndMax<?>> keyComparator = Types.forciblyCast(MinAndMax.getComparator(order));
        toSort.sort((idx1, idx2) -> {
            final int res = keyComparator.compare(minAndMaxes[idx1], minAndMaxes[idx2]);
            if (res != 0) {
                return res;
            }
            return shardsIts.get(idx1).compareTo(shardsIts.get(idx2));
        });
        final List<SearchShardIterator> list = new ArrayList<>(bound);
        for (int idx : toSort) {
            list.add(shardsIts.get(idx));
        }
        return list;
    }

    /**
     * Returns {@code true} when at least one bound was reported and all non-null bounds have a minimum of
     * the same runtime class; mixed types cannot be compared, so no sorting happens then.
     */
    private static boolean shouldSortShards(MinAndMax<?>[] minAndMaxes) {
        Class<?> clazz = null;
        for (MinAndMax<?> minAndMax : minAndMaxes) {
            if (minAndMax == null) {
                continue;
            }
            final Class<?> current = minAndMax.getMin().getClass();
            if (clazz == null) {
                clazz = current;
            } else if (clazz != current) {
                return false;
            }
        }
        return clazz != null;
    }

    /**
     * One attempt at sending can_match requests for a set of shards. Each shard's iterator is advanced to its
     * next target; shards whose request fails with an exception are retried in a follow-up round.
     */
    class Round extends AbstractRunnable {
        private final List<SearchShardIterator> shards;
        private final CountDown countDown;
        /** Failure per shard request index for this round; {@code null} means success or nothing left to retry. */
        private final AtomicReferenceArray<Exception> failedResponses;

        Round(List<SearchShardIterator> shards) {
            this.shards = shards;
            this.countDown = new CountDown(shards.size());
            this.failedResponses = new AtomicReferenceArray<>(shardsIts.size());
        }

        @Override
        protected void doRun() {
            assert assertSearchCoordinationThread();
            final Map<SendingTarget, List<SearchShardIterator>> requests = groupByNode(shards);
            for (Map.Entry<SendingTarget, List<SearchShardIterator>> entry : requests.entrySet()) {
                final CanMatchNodeRequest canMatchNodeRequest = createCanMatchRequest(entry);
                final List<CanMatchNodeRequest.Shard> shardLevelRequests = canMatchNodeRequest.getShardLevelRequests();
                final SendingTarget sendingTarget = entry.getKey();
                if (sendingTarget.nodeId() == null) {
                    // No target left for these shards: count them as failed without a cause so finishRound
                    // does not schedule another attempt.
                    for (CanMatchNodeRequest.Shard shard : shardLevelRequests) {
                        onOperationFailed(shard.getShardRequestIndex(), null);
                    }
                    continue;
                }
                try {
                    searchTransportService.sendCanMatch(
                        nodeIdToConnection.apply(sendingTarget.clusterAlias(), sendingTarget.nodeId()),
                        canMatchNodeRequest,
                        task,
                        new ActionListener<CanMatchNodeResponse>() {

                            @Override
                            public void onResponse(CanMatchNodeResponse canMatchNodeResponse) {
                                assert canMatchNodeResponse.getResponses().size() == canMatchNodeRequest.getShardLevelRequests().size();
                                // Responses are positional: response i belongs to shard-level request i.
                                for (int i = 0; i < canMatchNodeResponse.getResponses().size(); i++) {
                                    final CanMatchNodeResponse.ResponseOrFailure response = canMatchNodeResponse.getResponses().get(i);
                                    final int shardRequestIndex = shardLevelRequests.get(i).getShardRequestIndex();
                                    if (response.getResponse() != null) {
                                        final CanMatchShardResponse shardResponse = response.getResponse();
                                        shardResponse.setShardIndex(shardRequestIndex);
                                        Round.this.onOperation(shardResponse.getShardIndex(), shardResponse);
                                    } else {
                                        final Exception failure = response.getException();
                                        assert failure != null;
                                        Round.this.onOperationFailed(shardRequestIndex, failure);
                                    }
                                }
                            }

                            @Override
                            public void onFailure(Exception e) {
                                for (CanMatchNodeRequest.Shard shard : shardLevelRequests) {
                                    Round.this.onOperationFailed(shard.getShardRequestIndex(), e);
                                }
                            }
                        }
                    );
                } catch (Exception e) {
                    for (CanMatchNodeRequest.Shard shard : shardLevelRequests) {
                        onOperationFailed(shard.getShardRequestIndex(), e);
                    }
                }
            }
        }

        /** Records a successful shard response and counts the shard as done for this round. */
        private void onOperation(int idx, CanMatchShardResponse response) {
            failedResponses.set(idx, null);
            consumeResult(response, () -> {
                if (countDown.countDown()) {
                    finishRound();
                }
            });
        }

        /**
         * Records a failed shard (conservatively as a possible match) and counts it as done for this round.
         * A {@code null} {@code e} marks the shard as failed without requesting a retry.
         */
        private void onOperationFailed(int idx, Exception e) {
            failedResponses.set(idx, e);
            consumeResult(idx, true, null);
            if (countDown.countDown()) {
                finishRound();
            }
        }

        /** Finishes the phase, or schedules a new round for the shards that failed with an exception. */
        private void finishRound() {
            final List<SearchShardIterator> remainingShards = new ArrayList<>();
            for (SearchShardIterator ssi : shards) {
                final int shardIndex = shardItIndexMap.get(ssi);
                if (failedResponses.get(shardIndex) != null) {
                    remainingShards.add(ssi);
                }
            }
            if (remainingShards.isEmpty()) {
                finishPhase();
            } else {
                // isForceExecution() returns true, presumably so the retry is not rejected by a saturated
                // executor queue.
                executor.execute(new Round(remainingShards) {

                    @Override
                    public boolean isForceExecution() {
                        return true;
                    }
                });
            }
        }

        @Override
        public void onFailure(Exception e) {
            onPhaseFailure("round", e);
        }
    }

    /**
     * Destination of a node-level can_match request. Both components are {@code null} when the iterator had no
     * target left ({@code nextOrNull()} returned {@code null}).
     */
    public record SendingTarget(@Nullable String clusterAlias, @Nullable String nodeId) {
    }
}

