@@ -23,7 +23,7 @@ def fp8_attention_kernel(

     # Output tensor with 4D shape in FP8 format
     out = torch.empty(
-        [batch, heads, seq_len, head_dim], dtype=torch.float8_e5m2, device=q.device
+        [batch, heads, seq_len, head_dim], dtype=torch.float8_e4m3fn, device=q.device
     )

     # Scale factor for attention
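Why e4m3fn rather than e5m2 for the output: e4m3fn spends one more bit on mantissa and one fewer on exponent, so it resolves values more finely at the cost of range, and the comments removed further down note that e5m2 is not accepted by torch._scaled_mm for these matmuls. A minimal sketch of the trade-off, assuming a PyTorch build with the float8 dtypes and finfo support for them:

```python
import torch

# Compare the two FP8 formats touched by this diff.
for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    info = torch.finfo(dtype)
    print(dtype, "max:", info.max, "eps:", info.eps)

# Round-trip a tensor of moderate magnitude through each format;
# e4m3fn keeps roughly one extra bit of precision per value.
x = torch.randn(1024) * 4.0
for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    err = (x - x.to(dtype).to(torch.float32)).abs().mean()
    print(dtype, "mean abs round-trip error:", err.item())
```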
@@ -54,9 +54,7 @@ def fp8_attention_kernel(
                 k_tile_t = k_tile.transpose(0, 1)  # [dim, tile_n]

                 # Compute Q @ K^T with FP8 inputs, result in FP32
-                qk = torch.matmul(q_tile, k_tile_t).to(
-                    torch.float32
-                )  # [tile_m, tile_n]
+                qk = hl.dot(q_tile, k_tile_t)  # [tile_m, tile_n]

                 # Scale QK scores first
                 qk_scaled = qk * sm_scale  # [tile_m, tile_n]
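The hl.dot call replaces an explicit matmul-then-cast; per the surrounding comment it takes FP8 tiles and produces an FP32 result. For intuition, a plain-PyTorch reference of that contract (not Helion's implementation, just the math the tile-level call should match; the tile sizes below are made up):

```python
import torch

def fp8_dot_reference(a_fp8: torch.Tensor, b_fp8: torch.Tensor) -> torch.Tensor:
    """Reference semantics assumed for the FP8 dot: multiply-accumulate in
    float32 and return float32."""
    return torch.matmul(a_fp8.to(torch.float32), b_fp8.to(torch.float32))

# Shapes mirroring the kernel's tiles (hypothetical tile sizes).
q_tile = torch.randn(64, 64).to(torch.float8_e4m3fn)     # [tile_m, dim]
k_tile_t = torch.randn(64, 64).to(torch.float8_e4m3fn)   # [dim, tile_n]
qk = fp8_dot_reference(q_tile, k_tile_t)                  # [tile_m, tile_n], float32
```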
@@ -90,28 +88,28 @@ def fp8_attention_kernel(
                 p_fp8 = p.to(v.dtype)  # Convert to same FP8 type as V

                 # Accumulate attention @ V with FP8 GEMM
-                v_t = v_tile.transpose(0, 1)  # [tile_n, dim]
-                pv = torch.matmul(p_fp8, v_t).to(torch.float32)  # [tile_m, dim]
-                acc = acc + pv
+                # v_tile is [dim, tile_n], we need to transpose for P @ V^T
+                v_t = v_tile.t()  # [tile_n, dim]
+                acc = hl.dot(p_fp8, v_t, acc=acc)  # [tile_m, dim]

                 # Update max tracker
                 m_i = m_new

             # Final normalization
             acc = acc / l_i[:, None]
             # Convert to FP8 before writing to output
-            out[b, h, tile_m, :] = acc.to(torch.float8_e5m2)
+            out[b, h, tile_m, :] = acc.to(torch.float8_e4m3fn)

     return out


 def preprocess_fp8_attention_inputs(
     q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    q_fp8 = q.to(torch.float8_e5m2)
-    k_fp8 = k.to(torch.float8_e5m2)
+    q_fp8 = q.to(torch.float8_e4m3fn)
+    k_fp8 = k.to(torch.float8_e4m3fn)
     v = v.permute(0, 1, 3, 2)
-    v_fp8 = v.to(torch.float8_e5m2)
+    v_fp8 = v.to(torch.float8_e4m3fn)
     batch, heads, seq_len, head_dim = q.shape
     q_fp8_reshaped = q_fp8.reshape(batch * heads, seq_len, head_dim)
     k_fp8_reshaped = k_fp8.reshape(batch * heads, seq_len, head_dim)
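The preprocessing keeps Q and K in a [batch*heads, seq, dim] layout and permutes V so the baseline can read it as [dim, seq]. A small shape/dtype check, assuming preprocess_fp8_attention_inputs is importable from this module; the final reshape of V is outside this hunk, so its expected shape is inferred from how v_fp8[i] is indexed later:

```python
import torch

batch, heads, seq_len, head_dim = 2, 4, 128, 64
q = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

q_fp8, k_fp8, v_fp8 = preprocess_fp8_attention_inputs(q, k, v)

assert q_fp8.shape == (batch * heads, seq_len, head_dim)
assert k_fp8.shape == (batch * heads, seq_len, head_dim)
assert q_fp8.dtype == k_fp8.dtype == v_fp8.dtype == torch.float8_e4m3fn
# Assumed from its use as v_fp8[i] -> [dim, seq] in _fp8_attention_pytorch_impl:
print(v_fp8.shape)  # expected (batch * heads, head_dim, seq_len)
```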
@@ -147,13 +145,25 @@ def _fp8_attention_pytorch_impl(
         k_i = k_fp8[i]  # [seq, dim] - already FP8
         v_i = v_fp8[i]  # [dim, seq] - pre-transposed, already FP8

-        # For Q @ K^T, we need K^T to be column-major
-        kt_fp8 = k_i.t()  # column-major [dim, seq]
-
-        # Q @ K^T - dequantize and use regular matmul since e5m2 not supported by _scaled_mm
-        q_deq = q_i.to(torch.float32)
-        kt_deq = kt_fp8.to(torch.float32)
-        qk = torch.matmul(q_deq, kt_deq)
+        # For Q @ K^T using torch._scaled_mm
+        # torch._scaled_mm requires column-major for second operand
+        # k_i is [seq, dim], we need K^T as [dim, seq] in column-major
+        # Direct conversion: k_i -> contiguous -> transpose view
+        kt_fp8_col_major = k_i.contiguous().t()  # [dim, seq] in column-major
+
+        # Create scale tensors
+        scale_q = torch.tensor(1.0, device=q_i.device)
+        scale_k = torch.tensor(1.0, device=k_i.device)
+
+        # Q @ K^T using torch._scaled_mm
+        qk = torch._scaled_mm(
+            q_i,
+            kt_fp8_col_major,
+            scale_q,
+            scale_k,
+            use_fast_accum=False,
+            out_dtype=torch.float32,
+        )

         # Compute max before scaling
         qk_max = torch.amax(qk, dim=-1, keepdim=True)
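As a standalone illustration of the torch._scaled_mm call pattern used above: row-major first operand, column-major second operand, scalar float32 scale tensors, FP32 output. A sketch assuming a GPU with FP8 support (e.g. H100) and a recent PyTorch where the scale arguments are positional; the dimensions are multiples of 16, which the FP8 GEMM backend generally requires:

```python
import torch

device = "cuda"  # FP8 _scaled_mm needs a GPU with FP8 support
M, K, N = 128, 64, 128

a = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)      # row-major [M, K]
b = torch.randn(N, K, device=device).to(torch.float8_e4m3fn).t()  # column-major [K, N]

scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)

out = torch._scaled_mm(
    a,
    b,
    scale_a,
    scale_b,
    use_fast_accum=False,
    out_dtype=torch.float32,
)
print(out.shape)  # torch.Size([128, 128])
```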
@@ -168,16 +178,26 @@ def _fp8_attention_pytorch_impl(
         # Step 2: Attention @ V using FP8
         # P is [seq, seq], V is [dim, seq]
         # We want P @ V^T = [seq, seq] @ [seq, dim] = [seq, dim]
-        p_fp8 = p_norm.to(torch.float8_e5m2)  # row-major [seq, seq]
+        p_fp8 = p_norm.to(torch.float8_e4m3fn)  # row-major [seq, seq]

         # v_i is [dim, seq], already FP8
-        vt_fp8 = v_i.t()  # column-major [seq, dim]
-
-        # P @ V^T - dequantize and use regular matmul since e5m2 not supported by torch._scaled_mm
-        p_deq = p_fp8.to(torch.float32)
-        vt_deq = vt_fp8.to(torch.float32)
-        out_i = torch.matmul(p_deq, vt_deq)
-        out_i = out_i.to(torch.float8_e5m2)  # convert back to FP8
+        # Direct conversion: v_i -> contiguous -> transpose view
+        vt_fp8_col_major = v_i.contiguous().t()  # [seq, dim] in column-major
+
+        # Create scale tensors for P @ V^T
+        scale_p = torch.tensor(1.0, device=p_fp8.device)
+        scale_v = torch.tensor(1.0, device=v_i.device)
+
+        # P @ V^T using torch._scaled_mm
+        out_i = torch._scaled_mm(
+            p_fp8,
+            vt_fp8_col_major,
+            scale_p,
+            scale_v,
+            use_fast_accum=False,
+            out_dtype=torch.float32,
+        )
+        out_i = out_i.to(torch.float8_e4m3fn)  # convert back to FP8 to match kernel

         outputs.append(out_i)

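One way to sanity-check this switch is to compare the _scaled_mm result against the dequantize-and-matmul path the hunk removes; with unit scales the two should agree closely, since both accumulate in FP32. A sketch under the same FP8-capable-GPU assumption:

```python
import torch

device = "cuda"
p_fp8 = torch.rand(128, 128, device=device).to(torch.float8_e4m3fn)              # row-major [seq, seq]
vt_fp8 = torch.randn(64, 128, device=device).to(torch.float8_e4m3fn).t()         # column-major [seq, dim]

scale_p = torch.tensor(1.0, device=device)
scale_v = torch.tensor(1.0, device=device)
out_scaled = torch._scaled_mm(
    p_fp8, vt_fp8, scale_p, scale_v, use_fast_accum=False, out_dtype=torch.float32
)

# Dequantize-and-matmul reference (the path removed above).
out_ref = torch.matmul(p_fp8.to(torch.float32), vt_fp8.to(torch.float32))

print((out_scaled - out_ref).abs().max())  # expected to be small; only accumulation order differs
```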
@@ -192,7 +212,7 @@ def fp8_attention_pytorch(
     v: torch.Tensor,  # [batch, heads, seq, dim]
 ) -> Callable[[], torch.Tensor]:
     """
-    Baseline PyTorch implementation of FP8 attention using FP8 e5m2.
+    Baseline PyTorch implementation of FP8 attention using torch._scaled_mm.
     """
     batch, heads, seq_len, head_dim = q.shape
     q_fp8, k_fp8, v_fp8 = preprocess_fp8_attention_inputs(q, k, v)
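Since fp8_attention_pytorch returns a callable rather than the result itself, a typical (assumed) way to exercise the baseline for comparison or benchmarking would be:

```python
import torch

batch, heads, seq_len, head_dim = 2, 4, 256, 64
q = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

baseline_fn = fp8_attention_pytorch(q, k, v)  # builds the closure over the FP8-converted inputs
out = baseline_fn()                           # runs the _scaled_mm-based baseline
print(out.shape, out.dtype)
```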