One shot all reduce Example #245

Open
wants to merge 1 commit into
base: joydddd/stack/13

joydddd
Contributor

@joydddd joydddd commented Jul 8, 2025

joydddd added a commit that referenced this pull request Jul 8, 2025
stack-info: PR: #245, branch: joydddd/stack/12
@joydddd joydddd force-pushed the joydddd/stack/12 branch from cd18d79 to 19e6b5c Compare July 8, 2025 19:07
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jul 8, 2025
@joydddd joydddd changed the base branch from joydddd/stack/11 to main July 8, 2025 20:23
joydddd added a commit that referenced this pull request Jul 8, 2025
stack-info: PR: #245, branch: joydddd/stack/12
@joydddd joydddd force-pushed the joydddd/stack/12 branch from 19e6b5c to 2c79dd9 Compare July 8, 2025 20:23
@joydddd joydddd changed the base branch from main to joydddd/stack/11 July 8, 2025 20:24
@joydddd joydddd changed the base branch from joydddd/stack/11 to main July 8, 2025 20:29
joydddd added a commit that referenced this pull request Jul 8, 2025
stack-info: PR: #245, branch: joydddd/stack/12
@joydddd joydddd force-pushed the joydddd/stack/12 branch from 2c79dd9 to 374dfa3 Compare July 8, 2025 20:29
@joydddd joydddd changed the base branch from main to joydddd/stack/11 July 8, 2025 20:29
@joydddd joydddd changed the base branch from joydddd/stack/11 to main July 9, 2025 18:19
joydddd added a commit that referenced this pull request Jul 9, 2025
stack-info: PR: #245, branch: joydddd/stack/12
@joydddd joydddd force-pushed the joydddd/stack/12 branch from 374dfa3 to 1887088 Compare July 9, 2025 18:20
@joydddd joydddd changed the base branch from main to joydddd/stack/11 July 9, 2025 18:20
@joydddd joydddd changed the base branch from joydddd/stack/11 to main July 9, 2025 19:51
joydddd added a commit that referenced this pull request Jul 9, 2025
stack-info: PR: #245, branch: joydddd/stack/12
@joydddd joydddd force-pushed the joydddd/stack/12 branch from 1887088 to f3199f8 Compare July 9, 2025 19:51
@joydddd joydddd changed the base branch from main to joydddd/stack/11 July 9, 2025 19:51
@joydddd joydddd force-pushed the joydddd/stack/11 branch from 6815f03 to 3f492ce Compare July 9, 2025 19:52
joydddd added a commit that referenced this pull request Jul 9, 2025
stack-info: PR: #245, branch: joydddd/stack/12
@joydddd joydddd force-pushed the joydddd/stack/12 branch from f3199f8 to 027a16b Compare July 9, 2025 19:52
@joydddd joydddd changed the base branch from joydddd/stack/11 to main July 9, 2025 21:27
joydddd added a commit that referenced this pull request Jul 9, 2025
stack-info: PR: #245, branch: joydddd/stack/12
@joydddd joydddd force-pushed the joydddd/stack/12 branch from 027a16b to 2f52133 Compare July 9, 2025 21:28
@joydddd joydddd changed the base branch from main to joydddd/stack/11 July 9, 2025 21:28
@joydddd joydddd changed the base branch from joydddd/stack/11 to main July 9, 2025 21:30
@jansel
Contributor

jansel commented Jul 23, 2025

One shot All Reduce Performance on 8xH100.

Benchmark: joydddd/kraken@helion_bench

Performance drop at 1m due to hardcoded block_size.

| shape | dtype | nccl | helion_1shot | triton_1shot | dist_1shot | Speedup over nccl | Best Backend |
|---|---|---|---|---|---|---|---|
| (4k) | torch.bfloat16 | 21.408 | 14.304 | 13.440 | 16.032 | 1.593 | triton_1shot |
| (8k) | torch.bfloat16 | 22.112 | 14.432 | 13.280 | 15.968 | 1.665 | triton_1shot |
| (16k) | torch.bfloat16 | 23.904 | 14.912 | 13.696 | 16.960 | 1.745 | triton_1shot |
| (32k) | torch.bfloat16 | 24.032 | 15.232 | 14.688 | 18.080 | 1.636 | triton_1shot |
| (64k) | torch.bfloat16 | 24.128 | 17.312 | 17.088 | 19.840 | 1.412 | triton_1shot |
| (128k) | torch.bfloat16 | 24.416 | 20.704 | 21.376 | 25.088 | 1.179 | helion_1shot |
| (256k) | torch.bfloat16 | 24.736 | 29.184 | 29.312 | 30.400 | 1.000 | nccl |
| (512k) | torch.bfloat16 | 34.016 | 47.488 | 42.016 | 43.328 | 1.000 | nccl |
| (1m) | torch.bfloat16 | 55.808 | 91.168 | 63.488 | 65.376 | 1.000 | nccl |
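
For readers checking the last two columns: the numbers are consistent with "Speedup over nccl" being the nccl time divided by the fastest backend's time, with nccl itself counted as a candidate. That reading is inferred from the values, not stated in the benchmark. A minimal sketch that reproduces three of the rows under that assumption:

```python
# Timings copied from three rows of the table (same units as the table).
rows = {
    "4k": {"nccl": 21.408, "helion_1shot": 14.304, "triton_1shot": 13.440, "dist_1shot": 16.032},
    "128k": {"nccl": 24.416, "helion_1shot": 20.704, "triton_1shot": 21.376, "dist_1shot": 25.088},
    "1m": {"nccl": 55.808, "helion_1shot": 91.168, "triton_1shot": 63.488, "dist_1shot": 65.376},
}


def best_backend(timings):
    """Return (backend name, speedup over nccl) for the fastest backend in one row."""
    # Assumption: lower time is better and nccl is itself a candidate.
    name = min(timings, key=timings.get)
    return name, round(timings["nccl"] / timings[name], 3)


summary = {shape: best_backend(t) for shape, t in rows.items()}
for shape, (name, speedup) in summary.items():
    print(f"{shape}: best={name}, speedup over nccl={speedup:.3f}")
```

Under this reading, a speedup of 1.000 simply means no one-shot variant beat nccl at that size.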

What explains the gap between triton and helion here? How are you autotuning?

@joydddd
Contributor Author

joydddd commented Jul 23, 2025

> What explains the gap between triton and helion here? How are you autotuning?

No, we are not autotuning. I'd expect some performance boost once that is ready.

Helion also launches an additional kernel that turns the pointer to the dev_ptrs array into a tensor:

https://github.com/pytorch-labs/helion/pull/245/files#diff-4e027817dc836f1b7a352b698bd5cfdaf58189b065e007c80ff7fffc80383891R17

There is unfortunately no native way to get signal_pad/buffer pointers as a tensor from symmetric memory handlers now. If we decide this is what we'll need, we can talk to Ke Wen and see if they can add support.

@joydddd joydddd changed the base branch from joydddd/stack/13 to main July 24, 2025 23:44
joydddd added a commit that referenced this pull request Jul 24, 2025
stack-info: PR: #245, branch: joydddd/stack/12
@joydddd joydddd force-pushed the joydddd/stack/12 branch from ee67366 to 3cd2ead Compare July 24, 2025 23:44
@joydddd joydddd changed the base branch from main to joydddd/stack/13 July 24, 2025 23:44
@joydddd joydddd changed the base branch from joydddd/stack/13 to main July 25, 2025 00:16
@joydddd joydddd force-pushed the joydddd/stack/12 branch from 3cd2ead to b6d33df Compare July 25, 2025 00:16
@joydddd joydddd changed the title One shot all reduce & symm mem sync One shot all reduce Example Jul 25, 2025
@joydddd joydddd changed the base branch from main to joydddd/stack/13 July 25, 2025 00:16
Comment on lines +15 to +26
# Symmetric Memory Helpers
@triton.jit
def triton_copy(
    inp: tl.int64,  # pyright: ignore[reportInvalidTypeForm]
    out: tl.tensor,
    SIZE: tl.constexpr,
) -> None:
    tl.static_assert(out.dtype.is_ptr())
    # Reinterpret the raw 64-bit address as a pointer to out's element type.
    inp = inp.to(tl.pointer_type(out.dtype.element_ty))  # pyright: ignore[reportAttributeAccessIssue]
    # Copy SIZE elements from the array at `inp` into `out`.
    addrs = tl.load(inp + tl.arange(0, SIZE))
    tl.store(out + tl.arange(0, SIZE), addrs)

Contributor Author

This is a workaround until the symmetric memory handler natively supports returning dev_ptrs as a tensor rather than as a pointer to the array.
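
To make the workaround concrete, here is a host-memory analogue of what the copy kernel does, written with `ctypes` so it runs without a GPU. It is only a sketch: `copy_pointer_array` and the stand-in buffers are hypothetical names for illustration, and the real kernel operates on device memory.

```python
import ctypes


def copy_pointer_array(inp_addr: int, out, size: int) -> None:
    """Copy `size` 64-bit entries starting at raw address `inp_addr` into `out`."""
    # Reinterpret the integer address as an array of 64-bit values.
    src = (ctypes.c_uint64 * size).from_address(inp_addr)
    # Element-wise load and store, analogous to tl.load followed by tl.store.
    for i in range(size):
        out[i] = src[i]


# Hypothetical stand-ins for per-rank buffers (host memory, not device memory).
buffers = [ctypes.create_string_buffer(16) for _ in range(4)]
expected = [ctypes.addressof(b) for b in buffers]
dev_ptrs = (ctypes.c_uint64 * 4)(*expected)

# The destination plays the role of the freshly allocated output tensor.
out = (ctypes.c_uint64 * 4)()
copy_pointer_array(ctypes.addressof(dev_ptrs), out, 4)
copied = list(out)
print(copied == expected)
```

The extra allocation and copy here correspond to the additional kernel launch discussed in this thread.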

@joydddd joydddd changed the base branch from joydddd/stack/13 to main July 25, 2025 17:49
@joydddd joydddd force-pushed the joydddd/stack/12 branch from b6d33df to cad4646 Compare July 25, 2025 17:49
@joydddd joydddd changed the base branch from main to joydddd/stack/13 July 25, 2025 17:49
@joydddd joydddd changed the base branch from joydddd/stack/13 to main July 28, 2025 17:18
@joydddd joydddd force-pushed the joydddd/stack/12 branch from cad4646 to 9ca6ba5 Compare July 28, 2025 17:18
@joydddd joydddd changed the base branch from main to joydddd/stack/13 July 28, 2025 17:19
stack-info: PR: #245, branch: joydddd/stack/12
@joydddd joydddd changed the base branch from joydddd/stack/13 to main July 29, 2025 01:14
@joydddd joydddd force-pushed the joydddd/stack/12 branch from 9ca6ba5 to 7552c17 Compare July 29, 2025 01:14
@joydddd joydddd marked this pull request as ready for review July 29, 2025 01:17
@jansel
Contributor

jansel commented Jul 29, 2025

> There is unfortunately no native way to get signal_pad/buffer pointers as a tensor from symmetric memory handlers now. If we decide this is what we'll need, we can talk to Ke Wen and see if they can add support.

What are the baselines doing? Passing the pointers as kernel args? If that is needed for perf I think we can make that work.

@joydddd
Contributor Author

joydddd commented Jul 29, 2025

> What are the baselines doing? Passing the pointers as kernel args? If that is needed for perf I think we can make that work.

The Triton baseline passes signal_pad_ptrs_dev (the device pointer to the array of signal pad pointers, see _distributed_c10d.pyi:776) as a Triton kernel argument, followed by the number of signal pads.
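
As a rough illustration of that calling convention (a raw pointer to a pointer array plus a count, both passed as scalars), here is a host-memory sketch using `ctypes`. The function `signal_all_peers` and the one-flag-per-rank pad layout are assumptions made for the example, not the baseline's actual protocol.

```python
import ctypes


def signal_all_peers(signal_pad_ptrs_addr: int, world_size: int, rank: int) -> None:
    """Set slot `rank` to 1 in every peer's signal pad.

    `signal_pad_ptrs_addr` is the raw address of an array holding
    `world_size` pad addresses, so each access goes through two hops.
    """
    pad_ptrs = (ctypes.c_uint64 * world_size).from_address(signal_pad_ptrs_addr)
    for peer in range(world_size):
        # Second hop: follow the stored address to the peer's pad.
        pad = (ctypes.c_uint64 * world_size).from_address(pad_ptrs[peer])
        pad[rank] = 1


world_size = 4
# Hypothetical pads: one 64-bit flag per rank, zero-initialized.
pads = [(ctypes.c_uint64 * world_size)() for _ in range(world_size)]
pad_ptrs = (ctypes.c_uint64 * world_size)(*(ctypes.addressof(p) for p in pads))

for rank in range(world_size):
    signal_all_peers(ctypes.addressof(pad_ptrs), world_size, rank)

all_signaled = all(list(p) == [1] * world_size for p in pads)
print(all_signaled)
```

No tensor object is involved at the call site, which is why this path avoids the extra copy.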

If we want to keep the nice Helion abstraction of taking Tensors as arguments, we need to find a way to construct a Tensor out of shape, stride, & data_ptr directly without CPU sync or additional kernel launch.

Let me ask Natalia and Ke to see if they plan to add native support for getting the dev_ptrs array as a tensor.
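
What "construct a Tensor out of data_ptr directly" would mean can be sketched on the host with `ctypes`: wrap existing memory as an array object without copying it. This is an analogy only; `view_pointer_array` is a hypothetical helper, and no claim is made about which PyTorch or symmetric memory API would provide the equivalent on device.

```python
import ctypes


def view_pointer_array(data_ptr: int, numel: int):
    """Wrap `numel` 64-bit entries at `data_ptr` as an array without copying."""
    return (ctypes.c_uint64 * numel).from_address(data_ptr)


table = (ctypes.c_uint64 * 3)(10, 20, 30)
view = view_pointer_array(ctypes.addressof(table), 3)

# A later write to the underlying memory is visible through the view,
# which shows that no copy was made.
table[1] = 99
after = list(view)
print(after)
```

A view like this needs neither a synchronization nor a separate copy step, which is the property asked for above.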

@joydddd joydddd changed the base branch from main to joydddd/stack/13 July 29, 2025 17:39
@joydddd
Contributor Author

joydddd commented Jul 29, 2025

See #393 for more complete benchmark results and analysis after the multicastTensor implementation.

@joydddd joydddd changed the base branch from joydddd/stack/13 to main July 30, 2025 06:10
@joydddd joydddd changed the base branch from main to joydddd/stack/13 July 30, 2025 06:10