/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.logical;

import java.time.DateTimeException;
import java.util.Arrays;
import java.util.List;
import java.util.function.BiFunction;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.DefaultBinaryArithmeticOperation;
import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.ArithmeticOperation;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.BinaryComparisonInversible;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Neg;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Sub;
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;

/**
 * Logical optimizer rule that simplifies a binary comparison whose right side is a {@link Literal} and whose left
 * side is either an {@link ArithmeticOperation} with a literal operand or a {@link Neg} negation.
 * <p>
 * For arithmetic, the literal operand is moved to the right-hand side of the comparison through the operation's
 * inverse and folded there, e.g. {@code x + 2 > 5} becomes {@code x > 3}; when a multiplier/divisor is negative the
 * comparison is reversed as well. For a negation, the sign is moved onto the literal and the comparison is reversed,
 * e.g. {@code -x > 5} becomes {@code x < -5}.
 * <p>
 * The rewrite is only applied when it cannot change the outcome of the comparison; otherwise the original comparison
 * is returned unchanged. Modulo, null or floating-point literals, literal types rejected by {@code typesCompatible},
 * integer division, multiplications whose inverse would not be integral, multiplications/divisions whose moved
 * operand is not a literal (its sign is unknown) and folds that overflow are all left alone.
 */
public final class SimplifyComparisonsArithmetics extends OptimizerRules.OptimizerExpressionRule<BinaryComparison> {
    /** Decides whether the comparison literal and the arithmetic literal have compatible types (comparison, operand). */
    final BiFunction<DataType, DataType, Boolean> typesCompatible;

    /**
     * @param typesCompatible predicate deciding whether the comparison literal's type (first argument) is compatible
     *                        with the arithmetic literal's type (second argument)
     */
    public SimplifyComparisonsArithmetics(BiFunction<DataType, DataType, Boolean> typesCompatible) {
        super(OptimizerRules.TransformDirection.UP);
        this.typesCompatible = typesCompatible;
    }

    /**
     * Entry point: only comparisons with a literal on the right side are considered.
     *
     * @return the simplified expression, or {@code bc} itself when nothing can (safely) be simplified
     */
    @Override
    protected Expression rule(BinaryComparison bc, LogicalOptimizerContext ctx) {
        if (bc.right() instanceof Literal) {
            if (bc.left() instanceof ArithmeticOperation) {
                return simplifyBinaryComparison(ctx.foldCtx(), bc);
            }
            if (bc.left() instanceof Neg) {
                return foldNegation(ctx.foldCtx(), bc);
            }
        }
        return bc;
    }

    /**
     * Picks the operation-specific simplifier for {@code comparison} (whose left side is an arithmetic operation) and
     * applies it when it is safe to do so.
     */
    private Expression simplifyBinaryComparison(FoldContext foldContext, BinaryComparison comparison) {
        ArithmeticOperation operation = (ArithmeticOperation) comparison.left();
        // Operations are identified by their symbol.
        String opSymbol = operation.symbol();
        // Modulo has no inverse that could be moved across the comparison.
        if (opSymbol.equals(DefaultBinaryArithmeticOperation.MOD.symbol())) {
            return comparison;
        }
        OperationSimplifier simplification = null;
        if (isMulOrDiv(opSymbol)) {
            simplification = new MulDivSimplifier(foldContext, comparison);
        } else if (opSymbol.equals(DefaultBinaryArithmeticOperation.ADD.symbol())
            || opSymbol.equals(DefaultBinaryArithmeticOperation.SUB.symbol())) {
            simplification = new AddSubSimplifier(foldContext, comparison);
        }
        if (simplification == null || simplification.isUnsafe(typesCompatible)) {
            return comparison;
        }
        return simplification.apply();
    }

    /** Returns whether {@code opSymbol} is the multiplication or division symbol. */
    private static boolean isMulOrDiv(String opSymbol) {
        return opSymbol.equals(DefaultBinaryArithmeticOperation.MUL.symbol())
            || opSymbol.equals(DefaultBinaryArithmeticOperation.DIV.symbol());
    }

