Skip to content

Commit 4d9d129

Browse files
authored
[Relax][ONNX] Fix Cast operator float->int NaN/Inf handling (#19626)
Hi Committers, This PR is trying to fix issues #19542. Any suggestions would be appreciated if you are available. ### Root cause: FP to INT lowering can be implementation-defined or UB for NaN/Inf and extreme floats, producing backend-dependent results versus ONNX Runtime. ### Solution: Apply a minimal, deterministic frontend sanitization for float to integer Casts: map NaN and ±Inf to 0.0 before astype. This prevents NaN/Inf from reaching backend fptosi/fptoui lowers and yields stable behavior across targets. --------- Co-authored-by: cchung100m <cchung100m@users.noreply.github.com>
1 parent 913fc4b commit 4d9d129

2 files changed

Lines changed: 88 additions & 0 deletions

File tree

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,6 +1105,63 @@ def _impl_v13(cls, bb, inputs, attr, params):
11051105
return relax.const(output, to_type)
11061106
if isinstance(inputs[0], relax.PrimValue):
11071107
return relax.PrimValue(inputs[0].value.astype(to_type))
1108+
1109+
try:
1110+
np_dst = _np.dtype(str(to_type))
1111+
except Exception:
1112+
return relax.op.astype(inputs[0], to_type)
1113+
1114+
if np_dst.kind in ("i", "u"):
1115+
src = inputs[0]
1116+
src_dtype = getattr(getattr(src, "struct_info", None), "dtype", None) or getattr(
1117+
src, "dtype", None
1118+
)
1119+
if src_dtype is not None and _relax_dtype_is_floating_point(src_dtype):
1120+
x_sanitized = bb.emit(
1121+
relax.op.where(
1122+
relax.op.logical_not(relax.op.isfinite(src)),
1123+
relax.const(0.0, src_dtype),
1124+
src,
1125+
)
1126+
)
1127+
dst_str = str(to_type)
1128+
if dst_str.startswith("uint"):
1129+
signed = False
1130+
bits = int(dst_str[4:])
1131+
elif dst_str.startswith("int"):
1132+
signed = True
1133+
bits = int(dst_str[3:])
1134+
else:
1135+
return relax.op.astype(x_sanitized, to_type)
1136+
1137+
if bits == 64:
1138+
return relax.op.astype(x_sanitized, to_type)
1139+
1140+
temp_dtype = "int64" if bits >= 32 else "int32"
1141+
t = relax.op.astype(x_sanitized, temp_dtype)
1142+
if bits == 32:
1143+
two_pow = relax.const(1 << bits, temp_dtype)
1144+
uw = relax.op.floor_mod(t, two_pow)
1145+
else:
1146+
mask_val = (1 << bits) - 1
1147+
mask = relax.const(mask_val, temp_dtype)
1148+
uw = relax.op.bitwise_and(t, mask)
1149+
if signed:
1150+
half = 1 << (bits - 1)
1151+
half_c = relax.const(half, temp_dtype)
1152+
if bits == 32:
1153+
two_pow = relax.const(1 << bits, temp_dtype)
1154+
else:
1155+
two_pow = relax.op.add(mask, relax.const(1, temp_dtype))
1156+
wrapped = relax.op.where(
1157+
relax.op.greater_equal(uw, half_c),
1158+
relax.op.subtract(uw, two_pow),
1159+
uw,
1160+
)
1161+
else:
1162+
wrapped = uw
1163+
return relax.op.astype(wrapped, to_type)
1164+
11081165
return relax.op.astype(inputs[0], to_type)
11091166

11101167

tests/python/relax/test_frontend_onnx.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,37 @@ def test_cast(from_type, to_type):
863863
check_correctness(model, opset=13)
864864

865865

866+
@pytest.mark.parametrize("to_type", [TensorProto.INT64, TensorProto.UINT64])
867+
def test_cast_float_to_64bit_int_dynamic(to_type):
868+
cast_node = helper.make_node("Cast", ["a"], ["b"], to=to_type)
869+
graph = helper.make_graph(
870+
[cast_node],
871+
"cast_float_to_64bit_int_dynamic_test",
872+
inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, [1, 8])],
873+
outputs=[helper.make_tensor_value_info("b", to_type, [1, 8])],
874+
)
875+
model = helper.make_model(graph, producer_name="cast_float_to_64bit_int_dynamic_test")
876+
inputs = {"a": np.array([[0.0, 1.2, 2.8, 7.9, 15.1, 31.7, 63.4, 127.9]], dtype=np.float32)}
877+
check_correctness(model, inputs=inputs, opset=13, check_dtypes=True)
878+
879+
880+
def test_cast_nan_inf_to_int8():
881+
vals = np.array([300.0, np.nan, np.inf, -np.inf, 50.0, -50.0], dtype=np.float32)
882+
node = helper.make_node("Cast", inputs=["a"], outputs=["b"], to=TensorProto.INT8)
883+
graph = helper.make_graph(
884+
[node],
885+
"cast_nan_inf_test",
886+
inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, list(vals.shape))],
887+
outputs=[helper.make_tensor_value_info("b", TensorProto.INT8, list(vals.shape))],
888+
)
889+
model = helper.make_model(graph, producer_name="cast_nan_inf_test")
890+
tvm_output = run_in_tvm(model, inputs={"a": vals}, opset=13)
891+
out_np = tvm_output.numpy()
892+
expected = np.array([44, 0, 0, 0, 50, -50], dtype=np.int8)
893+
assert out_np.dtype == np.int8
894+
np.testing.assert_array_equal(out_np, expected)
895+
896+
866897
def test_gather():
867898
def _verify_gather(data_shape, indices, out_shape, axis=0):
868899
gather_node = helper.make_node("Gather", ["data", "indices"], ["y"], axis=axis)

0 commit comments

Comments
 (0)