/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.plugin;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionListenerResponseHandler;
import org.elasticsearch.action.NoShardAvailableActionException;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.search.SearchShardsGroup;
import org.elasticsearch.action.search.SearchShardsRequest;
import org.elasticsearch.action.search.SearchShardsResponse;
import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.action.support.TransportActions;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.compute.operator.DriverCompletionInfo;
import org.elasticsearch.compute.operator.FailureCollector;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.internal.AliasFilter;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportResponseHandler;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.esql.action.EsqlSearchShardsAction;
import org.elasticsearch.xpack.esql.plugin.ComputeListener;
import org.elasticsearch.xpack.esql.plugin.ComputeResponse;
import org.elasticsearch.xpack.esql.plugin.DataNodeComputeResponse;

abstract class DataNodeRequestSender {
    private static final Logger LOGGER = LogManager.getLogger(DataNodeRequestSender.class);
    private static final List<String> NODE_QUERY_ORDER = List.of(DiscoveryNodeRole.SEARCH_ROLE.roleName(), DiscoveryNodeRole.DATA_CONTENT_NODE_ROLE.roleName(), DiscoveryNodeRole.DATA_HOT_NODE_ROLE.roleName(), DiscoveryNodeRole.DATA_WARM_NODE_ROLE.roleName(), DiscoveryNodeRole.DATA_COLD_NODE_ROLE.roleName(), DiscoveryNodeRole.DATA_FROZEN_NODE_ROLE.roleName());
    private final ClusterService clusterService;
    private final TransportService transportService;
    private final Executor esqlExecutor;
    private final CancellableTask rootTask;
    private final String clusterAlias;
    private final OriginalIndices originalIndices;
    private final QueryBuilder requestFilter;
    private final boolean allowPartialResults;
    private final Semaphore concurrentRequests;
    private final ReentrantLock sendingLock = new ReentrantLock();
    private final Queue<ShardId> pendingShardIds = ConcurrentCollections.newQueue();
    private final Map<DiscoveryNode, Semaphore> nodePermits = new HashMap<DiscoveryNode, Semaphore>();
    private final Map<ShardId, ShardFailure> shardFailures = ConcurrentCollections.newConcurrentMap();
    private final AtomicInteger skippedShards = new AtomicInteger();
    private final AtomicBoolean changed = new AtomicBoolean();
    private boolean reportedFailure = false;
    private final AtomicInteger remainingUnavailableShardResolutionAttempts;

    DataNodeRequestSender(ClusterService clusterService, TransportService transportService, Executor esqlExecutor, CancellableTask rootTask, OriginalIndices originalIndices, QueryBuilder requestFilter, String clusterAlias, boolean allowPartialResults, int concurrentRequests, int unavailableShardResolutionAttempts) {
        this.clusterService = clusterService;
        this.transportService = transportService;
        this.esqlExecutor = esqlExecutor;
        this.rootTask = rootTask;
        this.originalIndices = originalIndices;
        this.requestFilter = requestFilter;
        this.clusterAlias = clusterAlias;
        this.allowPartialResults = allowPartialResults;
        this.concurrentRequests = concurrentRequests > 0 ? new Semaphore(concurrentRequests) : null;
        this.remainingUnavailableShardResolutionAttempts = new AtomicInteger(unavailableShardResolutionAttempts >= 0 ? unavailableShardResolutionAttempts : Integer.MAX_VALUE);
    }