    /**
     * Rewrites {@code -field <cmp> literal} into {@code field <reversed cmp> -literal}.
     *
     * @return the rewritten comparison, or {@code bc} unchanged if negating the literal fails to fold (overflow)
     */
    private static Expression foldNegation(FoldContext ctx, BinaryComparison bc) {
        Literal bcLiteral = (Literal) bc.right();
        Expression literalNeg = tryFolding(ctx, new Neg(bcLiteral.source(), bcLiteral));
        if (literalNeg == null) {
            return bc;
        }
        Expression field = ((Neg) bc.left()).field();
        return bc.reverse().replaceChildren(Arrays.asList(field, literalNeg));
    }

    /**
     * Folds {@code expression} into a {@link Literal} if it is foldable.
     *
     * @return the folded literal; {@code expression} itself when it is not foldable; or {@code null} when folding
     *         fails with an arithmetic/date-time error (e.g. overflow), which signals callers to skip the rewrite
     */
    private static Expression tryFolding(FoldContext ctx, Expression expression) {
        if (expression.foldable() == false) {
            return expression;
        }
        try {
            return new Literal(expression.source(), expression.fold(ctx), expression.dataType());
        } catch (ArithmeticException | DateTimeException e) {
            // Deliberately swallowed: a failed fold only means this optimization is not applicable.
            return null;
        }
    }

    /**
     * Shared machinery for the arithmetic rewrites. Captures the parts of
     * {@code (opLeft <op> opRight) <cmp> bcLiteral} and rewrites it into
     * {@code opLeft <cmp> (bcLiteral <inverse op> opRight)} when the subclass deems it safe.
     */
    private abstract static class OperationSimplifier {
        final FoldContext foldContext;
        /** The comparison being simplified. */
        final BinaryComparison comparison;
        /** The literal on the right side of the comparison. */
        final Literal bcLiteral;
        /** The arithmetic operation on the left side of the comparison. */
        final ArithmeticOperation operation;
        final Expression opLeft;
        final Expression opRight;
        /** A literal operand of {@link #operation} (the left one if both are literals), or {@code null}. */
        final Literal opLiteral;

        OperationSimplifier(FoldContext foldContext, BinaryComparison comparison) {
            this.foldContext = foldContext;
            this.comparison = comparison;
            this.operation = (ArithmeticOperation) comparison.left();
            this.bcLiteral = (Literal) comparison.right();
            this.opLeft = operation.left();
            this.opRight = operation.right();
            if (opLeft instanceof Literal) {
                this.opLiteral = (Literal) opLeft;
            } else if (opRight instanceof Literal) {
                this.opLiteral = (Literal) opRight;
            } else {
                this.opLiteral = null;
            }
        }

        /** Cheap checks for whether the rewrite could change the comparison's outcome (or cannot be performed). */
        final boolean isUnsafe(BiFunction<DataType, DataType, Boolean> typesCompatible) {
            // Without a literal operand there is nothing to move across the comparison.
            if (opLiteral == null) {
                return true;
            }
            // Null literals carry no numeric value; the Number casts in isOpUnsafe()/apply() could throw on them.
            if (opLiteral.value() == null || bcLiteral.value() == null) {
                return true;
            }
            // Only fixed-point literals: rearranging floating-point arithmetic can change the filter's outcome.
            if (opLiteral.dataType().isRationalNumber() || bcLiteral.dataType().isRationalNumber()) {
                return true;
            }
            // The literal is combined with the comparison literal, so their types must be compatible.
            if (!typesCompatible.apply(bcLiteral.dataType(), opLiteral.dataType())) {
                return true;
            }
            return isOpUnsafe();
        }

        /**
         * Performs the rewrite. Must only be called when {@link #isUnsafe} returned {@code false}.
         *
         * @return the rewritten comparison, or {@link #comparison} unchanged if folding the moved operand fails
         */
        final Expression apply() {
            // For floating-point operations, fold against a DOUBLE version of the comparison literal.
            Literal bcl = operation.dataType().isRationalNumber()
                ? new Literal(bcLiteral.source(), ((Number) bcLiteral.value()).doubleValue(), DataType.DOUBLE)
                : bcLiteral;
            Expression bcRightExpression = ((BinaryComparisonInversible) operation).binaryComparisonInverse()
                .create(bcl.source(), bcl, opRight);
            bcRightExpression = tryFolding(foldContext, bcRightExpression);
            if (bcRightExpression == null) {
                return comparison;
            }
            return postProcess((BinaryComparison) comparison.replaceChildren(List.of(opLeft, bcRightExpression)));
        }

