Skip to content

Commit ce8f748

Browse files
committed
gather cuda
1 parent 68a6130 commit ce8f748

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/tools/gather.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,22 @@ def main(args):
1717
print('---------------------------------------------')
1818
print(' All Gather Method ')
1919
print('---------------------------------------------')
20-
local_tensor = torch.tensor([torch.rand(world_size)[local_rank]])
20+
local_tensor = torch.tensor([torch.rand(world_size)[local_rank]]).cuda()
2121

22-
print(f'cuda_device({torch.cuda.current_device()}) : local({local_tensor})')
22+
print(f'cuda_device(local({local_tensor})')
2323
tensor = all_gather(local_tensor)
24-
print(f'cuda_device({torch.cuda.current_device()}) : gather({tensor})')
24+
print(f'cuda_device(gather({tensor})')
2525

2626
synchronize()
2727

2828
if is_main_process() :
2929
print('---------------------------------------------')
3030
print(' Gather Method ')
3131
print('---------------------------------------------')
32-
local_tensor = torch.tensor([torch.rand(world_size)[local_rank]])
33-
print(f'cuda_device({torch.cuda.current_device()}) : local({local_tensor})')
32+
local_tensor = torch.tensor([torch.rand(world_size)[local_rank]]).cuda()
33+
print(f'cuda_device(local({local_tensor})')
3434
tensor = gather(local_tensor, dst=0)
35-
print(f'cuda_device({torch.cuda.current_device()}) : gather({tensor})')
35+
print(f'cuda_device(gather({tensor})')
3636

3737
if __name__ == "__main__":
3838
args = default_argument_parser().parse_args()

0 commit comments

Comments
 (0)