Skip to content

Commit 37c13ad

Browse files
committed
gather cuda rollback
1 parent ce8f748 commit 37c13ad

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/tools/gather.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,22 @@ def main(args):
1717
print('---------------------------------------------')
1818
print(' All Gather Method ')
1919
print('---------------------------------------------')
20-
local_tensor = torch.tensor([torch.rand(world_size)[local_rank]]).cuda()
20+
local_tensor = torch.tensor([torch.rand(world_size)[local_rank]])
2121

22-
print(f'cuda_device(local({local_tensor})')
22+
print(f'cuda_device({torch.cuda.current_device()}) : local({local_tensor})')
2323
tensor = all_gather(local_tensor)
24-
print(f'cuda_device(gather({tensor})')
24+
print(f'cuda_device({torch.cuda.current_device()}) : gather({tensor})')
2525

2626
synchronize()
2727

2828
if is_main_process() :
2929
print('---------------------------------------------')
3030
print(' Gather Method ')
3131
print('---------------------------------------------')
32-
local_tensor = torch.tensor([torch.rand(world_size)[local_rank]]).cuda()
33-
print(f'cuda_device(local({local_tensor})')
32+
local_tensor = torch.tensor([torch.rand(world_size)[local_rank]])
33+
print(f'cuda_device({torch.cuda.current_device()}) : local({local_tensor})')
3434
tensor = gather(local_tensor, dst=0)
35-
print(f'cuda_device(gather({tensor})')
35+
print(f'cuda_device({torch.cuda.current_device()}) : gather({tensor})')
3636

3737
if __name__ == "__main__":
3838
args = default_argument_parser().parse_args()

0 commit comments

Comments
 (0)