        /** Operation-specific safety checks, run after the generic ones in {@link #isUnsafe}. */
        abstract boolean isOpUnsafe();

        /** Operation-specific adjustment of the rewritten comparison; identity by default. */
        Expression postProcess(BinaryComparison binaryComparison) {
            return binaryComparison;
        }
    }

    /** Simplifier for additions and subtractions. */
    private static class AddSubSimplifier extends OperationSimplifier {
        AddSubSimplifier(FoldContext foldContext, BinaryComparison comparison) {
            super(foldContext, comparison);
        }

        @Override
        boolean isOpUnsafe() {
            // Never rewrite floating-point additions/subtractions: moving a term may change rounding.
            if (operation.dataType().isRationalNumber()) {
                return true;
            }
            // Literal on the left of a subtraction (e.g. `1 - x > 5`): skip up front if the follow-up fold of
            // `opLeft - bcLiteral` would overflow anyway.
            if (operation.symbol().equals(DefaultBinaryArithmeticOperation.SUB.symbol()) && !(opRight instanceof Literal)) {
                return tryFolding(foldContext, new Sub(Source.EMPTY, opLeft, bcLiteral)) == null;
            }
            return false;
        }
    }

    /**
     * Simplifier for multiplications and divisions. The literal multiplier/divisor is moved across the comparison
     * and the comparison is reversed when that literal is negative.
     */
    private static class MulDivSimplifier extends OperationSimplifier {
        /** {@code true} for a division, {@code false} for a multiplication. */
        private final boolean isDiv;
        /** Sign (-1, 0 or 1) of the right operand, i.e. of the value moved across the comparison. */
        private final int opRightSign;

        MulDivSimplifier(FoldContext foldContext, BinaryComparison comparison) {
            super(foldContext, comparison);
            this.isDiv = operation.symbol().equals(DefaultBinaryArithmeticOperation.DIV.symbol());
            this.opRightSign = sign(opRight);
        }

        @Override
        boolean isOpUnsafe() {
            // apply() always moves opRight across the comparison, and postProcess() decides the comparison direction
            // from opRight's sign, so the literal must be the right operand. Otherwise the moved expression's sign is
            // unknown (sign() defaults to 1): e.g. `2 / x > 1` on a floating-point x would become a comparison of 2
            // against `1 * x`, which is wrong for x <= 0.
            if (opLiteral != opRight) {
                return true;
            }
            // Integer division truncates: `x / 5 > 1` is not equivalent to `x > 5` for x in [6, 9].
            if (operation.dataType().isWholeNumber() && isDiv) {
                return true;
            }
            // The inverse of an integer multiplication is a division: only safe when it is exact.
            if (!isDiv && opLeft.dataType().isWholeNumber()) {
                long opLiteralValue = ((Number) opLiteral.value()).longValue();
                return opLiteralValue == 0L || ((Number) bcLiteral.value()).longValue() % opLiteralValue != 0L;
            }
            // A zero multiplier/divisor cannot be moved across the comparison.
            return opRightSign == 0;
        }

        @Override
        Expression postProcess(BinaryComparison binaryComparison) {
            // Multiplying/dividing by a negative value flips the direction of the comparison.
            return opRightSign < 0 ? binaryComparison.reverse() : binaryComparison;
        }

        /**
         * Best-effort sign of {@code obj}: exact for numbers, literals, negations and products/quotients of those;
         * {@code 1} for anything else (i.e. unknown expressions are assumed positive, hence the guard in
         * {@link #isOpUnsafe()}).
         */
        private static int sign(Object obj) {
            if (obj instanceof Number) {
                return (int) Math.signum(((Number) obj).doubleValue());
            }
            if (obj instanceof Literal) {
                return sign(((Literal) obj).value());
            }
            if (obj instanceof Neg) {
                return -sign(((Neg) obj).field());
            }
            if (obj instanceof ArithmeticOperation) {
                ArithmeticOperation arithmetic = (ArithmeticOperation) obj;
                if (isMulOrDiv(arithmetic.symbol())) {
                    return sign(arithmetic.left()) * sign(arithmetic.right());
                }
            }
            return 1;
        }
    }
}

