Skip to content

Commit 3af8a82

Browse files
committed
[Ref Mode] Make Tile apis work in ref eager mode
stack-info: PR: #378, branch: yf225/stack/40
1 parent 4ddfba2 commit 3af8a82

File tree

7 files changed

+254
-105
lines changed

7 files changed

+254
-105
lines changed

helion/language/creation_ops.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,16 @@ def _(
150150
value: float,
151151
dtype: torch.dtype = torch.float32,
152152
) -> torch.Tensor:
153-
processed_shape = [s.stop - s.start if isinstance(s, slice) else s for s in shape]
153+
from .tile_proxy import RefTile
154+
processed_shape = []
155+
for s in shape:
156+
if isinstance(s, RefTile):
157+
# RefTile is a slice subclass with a block_size property
158+
processed_shape.append(s.block_size)
159+
elif isinstance(s, slice):
160+
processed_shape.append(s.stop - s.start)
161+
else:
162+
processed_shape.append(s)
154163
return torch.full(processed_shape, value, dtype=dtype, device="cuda")
155164

156165

helion/language/loops.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from ..autotuner.config_spec import StaticRangeSpec
4545
from . import _decorators
4646
from .tile_proxy import Tile
47+
from .tile_proxy import RefTile
4748

4849
if TYPE_CHECKING:
4950
from collections.abc import Sequence
@@ -455,7 +456,7 @@ def _(
455456
begin_or_end: int | torch.Tensor | list[int | torch.Tensor],
456457
end_or_none: int | torch.Tensor | list[int | torch.Tensor] | None = None,
457458
block_size: int | torch.Tensor | list[int | torch.Tensor] | None = None,
458-
) -> Iterator[slice | tuple[slice, ...]]:
459+
) -> Iterator[RefTile | tuple[RefTile, ...]]:
459460
# Convert tensor values to int
460461
def _to_int(value):
461462
if value is None:
@@ -511,14 +512,14 @@ def _normalize_to_list(
511512
e = end_list[0]
512513
bs = block_size_list[0]
513514
for i in range(b, e, bs):
514-
yield slice(i, min(i + bs, e))
515+
yield RefTile(i, min(i + bs, e))
515516
else:
516517
# Handle multi-dimensional case
517518
ranges = []
518519
for b, e, bs in zip(begin_list, end_list, block_size_list, strict=False):
519520
dim_ranges = []
520521
for i in range(b, e, bs):
521-
dim_ranges.append(slice(i, min(i + bs, e)))
522+
dim_ranges.append(RefTile(i, min(i + bs, e)))
522523
ranges.append(dim_ranges)
523524

524525
for combo in itertools.product(*ranges):

helion/language/memory_ops.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,13 @@ def _handle_mixed_indices(
8585
for i, idx in enumerate(indices):
8686
if isinstance(idx, slice):
8787
# Handle slice indices
88-
shape_size = idx.stop - idx.start
88+
if idx.start is None and idx.stop is None:
89+
# Full slice like `:`
90+
shape_size = tensor_shape[i] if i < len(tensor_shape) else 1
91+
else:
92+
start = idx.start or 0
93+
stop = idx.stop or (tensor_shape[i] if i < len(tensor_shape) else 1)
94+
shape_size = stop - start
8995
expected_shape.append(shape_size)
9096
actual_indices.append(idx)
9197
elif isinstance(idx, torch.Tensor):
@@ -204,6 +210,16 @@ def _(
204210
value: torch.Tensor,
205211
extra_mask: torch.Tensor | None = None,
206212
) -> None:
213+
# Convert RefTile objects to slices
214+
from .tile_proxy import RefTile
215+
processed_indices = []
216+
for idx in indices:
217+
if isinstance(idx, RefTile):
218+
processed_indices.append(idx._slice)
219+
else:
220+
processed_indices.append(idx)
221+
indices = processed_indices
222+
207223
normalized_indices = _normalize_indices(indices)
208224

209225
if extra_mask is not None:
@@ -269,6 +285,16 @@ def _(
269285
other = 0
270286

271287
assert isinstance(indices, (list, tuple))
288+
289+
# Convert RefTile objects to slices
290+
from .tile_proxy import RefTile
291+
processed_indices = []
292+
for idx in indices:
293+
if isinstance(idx, RefTile):
294+
processed_indices.append(idx._slice)
295+
else:
296+
processed_indices.append(idx)
297+
indices = processed_indices
272298

273299
# Case 1: Single tensor index (jagged indexing)
274300
if len(indices) == 1 and isinstance(indices[0], torch.Tensor):
@@ -400,6 +426,16 @@ def _(
400426
value: torch.Tensor | float,
401427
sem: str = "relaxed",
402428
) -> None:
429+
# Convert RefTile objects to slices
430+
from .tile_proxy import RefTile
431+
processed_indices = []
432+
for idx in indices:
433+
if isinstance(idx, RefTile):
434+
processed_indices.append(idx._slice)
435+
else:
436+
processed_indices.append(idx)
437+
indices = processed_indices
438+
403439
# Special handling for scatter-add pattern (`tensor[tensor_idx, slice] += value`)
404440
if isinstance(indices, (list, tuple)) and len(indices) == 2:
405441
idx0, idx1 = indices

helion/language/tile_ops.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,16 @@ def _(state: CodegenState) -> ast.AST:
4949

5050

5151
@_decorators.ref(tile_index)
52-
def _(tile: slice) -> torch.Tensor:
52+
def _(tile: slice | int) -> torch.Tensor:
5353
# Handle different tile representations in ref mode
54-
return torch.arange(tile.start, tile.stop, dtype=torch.int64, device="cuda")
54+
from .tile_proxy import RefTile
55+
if isinstance(tile, RefTile):
56+
return tile.index
57+
elif isinstance(tile, slice):
58+
return torch.arange(tile.start, tile.stop, dtype=torch.int64, device="cuda")
59+
else:
60+
# tiles_as_sizes=True means we get an int
61+
return torch.arange(0, tile, dtype=torch.int64, device="cuda")
5562

5663

5764
@_decorators.api(tiles_as_sizes=True)
@@ -91,7 +98,10 @@ def _(state: CodegenState) -> ast.AST:
9198
@_decorators.ref(tile_begin)
9299
def _(tile: int | slice) -> int:
93100
# Handle different tile representations in ref mode
94-
if isinstance(tile, slice):
101+
from .tile_proxy import RefTile
102+
if isinstance(tile, RefTile):
103+
return tile.begin
104+
elif isinstance(tile, slice):
95105
return tile.start
96106
# In ref mode with tiles_as_sizes=True, we lost the begin info
97107
# This is a limitation - we return 0 as we don't know the actual begin
@@ -140,7 +150,10 @@ def _(state: CodegenState) -> ast.AST:
140150
@_decorators.ref(tile_end)
141151
def _(tile: int | slice) -> int:
142152
# Handle different tile representations in ref mode
143-
if isinstance(tile, slice):
153+
from .tile_proxy import RefTile
154+
if isinstance(tile, RefTile):
155+
return tile.end
156+
elif isinstance(tile, slice):
144157
return tile.stop
145158
# In ref mode with tiles_as_sizes=True, we get the size
146159
# We lost the begin info, so we assume end = size
@@ -168,7 +181,10 @@ def _(tile: torch.SymInt) -> torch.SymInt:
168181
@_decorators.ref(tile_block_size)
169182
def _(tile: int | slice) -> int:
170183
# Handle different tile representations in ref mode
171-
if isinstance(tile, slice):
184+
from .tile_proxy import RefTile
185+
if isinstance(tile, RefTile):
186+
return tile.block_size
187+
elif isinstance(tile, slice):
172188
return tile.stop - tile.start
173189
# In ref mode with tiles_as_sizes=True, the tile IS the size
174190
return tile
@@ -206,5 +222,8 @@ def _(state: CodegenState) -> ast.AST:
206222
@_decorators.ref(tile_id)
207223
def _(tile: int | slice) -> int:
208224
# tile_id is the index of the tile in the grid
225+
from .tile_proxy import RefTile
226+
if isinstance(tile, RefTile):
227+
return tile.id
209228
# For ref mode we don't have the original block_size, so we return 0
210229
return 0

helion/language/tile_proxy.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import torch
1212
from torch.utils._pytree import tree_map_only
13+
import functools
1314

1415
from .. import exc
1516
from .._compiler.compile_environment import CompileEnvironment
@@ -182,3 +183,147 @@ def __enter__(self) -> Self:
182183

183184
def __exit__(self, *args: object) -> None:
184185
_tls.index_calls = None
186+
187+
188+
class RefTile(torch.Tensor):
189+
"""
190+
A tile-like object used in reference eager mode that behaves like a slice.
191+
This allows tile.index and other tile operations to work properly in ref eager mode.
192+
"""
193+
194+
def __new__(cls, start: int, stop: int, step: int | None = None):
195+
# Create a tensor instance
196+
instance = super().__new__(cls)
197+
return instance
198+
199+
def __init__(self, start: int, stop: int, step: int | None = None) -> None:
200+
super().__init__()
201+
# Store slice data
202+
self.start = start
203+
self.stop = stop
204+
self.step = step
205+
self._slice = slice(start, stop, step)
206+
# We need to set block_id to something for compatibility
207+
self.block_id = -1 # Special value for ref mode
208+
209+
@property
210+
def index(self) -> torch.Tensor:
211+
"""Return a tensor containing the offsets for this tile."""
212+
return torch.arange(
213+
self.start,
214+
self.stop,
215+
dtype=torch.int64,
216+
device="cuda"
217+
)
218+
219+
@property
220+
def begin(self) -> int:
221+
"""Return the start offset of this tile."""
222+
return self.start
223+
224+
@property
225+
def end(self) -> int:
226+
"""Return the end offset of this tile."""
227+
return self.stop
228+
229+
@property
230+
def block_size(self) -> int:
231+
"""Return the block size of this tile."""
232+
return self.stop - self.start
233+
234+
@property
235+
def id(self) -> int:
236+
"""Return the id of this tile (always 0 in ref mode)."""
237+
# We don't have enough info to compute the actual tile id
238+
return 0
239+
240+
def __repr__(self) -> str:
241+
return f"RefTile({self._slice!r})"
242+
243+
def __int__(self) -> int:
244+
"""Convert to int for cases where a size is expected."""
245+
return self.block_size
246+
247+
# Make RefTile usable as an index by delegating to the slice
248+
def indices(self, length: int) -> tuple[int, int, int]:
249+
"""Return (start, stop, step) tuple, like slice.indices()."""
250+
return self._slice.indices(length)
251+
252+
def __eq__(self, other: object) -> bool:
253+
"""Compare with other RefTile or slice objects."""
254+
if isinstance(other, RefTile):
255+
return self._slice == other._slice
256+
elif isinstance(other, slice):
257+
return self._slice == other
258+
return False
259+
260+
def __hash__(self) -> int:
261+
"""Hash based on the slice."""
262+
return hash(self._slice)
263+
264+
def __index__(self) -> int:
265+
"""Convert to int for use in tensor indexing.
266+
267+
This is called when RefTile is used in advanced indexing contexts.
268+
We return the start value which works for single-element tiles.
269+
"""
270+
# For single-element access (when block_size=1), return the index
271+
if self.block_size == 1:
272+
return self.start
273+
# For larger tiles, we can't meaningfully convert to a single index
274+
# This might happen in user lambdas trying to do advanced indexing
275+
raise TypeError(f"Cannot convert RefTile with block_size={self.block_size} to index")
276+
277+
@classmethod
278+
def __torch_function__(
279+
cls,
280+
func: Callable[..., object],
281+
types: object,
282+
args: tuple[object, ...] = (),
283+
kwargs: dict[str, object] | None = None,
284+
) -> object:
285+
from ..language.memory_ops import load
286+
from ..language.memory_ops import store
287+
288+
if func is torch.Tensor.__getitem__:
289+
if len(args) != 2 or kwargs:
290+
raise exc.IncorrectTileUsage(func)
291+
tensor, index = args
292+
assert isinstance(tensor, torch.Tensor)
293+
294+
# If a single RefTile is used as index, we want to use it as a slice
295+
# e.g., tensor[ref_tile] should behave like tensor[ref_tile._slice]
296+
if isinstance(index, RefTile):
297+
return tensor[index._slice]
298+
299+
# For multi-dimensional indexing (including lists)
300+
return load(tensor, cls._prepare_index(index))
301+
302+
if func is torch.Tensor.__setitem__:
303+
if len(args) != 3 or kwargs:
304+
raise exc.IncorrectTileUsage(func)
305+
tensor, index, value = args
306+
assert isinstance(tensor, torch.Tensor)
307+
assert isinstance(value, torch.Tensor)
308+
309+
# Similar handling for setitem
310+
if isinstance(index, RefTile):
311+
tensor[index._slice] = value
312+
return None
313+
314+
return store(tensor, cls._prepare_index(index), value)
315+
316+
if func is torch.Tensor.__format__:
317+
return repr(args[0])
318+
raise exc.IncorrectTileUsage(func)
319+
320+
@staticmethod
321+
def _prepare_index(index: object) -> list[object]:
322+
if isinstance(index, (list, tuple)):
323+
# When indexing with a list of RefTiles like bias[[tile_m, tile_n]],
324+
# we want it to be interpreted as bias[tile_m, tile_n]
325+
# So we return the list as-is for multi-dimensional indexing
326+
return [*index]
327+
assert isinstance(index, RefTile)
328+
return [index]
329+

test/ref_utils.py

Lines changed: 0 additions & 36 deletions
This file was deleted.

0 commit comments

Comments
 (0)