/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.analysis;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.stream.Stream;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.xpack.esql.LicenseAware;
import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware;
import org.elasticsearch.xpack.esql.capabilities.PostAnalysisVerificationAware;
import org.elasticsearch.xpack.esql.common.Failure;
import org.elasticsearch.xpack.esql.common.Failures;
import org.elasticsearch.xpack.esql.core.capabilities.Unresolvable;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.expression.function.Function;
import org.elasticsearch.xpack.esql.core.expression.predicate.BinaryOperator;
import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison;
import org.elasticsearch.xpack.esql.core.tree.Node;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.Holder;
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Neg;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
import org.elasticsearch.xpack.esql.plan.logical.InlineStats;
import org.elasticsearch.xpack.esql.plan.logical.Insist;
import org.elasticsearch.xpack.esql.plan.logical.Limit;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.Lookup;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.telemetry.FeatureMetric;
import org.elasticsearch.xpack.esql.telemetry.Metrics;

public class Verifier {
    private final List<BiConsumer<LogicalPlan, Failures>> extraCheckers;
    private final Metrics metrics;
    private final XPackLicenseState licenseState;

    public Verifier(Metrics metrics, XPackLicenseState licenseState) {
        this(metrics, licenseState, Collections.emptyList());
    }

    public Verifier(Metrics metrics, XPackLicenseState licenseState, List<BiConsumer<LogicalPlan, Failures>> extraCheckers) {
        this.metrics = metrics;
        this.licenseState = licenseState;
        this.extraCheckers = extraCheckers;
    }

    Collection<Failure> verify(LogicalPlan plan, BitSet partialMetrics) {
        assert (partialMetrics != null);
        Failures failures = new Failures();
        Verifier.checkUnresolvedAttributes(plan, failures);
        if (failures.hasFailures()) {
            return failures.failures();
        }
        List<BiConsumer<LogicalPlan, Failures>> planCheckers = Verifier.planCheckers(plan);
        planCheckers.addAll(this.extraCheckers);
        plan.forEachDown(p -> {
            if (!p.childrenResolved()) {
                return;
            }
            planCheckers.forEach(c -> c.accept(p, failures));
            Verifier.checkOperationsOnUnsignedLong(p, failures);
            Verifier.checkBinaryComparison(p, failures);
            Verifier.checkInsist(p, failures);
            Verifier.checkLimitBeforeInlineStats(p, failures);
        });
        if (!failures.hasFailures()) {
            this.licenseCheck(plan, failures);
        }
        if (!failures.hasFailures()) {
            this.gatherMetrics(plan, partialMetrics);
        }
        return failures.failures();
    }

    private static void checkUnresolvedAttributes(LogicalPlan plan, Failures failures) {
        plan.forEachUp(p -> {
            if (!p.childrenResolved()) {
                return;
            }
            if (p instanceof Unresolvable) {
                Unresolvable u = (Unresolvable)((Object)p);
                failures.add(Failure.fail(p, u.unresolvedMessage(), new Object[0]));
            } else if (p.resolved()) {
                return;
            }
            Consumer<Expression> unresolvedExpressions = e -> {
                if (e.resolved()) {
                    return;
                }
                e.forEachUp(ae -> {
                    if (p instanceof Project || p instanceof Insist) {
                        Alias as;
                        Expression patt0$temp;
                        if (ae instanceof Alias && (patt0$temp = (as = (Alias)ae).child()) instanceof UnsupportedAttribute) {
                            UnsupportedAttribute ua = (UnsupportedAttribute)patt0$temp;
                            failures.add(Failure.fail(ae, ua.unresolvedMessage(), new Object[0]));
                        }
                        if (ae instanceof UnsupportedAttribute) {
                            return;
                        }
                    }
                    if (!ae.childrenResolved()) {
                        return;
                    }
                    if (ae instanceof Unresolvable) {
                        Unresolvable u = (Unresolvable)((Object)ae);
                        failures.add(Failure.fail(ae, u.unresolvedMessage(), new Object[0]));
                    }
                    if (ae.typeResolved().unresolved()) {
                        failures.add(Failure.fail(ae, ae.typeResolved().message(), new Object[0]));
                    }
                });
            };
            if (p instanceof Aggregate) {
                Aggregate agg = (Aggregate)p;
                List<Expression> groupings = agg.groupings();
                groupings.forEach(unresolvedExpressions);
                boolean hasGroupByAll = false;
                for (Expression grouping : groupings) {
                    if (!MetadataAttribute.isTimeSeriesAttribute(grouping)) continue;
                    hasGroupByAll = true;
                    break;
                }
                int groupingSize = hasGroupByAll ? groupings.size() - 1 : groupings.size();
                List<? extends NamedExpression> aggs = agg.aggregates();
                int size = aggs.size() - groupingSize;
                aggs.subList(0, size).forEach(unresolvedExpressions);
            } else if (p instanceof Lookup) {
                Lookup lookup = (Lookup)p;
                Expression tableName = lookup.tableName();
                if (tableName instanceof Unresolvable) {
                    Unresolvable u = (Unresolvable)((Object)tableName);
                    failures.add(Failure.fail(tableName, u.unresolvedMessage(), new Object[0]));
                } else {
                    lookup.matchFields().forEach(unresolvedExpressions);
                }
            } else {
                p.forEachExpression(unresolvedExpressions);
            }
        });
    }

