/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.parser.promql;

import java.time.Duration;
import org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.Arithmetics;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.parser.ParsingException;
import org.elasticsearch.xpack.esql.plan.logical.promql.operator.arithmetic.VectorBinaryArithmetic;
import org.elasticsearch.xpack.esql.plan.logical.promql.operator.comparison.VectorBinaryComparison;

public class PromqlFoldingUtils {
    public static Object evaluate(Source source, Object left, Object right, VectorBinaryArithmetic.ArithmeticOp operation) {
        if (left instanceof Duration) {
            Duration leftDuration = (Duration)left;
            if (right instanceof Duration) {
                Duration rightDuration = (Duration)right;
                return PromqlFoldingUtils.arithmetics(source, leftDuration, rightDuration, operation);
            }
            if (right instanceof Number) {
                Number rightNumber = (Number)right;
                return PromqlFoldingUtils.arithmetics(source, leftDuration, rightNumber, operation);
            }
        } else if (left instanceof Number) {
            Number leftNumber = (Number)left;
            if (right instanceof Duration) {
                Duration rightDuration = (Duration)right;
                return PromqlFoldingUtils.arithmetics(source, leftNumber, rightDuration, operation);
            }
            if (right instanceof Number) {
                Number rightNumber = (Number)right;
                return PromqlFoldingUtils.numericArithmetics(source, leftNumber, rightNumber, operation);
            }
        }
        throw new ParsingException(source, "Cannot perform arithmetic between [{}] and [{}]", left.getClass().getSimpleName(), right.getClass().getSimpleName());
    }

    private static Duration arithmetics(Source source, Duration left, Duration right, VectorBinaryArithmetic.ArithmeticOp op) {
        Duration result = switch (op) {
            case VectorBinaryArithmetic.ArithmeticOp.ADD -> left.plus(right);
            case VectorBinaryArithmetic.ArithmeticOp.SUB -> left.minus(right);
            default -> throw new ParsingException(source, "Operation [{}] not supported between two durations", op);
        };
        return result;
    }

    private static Duration arithmetics(Source source, Duration duration, Number scalar, VectorBinaryArithmetic.ArithmeticOp op) {
        long durationSeconds = duration.getSeconds();
        long scalarValue = scalar.longValue();
        long resultSeconds = switch (op) {
            default -> throw new MatchException(null, null);
            case VectorBinaryArithmetic.ArithmeticOp.ADD -> Math.addExact(durationSeconds, scalarValue);
            case VectorBinaryArithmetic.ArithmeticOp.SUB -> Math.subtractExact(durationSeconds, scalarValue);
            case VectorBinaryArithmetic.ArithmeticOp.MUL -> Math.round((double)durationSeconds * scalar.doubleValue());
            case VectorBinaryArithmetic.ArithmeticOp.DIV -> {
                if (scalarValue == 0L) {
                    throw new ParsingException(source, "Cannot divide duration by zero", new Object[0]);
                }
                yield Math.round((double)durationSeconds / scalar.doubleValue());
            }
            case VectorBinaryArithmetic.ArithmeticOp.MOD -> {
                if (scalarValue == 0L) {
                    throw new ParsingException(source, "Cannot compute modulo with zero", new Object[0]);
                }
                yield Math.floorMod(durationSeconds, scalarValue);
            }
            case VectorBinaryArithmetic.ArithmeticOp.POW -> Math.round(Math.pow(durationSeconds, scalarValue));
        };
        return Duration.ofSeconds(resultSeconds);
    }

    private static Duration arithmetics(Source source, Number scalar, Duration duration, VectorBinaryArithmetic.ArithmeticOp op) {
        return switch (op) {
            case VectorBinaryArithmetic.ArithmeticOp.ADD -> PromqlFoldingUtils.arithmetics(source, duration, scalar, VectorBinaryArithmetic.ArithmeticOp.ADD);
            case VectorBinaryArithmetic.ArithmeticOp.SUB -> PromqlFoldingUtils.arithmetics(source, Duration.ofSeconds(scalar.longValue()), duration, VectorBinaryArithmetic.ArithmeticOp.SUB);
            case VectorBinaryArithmetic.ArithmeticOp.MUL -> PromqlFoldingUtils.arithmetics(source, duration, scalar, VectorBinaryArithmetic.ArithmeticOp.MUL);
            default -> throw new ParsingException(source, "Operation [{}] not supported with scalar on left and duration on right", op);
        };
    }

    private static Number numericArithmetics(Source source, Number left, Number right, VectorBinaryArithmetic.ArithmeticOp op) {
        try {
            return switch (op) {
                default -> throw new MatchException(null, null);
                case VectorBinaryArithmetic.ArithmeticOp.ADD -> Arithmetics.add((Number)left, (Number)right);
                case VectorBinaryArithmetic.ArithmeticOp.SUB -> Arithmetics.sub((Number)left, (Number)right);
                case VectorBinaryArithmetic.ArithmeticOp.MUL -> Arithmetics.mul((Number)left, (Number)right);
                case VectorBinaryArithmetic.ArithmeticOp.DIV -> Arithmetics.div((Number)left, (Number)right);
                case VectorBinaryArithmetic.ArithmeticOp.MOD -> Arithmetics.mod((Number)left, (Number)right);
                case VectorBinaryArithmetic.ArithmeticOp.POW -> {
                    double result = Math.pow(left.doubleValue(), right.doubleValue());
                    if (Double.isFinite(result) && result == (double)((long)result)) {
                        if (result >= -2.147483648E9 && result <= 2.147483647E9) {
                            yield (int)result;
                        }
                        yield (long)result;
                    }
                    yield result;
                }
            };
        }
        catch (ArithmeticException e) {
            throw new ParsingException(source, "Arithmetic error: {}", e.getMessage());
        }
    }

    public static boolean evaluate(Source source, Object left, Object right, VectorBinaryComparison.ComparisonOp operation) {
        if (left instanceof Number) {
            Number ln = (Number)left;
            if (right instanceof Number) {
                Number rn = (Number)right;
                double l = ln.doubleValue();
                double r = rn.doubleValue();
                return switch (operation) {
                    default -> throw new MatchException(null, null);
                    case VectorBinaryComparison.ComparisonOp.EQ -> {
                        if (l == r) {
                            yield true;
                        }
                        yield false;
                    }
                    case VectorBinaryComparison.ComparisonOp.NEQ -> {
                        if (l != r) {
                            yield true;
                        }
                        yield false;
                    }
                    case VectorBinaryComparison.ComparisonOp.GT -> {
                        if (l > r) {
                            yield true;
                        }
                        yield false;
                    }
                    case VectorBinaryComparison.ComparisonOp.GTE -> {
                        if (l >= r) {
                            yield true;
                        }
                        yield false;
                    }
                    case VectorBinaryComparison.ComparisonOp.LT -> {
                        if (l < r) {
                            yield true;
                        }
                        yield false;
                    }
                    case VectorBinaryComparison.ComparisonOp.LTE -> l <= r;
                };
            }
        }
        throw new ParsingException(source, "Cannot perform comparison between [{}] and [{}]", left.getClass().getSimpleName(), right.getClass().getSimpleName());
    }

    private static void validatePositiveDuration(Source source, Duration duration) {
        if (duration.isNegative() || duration.isZero()) {
            throw new ParsingException(source, "Duration must be positive, got [{}]", duration);
        }
    }
}