    final void startComputeOnDataNodes(Set<String> concreteIndices, Runnable runOnTaskFailure, ActionListener<ComputeResponse> listener) {
        assert (ThreadPool.assertCurrentThreadPool((String[])new String[]{"esql_worker", "system_read", "search", "search_coordination"}));
        long startTimeInNanos = System.nanoTime();
        this.searchShards(concreteIndices, (ActionListener<TargetShards>)ActionListener.wrap(targetShards -> {
            try (ComputeListener computeListener = new ComputeListener(this.transportService.getThreadPool(), runOnTaskFailure, (ActionListener<DriverCompletionInfo>)listener.map(completionInfo -> {
                int totalSkipShards = targetShards.skippedShards() + this.skippedShards.get();
                int failedShards = this.shardFailures.size();
                int successfulShards = targetShards.totalShards() - totalSkipShards - failedShards;
                return new ComputeResponse((DriverCompletionInfo)completionInfo, TimeValue.timeValueNanos((long)(System.nanoTime() - startTimeInNanos)), targetShards.totalShards(), successfulShards, totalSkipShards, failedShards, this.selectFailures());
            }));){
                this.pendingShardIds.addAll(DataNodeRequestSender.order(targetShards));
                this.trySendingRequestsForPendingShards((TargetShards)targetShards, computeListener);
            }
        }, arg_0 -> listener.onFailure(arg_0)));
    }

    private static List<ShardId> order(TargetShards targetShards) {
        HashMap computedNodeOrder = new HashMap();
        ArrayList<ShardId> ordered = new ArrayList<ShardId>(targetShards.shards.keySet());
        ordered.sort(Comparator.comparingInt(shardId -> DataNodeRequestSender.nodesOrder(targetShards.getShard((ShardId)shardId).remainingNodes, computedNodeOrder)));
        return ordered;
    }

    private static int nodesOrder(List<DiscoveryNode> nodes, Map<DiscoveryNode, Integer> computedNodeOrder) {
        if (nodes.isEmpty()) {
            return Integer.MAX_VALUE;
        }
        int order = 0;
        for (DiscoveryNode node : nodes) {
            order = Math.max(order, computedNodeOrder.computeIfAbsent(node, DataNodeRequestSender::nodeOrder));
        }
        return order;
    }