    private static List<BiConsumer<LogicalPlan, Failures>> planCheckers(LogicalPlan plan) {
        ArrayList<BiConsumer<LogicalPlan, Failures>> planCheckers = new ArrayList<BiConsumer<LogicalPlan, Failures>>();
        Consumer<Node> collectPlanCheckers = p -> {
            if (p instanceof PostAnalysisPlanVerificationAware) {
                PostAnalysisPlanVerificationAware pva = (PostAnalysisPlanVerificationAware)((Object)p);
                planCheckers.add(pva.postAnalysisPlanVerification());
            }
        };
        plan.forEachDown(p -> {
            collectPlanCheckers.accept((Node)p);
            p.forEachExpression(collectPlanCheckers);
            if (p instanceof PostAnalysisVerificationAware) {
                PostAnalysisVerificationAware va = (PostAnalysisVerificationAware)((Object)p);
                planCheckers.add((lp, failures) -> {
                    if (lp.getClass().equals(va.getClass())) {
                        va.postAnalysisVerification((Failures)failures);
                    }
                });
            }
        });
        return planCheckers;
    }

    private static void checkOperationsOnUnsignedLong(LogicalPlan p, Failures failures) {
        p.forEachExpression(e -> {
            Failure f = null;
            if (e instanceof BinaryOperator) {
                BinaryOperator bo = (BinaryOperator)e;
                f = Verifier.validateUnsignedLongOperator(bo);
            } else if (e instanceof Neg) {
                Neg neg = (Neg)e;
                f = Verifier.validateUnsignedLongNegation(neg);
            }
            if (f != null) {
                failures.add(f);
            }
        });
    }

    private static void checkBinaryComparison(LogicalPlan p, Failures failures) {
        p.forEachExpression(BinaryComparison.class, bc -> {
            Failure f = Verifier.validateBinaryComparison(bc);
            if (f != null) {
                failures.add(f);
            }
        });
    }

    private static void checkInsist(LogicalPlan p, Failures failures) {
        Insist i;
        LogicalPlan child;
        if (p instanceof Insist && !((child = (i = (Insist)p).child()) instanceof EsRelation || child instanceof Insist)) {
            failures.add(Failure.fail(i, "[insist] can only be used after [from] or [insist] commands, but was [{}]", child.sourceText()));
        }
    }

    private static void checkLimitBeforeInlineStats(LogicalPlan plan, Failures failures) {
        if (plan instanceof InlineStats) {
            InlineStats is = (InlineStats)plan;
            Holder inlineStatsDescendantLimit = new Holder();
            is.forEachDownMayReturnEarly((p, breakEarly) -> {
                if (p instanceof Limit) {
                    Limit l = (Limit)p;
                    inlineStatsDescendantLimit.set((Object)l);
                    breakEarly.set((Object)true);
                    return;
                }
            });
            Limit firstLimit = (Limit)inlineStatsDescendantLimit.get();
            if (firstLimit != null) {
                String isString = is.sourceText().length() > 110 ? is.sourceText().substring(0, 110) + "..." : is.sourceText();
                String limitString = firstLimit.sourceText().length() > 110 ? firstLimit.sourceText().substring(0, 110) + "..." : firstLimit.sourceText();
                failures.add(Failure.fail(is, "INLINE STATS cannot be used after an explicit or implicit LIMIT command, but was [{}] after [{}] [{}]", isString, limitString, firstLimit.source().source().toString()));
            }
        }
    }

