Skip to content

Commit c354236

Browse files
committed
[Benchmark] Support kernel variants; setup matmul tritonbench integration
stack-info: PR: #380, branch: yf225/stack/42
1 parent 93a39e0 commit c354236

File tree

1 file changed

+134
-77
lines changed

1 file changed

+134
-77
lines changed

benchmarks/run.py

Lines changed: 134 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
from typing import Callable
2727

2828
# Maps tritonbench op names to Helion kernel examples
29-
KERNEL_MAPPINGS: dict[str, tuple[str, str, str]] = {
29+
# Can map to a single kernel or a list of kernel variants
30+
KERNEL_MAPPINGS: dict[str, tuple[str, str, str] | tuple[str, list[tuple[str, str]]]] = {
3031
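# <tritonbench_op_name>: (<tritonbench_module_path>, <helion_kernel_module_path>, <helion_kernel_function_name>) for a single kernel, or (<tritonbench_module_path>, [(<helion_kernel_module_path>, <helion_kernel_function_name>), ...]) for multiple kernel variants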
3132
"vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
3233
"embedding": (
@@ -80,6 +81,14 @@
8081
"examples.layer_norm",
8182
"layer_norm_fwd",
8283
),
84+
# Multiple kernel variants:
85+
"gemm": (
86+
"tritonbench.operators.gemm.operator",
87+
[
88+
("examples.matmul", "matmul"),
89+
("examples.matmul_split_k", "matmul_split_k"),
90+
],
91+
),
8392
}
8493

8594

@@ -210,7 +219,7 @@ def run_kernel(
210219
tritonbench_args: list[str],
211220
input_shard_info: tuple[int, int] | None = None,
212221
) -> None:
213-
"""Run a single kernel benchmark."""
222+
"""Run a kernel benchmark, handling both single and multiple variants."""
214223
# Check if kernel is in the mapping table
215224
if kernel_name not in KERNEL_MAPPINGS:
216225
print(f"Error: Unknown kernel '{kernel_name}'", file=sys.stderr)
@@ -219,25 +228,33 @@ def run_kernel(
219228
)
220229
sys.exit(1)
221230

222-
tritonbench_module, module_path, func_name = KERNEL_MAPPINGS[kernel_name]
231+
mapping = KERNEL_MAPPINGS[kernel_name]
232+
233+
# Normalize to list of variants format
234+
if len(mapping) == 2 and isinstance(mapping[1], list):
235+
# Multiple variants with shared tritonbench module
236+
tritonbench_module = mapping[0]
237+
variants = mapping[1]
238+
else:
239+
# Single kernel with full mapping - convert to list format
240+
assert len(mapping) == 3 # Type narrowing for pyright
241+
tritonbench_module, module_path, func_name = mapping
242+
variants = [(module_path, func_name)]
243+
244+
# Run all variants in the same benchmark
245+
run_kernel_variants(
246+
kernel_name, tritonbench_module, variants, tritonbench_args, input_shard_info
247+
)
223248

224-
# Import from the mapped module
225-
try:
226-
module = importlib.import_module(module_path)
227-
if not hasattr(module, func_name):
228-
print(
229-
f"Error: Module '{module_path}' does not have a function named '{func_name}'",
230-
file=sys.stderr,
231-
)
232-
sys.exit(1)
233-
kernel_func = getattr(module, func_name)
234-
except ImportError as e:
235-
print(
236-
f"Error: Could not import {func_name} from {module_path}", file=sys.stderr
237-
)
238-
print(f"Import error: {e}", file=sys.stderr)
239-
sys.exit(1)
240-
return
249+
250+
def run_kernel_variants(
251+
kernel_name: str,
252+
tritonbench_module: str,
253+
variants: list[tuple[str, str]],
254+
tritonbench_args: list[str],
255+
input_shard_info: tuple[int, int] | None = None,
256+
) -> None:
257+
"""Run kernel variants in the same benchmark run."""
241258

242259
# Import tritonbench components
243260
try:
@@ -260,19 +277,26 @@ def run_kernel(
260277
assert "--op" not in tritonbench_args
261278
tritonbench_args = ["--op", operator_name, *tritonbench_args]
262279

263-
# Get module's TRITONBENCH_ARGS if any
264-
module_args = getattr(module, "TRITONBENCH_ARGS", {})
280+
# Collect all module args from all variants
281+
all_module_args = {}
282+
for module_path, _ in variants:
283+
try:
284+
module = importlib.import_module(module_path)
285+
module_args = getattr(module, "TRITONBENCH_ARGS", {})
286+
all_module_args.update(module_args)
287+
except ImportError:
288+
pass
265289

266290
# Add module args to tritonbench_args if not already present
267-
for arg_name, arg_value in module_args.items():
291+
for arg_name, arg_value in all_module_args.items():
268292
arg_flag = f"--{arg_name.replace('_', '-')}"
269293
if arg_flag not in tritonbench_args:
270294
tritonbench_args.extend([arg_flag, str(arg_value)])
271295

272296
# Parse known args and collect unknown ones for operator
273297
tb_args, unknown_args = tb_parser.parse_known_args(tritonbench_args)
274298

275-
# Import and run the operator
299+
# Import and get the operator class
276300
try:
277301
operator_module = importlib.import_module(tritonbench_module)
278302
Operator = operator_module.Operator
@@ -285,64 +309,97 @@ def run_kernel(
285309
print(f"Import error: {e}", file=sys.stderr)
286310
sys.exit(1)
287311

288-
# Create the benchmark method
289-
def helion_method(
290-
self: object,
291-
*args: object,
292-
) -> Callable[..., object]:
293-
"""Helion implementation."""
294-
295-
# Reset all Helion kernels before creating the benchmark function
296-
# so that each input size can go through its own autotuning.
297-
from helion.runtime.kernel import Kernel
298-
299-
for attr_name in dir(module):
300-
attr = getattr(module, attr_name)
301-
if isinstance(attr, Kernel):
302-
attr.reset()
303-
304-
def _inner() -> Callable[..., Any] | object:
305-
# Force autotuning unless HELION_USE_DEFAULT_CONFIG=1 is set
306-
# This ensures we run autotuning even if the kernel has pre-specified configs
307-
if os.environ.get("HELION_USE_DEFAULT_CONFIG", "0") != "1":
308-
# Find all Kernel objects in the module and force autotuning
309-
for attr_name in dir(module):
310-
attr = getattr(module, attr_name)
311-
if isinstance(attr, Kernel):
312-
attr.settings.force_autotune = True
313-
314-
result = kernel_func(*args)
315-
if callable(result):
316-
return result()
317-
return result
318-
319-
return _inner
320-
321-
# Method name for the benchmark
322-
helion_method_name = f"helion_{kernel_name}"
323-
324312
# Import register_benchmark API
325313
from tritonbench.utils.triton_op import ( # pyright: ignore[reportMissingImports]
326314
register_benchmark,
327315
)
328316

329-
# Use register_benchmark decorator
330-
decorated_method = register_benchmark(
331-
operator_name=operator_name,
332-
func_name=helion_method_name,
333-
baseline=False,
334-
enabled=True,
335-
fwd_only=False,
336-
label=helion_method_name,
337-
)(helion_method)
338-
339-
# Set the decorated method on the Operator class
340-
setattr(Operator, helion_method_name, decorated_method)
341-
342-
print(
343-
f"Running {operator_name} benchmark with Helion implementation...\n",
344-
file=sys.stderr,
345-
)
317+
# Register all variants as separate methods
318+
for module_path, func_name in variants:
319+
# Import the kernel function
320+
try:
321+
module = importlib.import_module(module_path)
322+
if not hasattr(module, func_name):
323+
print(
324+
f"Error: Module '{module_path}' does not have a function named '{func_name}'",
325+
file=sys.stderr,
326+
)
327+
continue
328+
kernel_func = getattr(module, func_name)
329+
except ImportError as e:
330+
print(
331+
f"Error: Could not import {func_name} from {module_path}",
332+
file=sys.stderr,
333+
)
334+
print(f"Import error: {e}", file=sys.stderr)
335+
continue
336+
337+
# Create the benchmark method closure to capture the correct module and function
338+
def create_helion_method(
339+
mod: Any, # noqa: ANN401
340+
kfunc: Callable[..., Any],
341+
) -> Callable[..., Any]:
342+
def helion_method(
343+
self: object,
344+
*args: object,
345+
) -> Callable[..., object]:
346+
"""Helion implementation."""
347+
348+
# Reset all Helion kernels before creating the benchmark function
349+
# so that each input size can go through its own autotuning.
350+
from helion.runtime.kernel import Kernel
351+
352+
for attr_name in dir(mod):
353+
attr = getattr(mod, attr_name)
354+
if isinstance(attr, Kernel):
355+
attr.reset()
356+
357+
def _inner() -> Callable[..., Any] | object:
358+
# Force autotuning unless HELION_USE_DEFAULT_CONFIG=1 is set
359+
# This ensures we run autotuning even if the kernel has pre-specified configs
360+
if os.environ.get("HELION_USE_DEFAULT_CONFIG", "0") != "1":
361+
# Find all Kernel objects in the module and force autotuning
362+
for attr_name in dir(mod):
363+
attr = getattr(mod, attr_name)
364+
if isinstance(attr, Kernel):
365+
attr.settings.force_autotune = True
366+
367+
result = kfunc(*args)
368+
if callable(result):
369+
return result()
370+
return result
371+
372+
return _inner
373+
374+
return helion_method
375+
376+
# Method name for the benchmark
377+
variant_name = func_name
378+
helion_method_name = f"helion_{variant_name}"
379+
380+
# Use register_benchmark decorator
381+
decorated_method = register_benchmark(
382+
operator_name=operator_name,
383+
func_name=helion_method_name,
384+
baseline=False,
385+
enabled=True,
386+
fwd_only=False,
387+
label=helion_method_name,
388+
)(create_helion_method(module, kernel_func))
389+
390+
# Set the decorated method on the Operator class
391+
setattr(Operator, helion_method_name, decorated_method)
392+
393+
if len(variants) == 1:
394+
print(
395+
f"Running {operator_name} benchmark with Helion implementation...\n",
396+
file=sys.stderr,
397+
)
398+
else:
399+
print(
400+
f"Running {operator_name} benchmark with {len(variants)} Helion implementations...\n",
401+
file=sys.stderr,
402+
)
346403

347404
# Create and run the operator with unknown args
348405
op = Operator(tb_args=tb_args, extra_args=unknown_args)

0 commit comments

Comments
 (0)