From 7e227302cbc08bc98d3d334a62a1ac476bafed3a Mon Sep 17 00:00:00 2001 From: fnhirwa Date: Sun, 14 Jun 2026 15:41:46 +0200 Subject: [PATCH 1/2] complex operators --- .../relax/frontend/tflite/tflite_frontend.py | 79 ++++++++++++++++++- tests/python/relax/test_frontend_tflite.py | 76 ++++++++++++++++++ 2 files changed, 154 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index d14643d75c60..409ddd3cfafc 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -207,6 +207,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "BROADCAST_ARGS": self.convert_broadcast_args, "CALL": self.convert_call, "CALL_ONCE": self.convert_call_once, + "COMPLEX_ABS": self.convert_complex_abs, "CAST": self.convert_cast, "CEIL": functools.partial(self._convert_unary_elemwise, relax_op=_op.ceil), "CONCATENATION": self.convert_concatenation, @@ -252,6 +253,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "HASHTABLE_LOOKUP": self.convert_hashtable_lookup, "HASHTABLE_SIZE": self.convert_hashtable_size, "IF": self.convert_if, + "IMAG": self.convert_imag, "L2_NORMALIZATION": self.convert_l2_normalization, "L2_POOL_2D": functools.partial(self.convert_pool2d, pool_type="l2"), "LEAKY_RELU": self.convert_leaky_relu, @@ -295,6 +297,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "RANDOM_STANDARD_NORMAL": self.convert_random_standard_normal, "RANDOM_UNIFORM": self.convert_random_uniform, "READ_VARIABLE": self.convert_read_variable, + "REAL": self.convert_real, "REDUCE_ALL": functools.partial(self._convert_reduce_bool, relax_op=_op.min), "REDUCE_ANY": functools.partial(self._convert_reduce_bool, relax_op=_op.max), "REDUCE_MAX": functools.partial(self._convert_reduce, relax_op=_op.max), @@ -303,6 +306,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "RELU": self.convert_relu, "RELU6": self.convert_relu6, "RELU_N1_TO_1": self.convert_relu_n1_to_1, + "RFFT2D": self.convert_rfft2d, "RESHAPE": self.convert_reshape, "RESIZE_BILINEAR": self.convert_resize_bilinear, "RESIZE_NEAREST_NEIGHBOR": self.convert_resize_nearest_neighbor, @@ -7580,6 +7584,69 @@ def convert_fake_quant(self, op): rounded = relax.op.floor(_op.add(_op.multiply(clamped_shifted, inv_scale), half)) return relax.op.add(_op.multiply(rounded, scale_expr), nudged_min_expr) + def convert_real(self, op): + """Convert TFLite REAL op. + + TFLite complex64 tensors are represented as float32[..., 2] in Relax, + where index 0 = real part, index 1 = imaginary part along the last axis + """ + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + input_tensor = self.get_expr(input_tensors[0].tensor_idx) + last_axis = int(input_tensor.struct_info.ndim) - 1 + # slice last axis at index 0, and squeeze to remove the last axis + real = _op.strided_slice(input_tensor, begin=[0], end=[1], strides=[1], axes=[last_axis]) + return _op.squeeze(real, axis=[last_axis]) + + def convert_imag(self, op): + """Convert TFLite IMAG op. + + See convert_real for representation of complex64 tensors in Relax. + """ + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + input_tensor = self.get_expr(input_tensors[0].tensor_idx) + last_axis = int(input_tensor.struct_info.ndim) - 1 + # slice last axis at index 1, and squeeze to remove the last axis + imag = _op.strided_slice(input_tensor, begin=[1], end=[2], strides=[1], axes=[last_axis]) + return _op.squeeze(imag, axis=[last_axis]) + + def convert_complex_abs(self, op): + """Convert TFLite COMPLEX_ABS op: sqrt(real^2 + imag^2) + + See convert_real for the float32[..., 2] complex representation convention. + """ + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + input_tensor = self.get_expr(input_tensors[0].tensor_idx) + last_axis = int(input_tensor.struct_info.ndim) - 1 + real = self.bb.emit( + _op.strided_slice(input_tensor, begin=[0], end=[1], strides=[1], axes=[last_axis]) + ) + real = self.bb.emit(_op.squeeze(real, axis=[last_axis])) + imag = self.bb.emit( + _op.strided_slice(input_tensor, begin=[1], end=[2], strides=[1], axes=[last_axis]) + ) + imag = self.bb.emit(_op.squeeze(imag, axis=[last_axis])) + real_sq = self.bb.emit(_op.multiply(real, real)) + imag_sq = self.bb.emit(_op.multiply(imag, imag)) + sum_expr = self.bb.emit(_op.add(real_sq, imag_sq)) + return _op.sqrt(sum_expr) + + def convert_rfft2d(self, op): + """Convert TFLite RFFT2D op. + + Not implemented: Relax has no native FFT operator and topi.signal.dft + has no C++ registered backend (tvm.get_global_func returns None). + Implement relax.op.signal.rfft2d first, then route here. + """ + raise tvm.error.OpNotImplemented( + "RFFT2D is not supported in the Relax TFLite frontend. " + "topi.signal.dft is pure Python TE with no TVM_REGISTER_GLOBAL entry " + "and cannot be called via call_dps_packed. " + "A native relax.op.signal.rfft2d op is required." + ) + def get_expr(self, input_tensor_idx): return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx)) @@ -8044,8 +8111,14 @@ def _input_type(model): input_shape = tuple(tensor.ShapeAsNumpy()) tensor_type = tensor.Type() input_name = get_tensor_name(subgraph, input_) + input_dtype = _decode_type(tensor_type) + # Relax models complex64 tensors as float32[..., 2] where the trailing + # dimension stores real/imag parts. + if input_dtype == "complex64": + input_shape = input_shape + (2,) + input_dtype = "float32" shape_dict[input_name] = input_shape - dtype_dict[input_name] = _decode_type(tensor_type) + dtype_dict[input_name] = input_dtype return shape_dict, dtype_dict @@ -8183,6 +8256,10 @@ def func(self, data): dtype = ( _dtype_dict[model_input_name] if model_input_name in _dtype_dict else "float32" ) + if dtype == "complex64": + dtype = "float32" + if shape is not None: + shape = tuple(shape) + (2,) input_var = relax.Var( name_hint=model_input_name, struct_info=relax.TensorStructInfo(shape=shape, dtype=dtype), diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index e4483b9d41cc..5492eb8edd32 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -13020,5 +13020,81 @@ def test_unidirectional_sequence_rnn_time_major(): assert tuple(int(d) for d in out_shape) == (batch, time, num_units) +def test_real(): + class Real(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 4), dtype=tf.complex64)]) + def func(self, x): + return tf.math.real(x) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 4, 2), dtype="float32")) -> R.Tensor((2, 4), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + # slice real part (index 0 along last axis) + lv: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice( + x, axes=[2], begin=[0], end=[1], strides=[1] + ) + gv: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv, axis=[2]) + R.output(gv) + return gv + + verify(Real, Expected) + + +def test_imag(): + class Imag(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 4), dtype=tf.complex64)]) + def func(self, x): + return tf.math.imag(x) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 4, 2), dtype="float32")) -> R.Tensor((2, 4), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + # slice imaginary part (index 1 along last axis) + lv: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice( + x, axes=[2], begin=[1], end=[2], strides=[1] + ) + gv: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv, axis=[2]) + R.output(gv) + return gv + + verify(Imag, Expected) + + +def test_complex_abs(): + class ComplexAbs(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 4), dtype=tf.complex64)]) + def func(self, x): + return tf.math.abs(x) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 4, 2), dtype="float32")) -> R.Tensor((2, 4), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + lv0: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice( + x, axes=[2], begin=[0], end=[1], strides=[1] + ) + real: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv0, axis=[2]) + lv1: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice( + x, axes=[2], begin=[1], end=[2], strides=[1] + ) + imag: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv1, axis=[2]) + lv2: R.Tensor((2, 4), dtype="float32") = R.multiply(real, real) + lv3: R.Tensor((2, 4), dtype="float32") = R.multiply(imag, imag) + lv4: R.Tensor((2, 4), dtype="float32") = R.add(lv2, lv3) + gv: R.Tensor((2, 4), dtype="float32") = R.sqrt(lv4) + R.output(gv) + return gv + + verify(ComplexAbs, Expected) + + if __name__ == "__main__": pytest.main(["-s", __file__]) From 9c464770ec19b0212731f2a6d5c62458f8979091 Mon Sep 17 00:00:00 2001 From: fnhirwa Date: Sun, 14 Jun 2026 16:18:08 +0200 Subject: [PATCH 2/2] apply gemini suggestions --- .../relax/frontend/tflite/tflite_frontend.py | 19 ++++++++----------- tests/python/relax/test_frontend_tflite.py | 16 ++++++++-------- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 409ddd3cfafc..e2a7eed8f1e7 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -7593,10 +7593,9 @@ def convert_real(self, op): input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" input_tensor = self.get_expr(input_tensors[0].tensor_idx) - last_axis = int(input_tensor.struct_info.ndim) - 1 # slice last axis at index 0, and squeeze to remove the last axis - real = _op.strided_slice(input_tensor, begin=[0], end=[1], strides=[1], axes=[last_axis]) - return _op.squeeze(real, axis=[last_axis]) + real = _op.strided_slice(input_tensor, begin=[0], end=[1], strides=[1], axes=[-1]) + return _op.squeeze(real, axis=[-1]) def convert_imag(self, op): """Convert TFLite IMAG op. @@ -7606,10 +7605,9 @@ def convert_imag(self, op): input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" input_tensor = self.get_expr(input_tensors[0].tensor_idx) - last_axis = int(input_tensor.struct_info.ndim) - 1 # slice last axis at index 1, and squeeze to remove the last axis - imag = _op.strided_slice(input_tensor, begin=[1], end=[2], strides=[1], axes=[last_axis]) - return _op.squeeze(imag, axis=[last_axis]) + imag = _op.strided_slice(input_tensor, begin=[1], end=[2], strides=[1], axes=[-1]) + return _op.squeeze(imag, axis=[-1]) def convert_complex_abs(self, op): """Convert TFLite COMPLEX_ABS op: sqrt(real^2 + imag^2) @@ -7619,15 +7617,14 @@ def convert_complex_abs(self, op): input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" input_tensor = self.get_expr(input_tensors[0].tensor_idx) - last_axis = int(input_tensor.struct_info.ndim) - 1 real = self.bb.emit( - _op.strided_slice(input_tensor, begin=[0], end=[1], strides=[1], axes=[last_axis]) + _op.strided_slice(input_tensor, begin=[0], end=[1], strides=[1], axes=[-1]) ) - real = self.bb.emit(_op.squeeze(real, axis=[last_axis])) + real = self.bb.emit(_op.squeeze(real, axis=[-1])) imag = self.bb.emit( - _op.strided_slice(input_tensor, begin=[1], end=[2], strides=[1], axes=[last_axis]) + _op.strided_slice(input_tensor, begin=[1], end=[2], strides=[1], axes=[-1]) ) - imag = self.bb.emit(_op.squeeze(imag, axis=[last_axis])) + imag = self.bb.emit(_op.squeeze(imag, axis=[-1])) real_sq = self.bb.emit(_op.multiply(real, real)) imag_sq = self.bb.emit(_op.multiply(imag, imag)) sum_expr = self.bb.emit(_op.add(real_sq, imag_sq)) diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 5492eb8edd32..dc572e1edd75 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -13034,9 +13034,9 @@ def main(x: R.Tensor((2, 4, 2), dtype="float32")) -> R.Tensor((2, 4), dtype="flo with R.dataflow(): # slice real part (index 0 along last axis) lv: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice( - x, axes=[2], begin=[0], end=[1], strides=[1] + x, axes=[-1], begin=[0], end=[1], strides=[1] ) - gv: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv, axis=[2]) + gv: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv, axis=[-1]) R.output(gv) return gv @@ -13057,9 +13057,9 @@ def main(x: R.Tensor((2, 4, 2), dtype="float32")) -> R.Tensor((2, 4), dtype="flo with R.dataflow(): # slice imaginary part (index 1 along last axis) lv: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice( - x, axes=[2], begin=[1], end=[2], strides=[1] + x, axes=[-1], begin=[1], end=[2], strides=[1] ) - gv: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv, axis=[2]) + gv: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv, axis=[-1]) R.output(gv) return gv @@ -13079,13 +13079,13 @@ def main(x: R.Tensor((2, 4, 2), dtype="float32")) -> R.Tensor((2, 4), dtype="flo R.func_attr({"num_input": 1}) with R.dataflow(): lv0: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice( - x, axes=[2], begin=[0], end=[1], strides=[1] + x, axes=[-1], begin=[0], end=[1], strides=[1] ) - real: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv0, axis=[2]) + real: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv0, axis=[-1]) lv1: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice( - x, axes=[2], begin=[1], end=[2], strides=[1] + x, axes=[-1], begin=[1], end=[2], strides=[1] ) - imag: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv1, axis=[2]) + imag: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv1, axis=[-1]) lv2: R.Tensor((2, 4), dtype="float32") = R.multiply(real, real) lv3: R.Tensor((2, 4), dtype="float32") = R.multiply(imag, imag) lv4: R.Tensor((2, 4), dtype="float32") = R.add(lv2, lv3)