File tree Expand file tree Collapse file tree 1 file changed +15
-0
lines changed Expand file tree Collapse file tree 1 file changed +15
-0
lines changed Original file line number Diff line number Diff line change 15
15
PyTorch utilities: Utilities related to PyTorch
16
16
"""
17
17
18
+ import re
18
19
from typing import List , Optional , Tuple , Union
19
20
20
21
from . import logging
@@ -195,3 +196,17 @@ def device_synchronize(device_type: Optional[str] = None):
195
196
device_type = get_device ()
196
197
device_mod = getattr (torch , device_type , torch .cuda )
197
198
device_mod .synchronize ()
199
+
200
+
201
+ def _find_modules_by_class_name (module : "torch.nn.Module" , class_name : str ) -> List [Tuple [str , "torch.nn.Module" ]]:
202
+ """
203
+ Recursively find all modules in a PyTorch module that match the specified class name. The class
204
+ name could be partial/full name or a regex pattern.
205
+ """
206
+ pattern = re .compile (class_name )
207
+ matching_name_module_pairs = []
208
+ for name , submodule in module .named_modules ():
209
+ submodule_cls = unwrap_module (submodule ).__class__
210
+ if pattern .search (submodule_cls .__name__ ):
211
+ matching_name_module_pairs .append ((name , submodule ))
212
+ return matching_name_module_pairs
You can’t perform that action at this time.
0 commit comments