Skip to content

Commit 6832e6a

Browse files
committed
update
1 parent 3d2f8ae commit 6832e6a

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

src/diffusers/utils/torch_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
PyTorch utilities: Utilities related to PyTorch
1616
"""
1717

18+
import re
1819
from typing import List, Optional, Tuple, Union
1920

2021
from . import logging
@@ -195,3 +196,17 @@ def device_synchronize(device_type: Optional[str] = None):
195196
device_type = get_device()
196197
device_mod = getattr(torch, device_type, torch.cuda)
197198
device_mod.synchronize()
199+
200+
201+
def _find_modules_by_class_name(module: "torch.nn.Module", class_name: str) -> List[Tuple[str, "torch.nn.Module"]]:
202+
"""
203+
Recursively find all modules in a PyTorch module that match the specified class name. The class
204+
name could be partial/full name or a regex pattern.
205+
"""
206+
pattern = re.compile(class_name)
207+
matching_name_module_pairs = []
208+
for name, submodule in module.named_modules():
209+
submodule_cls = unwrap_module(submodule).__class__
210+
if pattern.search(submodule_cls.__name__):
211+
matching_name_module_pairs.append((name, submodule))
212+
return matching_name_module_pairs

0 commit comments

Comments
 (0)