26
26
from typing import Callable
27
27
28
28
# Maps tritonbench op names to Helion kernel examples
29
- KERNEL_MAPPINGS : dict [str , tuple [str , str , str ]] = {
29
+ # Can map to a single kernel or a list of kernel variants
30
+ KERNEL_MAPPINGS : dict [str , tuple [str , str , str ] | tuple [str , list [tuple [str , str ]]]] = {
30
31
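# Single kernel:
#   <tritonbench_op_name>: (<tritonbench_module_path>, <helion_kernel_module_path>, <helion_kernel_function_name>)
# Multiple kernel variants sharing one tritonbench operator:
#   <tritonbench_op_name>: (<tritonbench_module_path>, [(<helion_kernel_module_path>, <helion_kernel_function_name>), ...])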
31
32
"vector_add" : ("tritonbench.operators.vector_add.operator" , "examples.add" , "add" ),
32
33
"embedding" : (
80
81
"examples.layer_norm" ,
81
82
"layer_norm_fwd" ,
82
83
),
84
+ # Multiple kernel variants:
85
+ "gemm" : (
86
+ "tritonbench.operators.gemm.operator" ,
87
+ [
88
+ ("examples.matmul" , "matmul" ),
89
+ ("examples.matmul_split_k" , "matmul_split_k" ),
90
+ ],
91
+ ),
83
92
}
84
93
85
94
@@ -210,7 +219,7 @@ def run_kernel(
210
219
tritonbench_args : list [str ],
211
220
input_shard_info : tuple [int , int ] | None = None ,
212
221
) -> None :
213
- """Run a single kernel benchmark."""
222
+ """Run a kernel benchmark, handling both single and multiple variants ."""
214
223
# Check if kernel is in the mapping table
215
224
if kernel_name not in KERNEL_MAPPINGS :
216
225
print (f"Error: Unknown kernel '{ kernel_name } '" , file = sys .stderr )
@@ -219,25 +228,33 @@ def run_kernel(
219
228
)
220
229
sys .exit (1 )
221
230
222
- tritonbench_module , module_path , func_name = KERNEL_MAPPINGS [kernel_name ]
231
+ mapping = KERNEL_MAPPINGS [kernel_name ]
232
+
233
+ # Normalize to list of variants format
234
+ if len (mapping ) == 2 and isinstance (mapping [1 ], list ):
235
+ # Multiple variants with shared tritonbench module
236
+ tritonbench_module = mapping [0 ]
237
+ variants = mapping [1 ]
238
+ else :
239
+ # Single kernel with full mapping - convert to list format
240
+ assert len (mapping ) == 3 # Type narrowing for pyright
241
+ tritonbench_module , module_path , func_name = mapping
242
+ variants = [(module_path , func_name )]
243
+
244
+ # Run all variants in the same benchmark
245
+ run_kernel_variants (
246
+ kernel_name , tritonbench_module , variants , tritonbench_args , input_shard_info
247
+ )
223
248
224
- # Import from the mapped module
225
- try :
226
- module = importlib .import_module (module_path )
227
- if not hasattr (module , func_name ):
228
- print (
229
- f"Error: Module '{ module_path } ' does not have a function named '{ func_name } '" ,
230
- file = sys .stderr ,
231
- )
232
- sys .exit (1 )
233
- kernel_func = getattr (module , func_name )
234
- except ImportError as e :
235
- print (
236
- f"Error: Could not import { func_name } from { module_path } " , file = sys .stderr
237
- )
238
- print (f"Import error: { e } " , file = sys .stderr )
239
- sys .exit (1 )
240
- return
249
+
250
+ def run_kernel_variants (
251
+ kernel_name : str ,
252
+ tritonbench_module : str ,
253
+ variants : list [tuple [str , str ]],
254
+ tritonbench_args : list [str ],
255
+ input_shard_info : tuple [int , int ] | None = None ,
256
+ ) -> None :
257
+ """Run kernel variants in the same benchmark run."""
241
258
242
259
# Import tritonbench components
243
260
try :
@@ -260,19 +277,26 @@ def run_kernel(
260
277
assert "--op" not in tritonbench_args
261
278
tritonbench_args = ["--op" , operator_name , * tritonbench_args ]
262
279
263
- # Get module's TRITONBENCH_ARGS if any
264
- module_args = getattr (module , "TRITONBENCH_ARGS" , {})
280
+ # Collect all module args from all variants
281
+ all_module_args = {}
282
+ for module_path , _ in variants :
283
+ try :
284
+ module = importlib .import_module (module_path )
285
+ module_args = getattr (module , "TRITONBENCH_ARGS" , {})
286
+ all_module_args .update (module_args )
287
+ except ImportError :
288
+ pass
265
289
266
290
# Add module args to tritonbench_args if not already present
267
- for arg_name , arg_value in module_args .items ():
291
+ for arg_name , arg_value in all_module_args .items ():
268
292
arg_flag = f"--{ arg_name .replace ('_' , '-' )} "
269
293
if arg_flag not in tritonbench_args :
270
294
tritonbench_args .extend ([arg_flag , str (arg_value )])
271
295
272
296
# Parse known args and collect unknown ones for operator
273
297
tb_args , unknown_args = tb_parser .parse_known_args (tritonbench_args )
274
298
275
- # Import and run the operator
299
+ # Import and get the operator class
276
300
try :
277
301
operator_module = importlib .import_module (tritonbench_module )
278
302
Operator = operator_module .Operator
@@ -285,64 +309,97 @@ def run_kernel(
285
309
print (f"Import error: { e } " , file = sys .stderr )
286
310
sys .exit (1 )
287
311
288
- # Create the benchmark method
289
- def helion_method (
290
- self : object ,
291
- * args : object ,
292
- ) -> Callable [..., object ]:
293
- """Helion implementation."""
294
-
295
- # Reset all Helion kernels before creating the benchmark function
296
- # so that each input size can go through its own autotuning.
297
- from helion .runtime .kernel import Kernel
298
-
299
- for attr_name in dir (module ):
300
- attr = getattr (module , attr_name )
301
- if isinstance (attr , Kernel ):
302
- attr .reset ()
303
-
304
- def _inner () -> Callable [..., Any ] | object :
305
- # Force autotuning unless HELION_USE_DEFAULT_CONFIG=1 is set
306
- # This ensures we run autotuning even if the kernel has pre-specified configs
307
- if os .environ .get ("HELION_USE_DEFAULT_CONFIG" , "0" ) != "1" :
308
- # Find all Kernel objects in the module and force autotuning
309
- for attr_name in dir (module ):
310
- attr = getattr (module , attr_name )
311
- if isinstance (attr , Kernel ):
312
- attr .settings .force_autotune = True
313
-
314
- result = kernel_func (* args )
315
- if callable (result ):
316
- return result ()
317
- return result
318
-
319
- return _inner
320
-
321
- # Method name for the benchmark
322
- helion_method_name = f"helion_{ kernel_name } "
323
-
324
312
# Import register_benchmark API
325
313
from tritonbench .utils .triton_op import ( # pyright: ignore[reportMissingImports]
326
314
register_benchmark ,
327
315
)
328
316
329
- # Use register_benchmark decorator
330
- decorated_method = register_benchmark (
331
- operator_name = operator_name ,
332
- func_name = helion_method_name ,
333
- baseline = False ,
334
- enabled = True ,
335
- fwd_only = False ,
336
- label = helion_method_name ,
337
- )(helion_method )
338
-
339
- # Set the decorated method on the Operator class
340
- setattr (Operator , helion_method_name , decorated_method )
341
-
342
- print (
343
- f"Running { operator_name } benchmark with Helion implementation...\n " ,
344
- file = sys .stderr ,
345
- )
317
+ # Register all variants as separate methods
318
+ for module_path , func_name in variants :
319
+ # Import the kernel function
320
+ try :
321
+ module = importlib .import_module (module_path )
322
+ if not hasattr (module , func_name ):
323
+ print (
324
+ f"Error: Module '{ module_path } ' does not have a function named '{ func_name } '" ,
325
+ file = sys .stderr ,
326
+ )
327
+ continue
328
+ kernel_func = getattr (module , func_name )
329
+ except ImportError as e :
330
+ print (
331
+ f"Error: Could not import { func_name } from { module_path } " ,
332
+ file = sys .stderr ,
333
+ )
334
+ print (f"Import error: { e } " , file = sys .stderr )
335
+ continue
336
+
337
+ # Create the benchmark method closure to capture the correct module and function
338
def create_helion_method(
    mod: Any,  # noqa: ANN401
    kfunc: Callable[..., Any],
) -> Callable[..., Any]:
    """Build the benchmark method for one Helion kernel variant.

    ``mod`` and ``kfunc`` are received as arguments so that the returned
    method is bound to this specific module and kernel function.

    Args:
        mod: Imported module that is scanned for Helion ``Kernel`` objects.
        kfunc: Kernel entry point that is called with the benchmark inputs.

    Returns:
        The ``helion_method`` function to be registered as a benchmark.
    """

    def helion_method(
        self: object,
        *args: object,
    ) -> Callable[..., object]:
        """Helion implementation."""
        from helion.runtime.kernel import Kernel

        def module_kernels() -> Any:  # noqa: ANN401
            # Lazily walk the module attributes and yield only Kernel objects;
            # every call performs a fresh scan of ``dir(mod)``.
            for name in dir(mod):
                candidate = getattr(mod, name)
                if isinstance(candidate, Kernel):
                    yield candidate

        # Reset all Helion kernels before creating the benchmark function
        # so that each input size can go through its own autotuning.
        for kernel in module_kernels():
            kernel.reset()

        def _inner() -> Callable[..., Any] | object:
            # Autotuning is forced unless HELION_USE_DEFAULT_CONFIG=1 is set,
            # so it also runs for kernels that have pre-specified configs.
            autotune_forced = os.environ.get("HELION_USE_DEFAULT_CONFIG", "0") != "1"
            if autotune_forced:
                for kernel in module_kernels():
                    kernel.settings.force_autotune = True

            # If the kernel function hands back a callable, invoke it and
            # return its result; otherwise return the value as-is.
            outcome = kfunc(*args)
            return outcome() if callable(outcome) else outcome

        return _inner

    return helion_method
375
+
376
+ # Method name for the benchmark
377
+ variant_name = func_name
378
+ helion_method_name = f"helion_{ variant_name } "
379
+
380
+ # Use register_benchmark decorator
381
+ decorated_method = register_benchmark (
382
+ operator_name = operator_name ,
383
+ func_name = helion_method_name ,
384
+ baseline = False ,
385
+ enabled = True ,
386
+ fwd_only = False ,
387
+ label = helion_method_name ,
388
+ )(create_helion_method (module , kernel_func ))
389
+
390
+ # Set the decorated method on the Operator class
391
+ setattr (Operator , helion_method_name , decorated_method )
392
+
393
+ if len (variants ) == 1 :
394
+ print (
395
+ f"Running { operator_name } benchmark with Helion implementation...\n " ,
396
+ file = sys .stderr ,
397
+ )
398
+ else :
399
+ print (
400
+ f"Running { operator_name } benchmark with { len (variants )} Helion implementations...\n " ,
401
+ file = sys .stderr ,
402
+ )
346
403
347
404
# Create and run the operator with unknown args
348
405
op = Operator (tb_args = tb_args , extra_args = unknown_args )
0 commit comments