
Commit c01cc2f

[Ref Mode] PyTorch reference mode (eager only)
Part of #77. Please see inline code comments on the PR.
stack-info: PR: #339, branch: yf225/stack/34
1 parent 01fee7b commit c01cc2f

20 files changed: +1035 −8 lines changed

helion/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -11,12 +11,14 @@
 from .runtime import Kernel
 from .runtime import kernel
 from .runtime import kernel as jit  # alias
+from .runtime.settings import RefMode
 from .runtime.settings import Settings
 from .runtime.settings import set_default_settings
 
 __all__ = [
     "Config",
     "Kernel",
+    "RefMode",
     "Settings",
     "cdiv",
     "exc",

helion/_testing.py

Lines changed: 9 additions & 0 deletions
@@ -45,6 +45,15 @@ def code_and_output(
     args: tuple[object, ...],
     **kwargs: object,
 ) -> tuple[str, object]:
+    bound = fn.bind(args)
+    from helion.runtime.settings import RefMode
+
+    if bound.kernel.settings.ref_mode != RefMode.OFF:
+        result = fn(*args)
+        # Return the original kernel source code
+        code = inspect.getsource(fn.fn)
+        return code, result
+
     if kwargs:
         config = Config(
             **kwargs  # pyright: ignore[reportArgumentType]
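
For tests, this means that when the bound kernel's settings have ref_mode set to anything other than RefMode.OFF, code_and_output runs the kernel eagerly and returns the kernel's original Python source (via inspect.getsource) instead of generated Triton code. A hedged, self-contained sketch of a caller; the ref_mode kwarg on helion.kernel and the RefMode.EAGER value are assumptions not shown in this diff.

# Hedged sketch of calling code_and_output under ref mode (kwarg spelling assumed).
import torch
import helion
import helion.language as hl
from helion._testing import code_and_output

@helion.kernel(ref_mode=helion.RefMode.EAGER)  # assumed way to enable ref mode per kernel
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.size(0)):
        out[tile] = x[tile] + y[tile]
    return out

x = torch.randn(1024, device="cuda")
y = torch.randn(1024, device="cuda")
code, result = code_and_output(add, (x, y))
# `code` is the Python source of `add`, not Triton output; `result` was computed eagerly.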

helion/language/_decorators.py

Lines changed: 27 additions & 0 deletions
@@ -79,6 +79,7 @@ class APIFunc(Protocol):
     _to_device_ir: Callable[..., object] | None
     _allow_host_tensor: bool
     _signature: inspect.Signature
+    _ref_fn: Callable[..., object] | None
 
     def __call__(self, *args: object, **kwargs: object) -> object: ...
 
@@ -133,6 +134,15 @@ def api(
     def _impl(fn: _C) -> _C:
         @functools.wraps(fn)
         def wrapper(*args: object, **kwargs: object) -> object:
+            from ..runtime.ref_mode import is_ref_mode_enabled
+
+            if is_ref_mode_enabled() and api._ref_fn is not None:
+                # In ref mode, use the registered ref implementation
+                bound = api._signature.bind(*args, **kwargs)
+                bound.apply_defaults()
+                flat_args = api._prepare_args(*bound.arguments.values())
+                return api._ref_fn(*flat_args)
+
             bound = api._signature.bind(*args, **kwargs)
             bound.apply_defaults()
             flat_args = api._prepare_args(*bound.arguments.values())
@@ -187,6 +197,7 @@ def wrapper(*args: object, **kwargs: object) -> object:
         api._signature = signature or inspect.signature(
             cast("Callable[..., object]", fn)
         )
+        api._ref_fn = None
         return wrapper  # pyright: ignore[reportReturnType]
 
     return _impl
@@ -289,6 +300,22 @@ def _impl(to_device_ir_fn: Callable[..., object]) -> Callable[..., Never]:
     return _impl  # pyright: ignore[reportReturnType]
 
 
+def ref(
+    original_fn: Callable[..., object],
+) -> _NoReturnDecorator[object]:
+    def _impl(ref_fn: Callable[..., object]) -> Callable[..., Never]:
+        assert is_api_func(original_fn), (
+            f"{register_ref.__qualname__} can only be used on API functions"
+        )
+        assert original_fn._ref_fn is None, (
+            "ref mode implementation can only be registered once per function"
+        )
+        original_fn._ref_fn = ref_fn
+        return _no_call
+
+    return _impl  # pyright: ignore[reportReturnType]
+
+
 def _default_type_function(
     fake_fn: Callable[..., object], tiles_as_sizes: bool
 ) -> Callable[..., TypeInfo]:
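
The dispatch is straightforward: wrapper checks a process-wide flag and, when ref mode is on, routes the call to the _ref_fn registered through the new ref decorator instead of the tracing path. Below is a self-contained toy re-creation of that register-and-dispatch pattern (plain Python, not the Helion API) to make the control flow concrete.

# Toy re-creation of the registration/dispatch pattern above (not Helion code).
from typing import Any, Callable

_REF_MODE_ENABLED = True  # stand-in for is_ref_mode_enabled()

def api(fn: Callable[..., Any]) -> Callable[..., Any]:
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if _REF_MODE_ENABLED and wrapper._ref_fn is not None:
            return wrapper._ref_fn(*args, **kwargs)  # eager reference path
        raise NotImplementedError("the compiled path only exists inside a traced kernel")
    wrapper._ref_fn = None
    return wrapper

def ref(original_fn: Callable[..., Any]):
    def _impl(ref_fn: Callable[..., Any]) -> Callable[..., Any]:
        assert original_fn._ref_fn is None, "only one ref impl per API function"
        original_fn._ref_fn = ref_fn  # register the eager reference implementation
        return ref_fn
    return _impl

@api
def double(x: int) -> int: ...

@ref(double)
def _(x: int) -> int:
    return 2 * x

assert double(3) == 6  # dispatched to the registered ref implementation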

helion/language/constexpr.py

Lines changed: 5 additions & 0 deletions
@@ -95,3 +95,8 @@ def _(state: CodegenState) -> ast.AST:
     value = value.__int__()
     assert isinstance(value, int)
     return expr_from_string(repr(value))
+
+
+@_decorators.ref(specialize)
+def _(value: int | torch.SymInt) -> int:
+    return int(value)

helion/language/creation_ops.py

Lines changed: 10 additions & 0 deletions
@@ -144,6 +144,16 @@ def _(
     return None
 
 
+@_decorators.ref(full)
+def _(
+    shape: list[int | slice],
+    value: float,
+    dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    processed_shape = [s.stop - s.start if isinstance(s, slice) else s for s in shape]
+    return torch.full(processed_shape, value, dtype=dtype, device="cuda")
+
+
 def arange(
     *args: int,
     dtype: torch.dtype | None = None,
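
In ref mode, tile extents arrive as slice objects (see the tile reference implementation in loops.py below), so this full reference converts each slice to its length before calling torch.full. A small standalone illustration of that shape handling; the device is left at the default here so it runs on CPU, whereas the ref implementation hard-codes device="cuda".

# Standalone illustration of the slice-to-length conversion in the ref `full`.
import torch

shape = [slice(8, 16), 4]  # a tile covering rows 8..15 plus a plain int dimension
processed = [s.stop - s.start if isinstance(s, slice) else s for s in shape]
assert processed == [8, 4]

t = torch.full(processed, 0.0, dtype=torch.float32)  # ref impl additionally passes device="cuda"
assert t.shape == (8, 4)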

helion/language/device_print.py

Lines changed: 5 additions & 0 deletions
@@ -90,3 +90,8 @@ def _(state: CodegenState) -> None:
     )
     stmt = create(ast.Expr, value=call_expr)
     state.add_statement(stmt)
+
+
+@_decorators.ref(device_print)
+def _(prefix: str, *args: object) -> None:
+    print(prefix, *args)

helion/language/inline_asm_ops.py

Lines changed: 12 additions & 0 deletions
@@ -205,3 +205,15 @@ def _(state: CodegenState) -> ast.AST | list[ast.AST]:
     ]
 
     return inline_asm_call
+
+
+@_decorators.ref(inline_asm_elementwise)
+def _(
+    asm: str,
+    constraints: str,
+    args: Sequence[torch.Tensor],
+    dtype: torch.dtype | Sequence[torch.dtype],
+    is_pure: bool,
+    pack: int,
+) -> torch.Tensor | tuple[torch.Tensor, ...]:
+    raise NotImplementedError("inline_asm_elementwise is not supported in ref mode")

helion/language/loops.py

Lines changed: 113 additions & 0 deletions
@@ -3,6 +3,7 @@
 import ast
 import builtins
 import inspect
+import itertools
 from itertools import starmap
 from typing import TYPE_CHECKING
 from typing import Iterator
@@ -449,6 +450,81 @@ def _(state: CodegenState) -> ast.AST:
     return _codegen_loop_helper(state)
 
 
+@_decorators.ref(tile)
+def _(
+    begin_or_end: int | torch.Tensor | list[int | torch.Tensor],
+    end_or_none: int | torch.Tensor | list[int | torch.Tensor] | None = None,
+    block_size: int | torch.Tensor | list[int | torch.Tensor] | None = None,
+) -> Iterator[slice | tuple[slice, ...]]:
+    # Convert tensor values to int
+    def _to_int(value):
+        if value is None:
+            return None
+        if isinstance(value, torch.Tensor):
+            return int(value.item())
+        return int(value)
+
+    # Step 1: Normalize begin and end values based on the number of arguments
+    if end_or_none is not None:
+        # Two positional args: begin_or_end is begin, end_or_none is end
+        begin = begin_or_end
+        end = end_or_none
+    else:
+        # One positional arg: begin_or_end is end, begin defaults to 0
+        end = begin_or_end
+        # Create begin with same structure as end, but all zeros
+        if isinstance(end, (list, tuple)):
+            begin = [0] * len(end)
+        else:
+            begin = 0
+
+    # Step 2: Convert inputs to lists for uniform handling
+    def _normalize_to_list(
+        value: int | torch.Tensor | list[int | torch.Tensor],
+    ) -> list[int | torch.Tensor]:
+        if isinstance(value, (list, tuple)):
+            return list(value)
+        return [value]
+
+    begin_list = _normalize_to_list(begin)
+    end_list = _normalize_to_list(end)
+
+    # Convert all values to int
+    begin_list = [_to_int(b) for b in begin_list]
+    end_list = [_to_int(e) for e in end_list]
+
+    # Step 3: Determine block_size based on the arguments
+    if block_size is None:
+        # Default block_size to end - begin for each dimension
+        block_size_list = [e - b for b, e in zip(begin_list, end_list, strict=False)]
+    else:
+        block_size_list = _normalize_to_list(block_size)
+        block_size_list = [
+            _to_int(bs) if bs is not None else (e - b)
+            for bs, b, e in zip(block_size_list, begin_list, end_list, strict=False)
+        ]
+
+    # Step 4: Yield tile ranges
+    # Handle single dimension case
+    if len(begin_list) == 1:
+        b = begin_list[0]
+        e = end_list[0]
+        bs = block_size_list[0]
+        for i in range(b, e, bs):
+            yield slice(i, min(i + bs, e))
+    else:
+        # Handle multi-dimensional case
+        ranges = []
+        for b, e, bs in zip(begin_list, end_list, block_size_list, strict=False):
+            dim_ranges = []
+            for i in range(b, e, bs):
+                dim_ranges.append(slice(i, min(i + bs, e)))
+            ranges.append(dim_ranges)
+
+        for combo in itertools.product(*ranges):
+            yield combo
+
+
 def _codegen_loop_helper(
     state: CodegenState,
 ) -> ast.AST:
@@ -637,6 +713,32 @@ def _(state: CodegenState) -> ast.AST:
     return _codegen_loop_helper(state)
 
 
+@_decorators.ref(grid)
+def _(
+    begin_or_end: int | torch.Tensor | list[int | torch.Tensor],
+    end_or_none: int | torch.Tensor | list[int | torch.Tensor] | None = None,
+    step: object = None,
+) -> range | Iterator[tuple[int, ...]]:
+    # Similar to tile but yields indices instead of slices
+    if end_or_none is not None:
+        begin = begin_or_end
+        end = end_or_none
+    else:
+        end = begin_or_end
+        if isinstance(end, (list, tuple)):
+            begin = [0] * len(end)
+        else:
+            begin = 0
+
+    # Handle single dimension
+    if not isinstance(begin, (list, tuple)):
+        return range(begin, end)
+
+    # Handle multi-dimensional
+    ranges = list(itertools.starmap(range, zip(begin, end, strict=False)))
+    return itertools.product(*ranges)
+
+
 @_decorators.device_func_replacement(builtins.zip)
 @_decorators.api(is_device_only=True, cache_type=True)
 def _zip_replacement(
@@ -898,3 +1000,14 @@ def _(
 
     # Return tuple(range(...)) which will trigger existing tuple/list unrolling
     return tuple(range(begin_val, end_val, step))
+
+
+@_decorators.ref(static_range)
+def _(
+    begin_or_end: int,
+    end_or_none: int | None = None,
+    step: int = 1,
+) -> range:
+    if end_or_none is not None:
+        return range(begin_or_end, end_or_none, step)
+    return range(begin_or_end)
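
To see what these reference loops produce, here is a standalone re-creation of the single-dimension tile logic and the multi-dimensional grid product (plain Python mirroring the code above, not an import from Helion).

# Standalone mirror of the ref-mode loop helpers' output.
import itertools

def ref_tile_1d(end: int, block_size: int):
    # Same logic as the single-dimension branch of the ref `tile` above.
    for i in range(0, end, block_size):
        yield slice(i, min(i + block_size, end))

print(list(ref_tile_1d(10, 4)))
# -> [slice(0, 4, None), slice(4, 8, None), slice(8, 10, None)]

# Multi-dimensional `grid` is the cartesian product of per-dimension ranges,
# matching the itertools.product call in the ref implementation.
print(list(itertools.product(range(2), range(3))))
# -> [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]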
