-
Notifications
You must be signed in to change notification settings - Fork 15
One shot all reduce Example #245
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: joydddd/stack/13
Are you sure you want to change the base?
Conversation
stack-info: PR: #245, branch: joydddd/stack/12
stack-info: PR: #245, branch: joydddd/stack/12
stack-info: PR: #245, branch: joydddd/stack/12
stack-info: PR: #245, branch: joydddd/stack/12
stack-info: PR: #245, branch: joydddd/stack/12
stack-info: PR: #245, branch: joydddd/stack/12
stack-info: PR: #245, branch: joydddd/stack/12
What explains the gap between triton and helion here? How are you autotuning? |
No, we are not autotuning. I'd expect some performance boost once that is ready. Helion is launching an additional kernel to turn a pointer to dev_ptr array into a tensor. There is unfortunately no native way to get signal_pad/buffer pointers as a tensor from symmetric memory handlers now. If we decide this is what we'll need, we can talk to Ke Wen and see if they can add support. |
stack-info: PR: #245, branch: joydddd/stack/12
ee67366
to
3cd2ead
Compare
3cd2ead
to
b6d33df
Compare
# Symmemtric Memory Helpers | ||
@triton.jit | ||
def triton_copy( | ||
inp: tl.int64, # pyright: ignore[reportInvalidTypeForm] | ||
out: tl.tensor, | ||
SIZE: tl.constexpr, | ||
) -> None: | ||
tl.static_assert(out.dtype.is_ptr()) | ||
inp = inp.to(tl.pointer_type(out.dtype.element_ty)) # pyright: ignore[reportAttributeAccessIssue] | ||
addrs = tl.load(inp + tl.arange(0, SIZE)) | ||
tl.store(out + tl.arange(0, SIZE), addrs) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Workaround before symmetric memory handler natively supports getting dev_ptrs
as a tensor, instead of a pointer to the array.
b6d33df
to
cad4646
Compare
cad4646
to
9ca6ba5
Compare
stack-info: PR: #245, branch: joydddd/stack/12
9ca6ba5
to
7552c17
Compare
What are the baselines doing? Passing the pointers as kernel args? If that is needed for perf I think we can make that work. |
The Triton baseline is passing If we want to keep the nice Helion abstraction of taking Tensors as arguments, we need to find a way to construct a Tensor out of shape, stride, & data_ptr directly without CPU sync or additional kernel launch. Let me ask Natalia and Ke to see if they plan to add native support for getting the dev_ptrs array as a tensor. |
See #393 for a more complete benchmark result & analysis/ after the multicastTensor implementation. |
Stacked PRs:
One shot all reduce Example