    private void licenseCheck(LogicalPlan plan, Failures failures) {
        Consumer<Node> licenseCheck = n -> {
            LicenseAware la;
            if (n instanceof LicenseAware && !(la = (LicenseAware)((Object)n)).licenseCheck(this.licenseState)) {
                failures.add(Failure.fail(n, "current license is non-compliant for [{}]", n.sourceText()));
            }
        };
        plan.forEachDown(p -> {
            licenseCheck.accept((Node)p);
            p.forEachExpression(Expression.class, licenseCheck);
        });
    }

    private void gatherMetrics(LogicalPlan plan, BitSet b) {
        plan.forEachDown(p -> FeatureMetric.set(p, b));
        int i = b.nextSetBit(0);
        while (i >= 0) {
            this.metrics.inc(FeatureMetric.values()[i]);
            i = b.nextSetBit(i + 1);
        }
        HashSet functions = new HashSet();
        plan.forEachExpressionDown(Function.class, p -> functions.add(p.getClass()));
        functions.forEach(f -> this.metrics.incFunctionMetric((Class<?>)f));
    }

    public XPackLicenseState licenseState() {
        return this.licenseState;
    }

    public static Failure validateBinaryComparison(BinaryComparison bc) {
        if (bc.left().dataType().isNumeric()) {
            if (!bc.right().dataType().isNumeric()) {
                return Failure.fail(bc, "first argument of [{}] is [numeric] so second argument must also be [numeric] but was [{}]", bc.sourceText(), bc.right().dataType().typeName());
            }
            return null;
        }
        ArrayList<DataType> allowed = new ArrayList<DataType>();
        allowed.add(DataType.KEYWORD);
        allowed.add(DataType.TEXT);
        allowed.add(DataType.IP);
        allowed.add(DataType.DATETIME);
        allowed.add(DataType.DATE_NANOS);
        allowed.add(DataType.VERSION);
        allowed.add(DataType.GEO_POINT);
        allowed.add(DataType.GEO_SHAPE);
        allowed.add(DataType.CARTESIAN_POINT);
        allowed.add(DataType.CARTESIAN_SHAPE);
        allowed.add(DataType.GEOHASH);
        allowed.add(DataType.GEOTILE);
        allowed.add(DataType.GEOHEX);
        if (bc instanceof Equals || bc instanceof NotEquals) {
            allowed.add(DataType.BOOLEAN);
        }
        Expression.TypeResolution r = TypeResolutions.isType(bc.left(), allowed::contains, bc.sourceText(), TypeResolutions.ParamOrdinal.FIRST, (String[])Stream.concat(Stream.of("numeric"), allowed.stream().map(DataType::typeName)).toArray(String[]::new));
        if (!r.resolved()) {
            return Failure.fail(bc, r.message(), new Object[0]);
        }
        if (DataType.isString(bc.left().dataType()) && DataType.isString(bc.right().dataType())) {
            return null;
        }
        if (bc.left().dataType().isDate() && bc.right().dataType().isDate()) {
            return null;
        }
        if (bc.left().dataType() != bc.right().dataType()) {
            return Failure.fail(bc, "first argument of [{}] is [{}] so second argument must also be [{}] but was [{}]", bc.sourceText(), bc.left().dataType().typeName(), bc.left().dataType().typeName(), bc.right().dataType().typeName());
        }
        return null;
    }

    public static Failure validateUnsignedLongOperator(BinaryOperator<?, ?, ?, ?> bo) {
        DataType leftType = bo.left().dataType();
        DataType rightType = bo.right().dataType();
        if ((leftType == DataType.UNSIGNED_LONG || rightType == DataType.UNSIGNED_LONG) && leftType != rightType) {
            return Failure.fail(bo, "first argument of [{}] is [{}] and second is [{}]. [{}] can only be operated on together with another [{}]", bo.sourceText(), leftType.typeName(), rightType.typeName(), DataType.UNSIGNED_LONG.typeName(), DataType.UNSIGNED_LONG.typeName());
        }
        return null;
    }

    private static Failure validateUnsignedLongNegation(Neg neg) {
        DataType childExpressionType = neg.field().dataType();
        if (childExpressionType.equals((Object)DataType.UNSIGNED_LONG)) {
            return Failure.fail(neg, "negation unsupported for arguments of type [{}] in expression [{}]", childExpressionType.typeName(), neg.sourceText());
        }
        return null;
    }
}

