/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.esql.parser;

import org.antlr.v4.runtime.ParserRuleContext;
import org.antlr.v4.runtime.Token;
import org.antlr.v4.runtime.misc.Interval;
import org.antlr.v4.runtime.tree.ParseTree;
import org.antlr.v4.runtime.tree.ParseTreeVisitor;
import org.antlr.v4.runtime.tree.TerminalNode;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.xpack.esql.core.tree.Location;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.util.Check;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

import static java.util.Collections.emptyList;

/**
 * Static helpers shared by the parser's visitors: visiting parse trees with result
 * checking, and building {@link Source} instances from ANTLR contexts, terminal
 * nodes and tokens. The class cannot be instantiated.
 */
public final class ParserUtils {
    /**
     * The classifications a parameter can be given.
     */
    public enum ParamClassification {
        VALUE,
        IDENTIFIER,
        PATTERN
    }

    /**
     * Lookup from the name of every {@link ParamClassification} constant to the constant itself.
     * Populated once by the static initializer below.
     */
    public static final Map<String, ParamClassification> paramClassifications = Maps.newMapWithExpectedSize(
        ParamClassification.values().length
    );

    static {
        for (ParamClassification classification : ParamClassification.values()) {
            paramClassifications.put(classification.name(), classification);
        }
    }

    private ParserUtils() {}

    /**
     * Applies {@code visitor} to {@code tree} and returns whatever it produced.
     * The outcome is passed through {@link Check#notNull} together with a message naming
     * the tree's class and text, so a visitor that yields nothing for a node is reported.
     *
     * @param visitor function invoked on the tree
     * @param tree    the tree handed to the visitor
     * @return the visitor's result
     */
    public static Object visit(Function<ParseTree, Object> visitor, ParseTree tree) {
        Object outcome = visitor.apply(tree);
        Check.notNull(outcome, "Don't know how to handle context [{}] with value [{}]", tree.getClass(), tree.getText());
        return outcome;
    }

    /**
     * Visits each context in order and collects the results, each cast to {@code clazz}.
     *
     * @param visitor  visitor applied to every context
     * @param contexts contexts to visit; may be {@code null} or empty
     * @param clazz    type every visit result is cast to
     * @return an empty list when there is nothing to visit, otherwise a new list holding
     *         one result per context, in the same order
     */
    public static <T> List<T> visitList(ParseTreeVisitor<?> visitor, List<? extends ParserRuleContext> contexts, Class<T> clazz) {
        boolean nothingToVisit = contexts == null || contexts.isEmpty();
        if (nothingToVisit) {
            return emptyList();
        }

        List<T> visited = new ArrayList<>(contexts.size());
        for (ParserRuleContext each : contexts) {
            Object produced = visitor.visit(each);
            visited.add(clazz.cast(produced));
        }
        return visited;
    }

    /**
     * Lets {@code ctx} accept {@code visitor} and returns the result as a {@code T}.
     *
     * @param visitor visitor the tree accepts
     * @param ctx     tree being parsed
     * @param type    type the result has to be an instance of
     * @return the visitor's result, typed as {@code T}
     * @throws ParsingException if the result is not an instance of {@code type}
     *                          (a {@code null} result is reported as "null")
     */
    public static <T> T typedParsing(ParseTreeVisitor<?> visitor, ParseTree ctx, Class<T> type) {
        Object result = ctx.accept(visitor);
        if (type.isInstance(result) == false) {
            String found = result == null ? "null" : result.getClass().getSimpleName();
            throw new ParsingException(
                source(ctx),
                "Invalid query '{}'[{}] given; expected {} but found {}",
                ctx.getText(),
                ctx.getClass().getSimpleName(),
                type.getSimpleName(),
                found
            );
        }
        // isInstance succeeded, so this cast cannot fail
        return type.cast(result);
    }

    /**
     * Builds the source of an arbitrary parse tree node.
     *
     * @return the source of the node when it is a {@link ParserRuleContext},
     *         {@link Source#EMPTY} for anything else
     */
    public static Source source(ParseTree ctx) {
        return ctx instanceof ParserRuleContext ? source((ParserRuleContext) ctx) : Source.EMPTY;
    }

    /**
     * Builds the source of a terminal node from the node's symbol.
     */
    public static Source source(TerminalNode terminalNode) {
        Check.notNull(terminalNode, "terminalNode is null");
        Token symbol = terminalNode.getSymbol();
        return source(symbol);
    }

    /**
     * Builds the source spanning a rule context, from its start token to its stop token.
     * When the context has no stop token, the start token is used for both ends.
     */
    public static Source source(ParserRuleContext parserRuleContext) {
        Check.notNull(parserRuleContext, "parserRuleContext is null");
        Token first = parserRuleContext.start;
        Token last = parserRuleContext.stop;
        if (last == null) {
            last = first;
        }
        return source(first, last);
    }

    /**
     * Builds the source of a single token: its line and char position in line as the
     * location, and the input stream text between its start and stop index as the text.
     */
    public static Source source(Token token) {
        Check.notNull(token, "token is null");
        Interval span = new Interval(token.getStartIndex(), token.getStopIndex());
        String text = token.getInputStream().getText(span);
        Location location = new Location(token.getLine(), token.getCharPositionInLine());
        return new Source(location, text);
    }

    /**
     * Builds the source running from the start token of {@code begin} to the stop token of {@code end}.
     * When {@code end} has no stop token, the stop token of {@code begin} is used instead.
     */
    public static Source source(ParserRuleContext begin, ParserRuleContext end) {
        Check.notNull(begin, "begin is null");
        Check.notNull(end, "end is null");
        Token last;
        if (end.stop != null) {
            last = end.stop;
        } else {
            last = begin.stop;
        }
        return source(begin.start, last);
    }

    /**
     * Builds the source running from the symbol of {@code begin} to the stop token of {@code end}.
     * When {@code end} has no stop token, the symbol of {@code begin} is used for both ends.
     */
    public static Source source(TerminalNode begin, ParserRuleContext end) {
        Check.notNull(begin, "begin is null");
        Check.notNull(end, "end is null");
        Token first = begin.getSymbol();
        Token last = end.stop;
        if (last == null) {
            last = first;
        }
        return source(first, last);
    }

    /**
     * Builds the source covering everything between two tokens: the location is the line and
     * char position in line of {@code start}, the text is read from the input stream of
     * {@code start}, from its start index up to the stop index of {@code stop}.
     *
     * @param start first token of the span
     * @param stop  last token of the span; when {@code null}, {@code start} is used
     */
    public static Source source(Token start, Token stop) {
        Check.notNull(start, "start is null");
        Token last = stop != null ? stop : start;
        Interval span = new Interval(start.getStartIndex(), last.getStopIndex());
        String text = start.getInputStream().getText(span);
        Location location = new Location(start.getLine(), start.getCharPositionInLine());
        return new Source(location, text);
    }

    /**
     * Retrieves the raw text of the node (without interpreting it as a string literal).
     *
     * @return the node's text, or {@code null} when the node itself is {@code null}
     */
    public static String text(ParseTree node) {
        if (node == null) {
            return null;
        }
        return node.getText();
    }
}
