Skip to content

Commit a61bd17

Browse files
authored
[Fix CI] Convert tiles to sizes for all torch.* functions (#563)
1 parent e5bb08c commit a61bd17

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

helion/language/tile_proxy.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,15 @@ def __torch_function__(
7979
index_calls.count += 1
8080
if func is torch.Tensor.__format__:
8181
return repr(args[0])
82+
83+
# For any other torch.* function or torch.Tensor.* method, convert tiles to sizes
84+
is_torch_func = getattr(func, "__module__", "") == "torch"
85+
is_tensor_method = hasattr(torch.Tensor, getattr(func, "__name__", ""))
86+
if is_torch_func or is_tensor_method:
87+
new_args = cls._tiles_to_sizes(args)
88+
new_kwargs = cls._tiles_to_sizes(kwargs) if kwargs else {}
89+
return func(*new_args, **new_kwargs)
90+
8291
raise exc.IncorrectTileUsage(func)
8392

8493
@staticmethod

0 commit comments

Comments
 (0)