/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.action.search;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.join.ScoreMode;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.AbstractSearchAsyncAction;
import org.elasticsearch.action.search.CountedCollector;
import org.elasticsearch.action.search.SearchActionListener;
import org.elasticsearch.action.search.SearchPhase;
import org.elasticsearch.action.search.SearchPhaseController;
import org.elasticsearch.action.search.SearchPhaseResults;
import org.elasticsearch.action.search.SearchProgressListener;
import org.elasticsearch.action.search.SearchQueryThenFetchAsyncAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.index.query.NestedQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.builder.SubSearchSourceBuilder;
import org.elasticsearch.search.dfs.AggregatedDfs;
import org.elasticsearch.search.dfs.DfsKnnResults;
import org.elasticsearch.search.dfs.DfsSearchResult;
import org.elasticsearch.search.internal.ShardSearchRequest;
import org.elasticsearch.search.query.QuerySearchRequest;
import org.elasticsearch.search.query.QuerySearchResult;
import org.elasticsearch.search.vectors.KnnScoreDocQueryBuilder;
import org.elasticsearch.transport.Transport;

/**
 * Search phase ({@value #NAME}) that consumes the per-shard {@link DfsSearchResult}s of the preceding DFS
 * phase and issues the actual query to every shard that produced one.
 *
 * <p>It merges the per-shard term and field statistics into a single {@link AggregatedDfs}, merges any kNN
 * results across shards, and sends each shard a {@link QuerySearchRequest} carrying those global statistics
 * (and, for kNN searches, the merged hits that live on that shard). Once every shard has either answered or
 * failed, the phase duration is recorded and the phase built by {@link #nextPhase(AggregatedDfs)} is executed.
 */
class DfsQueryPhase extends SearchPhase {
    public static final String NAME = "dfs_query";

    private final SearchPhaseResults<SearchPhaseResult> queryResult;
    private final Client client;
    private final AbstractSearchAsyncAction<?> context;
    private final SearchProgressListener progressListener;
    private long phaseStartTimeInNanos;

    /**
     * @param queryResult receives the per-shard query results produced by this phase
     * @param client      handed on to the phase that follows this one
     * @param context     the owning search action; its {@code results} are read as the DFS results
     */
    DfsQueryPhase(SearchPhaseResults<SearchPhaseResult> queryResult, Client client, AbstractSearchAsyncAction<?> context) {
        super(NAME);
        this.progressListener = context.getTask().getProgressListener();
        this.queryResult = queryResult;
        this.client = client;
        this.context = context;
    }

    /**
     * Builds the phase to run after all shard queries of this phase have completed, by delegating to
     * {@link SearchQueryThenFetchAsyncAction#nextPhase}. Protected so that it can be overridden.
     *
     * @param dfs the statistics aggregated from all DFS results
     * @return the next search phase
     */
    protected SearchPhase nextPhase(AggregatedDfs dfs) {
        return SearchQueryThenFetchAsyncAction.nextPhase(this.client, this.context, this.queryResult, dfs);
    }

