File tree Expand file tree Collapse file tree 1 file changed +6
-6
lines changed Expand file tree Collapse file tree 1 file changed +6
-6
lines changed Original file line number Diff line number Diff line change @@ -17,22 +17,22 @@ def main(args):
17
17
print ('---------------------------------------------' )
18
18
print (' All Gather Method ' )
19
19
print ('---------------------------------------------' )
20
- local_tensor = torch .tensor ([torch .rand (world_size )[local_rank ]]). cuda ()
20
+ local_tensor = torch .tensor ([torch .rand (world_size )[local_rank ]])
21
21
22
- print (f'cuda_device(local({ local_tensor } )' )
22
+ print (f'cuda_device({ torch . cuda . current_device () } ) : local({ local_tensor } )' )
23
23
tensor = all_gather (local_tensor )
24
- print (f'cuda_device(gather({ tensor } )' )
24
+ print (f'cuda_device({ torch . cuda . current_device () } ) : gather({ tensor } )' )
25
25
26
26
synchronize ()
27
27
28
28
if is_main_process () :
29
29
print ('---------------------------------------------' )
30
30
print (' Gather Method ' )
31
31
print ('---------------------------------------------' )
32
- local_tensor = torch .tensor ([torch .rand (world_size )[local_rank ]]). cuda ()
33
- print (f'cuda_device(local({ local_tensor } )' )
32
+ local_tensor = torch .tensor ([torch .rand (world_size )[local_rank ]])
33
+ print (f'cuda_device({ torch . cuda . current_device () } ) : local({ local_tensor } )' )
34
34
tensor = gather (local_tensor , dst = 0 )
35
- print (f'cuda_device(gather({ tensor } )' )
35
+ print (f'cuda_device({ torch . cuda . current_device () } ) : gather({ tensor } )' )
36
36
37
37
if __name__ == "__main__" :
38
38
args = default_argument_parser ().parse_args ()
You can’t perform that action at this time.
0 commit comments