Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 75 additions & 1 deletion python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None):
"BROADCAST_ARGS": self.convert_broadcast_args,
"CALL": self.convert_call,
"CALL_ONCE": self.convert_call_once,
"COMPLEX_ABS": self.convert_complex_abs,
"CAST": self.convert_cast,
"CEIL": functools.partial(self._convert_unary_elemwise, relax_op=_op.ceil),
"CONCATENATION": self.convert_concatenation,
Expand Down Expand Up @@ -252,6 +253,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None):
"HASHTABLE_LOOKUP": self.convert_hashtable_lookup,
"HASHTABLE_SIZE": self.convert_hashtable_size,
"IF": self.convert_if,
"IMAG": self.convert_imag,
"L2_NORMALIZATION": self.convert_l2_normalization,
"L2_POOL_2D": functools.partial(self.convert_pool2d, pool_type="l2"),
"LEAKY_RELU": self.convert_leaky_relu,
Expand Down Expand Up @@ -295,6 +297,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None):
"RANDOM_STANDARD_NORMAL": self.convert_random_standard_normal,
"RANDOM_UNIFORM": self.convert_random_uniform,
"READ_VARIABLE": self.convert_read_variable,
"REAL": self.convert_real,
"REDUCE_ALL": functools.partial(self._convert_reduce_bool, relax_op=_op.min),
"REDUCE_ANY": functools.partial(self._convert_reduce_bool, relax_op=_op.max),
"REDUCE_MAX": functools.partial(self._convert_reduce, relax_op=_op.max),
Expand All @@ -303,6 +306,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None):
"RELU": self.convert_relu,
"RELU6": self.convert_relu6,
"RELU_N1_TO_1": self.convert_relu_n1_to_1,
"RFFT2D": self.convert_rfft2d,
"RESHAPE": self.convert_reshape,
"RESIZE_BILINEAR": self.convert_resize_bilinear,
"RESIZE_NEAREST_NEIGHBOR": self.convert_resize_nearest_neighbor,
Expand Down Expand Up @@ -7580,6 +7584,66 @@ def convert_fake_quant(self, op):
rounded = relax.op.floor(_op.add(_op.multiply(clamped_shifted, inv_scale), half))
return relax.op.add(_op.multiply(rounded, scale_expr), nudged_min_expr)

def convert_real(self, op):
"""Convert TFLite REAL op.

TFLite complex64 tensors are represented as float32[..., 2] in Relax,
where index 0 = real part, index 1 = imaginary part along the last axis
"""
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = self.get_expr(input_tensors[0].tensor_idx)
# slice last axis at index 0, and squeeze to remove the last axis
real = _op.strided_slice(input_tensor, begin=[0], end=[1], strides=[1], axes=[-1])
return _op.squeeze(real, axis=[-1])

def convert_imag(self, op):
"""Convert TFLite IMAG op.

See convert_real for representation of complex64 tensors in Relax.
"""
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = self.get_expr(input_tensors[0].tensor_idx)
# slice last axis at index 1, and squeeze to remove the last axis
imag = _op.strided_slice(input_tensor, begin=[1], end=[2], strides=[1], axes=[-1])
return _op.squeeze(imag, axis=[-1])

def convert_complex_abs(self, op):
"""Convert TFLite COMPLEX_ABS op: sqrt(real^2 + imag^2)

See convert_real for the float32[..., 2] complex representation convention.
"""
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = self.get_expr(input_tensors[0].tensor_idx)
real = self.bb.emit(
_op.strided_slice(input_tensor, begin=[0], end=[1], strides=[1], axes=[-1])
)
real = self.bb.emit(_op.squeeze(real, axis=[-1]))
imag = self.bb.emit(
_op.strided_slice(input_tensor, begin=[1], end=[2], strides=[1], axes=[-1])
)
imag = self.bb.emit(_op.squeeze(imag, axis=[-1]))
real_sq = self.bb.emit(_op.multiply(real, real))
imag_sq = self.bb.emit(_op.multiply(imag, imag))
sum_expr = self.bb.emit(_op.add(real_sq, imag_sq))
return _op.sqrt(sum_expr)

