Commit 89c1f24

Add hl.dot() API; use hl.dot instead of torch.matmul for FP8 GEMM ops in Helion kernels
stack-info: PR: #356, branch: yf225/stack/39
1 parent 869b590 commit 89c1f24

9 files changed (+2036, -68 lines)

examples/fp8_attention.py

Lines changed: 47 additions & 27 deletions
@@ -23,7 +23,7 @@ def fp8_attention_kernel(

    # Output tensor with 4D shape in FP8 format
    out = torch.empty(
-        [batch, heads, seq_len, head_dim], dtype=torch.float8_e5m2, device=q.device
+        [batch, heads, seq_len, head_dim], dtype=torch.float8_e4m3fn, device=q.device
    )

    # Scale factor for attention

@@ -54,9 +54,7 @@ def fp8_attention_kernel(
                k_tile_t = k_tile.transpose(0, 1)  # [dim, tile_n]

                # Compute Q @ K^T with FP8 inputs, result in FP32
-                qk = torch.matmul(q_tile, k_tile_t).to(
-                    torch.float32
-                )  # [tile_m, tile_n]
+                qk = hl.dot(q_tile, k_tile_t)  # [tile_m, tile_n]

                # Scale QK scores first
                qk_scaled = qk * sm_scale  # [tile_m, tile_n]

@@ -90,28 +88,28 @@ def fp8_attention_kernel(
                p_fp8 = p.to(v.dtype)  # Convert to same FP8 type as V

                # Accumulate attention @ V with FP8 GEMM
-                v_t = v_tile.transpose(0, 1)  # [tile_n, dim]
-                pv = torch.matmul(p_fp8, v_t).to(torch.float32)  # [tile_m, dim]
-                acc = acc + pv
+                # v_tile is [dim, tile_n], we need to transpose for P @ V^T
+                v_t = v_tile.t()  # [tile_n, dim]
+                acc = hl.dot(p_fp8, v_t, acc=acc)  # [tile_m, dim]

                # Update max tracker
                m_i = m_new

            # Final normalization
            acc = acc / l_i[:, None]
            # Convert to FP8 before writing to output
-            out[b, h, tile_m, :] = acc.to(torch.float8_e5m2)
+            out[b, h, tile_m, :] = acc.to(torch.float8_e4m3fn)

    return out


def preprocess_fp8_attention_inputs(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    q_fp8 = q.to(torch.float8_e5m2)
-    k_fp8 = k.to(torch.float8_e5m2)
+    q_fp8 = q.to(torch.float8_e4m3fn)
+    k_fp8 = k.to(torch.float8_e4m3fn)
    v = v.permute(0, 1, 3, 2)
-    v_fp8 = v.to(torch.float8_e5m2)
+    v_fp8 = v.to(torch.float8_e4m3fn)
    batch, heads, seq_len, head_dim = q.shape
    q_fp8_reshaped = q_fp8.reshape(batch * heads, seq_len, head_dim)
    k_fp8_reshaped = k_fp8.reshape(batch * heads, seq_len, head_dim)

@@ -147,13 +145,25 @@ def _fp8_attention_pytorch_impl(
        k_i = k_fp8[i]  # [seq, dim] - already FP8
        v_i = v_fp8[i]  # [dim, seq] - pre-transposed, already FP8

-        # For Q @ K^T, we need K^T to be column-major
-        kt_fp8 = k_i.t()  # column-major [dim, seq]
-
-        # Q @ K^T - dequantize and use regular matmul since e5m2 not supported by _scaled_mm
-        q_deq = q_i.to(torch.float32)
-        kt_deq = kt_fp8.to(torch.float32)
-        qk = torch.matmul(q_deq, kt_deq)
+        # For Q @ K^T using torch._scaled_mm
+        # torch._scaled_mm requires column-major for second operand
+        # k_i is [seq, dim], we need K^T as [dim, seq] in column-major
+        # Direct conversion: k_i -> contiguous -> transpose view
+        kt_fp8_col_major = k_i.contiguous().t()  # [dim, seq] in column-major
+
+        # Create scale tensors
+        scale_q = torch.tensor(1.0, device=q_i.device)
+        scale_k = torch.tensor(1.0, device=k_i.device)
+
+        # Q @ K^T using torch._scaled_mm
+        qk = torch._scaled_mm(
+            q_i,
+            kt_fp8_col_major,
+            scale_q,
+            scale_k,
+            use_fast_accum=False,
+            out_dtype=torch.float32,
+        )

        # Compute max before scaling
        qk_max = torch.amax(qk, dim=-1, keepdim=True)

@@ -168,16 +178,26 @@ def _fp8_attention_pytorch_impl(
        # Step 2: Attention @ V using FP8
        # P is [seq, seq], V is [dim, seq]
        # We want P @ V^T = [seq, seq] @ [seq, dim] = [seq, dim]
-        p_fp8 = p_norm.to(torch.float8_e5m2)  # row-major [seq, seq]
+        p_fp8 = p_norm.to(torch.float8_e4m3fn)  # row-major [seq, seq]

        # v_i is [dim, seq], already FP8
-        vt_fp8 = v_i.t()  # column-major [seq, dim]
-
-        # P @ V^T - dequantize and use regular matmul since e5m2 not supported by torch._scaled_mm
-        p_deq = p_fp8.to(torch.float32)
-        vt_deq = vt_fp8.to(torch.float32)
-        out_i = torch.matmul(p_deq, vt_deq)
-        out_i = out_i.to(torch.float8_e5m2)  # convert back to FP8
+        # Direct conversion: v_i -> contiguous -> transpose view
+        vt_fp8_col_major = v_i.contiguous().t()  # [seq, dim] in column-major
+
+        # Create scale tensors for P @ V^T
+        scale_p = torch.tensor(1.0, device=p_fp8.device)
+        scale_v = torch.tensor(1.0, device=v_i.device)
+
+        # P @ V^T using torch._scaled_mm
+        out_i = torch._scaled_mm(
+            p_fp8,
+            vt_fp8_col_major,
+            scale_p,
+            scale_v,
+            use_fast_accum=False,
+            out_dtype=torch.float32,
+        )
+        out_i = out_i.to(torch.float8_e4m3fn)  # convert back to FP8 to match kernel

        outputs.append(out_i)

@@ -192,7 +212,7 @@ def fp8_attention_pytorch(
    v: torch.Tensor,  # [batch, heads, seq, dim]
) -> Callable[[], torch.Tensor]:
    """
-    Baseline PyTorch implementation of FP8 attention using FP8 e5m2.
+    Baseline PyTorch implementation of FP8 attention using torch._scaled_mm.
    """
    batch, heads, seq_len, head_dim = q.shape
    q_fp8, k_fp8, v_fp8 = preprocess_fp8_attention_inputs(q, k, v)
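
For readers unfamiliar with the baseline path, here is a minimal standalone sketch of the torch._scaled_mm pattern used above with unit scales. The tensor names, shapes, and tolerance are illustrative (not from this commit), and it assumes a recent PyTorch build on an FP8-capable CUDA GPU (e.g. H100).

# Hedged sketch of an FP8 e4m3 matmul via torch._scaled_mm; names and shapes are illustrative.
import torch

M, K, N = 64, 64, 64
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)      # row-major LHS
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()  # [K, N] view, column-major RHS
scale_a = torch.tensor(1.0, device="cuda")
scale_b = torch.tensor(1.0, device="cuda")

out = torch._scaled_mm(a, b, scale_a, scale_b, use_fast_accum=False, out_dtype=torch.float32)
ref = a.to(torch.float32) @ b.to(torch.float32)
print(torch.allclose(out, ref, atol=1e-2, rtol=1e-2))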

examples/fp8_gemm.py

Lines changed: 11 additions & 6 deletions
@@ -1,13 +1,21 @@
from __future__ import annotations

+import os
+
import torch

import helion
from helion._testing import run_example
import helion.language as hl

+# Override default config to work around Triton tl.dot requirement:
+# `AssertionError: Input shapes should have M >= 16, N >= 16 and K >= 32`
+config = None
+if os.environ.get("HELION_USE_DEFAULT_CONFIG") == "1":
+    config = helion.Config(block_sizes=[32, 32, 32])
+

-@helion.kernel(static_shapes=True)
+@helion.kernel(static_shapes=True, config=config)
def fp8_gemm(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """FP8 General Matrix Multiplication (GEMM).

@@ -37,11 +45,8 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            x_tile = x[tile_m, tile_k]
            y_tile = y[tile_k, tile_n]

-            # Use torch.matmul which will be lowered to tl.dot
-            # When the inputs are FP8, tl.dot handles them natively
-            # The result needs to be converted to FP32 for accumulation
-            result = torch.matmul(x_tile, y_tile).to(torch.float32)
-            acc = acc + result
+            # Use hl.dot for FP8 GEMM
+            acc = hl.dot(x_tile, y_tile, acc=acc)
        out[tile_m, tile_n] = acc.to(torch.float16)

    return out
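
As a usage note, a hedged driver sketch for the fp8_gemm kernel above; the shapes, tolerance, and reference helper are illustrative and a CUDA device with FP8 support is assumed.

# Hypothetical driver; fp8_gemm is the Helion kernel defined above.
import torch

def fp8_gemm_reference(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Dequantize to FP32, matmul, then cast to the kernel's FP16 output dtype
    return torch.matmul(x.to(torch.float32), y.to(torch.float32)).to(torch.float16)

x = torch.randn(256, 256, device="cuda").to(torch.float8_e4m3fn)
y = torch.randn(256, 256, device="cuda").to(torch.float8_e4m3fn)
out = fp8_gemm(x, y)  # FP32 accumulation via hl.dot, FP16 output
print(torch.allclose(out, fp8_gemm_reference(x, y), atol=1e-1, rtol=1e-1))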

helion/_compiler/indexing_strategy.py

Lines changed: 8 additions & 1 deletion
@@ -70,7 +70,14 @@ def codegen_load(
        extra_mask: ast.AST | None,
    ) -> ast.AST:
        indexing = SubscriptIndexing.create(state, fake_tensor, subscript, extra_mask)
-        extra = ", other=0" if indexing.has_mask() else ""
+        extra = ""
+        if indexing.has_mask():
+            # For FP8 dtypes, use other=0.0 (float literal) instead of other=0 (int literal)
+            # because Triton cannot cast integer 0 to FP8 types
+            if fake_tensor.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+                extra = ", other=0.0"
+            else:
+                extra = ", other=0"
        name = state.device_function.tensor_arg(fake_tensor).name
        return expr_from_string(
            f"tl.load({name} + offset, mask{extra})",

helion/_compiler/inductor_lowering.py

Lines changed: 8 additions & 4 deletions
@@ -848,15 +848,19 @@ def reduce_3d_dot(
    rhs_node = node.args[1]
    assert isinstance(lhs, ast.AST)
    assert isinstance(rhs, ast.AST)
+    assert isinstance(lhs_node, torch.fx.Node)
+    assert isinstance(rhs_node, torch.fx.Node)

    # Check if inputs are FP8 - if so, don't specify input_precision to allow native FP8 computation
-    lhs_dtype = lhs_node.meta["val"].dtype  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
-    rhs_dtype = rhs_node.meta["val"].dtype  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
+    lhs_dtype = lhs_node.meta["val"].dtype
+    rhs_dtype = rhs_node.meta["val"].dtype
    if lhs_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and rhs_dtype in [
        torch.float8_e4m3fn,
        torch.float8_e5m2,
    ]:
-        datatype = None  # Let Triton use native FP8 computation
+        raise NotImplementedError(
+            "FP8 GEMM via torch API is not supported yet. Please use hl.dot() instead."
+        )

    lhs_size = lhs_node.meta["val"].size()  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
    rhs_size = rhs_node.meta["val"].size()  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]

@@ -1138,7 +1142,7 @@ def proxy_arg(self, i: int) -> object:

    def ast_arg(self, i: int) -> ast.AST:
        rv = self.ast_args[i]
-        if isinstance(rv, int | float | bool):
+        if isinstance(rv, int | float | bool | None):
            rv = ast.Constant(value=rv)
        assert isinstance(rv, ast.AST), "TODO: convert nested/defaults"
        return rv
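
The ast_arg change above is small but load-bearing: None must now survive as an ast.Constant, presumably so that optional arguments such as hl.dot's acc=None can be emitted (an assumption based on this commit). A standalone sketch of the widened constant handling; the function name is illustrative.

import ast

def to_ast_constant(rv: object) -> ast.AST:
    # Mirrors the updated ast_arg check: None is now boxed like int/float/bool
    if isinstance(rv, int | float | bool | None):
        rv = ast.Constant(value=rv)
    assert isinstance(rv, ast.AST)
    return rv

print(ast.dump(to_ast_constant(None)))  # Constant(value=None)
print(ast.dump(to_ast_constant(0.0)))   # Constant(value=0.0)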

helion/language/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@
from .loops import grid as grid
from .loops import static_range as static_range
from .loops import tile as tile
+from .matmul_ops import dot as dot
from .memory_ops import atomic_add as atomic_add
from .memory_ops import load as load
from .memory_ops import store as store
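
For reference, the newly exported hl.dot is used in the examples above as acc = hl.dot(a, b, acc=acc) on FP8 tiles. Below is a hedged CPU emulation of the semantics implied by those call sites (an assumption, not Helion's implementation): upcast the operands, matmul with FP32 accumulation, and optionally add into an accumulator.

import torch

def dot_reference(a: torch.Tensor, b: torch.Tensor, acc: torch.Tensor | None = None) -> torch.Tensor:
    # Emulates the observable behavior at the call sites above: FP32 result,
    # optionally accumulated into `acc`.
    out = torch.matmul(a.to(torch.float32), b.to(torch.float32))
    return out if acc is None else acc + out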