    private static int nodeOrder(DiscoveryNode node) {
        for (int i = 0; i < NODE_QUERY_ORDER.size(); ++i) {
            if (!node.hasRole(NODE_QUERY_ORDER.get(i))) continue;
            return i;
        }
        return Integer.MAX_VALUE;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private void trySendingRequestsForPendingShards(TargetShards targetShards, ComputeListener computeListener) {
        this.changed.set(true);
        ActionListener<Void> listener = computeListener.acquireAvoid();
        try {
            while (this.sendingLock.tryLock()) {
                try {
                    if (this.changed.compareAndSet(true, false)) {
                        HashSet<ShardId> pendingRetries = new HashSet<ShardId>();
                        for (ShardId shardId : this.pendingShardIds) {
                            if (!targetShards.getShard((ShardId)shardId).remainingNodes.isEmpty() || !DataNodeRequestSender.isRetryableFailure(this.shardFailures.get(shardId))) continue;
                            pendingRetries.add(shardId);
                        }
                        if (!pendingRetries.isEmpty() && this.remainingUnavailableShardResolutionAttempts.decrementAndGet() >= 0) {
                            for (Map.Entry entry : this.resolveShards(pendingRetries).entrySet()) {
                                targetShards.getShard((ShardId)((ShardId)entry.getKey())).remainingNodes.addAll((Collection)entry.getValue());
                            }
                        }
                        for (ShardId shardId : this.pendingShardIds) {
                            if (!targetShards.getShard((ShardId)shardId).remainingNodes.isEmpty() || DataNodeRequestSender.isRetryableFailure(this.shardFailures.get(shardId)) && !pendingRetries.contains(shardId)) continue;
                            this.shardFailures.compute(shardId, (k, v) -> new ShardFailure(true, (Exception)(v == null ? new NoShardAvailableActionException(shardId, "no shard copies found") : v.failure)));
                        }
                        if (this.reportedFailure || !this.allowPartialResults && this.shardFailures.values().stream().anyMatch(shardFailure -> shardFailure.fatal)) {
                            this.reportedFailure = true;
                            this.reportFailures(computeListener);
                            continue;
                        }
                        for (NodeRequest nodeRequest : this.selectNodeRequests(targetShards)) {
                            this.sendOneNodeRequest(targetShards, computeListener, nodeRequest);
                        }
                        continue;
                    }
                    break;
                }
                finally {
                    this.sendingLock.unlock();
                }
            }
        }
        finally {
            listener.onResponse(null);
        }
    }

    private void reportFailures(ComputeListener computeListener) {
        assert (this.sendingLock.isHeldByCurrentThread());
        assert (this.reportedFailure);
        Iterator<ShardFailure> it = this.shardFailures.values().iterator();
        Set seen = Collections.newSetFromMap(new IdentityHashMap());
        while (it.hasNext()) {
            ShardFailure failure = it.next();
            if (seen.add(failure.failure)) {
                computeListener.acquireAvoid().onFailure(failure.failure);
            }
            it.remove();
        }
    }

    private List<ShardSearchFailure> selectFailures() {
        assert (!this.reportedFailure);
        ArrayList<ShardSearchFailure> failures = new ArrayList<ShardSearchFailure>();
        Set seen = Collections.newSetFromMap(new IdentityHashMap());
        for (Map.Entry<ShardId, ShardFailure> e : this.shardFailures.entrySet()) {
            ShardFailure failure = e.getValue();
            if (ExceptionsHelper.unwrap((Throwable)failure.failure(), (Class[])new Class[]{TaskCancelledException.class}) != null || !seen.add(failure.failure) || failures.size() >= 5) continue;
            failures.add(new ShardSearchFailure(failure.failure, new SearchShardTarget(null, e.getKey(), this.clusterAlias)));
        }
        if (failures.isEmpty() && !this.shardFailures.isEmpty()) {
            ShardFailure any = this.shardFailures.values().iterator().next();
            failures.add(new ShardSearchFailure(any.failure));
        }
        return failures;
    }

    private void sendOneNodeRequest(final TargetShards targetShards, final ComputeListener computeListener, final NodeRequest request) {
        final ActionListener<DriverCompletionInfo> listener = computeListener.acquireCompute();
        this.sendRequest(request.node, request.shardIds, request.aliasFilters, new NodeListener(){

            void onAfterRequest() {
                DataNodeRequestSender.this.nodePermits.get(request.node).release();
                if (DataNodeRequestSender.this.concurrentRequests != null) {
                    DataNodeRequestSender.this.concurrentRequests.release();
                }
                DataNodeRequestSender.this.trySendingRequestsForPendingShards(targetShards, computeListener);
            }

            @Override
            public void onResponse(DataNodeComputeResponse response) {
                try {
                    for (ShardId shardId : request.shardIds()) {
                        if (response.shardLevelFailures().containsKey(shardId)) continue;
                        DataNodeRequestSender.this.shardFailures.remove(shardId);
                    }
                    for (Map.Entry entry : response.shardLevelFailures().entrySet()) {
                        ShardId shardId = (ShardId)entry.getKey();
                        DataNodeRequestSender.this.trackShardLevelFailure(shardId, false, (Exception)entry.getValue());
                        DataNodeRequestSender.this.pendingShardIds.add(shardId);
                    }
                    this.onAfterRequest();
                }
                catch (Exception ex) {
                    this.expectNoFailure("expect no failure while handling data node response", ex);
                    listener.onFailure(ex);
                    return;
                }
                listener.onResponse((Object)response.completionInfo());
            }

            @Override
            public void onFailure(Exception e, boolean receivedData) {
                try {
                    for (ShardId shardId : request.shardIds) {
                        DataNodeRequestSender.this.trackShardLevelFailure(shardId, receivedData, e);
                        DataNodeRequestSender.this.pendingShardIds.add(shardId);
                    }
                    this.onAfterRequest();
                }
                catch (Exception ex) {
                    this.expectNoFailure("expect no failure while handling failure of data node request", ex);
                    listener.onFailure(ex);
                    return;
                }
                listener.onResponse((Object)DriverCompletionInfo.EMPTY);
            }

            @Override
            public void onSkip() {
                DataNodeRequestSender.this.skippedShards.incrementAndGet();
                if (DataNodeRequestSender.this.rootTask.isCancelled()) {
                    this.onFailure((Exception)((Object)new TaskCancelledException("null")), true);
                } else {
                    this.onResponse(new DataNodeComputeResponse(DriverCompletionInfo.EMPTY, Map.of()));
                }
            }

            private void expectNoFailure(String message, Exception e) {
                LOGGER.error(message, (Throwable)e);
                assert (false) : new AssertionError(message, e);
            }
        });
    }

    abstract void sendRequest(DiscoveryNode var1, List<ShardId> var2, Map<Index, AliasFilter> var3, NodeListener var4);

    private static Exception unwrapFailure(Exception e) {
        if (e instanceof TransportException) {
            TransportException te = (TransportException)e;
            v0 = FailureCollector.unwrapTransportException((TransportException)te);
        } else {
            v0 = e = e;
        }
        if (TransportActions.isShardNotAvailableException((Throwable)e)) {
            return NoShardAvailableActionException.forOnShardFailureWrapper((String)e.getMessage());
        }
        return e;
    }

    private void trackShardLevelFailure(ShardId shardId, boolean fatal, Exception originalEx) {
        Exception e = DataNodeRequestSender.unwrapFailure(originalEx);
        boolean isTaskCanceledException = ExceptionsHelper.unwrap((Throwable)e, (Class[])new Class[]{TaskCancelledException.class}) != null;
        boolean isCircuitBreakerException = ExceptionsHelper.unwrap((Throwable)e, (Class[])new Class[]{CircuitBreakingException.class}) != null;
        this.shardFailures.compute(shardId, (k, current) -> {
            boolean mergedFatal;
            boolean bl = mergedFatal = fatal || isTaskCanceledException || isCircuitBreakerException;
            return current == null ? new ShardFailure(mergedFatal, e) : new ShardFailure(mergedFatal || current.fatal, isTaskCanceledException || e instanceof NoShardAvailableActionException ? current.failure : e);
        });
    }

    private static boolean isRetryableFailure(ShardFailure failure) {
        return failure != null && !failure.fatal && failure.failure instanceof NoShardAvailableActionException;
    }

    private List<NodeRequest> selectNodeRequests(TargetShards targetShards) {
        assert (this.sendingLock.isHeldByCurrentThread());
        LinkedHashMap<DiscoveryNode, ArrayList<ShardId>> nodeToShardIds = new LinkedHashMap<DiscoveryNode, ArrayList<ShardId>>();
        Iterator shardsIt = this.pendingShardIds.iterator();
        block0: while (shardsIt.hasNext()) {
            ShardId shardId = (ShardId)shardsIt.next();
            ShardFailure failure = this.shardFailures.get(shardId);
            if (failure != null && failure.fatal) {
                shardsIt.remove();
                continue;
            }
            TargetShard shard = targetShards.getShard(shardId);
            Iterator<DiscoveryNode> nodesIt = shard.remainingNodes.iterator();
            while (nodesIt.hasNext()) {
                DiscoveryNode node = nodesIt.next();
                ArrayList<ShardId> pendingRequest = (ArrayList<ShardId>)nodeToShardIds.get(node);
                if (pendingRequest != null) {
                    pendingRequest.add(shard.shardId);
                    nodesIt.remove();
                    shardsIt.remove();
                    continue block0;
                }
                if (this.concurrentRequests != null && !this.concurrentRequests.tryAcquire()) continue;
                if (this.nodePermits.computeIfAbsent(node, n -> new Semaphore(1)).tryAcquire()) {
                    pendingRequest = new ArrayList<ShardId>();
                    pendingRequest.add(shard.shardId);
                    nodeToShardIds.put(node, pendingRequest);
                    nodesIt.remove();
                    shardsIt.remove();
                    continue block0;
                }
                if (this.concurrentRequests == null) continue;
                this.concurrentRequests.release();
            }
        }
        ArrayList<NodeRequest> nodeRequests = new ArrayList<NodeRequest>(nodeToShardIds.size());
        for (Map.Entry entry : nodeToShardIds.entrySet()) {
            DiscoveryNode node = (DiscoveryNode)entry.getKey();
            List shardIds = (List)entry.getValue();
            HashMap<Index, AliasFilter> aliasFilters = new HashMap<Index, AliasFilter>();
            for (ShardId shardId : shardIds) {
                AliasFilter aliasFilter = targetShards.getShard((ShardId)shardId).aliasFilter;
                if (aliasFilter == null) continue;
                aliasFilters.put(shardId.getIndex(), aliasFilter);
            }
            nodeRequests.add(new NodeRequest(node, shardIds, aliasFilters));
        }
        return nodeRequests;
    }

    void searchShards(Set<String> concreteIndices, ActionListener<TargetShards> listener) {
        ActionListener searchShardsListener = listener.map(resp -> {
            Map nodes = Maps.newHashMapWithExpectedSize((int)resp.getNodes().size());
            for (DiscoveryNode node : resp.getNodes()) {
                nodes.put(node.getId(), node);
            }
            int totalShards = 0;
            int skippedShards = 0;
            Map shards = Maps.newHashMapWithExpectedSize((int)resp.getGroups().size());
            for (SearchShardsGroup group : resp.getGroups()) {
                ShardId shardId = group.shardId();
                if (!concreteIndices.contains(shardId.getIndexName())) continue;
                ++totalShards;
                if (group.skipped()) {
                    ++skippedShards;
                    continue;
                }
                ArrayList<DiscoveryNode> allocatedNodes = new ArrayList<DiscoveryNode>(group.allocatedNodes().size());
                for (String n : group.allocatedNodes()) {
                    allocatedNodes.add((DiscoveryNode)nodes.get(n));
                }
                AliasFilter aliasFilter = (AliasFilter)resp.getAliasFilters().get(shardId.getIndex().getUUID());
                shards.put(shardId, new TargetShard(shardId, allocatedNodes, aliasFilter));
            }
            return new TargetShards(shards, totalShards, skippedShards);
        });
        SearchShardsRequest searchShardsRequest = new SearchShardsRequest(this.originalIndices.indices(), this.originalIndices.indicesOptions(), this.requestFilter, null, null, true, this.clusterAlias);
        this.transportService.sendChildRequest(this.transportService.getLocalNode(), EsqlSearchShardsAction.TYPE.name(), (TransportRequest)searchShardsRequest, (Task)this.rootTask, TransportRequestOptions.EMPTY, (TransportResponseHandler)new ActionListenerResponseHandler(searchShardsListener, SearchShardsResponse::new, this.esqlExecutor));
    }

    Map<ShardId, List<DiscoveryNode>> resolveShards(Set<ShardId> shardIds) {
        ClusterState state = this.clusterService.state();
        Map nodes = Maps.newMapWithExpectedSize((int)shardIds.size());
        for (ShardId shardId : shardIds) {
            List<DiscoveryNode> allocatedNodes;
            try {
                allocatedNodes = state.routingTable().shardRoutingTable(shardId).allShards().filter(shard -> shard.active() && shard.isSearchable()).map(shard -> state.nodes().get(shard.currentNodeId())).toList();
            }
            catch (Exception ignored) {
                continue;
            }
            nodes.put(shardId, allocatedNodes);
        }
        return nodes;
    }

    record TargetShards(Map<ShardId, TargetShard> shards, int totalShards, int skippedShards) {
        TargetShard getShard(ShardId shardId) {
            return this.shards.get(shardId);
        }
    }

    record TargetShard(ShardId shardId, List<DiscoveryNode> remainingNodes, AliasFilter aliasFilter) {
    }

    private record ShardFailure(boolean fatal, Exception failure) {
    }

    record NodeRequest(DiscoveryNode node, List<ShardId> shardIds, Map<Index, AliasFilter> aliasFilters) {
    }

    static interface NodeListener {
        public void onResponse(DataNodeComputeResponse var1);

        public void onFailure(Exception var1, boolean var2);

        public void onSkip();
    }
}