def convert_rfft2d(self, op):
"""Convert TFLite RFFT2D op.

Not implemented: Relax has no native FFT operator and topi.signal.dft
has no C++ registered backend (tvm.get_global_func returns None).
Implement relax.op.signal.rfft2d first, then route here.
"""
raise tvm.error.OpNotImplemented(
"RFFT2D is not supported in the Relax TFLite frontend. "
"topi.signal.dft is pure Python TE with no TVM_REGISTER_GLOBAL entry "
"and cannot be called via call_dps_packed. "
"A native relax.op.signal.rfft2d op is required."
)

def get_expr(self, input_tensor_idx):
return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))

Expand Down Expand Up @@ -8044,8 +8108,14 @@ def _input_type(model):
input_shape = tuple(tensor.ShapeAsNumpy())
tensor_type = tensor.Type()
input_name = get_tensor_name(subgraph, input_)
input_dtype = _decode_type(tensor_type)
# Relax models complex64 tensors as float32[..., 2] where the trailing
# dimension stores real/imag parts.
if input_dtype == "complex64":
input_shape = input_shape + (2,)
input_dtype = "float32"
shape_dict[input_name] = input_shape
dtype_dict[input_name] = _decode_type(tensor_type)
dtype_dict[input_name] = input_dtype

return shape_dict, dtype_dict

Expand Down Expand Up @@ -8183,6 +8253,10 @@ def func(self, data):
dtype = (
_dtype_dict[model_input_name] if model_input_name in _dtype_dict else "float32"
)
if dtype == "complex64":
dtype = "float32"
if shape is not None:
shape = tuple(shape) + (2,)
input_var = relax.Var(
name_hint=model_input_name,
struct_info=relax.TensorStructInfo(shape=shape, dtype=dtype),
Expand Down
76 changes: 76 additions & 0 deletions tests/python/relax/test_frontend_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -13020,5 +13020,81 @@ def test_unidirectional_sequence_rnn_time_major():
assert tuple(int(d) for d in out_shape) == (batch, time, num_units)


def test_real():
class Real(tf.Module):
@tf.function(input_signature=[tf.TensorSpec(shape=(2, 4), dtype=tf.complex64)])
def func(self, x):
return tf.math.real(x)

@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor((2, 4, 2), dtype="float32")) -> R.Tensor((2, 4), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
# slice real part (index 0 along last axis)
lv: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice(
x, axes=[-1], begin=[0], end=[1], strides=[1]
)
gv: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv, axis=[-1])
R.output(gv)
return gv

verify(Real, Expected)


def test_imag():
class Imag(tf.Module):
@tf.function(input_signature=[tf.TensorSpec(shape=(2, 4), dtype=tf.complex64)])
def func(self, x):
return tf.math.imag(x)

@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor((2, 4, 2), dtype="float32")) -> R.Tensor((2, 4), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
# slice imaginary part (index 1 along last axis)
lv: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice(
x, axes=[-1], begin=[1], end=[2], strides=[1]
)
gv: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv, axis=[-1])
R.output(gv)
return gv

verify(Imag, Expected)


def test_complex_abs():
class ComplexAbs(tf.Module):
@tf.function(input_signature=[tf.TensorSpec(shape=(2, 4), dtype=tf.complex64)])
def func(self, x):
return tf.math.abs(x)

@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor((2, 4, 2), dtype="float32")) -> R.Tensor((2, 4), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
lv0: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice(
x, axes=[-1], begin=[0], end=[1], strides=[1]
)
real: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv0, axis=[-1])
lv1: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice(
x, axes=[-1], begin=[1], end=[2], strides=[1]
)
imag: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv1, axis=[-1])
lv2: R.Tensor((2, 4), dtype="float32") = R.multiply(real, real)
lv3: R.Tensor((2, 4), dtype="float32") = R.multiply(imag, imag)
lv4: R.Tensor((2, 4), dtype="float32") = R.add(lv2, lv3)
gv: R.Tensor((2, 4), dtype="float32") = R.sqrt(lv4)
R.output(gv)
return gv

verify(ComplexAbs, Expected)


if __name__ == "__main__":
pytest.main(["-s", __file__])
Loading