Commit c8daec5

jinhongyii and claude committed
[Arith] Gate canonical-simplify LT Case 2 on extra scale == +1
CanonicalSimplifier::Impl::VisitExpr_(LTNode) Case 2 rewrites a "scaled-by-d sum plus a single leftover split" comparison

    S + xn < 0  <=>  S/d + (xn // d) < 0,  where d = gcd(scales)

into one where the leftover yn % m gets replaced by floormod(floordiv(yn, d*L), m/(d*L)). The Case 1 derivation that justifies dropping the remainder xn % d in [0, d) only works when xn >= 0. With scale = -1 the equivalence becomes <= rather than <, and the rewrite silently strengthens the predicate by dropping the boundary case S/d == xn // d.

This surfaced as a miscompile in kernels that mask a per-lane write by `row > col`, where `row = (lane_id // 4) + 16 * warp_id` and `col = 2 * (lane_id % 4)` are independent projections of the same lane id. After CSE+inlining the comparison hit canonical_simplify with the divided projection on the LHS (scale = -1), and Case 2 folded `2*(tx%4) < 16*warp + (tx%32)//4` into a plain `0 < warp_id`, zeroing every thread that should have written `val` in warp 0. The same path also folded other configurations (e.g. `0 < (tx%32) - 8*warp`) all the way to False.

Gate Case 2 with `extra->args[0]->scale == 1`. The original target shape (`(yn % m)` with positive scale and lower_factor=1, as well as the scale=+1 + lower_factor>1 generalization) is unchanged; both are covered by the existing `test_simplify_le` cases and by the new `test_simplify_le_negative_scale_extra` regression test, which also pins the buggy scale=-1 shape to its unsimplified form and re-asserts that the truly-always-true `r=2` variant still folds to True.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
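The miscompile described above can be reproduced without TVM by enumerating the predicate in plain Python. This is only an illustrative sketch: the helper names `original`, `folded`, and `mismatches` are hypothetical, the ranges `tx in [0, 128)` and `warp in [0, 4)` are borrowed from the regression test's bindings, and Python's `//` and `%` are assumed to match the floor semantics of the simplified expressions.

```python
from itertools import product


def original(tx, warp):
    # The mask as it reached canonical_simplify, per the commit message.
    return 2 * (tx % 4) < 16 * warp + (tx % 32) // 4


def folded(tx, warp):
    # What the pre-fix Case 2 rewrite reduced it to, per the commit message.
    return 0 < warp


# Every (tx, warp) pair on which the folded predicate disagrees with the original.
mismatches = [
    (tx, warp)
    for tx, warp in product(range(128), range(4))
    if original(tx, warp) != folded(tx, warp)
]

# All disagreements sit in warp 0, where the folded form is always False.
print(len(mismatches), {warp for _, warp in mismatches})

# The second shape from the message, `0 < (tx % 32) - 8 * warp`, is satisfiable,
# so folding it to False is also unsound.
second_shape_satisfiable = any(
    0 < (tx % 32) - 8 * warp for tx, warp in product(range(128), range(4))
)
print(second_shape_satisfiable)
```

Every mismatch is a lane in warp 0 whose write the folded mask would have suppressed, which matches the "zeroing every thread that should have written `val` in warp 0" symptom.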
1 parent 1240649 commit c8daec5

2 files changed

Lines changed: 53 additions & 2 deletions

src/arith/canonical_simplify.cc

Lines changed: 9 additions & 2 deletions
@@ -1419,10 +1419,17 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) {
       // Case 1. 0 <= xn < d
       divisible.CopyOnWrite()->DivideBy(gcd);
       return Rewriter::VisitExpr(divisible->Normalize() < make_zero(dtype));
-    } else if (extra->args.size() == 1 &&
+    } else if (extra->args.size() == 1 && extra->args[0]->scale == 1 &&
                extra->args[0]->upper_factor != ConstIntBoundNode::kPosInf &&
                extra->args[0]->upper_factor % (gcd * extra->args[0]->lower_factor) == 0) {
-      // Case 2. xn == yn % m, where m % d == 0
+      // Case 2. xn == ((yn % m) // L), scale = +1, m % (d*L) == 0.
+      // S + xn < 0 with S divisible by d ⇔ S/d + xn // d < 0, because
+      // xn % d ∈ [0, d) lets us drop the remainder via the Case 1 argument,
+      // and xn // d = (yn // (d*L)) % (m/(d*L)).
+      // The scale must be +1: with scale = -1 the equivalence becomes ≤
+      // rather than <, so the rewrite would strengthen the predicate and
+      // silently drop the boundary S/d == xn // d (e.g. row > col where
+      // row and col are independent projections of the same lane id).
       divisible.CopyOnWrite()->DivideBy(gcd);
       const auto split_expr = extra->args[0];
       int64_t lower_factor = gcd * extra->args[0]->lower_factor;
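The sign dependence that the new comment describes can be checked by brute force on a small grid. The sketch below is not part of the commit: `first_counterexample` and `split_identity_holds` are hypothetical helpers, `S` is modelled as `d * s`, the leftover term as `scale * xn` with `xn >= 0`, and Python's floor division and modulo stand in for `floordiv` and `floormod`.

```python
def first_counterexample(scale, d=8, s_values=range(-6, 7), xn_values=range(0, 64)):
    """Return the first (s, xn) where dividing through by d changes the predicate.

    Compares  d*s + scale*xn < 0  against  s + scale*(xn // d) < 0.
    Returns None if the two agree on the whole grid.
    """
    for s in s_values:
        for xn in xn_values:
            before = d * s + scale * xn < 0
            after = s + scale * (xn // d) < 0
            if before != after:
                return (s, xn)
    return None


def split_identity_holds(m, d, L, yn_values=range(-256, 256)):
    """Check ((yn % m) // L) // d == (yn // (d*L)) % (m // (d*L)) when m % (d*L) == 0."""
    assert m % (d * L) == 0, "Case 2 precondition: m must be divisible by d*L"
    return all(
        ((yn % m) // L) // d == (yn // (d * L)) % (m // (d * L))
        for yn in yn_values
    )


print(first_counterexample(+1))        # scale = +1: the rewrite Case 2 relies on
print(first_counterexample(-1))        # scale = -1: the shape the new gate excludes
print(split_identity_holds(64, 8, 1))  # lower_factor = 1 shape
print(split_identity_holds(32, 2, 4))  # lower_factor > 1 shape, as in (tx % 32) // 4
```

With `scale = +1` the grid yields no counterexample, while `scale = -1` fails as soon as `xn % d` is nonzero at the boundary, which is the case the added `scale == 1` condition rules out.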

tests/python/arith/test_arith_canonical_simplify.py

Lines changed: 44 additions & 0 deletions
@@ -490,5 +490,49 @@ def test_simplify_le():
     ck.verify(x * 1024 + y < z * 7168, x - z * 7 < 0)
 
 
+def test_simplify_le_negative_scale_extra():
+    """Regression: Case 2 of the LT-with-divisible-coeffs rewrite must not
+    fire when the leftover split term has a negative scale.
+
+    The rewrite ``S + xn < 0 ⇔ S/d + xn // d < 0`` is only sound when
+    the leftover ``xn`` has scale ``+1``. With scale ``-1`` the equivalence
+    becomes ``≤`` rather than ``<`` and the rewrite silently strengthens
+    the predicate. The original bug surfaced as ``row > col`` masks of
+    ``.16x*b`` tcgen05 readbacks collapsing to plain ``warp_id > k``
+    comparisons (lower-triangle writes were silently dropped on the
+    boundary warp).
+    """
+    ck = CanonicalChecker()
+    tx = tvm.tirx.Var("tx", "int32")
+    warp = tvm.tirx.Var("warp", "int32")
+    ck.analyzer.bind(tx, tvm.ir.Range(0, 128))
+    ck.analyzer.bind(warp, tvm.ir.Range(0, 4))
+
+    # Same-source joint projection: the comparison genuinely depends on tx
+    # at warp == 0 (e.g. tx == 4 ⇒ 0 < 1 = True; tx == 1 ⇒ 2 < 0 = False),
+    # so the simplifier must keep both sides. Pre-fix this folded to
+    # ``0 < warp`` and dropped every True case in warp 0.
+    expr = (tx % 4) * 2 < warp * 16 + (tx % 32) // 4
+    ck.verify(expr, expr)
+
+    # The simpler ``scale = -1`` with ``lower_factor = 1`` shape. Pre-fix
+    # this folded to ``False`` (drops all warp >= 1 cases where the rhs
+    # actually exceeds 8*warp).
+    expr = warp * 8 < (tx % 32)
+    ck.verify(expr, expr)
+
+    # The corresponding ``scale = +1`` Case 2 path (the rewrite this guards)
+    # must still optimize, which verifies we did not over-restrict.
+    x1 = tvm.tirx.Var("x1", "int32")
+    y1 = tvm.tirx.Var("y1", "int32")
+    ck.verify(x1 * 64 + (y1 % 64) < 120, x1 * 8 + (y1 % 64) // 8 < 15)
+
+    # The truly-always-true comparison that arises from the same kernel
+    # (``r = 2 / va = 1`` in the tcgen05.ld.16x256b readback) must still
+    # fold to True so the masked store can be elided.
+    expr_true = (tx % 4) * 2 < warp * 16 + (tx % 32) // 4 + 8
+    ck.verify(expr_true, tvm.tirx.const(True, "bool"))
+
+
 if __name__ == "__main__":
     tvm.testing.main()
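The two "must still simplify" expectations in the new test can be cross-checked independently of the simplifier with a small enumeration oracle. This is a sketch under stated assumptions, not part of the test suite: `equivalent` is a hypothetical helper, the finite ranges chosen for the unbound `x1` and `y1` are arbitrary, and Python's `//` and `%` are assumed to agree with the floor semantics the test relies on.

```python
from itertools import product


def equivalent(lhs, rhs, *domains):
    """True if lhs and rhs agree on every point of the Cartesian product of domains."""
    return all(lhs(*point) == rhs(*point) for point in product(*domains))


# scale = +1 Case 2 shape: the expected rewrite is a genuine equivalence.
positive_scale_ok = equivalent(
    lambda x1, y1: x1 * 64 + (y1 % 64) < 120,
    lambda x1, y1: x1 * 8 + (y1 % 64) // 8 < 15,
    range(-20, 21),
    range(-200, 200),
)

# The `+ 8` variant is a tautology on the bound ranges, so folding to True is sound.
always_true_ok = equivalent(
    lambda tx, warp: (tx % 4) * 2 < warp * 16 + (tx % 32) // 4 + 8,
    lambda tx, warp: True,
    range(128),
    range(4),
)

# The scale = -1, lower_factor = 1 shape takes both truth values, so it cannot
# be folded to a constant.
second_shape_values = {
    warp * 8 < (tx % 32) for tx, warp in product(range(128), range(4))
}

print(positive_scale_ok, always_true_ok, second_shape_values)
```

An exhaustive check like this only covers the sampled ranges, so it complements rather than replaces the `ck.verify` assertions.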