    /**
     * Aggregates the DFS statistics, merges kNN results and sends one query request per DFS result.
     * Each shard's outcome (response or failure) is fed into a {@link CountedCollector} that calls
     * {@link #onFinish(AggregatedDfs)} once all shards are accounted for.
     */
    @Override
    protected void run() {
        this.phaseStartTimeInNanos = System.nanoTime();
        // The owning action is held with a wildcard type, so reading its results as DFS results needs an
        // unchecked cast. NOTE(review): relies on the previous phase having stored DfsSearchResult instances.
        @SuppressWarnings("unchecked")
        final List<DfsSearchResult> searchResults = (List<DfsSearchResult>) this.context.results.getAtomicArray().asList();
        final AggregatedDfs dfs = aggregateDfs(searchResults);
        final CountedCollector<SearchPhaseResult> counter =
            new CountedCollector<>(this.queryResult, searchResults.size(), () -> this.onFinish(dfs), this.context);
        // Merged once for all shards; each shard request later keeps only the hits that belong to it.
        final List<DfsKnnResults> knnResults = mergeKnnResults(this.context.getRequest(), searchResults);
        for (final DfsSearchResult dfsResult : searchResults) {
            final SearchShardTarget shardTarget = dfsResult.getSearchShardTarget();
            final int shardIndex = dfsResult.getShardIndex();
            final QuerySearchRequest querySearchRequest = new QuerySearchRequest(
                this.context.getOriginalIndices(shardIndex),
                dfsResult.getContextId(),
                this.rewriteShardSearchRequest(knnResults, dfsResult.getShardSearchRequest()),
                dfs
            );
            final Transport.Connection connection;
            try {
                connection = this.context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId());
            } catch (Exception e) {
                // Count the shard as failed so the collector can still complete.
                this.shardFailure(e, querySearchRequest, shardIndex, shardTarget, counter);
                continue;
            }
            this.context.getSearchTransport().sendExecuteQuery(
                connection,
                querySearchRequest,
                this.context.getTask(),
                new SearchActionListener<QuerySearchResult>(shardTarget, shardIndex) {

                    @Override
                    protected void innerOnResponse(QuerySearchResult response) {
                        try {
                            response.setSearchProfileDfsPhaseResult(dfsResult.searchProfileDfsPhaseResult());
                            counter.onResult(response);
                        } catch (Exception e) {
                            DfsQueryPhase.this.context.onPhaseFailure(DfsQueryPhase.NAME, "", e);
                        }
                    }

                    @Override
                    public void onFailure(Exception exception) {
                        try {
                            DfsQueryPhase.this.shardFailure(exception, querySearchRequest, shardIndex, shardTarget, counter);
                        } finally {
                            // Explicitly release the shard-side search context unless it belongs to a point-in-time;
                            // presumably so a context whose query never ran is not leaked.
                            if (!DfsQueryPhase.this.context.isPartOfPointInTime(querySearchRequest.contextId())) {
                                DfsQueryPhase.this.context.sendReleaseSearchContext(querySearchRequest.contextId(), connection);
                            }
                        }
                    }
                }
            );
        }
    }

    /** Records this phase's duration in the search response metrics, then hands over to {@link #nextPhase(AggregatedDfs)}. */
    private void onFinish(AggregatedDfs dfs) {
        final long tookInNanos = System.nanoTime() - this.phaseStartTimeInNanos;
        this.context.getSearchResponseMetrics()
            .recordSearchPhaseDuration(this.getName(), tookInNanos, this.context.getSearchRequestAttributes());
        this.context.executeNextPhase(NAME, () -> this.nextPhase(dfs));
    }

    /**
     * Handles a failed shard query: logs it at debug level, notifies the progress listener and records the
     * failure on {@code counter}.
     */
    private void shardFailure(
        Exception exception,
        QuerySearchRequest querySearchRequest,
        int shardIndex,
        SearchShardTarget shardTarget,
        CountedCollector<SearchPhaseResult> counter
    ) {
        this.context.getLogger().debug(() -> "[" + querySearchRequest.contextId() + "] Failed to execute query phase", exception);
        this.progressListener.notifyQueryFailure(shardIndex, shardTarget, exception);
        counter.onFailure(shardIndex, shardTarget, exception);
    }

    /**
     * Replaces the kNN searches of a shard request with sub-searches that match exactly the globally merged
     * kNN hits located on that shard.
     *
     * <p>For the i-th merged result, the hits whose {@code shardIndex} equals
     * {@link ShardSearchRequest#shardRequestIndex()} are sorted by doc id and turned into a
     * {@link KnnScoreDocQueryBuilder} that uses the field, query vector, similarity, filters, boost and name
     * of the i-th kNN search. The query is wrapped in a {@link NestedQueryBuilder} (score mode max, with the
     * kNN search's inner hits) when a nested path was reported. Each query is appended as a
     * {@link SubSearchSourceBuilder}, and the kNN searches are cleared on a shallow copy of the source.
     *
     * @param knnResults merged kNN results; assumed to be in the same order as {@code request.source().knnSearch()}
     *                   and non-null whenever that list is non-empty
     * @param request    the shard request; its source is replaced in place when it carries kNN searches
     * @return {@code request} itself, untouched when it has no source or no kNN searches
     */
    ShardSearchRequest rewriteShardSearchRequest(List<DfsKnnResults> knnResults, ShardSearchRequest request) {
        SearchSourceBuilder source = request.source();
        if (source == null || source.knnSearch().isEmpty()) {
            return request;
        }
        final List<SubSearchSourceBuilder> subSearches = new ArrayList<>(source.subSearches());
        int i = 0;
        for (DfsKnnResults dfsKnnResults : knnResults) {
            // Keep only the globally selected hits that live on this shard.
            final List<ScoreDoc> shardScoreDocs = new ArrayList<>();
            for (ScoreDoc scoreDoc : dfsKnnResults.scoreDocs()) {
                if (scoreDoc.shardIndex == request.shardRequestIndex()) {
                    shardScoreDocs.add(scoreDoc);
                }
            }
            shardScoreDocs.sort(Comparator.comparingInt(sd -> sd.doc));
            final String nestedPath = dfsKnnResults.getNestedPath();
            QueryBuilder query = new KnnScoreDocQueryBuilder(
                shardScoreDocs.toArray(Lucene.EMPTY_SCORE_DOCS),
                source.knnSearch().get(i).getField(),
                source.knnSearch().get(i).getQueryVector(),
                source.knnSearch().get(i).getSimilarity(),
                source.knnSearch().get(i).getFilterQueries()
            ).boost(source.knnSearch().get(i).boost()).queryName(source.knnSearch().get(i).queryName());
            if (nestedPath != null) {
                query = new NestedQueryBuilder(nestedPath, query, ScoreMode.Max).innerHit(source.knnSearch().get(i).innerHit());
            }
            subSearches.add(new SubSearchSourceBuilder(query));
            ++i;
        }
        source = source.shallowCopy().subSearches(subSearches).knnSearch(List.of());
        request.source(source);
        return request;
    }

    /**
     * Merges the per-shard kNN results of every kNN search in {@code request} into its global top {@code k}.
     *
     * <p>Every shard's hits are tagged with that shard's index before merging, so the merged list can later be
     * split back per shard. For each kNN search, the first nested path reported by any shard is kept.
     *
     * @return one merged {@link DfsKnnResults} per kNN search of the request, in request order, or {@code null}
     *         when the request has no kNN search
     */
    private static List<DfsKnnResults> mergeKnnResults(SearchRequest request, List<DfsSearchResult> dfsSearchResults) {
        if (!request.hasKnnSearch()) {
            return null;
        }
        final SearchSourceBuilder source = request.source();
        final int knnSearchCount = source.knnSearch().size();
        final List<List<TopDocs>> topDocsLists = new ArrayList<>(knnSearchCount);
        final List<SetOnce<String>> nestedPaths = new ArrayList<>(knnSearchCount);
        for (int i = 0; i < knnSearchCount; ++i) {
            topDocsLists.add(new ArrayList<>());
            nestedPaths.add(new SetOnce<>());
        }
        for (DfsSearchResult dfsSearchResult : dfsSearchResults) {
            final List<DfsKnnResults> shardKnnResults = dfsSearchResult.knnResults();
            if (shardKnnResults == null) {
                continue;
            }
            for (int i = 0; i < shardKnnResults.size(); ++i) {
                final DfsKnnResults knnResults = shardKnnResults.get(i);
                final ScoreDoc[] scoreDocs = knnResults.scoreDocs();
                final TopDocs shardTopDocs = new TopDocs(new TotalHits(scoreDocs.length, TotalHits.Relation.EQUAL_TO), scoreDocs);
                SearchPhaseController.setShardIndex(shardTopDocs, dfsSearchResult.getShardIndex());
                topDocsLists.get(i).add(shardTopDocs);
                // trySet only succeeds for the first value, so the first reporting shard wins.
                nestedPaths.get(i).trySet(knnResults.getNestedPath());
            }
        }
        final List<DfsKnnResults> mergedResults = new ArrayList<>(knnSearchCount);
        for (int i = 0; i < knnSearchCount; ++i) {
            final int k = (int) source.knnSearch().get(i).k();
            final TopDocs mergedTopDocs = TopDocs.merge(k, topDocsLists.get(i).toArray(new TopDocs[0]));
            mergedResults.add(new DfsKnnResults(nestedPaths.get(i).get(), mergedTopDocs.scoreDocs));
        }
        return mergedResults;
    }

    /**
     * Sums the per-shard DFS statistics into global ones.
     *
     * <p>Term statistics are summed per term (doc freq and total term freq), collection statistics per field
     * (max doc, doc count, sum of total term freq, sum of doc freq), and {@code maxDoc} across all results.
     * {@code null} per-term or per-field entries are skipped.
     *
     * @param results the DFS results of all shards
     * @return the aggregated statistics
     */
    private static AggregatedDfs aggregateDfs(Collection<DfsSearchResult> results) {
        final Map<Term, TermStatistics> termStatistics = new HashMap<>();
        final Map<String, CollectionStatistics> fieldStatistics = new HashMap<>();
        long aggMaxDoc = 0L;
        for (DfsSearchResult lEntry : results) {
            final Term[] terms = lEntry.terms();
            final TermStatistics[] stats = lEntry.termStatistics();
            assert terms.length == stats.length;
            for (int i = 0; i < terms.length; ++i) {
                assert terms[i] != null;
                if (stats[i] == null) {
                    continue;
                }
                final TermStatistics existing = termStatistics.get(terms[i]);
                if (existing != null) {
                    assert terms[i].bytes().equals(existing.term());
                    termStatistics.put(
                        terms[i],
                        new TermStatistics(
                            existing.term(),
                            existing.docFreq() + stats[i].docFreq(),
                            existing.totalTermFreq() + stats[i].totalTermFreq()
                        )
                    );
                } else {
                    termStatistics.put(terms[i], stats[i]);
                }
            }
            assert !lEntry.fieldStatistics().containsKey(null);
            for (Map.Entry<String, CollectionStatistics> entry : lEntry.fieldStatistics().entrySet()) {
                final String key = entry.getKey();
                final CollectionStatistics value = entry.getValue();
                if (value == null) {
                    continue;
                }
                assert key != null;
                final CollectionStatistics existing = fieldStatistics.get(key);
                if (existing != null) {
                    fieldStatistics.put(
                        key,
                        new CollectionStatistics(
                            key,
                            existing.maxDoc() + value.maxDoc(),
                            existing.docCount() + value.docCount(),
                            existing.sumTotalTermFreq() + value.sumTotalTermFreq(),
                            existing.sumDocFreq() + value.sumDocFreq()
                        )
                    );
                } else {
                    fieldStatistics.put(key, value);
                }
            }
            aggMaxDoc += lEntry.maxDoc();
        }
        return new AggregatedDfs(termStatistics, fieldStatistics, aggMaxDoc);
    }
}

