Skip to content

Commit 890a01f

Browse files
authored
Merge branch 'main' into up-huggingface-hub
2 parents b686f98 + 3d2f8ae commit 890a01f

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,8 @@ def __init__(self):
367367
def initialize_hook(self, module):
368368
def make_execution_order_update_callback(current_name, current_submodule):
369369
def callback():
370-
logger.debug(f"Adding {current_name} to the execution order")
370+
if not torch.compiler.is_compiling():
371+
logger.debug(f"Adding {current_name} to the execution order")
371372
self.execution_order.append((current_name, current_submodule))
372373

373374
return callback
@@ -404,12 +405,13 @@ def post_forward(self, module, output):
404405
# if the missing layers end up being executed in the future.
405406
if execution_order_module_names != self._layer_execution_tracker_module_names:
406407
unexecuted_layers = list(self._layer_execution_tracker_module_names - execution_order_module_names)
407-
logger.warning(
408-
"It seems like some layers were not executed during the forward pass. This may lead to problems when "
409-
"applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
410-
"make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
411-
f"{unexecuted_layers=}"
412-
)
408+
if not torch.compiler.is_compiling():
409+
logger.warning(
410+
"It seems like some layers were not executed during the forward pass. This may lead to problems when "
411+
"applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
412+
"make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
413+
f"{unexecuted_layers=}"
414+
)
413415

414416
# Remove the layer execution tracker hooks from the submodules
415417
base_module_registry = module._diffusers_hook
@@ -437,7 +439,8 @@ def post_forward(self, module, output):
437439
for i in range(num_executed - 1):
438440
name1, _ = self.execution_order[i]
439441
name2, _ = self.execution_order[i + 1]
440-
logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
442+
if not torch.compiler.is_compiling():
443+
logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
441444
group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group
442445
group_offloading_hooks[i].next_group.onload_self = False
443446

0 commit comments

Comments
 (0)