/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.logical;

import java.lang.runtime.SwitchBootstraps;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeMap;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.util.Holder;
import org.elasticsearch.xpack.esql.expression.Order;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.TemporaryNameUtils;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
import org.elasticsearch.xpack.esql.plan.logical.Eval;
import org.elasticsearch.xpack.esql.plan.logical.ExecutesOn;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.PipelineBreaker;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.Streaming;
import org.elasticsearch.xpack.esql.plan.logical.TopN;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;

public final class HoistRemoteEnrichTopN
extends OptimizerRules.OptimizerRule<Enrich>
implements OptimizerRules.CoordinatorOnly {
    public HoistRemoteEnrichTopN() {
        super(OptimizerRules.TransformDirection.UP);
    }

    @Override
    protected LogicalPlan rule(Enrich en) {
        if (en.mode() == Enrich.Mode.REMOTE) {
            LogicalPlan plan = en.child();
            List<Attribute> outputs = en.output();
            while (true) {
                Enrich e;
                ExecutesOn ex;
                TopN topN;
                if (plan instanceof TopN && !(topN = (TopN)plan).local()) {
                    Set<Attribute> missingAttributes = this.checkOrderInOutputs(topN.order(), outputs);
                    if (missingAttributes.isEmpty()) {
                        LogicalPlan transformedEnrich = en.transformDown(TopN.class, t -> t == topN ? topN.withLocal(true) : t);
                        return new TopN(topN.source(), transformedEnrich, topN.order(), topN.limit(), false);
                    }
                    HashSet<String> missingNames = new HashSet<String>(Expressions.names(missingAttributes));
                    AttributeMap.Builder aliasesForReplacedAttributesBuilder = AttributeMap.builder();
                    ArrayList<Order> rewrittenExpressions = new ArrayList<Order>();
                    for (Order expr : topN.order()) {
                        rewrittenExpressions.add((Order)expr.transformUp(Attribute.class, attr -> {
                            if (missingNames.contains(attr.name())) {
                                Alias renamedAttribute = aliasesForReplacedAttributesBuilder.computeIfAbsent((Attribute)attr, a -> {
                                    String tempName = TemporaryNameUtils.locallyUniqueTemporaryName(a.name());
                                    return new Alias(a.source(), tempName, (Expression)a, null, true);
                                });
                                return renamedAttribute.toAttribute();
                            }
                            return attr;
                        }));
                    }
                    AttributeMap replacedAttributes = aliasesForReplacedAttributesBuilder.build();
                    Eval replacementTop = new Eval(topN.source(), topN.withLocal(true), new ArrayList<Alias>(replacedAttributes.values()));
                    ArrayList<Order> newOrder = rewrittenExpressions;
                    Holder stop = new Holder((Object)false);
                    LogicalPlan transformedEnrich = en.transformDown(p -> {
                        LogicalPlan logicalPlan;
                        LogicalPlan logicalPlan2 = p;
                        Objects.requireNonNull(logicalPlan2);
                        LogicalPlan selector0$temp = logicalPlan2;
                        int index$1 = 0;
                        block4: while (true) {
                            switch (SwitchBootstraps.typeSwitch("typeSwitch", new Object[]{TopN.class, Project.class}, (Object)selector0$temp, index$1)) {
                                case 0: {
                                    TopN t = (TopN)selector0$temp;
                                    if (((Boolean)stop.get()).booleanValue() || t != topN) {
                                        index$1 = 1;
                                        continue block4;
                                    }
                                    stop.set((Object)true);
                                    logicalPlan = replacementTop;
                                    break block4;
                                }
                                case 1: {
                                    Project pr = (Project)selector0$temp;
                                    if (((Boolean)stop.get()).booleanValue()) {
                                        index$1 = 2;
                                        continue block4;
                                    }
                                    LinkedList<? extends NamedExpression> allFields = new LinkedList<NamedExpression>(pr.projections());
                                    allFields.addAll(replacementTop.fields());
                                    logicalPlan = pr.withProjections(allFields);
                                    break block4;
                                }
                                default: {
                                    logicalPlan = p;
                                    break block4;
                                }
                            }
                            break;
                        }
                        return logicalPlan;
                    });
                    Set<Attribute> missing = this.checkOrderInOutputs(newOrder, transformedEnrich.output());
                    if (!missing.isEmpty()) {
                        throw new IllegalStateException("Attribute [" + String.valueOf(Expressions.names(missing)) + "] required by TopN but is not in the outputs of Enrich [" + String.valueOf(Expressions.names(transformedEnrich.output())) + "]");
                    }
                    TopN copyTop = new TopN(topN.source(), transformedEnrich, newOrder, topN.limit(), false);
                    return new Project(en.source(), copyTop, outputs);
                }
                if (!(plan instanceof Streaming) || plan instanceof ExecutesOn && (ex = (ExecutesOn)((Object)plan)).executesOn() == ExecutesOn.ExecuteLocation.COORDINATOR || plan instanceof Enrich && (e = (Enrich)plan).mode() == Enrich.Mode.REMOTE || plan instanceof PipelineBreaker || !(plan instanceof UnaryPlan)) break;
                UnaryPlan u = (UnaryPlan)plan;
                plan = u.child();
            }
        }
        return en;
    }

    private Set<Attribute> checkOrderInOutputs(List<Order> orderList, List<Attribute> outputs) {
        Set outputsSet = outputs.stream().map(NamedExpression::id).collect(Collectors.toSet());
        HashSet<Attribute> missingAttributes = new HashSet<Attribute>();
        for (Order o : orderList) {
            o.child().forEachDown(Attribute.class, a -> {
                if (!outputsSet.contains(a.id())) {
                    missingAttributes.add((Attribute)a);
                }
            });
        }
        return missingAttributes;
    }
}

