75
75
require_torch_2 ,
76
76
require_torch_accelerator ,
77
77
require_torch_accelerator_with_training ,
78
- require_torch_gpu ,
79
78
require_torch_multi_accelerator ,
80
79
require_torch_version_greater ,
81
80
run_test_in_subprocess ,
@@ -1829,8 +1828,8 @@ def test_wrong_device_map_raises_error(self, device_map, msg_substring):
1829
1828
1830
1829
assert msg_substring in str (err_ctx .exception )
1831
1830
1832
- @parameterized .expand ([0 , "cuda" , torch .device ("cuda" )])
1833
- @require_torch_gpu
1831
+ @parameterized .expand ([0 , torch_device , torch .device (torch_device )])
1832
+ @require_torch_accelerator
1834
1833
def test_passing_non_dict_device_map_works (self , device_map ):
1835
1834
init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
1836
1835
model = self .model_class (** init_dict ).eval ()
@@ -1839,8 +1838,8 @@ def test_passing_non_dict_device_map_works(self, device_map):
1839
1838
loaded_model = self .model_class .from_pretrained (tmpdir , device_map = device_map )
1840
1839
_ = loaded_model (** inputs_dict )
1841
1840
1842
- @parameterized .expand ([("" , "cuda" ), ("" , torch .device ("cuda" ))])
1843
- @require_torch_gpu
1841
+ @parameterized .expand ([("" , torch_device ), ("" , torch .device (torch_device ))])
1842
+ @require_torch_accelerator
1844
1843
def test_passing_dict_device_map_works (self , name , device ):
1845
1844
# There are other valid dict-based `device_map` values too. It's best to refer to
1846
1845
# the docs for those: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap.
@@ -1945,7 +1944,7 @@ def test_push_to_hub_library_name(self):
1945
1944
delete_repo (self .repo_id , token = TOKEN )
1946
1945
1947
1946
1948
- @require_torch_gpu
1947
+ @require_torch_accelerator
1949
1948
@require_torch_2
1950
1949
@is_torch_compile
1951
1950
@slow
@@ -2013,7 +2012,7 @@ def test_compile_with_group_offloading(self):
2013
2012
model .eval ()
2014
2013
# TODO: Can test for other group offloading kwargs later if needed.
2015
2014
group_offload_kwargs = {
2016
- "onload_device" : "cuda" ,
2015
+ "onload_device" : torch_device ,
2017
2016
"offload_device" : "cpu" ,
2018
2017
"offload_type" : "block_level" ,
2019
2018
"num_blocks_per_group" : 1 ,
0 commit comments