/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.logical;

import java.util.List;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.function.scalar.ScalarFunction;
import org.elasticsearch.xpack.esql.core.expression.predicate.BinaryPredicate;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.CollectionUtils;
import org.elasticsearch.xpack.esql.expression.predicate.Predicates;
import org.elasticsearch.xpack.esql.expression.predicate.logical.And;
import org.elasticsearch.xpack.esql.expression.predicate.logical.Not;
import org.elasticsearch.xpack.esql.expression.predicate.logical.Or;
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;

/**
 * Logical optimizer rule that simplifies boolean expressions built from {@link And}, {@link Or} and {@link Not}.
 *
 * <p>Rewrites applied:
 * <ul>
 *   <li>folding of {@code TRUE}/{@code FALSE} literal operands;</li>
 *   <li>collapsing of semantically equal operands ({@code a AND a => a}, {@code a OR a => a});</li>
 *   <li>absorption and common-factor extraction across the two sides of an AND/OR;</li>
 *   <li>negation of literals and removal of double negation ({@code NOT NOT a => a}).</li>
 * </ul>
 */
public final class BooleanSimplification extends OptimizerRules.OptimizerExpressionRule<ScalarFunction> {

    /**
     * Creates the rule with the {@code UP} transform direction.
     * NOTE(review): presumably this means children are rewritten before their parents; confirm against OptimizerRules.
     */
    public BooleanSimplification() {
        super(OptimizerRules.TransformDirection.UP);
    }

    /**
     * Dispatches to the AND/OR or NOT simplification. Any other scalar function is returned unchanged.
     *
     * @param e   the scalar function being visited
     * @param ctx the optimizer context (unused by this rule)
     * @return the simplified expression, or {@code e} itself when no rewrite applies
     */
    @Override
    public Expression rule(ScalarFunction e, LogicalOptimizerContext ctx) {
        boolean binaryLogic = e instanceof And || e instanceof Or;
        if (binaryLogic) {
            return simplifyAndOr((BinaryPredicate<?, ?, ?, ?>) e);
        }
        return e instanceof Not ? simplifyNot((Not) e) : e;
    }

    /**
     * Routes a binary logical predicate to the conjunction or disjunction simplifier.
     *
     * @param bc an {@link And} or {@link Or}; anything else is returned untouched
     * @return the simplified expression
     */
    private static Expression simplifyAndOr(BinaryPredicate<?, ?, ?, ?> bc) {
        if (bc instanceof And) {
            return simplifyConjunction(bc);
        }
        if (bc instanceof Or) {
            return simplifyDisjunction(bc);
        }
        return bc;
    }

    /** Simplifies {@code left AND right}. */
    private static Expression simplifyConjunction(BinaryPredicate<?, ?, ?, ?> conjunction) {
        Expression left = conjunction.left();
        Expression right = conjunction.right();

        // TRUE is the identity element of AND.
        if (Literal.TRUE.equals(left)) {
            return right;
        }
        if (Literal.TRUE.equals(right)) {
            return left;
        }
        // A FALSE operand makes the whole conjunction FALSE.
        if (Literal.FALSE.equals(left) || Literal.FALSE.equals(right)) {
            return new Literal(conjunction.source(), Boolean.FALSE, DataType.BOOLEAN);
        }
        // a AND a => a
        if (left.semanticEquals(right)) {
            return left;
        }

        // Common factor extraction: (a OR b) AND (a OR c) => a OR (b AND c)
        List<Expression> leftTerms = Predicates.splitOr(left);
        List<Expression> rightTerms = Predicates.splitOr(right);
        List<Expression> shared = Predicates.inCommon(leftTerms, rightTerms);
        if (shared.isEmpty()) {
            return conjunction;
        }
        List<Expression> leftRest = Predicates.subtract(leftTerms, shared);
        List<Expression> rightRest = Predicates.subtract(rightTerms, shared);
        // Absorption: one side is entirely shared, e.g. (a OR b) AND a => a
        if (leftRest.isEmpty() || rightRest.isEmpty()) {
            return Predicates.combineOr(shared);
        }
        Expression leftOr = Predicates.combineOr(leftRest);
        Expression rightOr = Predicates.combineOr(rightRest);
        And factored = new And(leftOr.source(), leftOr, rightOr);
        return Predicates.combineOr(CollectionUtils.combine(shared, factored));
    }

    /** Simplifies {@code left OR right}. */
    private static Expression simplifyDisjunction(BinaryPredicate<?, ?, ?, ?> disjunction) {
        Expression left = disjunction.left();
        Expression right = disjunction.right();

        // A TRUE operand makes the whole disjunction TRUE.
        if (Literal.TRUE.equals(left) || Literal.TRUE.equals(right)) {
            return new Literal(disjunction.source(), Boolean.TRUE, DataType.BOOLEAN);
        }
        // FALSE is the identity element of OR.
        if (Literal.FALSE.equals(left)) {
            return right;
        }
        if (Literal.FALSE.equals(right)) {
            return left;
        }
        // a OR a => a
        if (left.semanticEquals(right)) {
            return left;
        }

        // Common factor extraction: (a AND b) OR (a AND c) => a AND (b OR c)
        List<Expression> leftTerms = Predicates.splitAnd(left);
        List<Expression> rightTerms = Predicates.splitAnd(right);
        List<Expression> shared = Predicates.inCommon(leftTerms, rightTerms);
        if (shared.isEmpty()) {
            return disjunction;
        }
        List<Expression> leftRest = Predicates.subtract(leftTerms, shared);
        List<Expression> rightRest = Predicates.subtract(rightTerms, shared);
        // Absorption: one side is entirely shared, e.g. (a AND b) OR a => a
        if (leftRest.isEmpty() || rightRest.isEmpty()) {
            return Predicates.combineAnd(shared);
        }
        Expression leftAnd = Predicates.combineAnd(leftRest);
        Expression rightAnd = Predicates.combineAnd(rightRest);
        Or factored = new Or(leftAnd.source(), leftAnd, rightAnd);
        return Predicates.combineAnd(CollectionUtils.combine(shared, factored));
    }

    /**
     * Simplifies a negation.
     *
     * <p>Literal operands are negated directly. Otherwise the {@link #maybeSimplifyNegatable} hook is consulted.
     * Finally, a double negation is unwrapped.
     *
     * @param n the NOT expression
     * @return the simplified expression, or {@code n} itself when no rewrite applies
     */
    private Expression simplifyNot(Not n) {
        Expression operand = n.field();
        if (Literal.TRUE.semanticEquals(operand)) {
            return new Literal(n.source(), Boolean.FALSE, DataType.BOOLEAN);
        }
        if (Literal.FALSE.semanticEquals(operand)) {
            return new Literal(n.source(), Boolean.TRUE, DataType.BOOLEAN);
        }

        Expression negated = maybeSimplifyNegatable(operand);
        if (negated != null) {
            return negated;
        }
        // NOT NOT a => a
        return operand instanceof Not ? ((Not) operand).field() : n;
    }

    /**
     * Hook for negating the operand of a NOT in a specialised way.
     *
     * <p>A non-null result replaces the entire NOT expression. This implementation never simplifies and always
     * returns {@code null}.
     * NOTE(review): the class is {@code final}, so this protected hook cannot actually be overridden.
     *
     * @param e the operand of the NOT
     * @return the negated replacement, or {@code null} when no specialised negation applies
     */
    protected Expression maybeSimplifyNegatable(Expression e) {
        return null;
    }
}

