@@ -367,7 +367,8 @@ def __init__(self):
367
367
def initialize_hook (self , module ):
368
368
def make_execution_order_update_callback (current_name , current_submodule ):
369
369
def callback ():
370
- logger .debug (f"Adding { current_name } to the execution order" )
370
+ if not torch .compiler .is_compiling ():
371
+ logger .debug (f"Adding { current_name } to the execution order" )
371
372
self .execution_order .append ((current_name , current_submodule ))
372
373
373
374
return callback
@@ -404,12 +405,13 @@ def post_forward(self, module, output):
404
405
# if the missing layers end up being executed in the future.
405
406
if execution_order_module_names != self ._layer_execution_tracker_module_names :
406
407
unexecuted_layers = list (self ._layer_execution_tracker_module_names - execution_order_module_names )
407
- logger .warning (
408
- "It seems like some layers were not executed during the forward pass. This may lead to problems when "
409
- "applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
410
- "make sure that all layers are executed during the forward pass. The following layers were not executed:\n "
411
- f"{ unexecuted_layers = } "
412
- )
408
+ if not torch .compiler .is_compiling ():
409
+ logger .warning (
410
+ "It seems like some layers were not executed during the forward pass. This may lead to problems when "
411
+ "applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
412
+ "make sure that all layers are executed during the forward pass. The following layers were not executed:\n "
413
+ f"{ unexecuted_layers = } "
414
+ )
413
415
414
416
# Remove the layer execution tracker hooks from the submodules
415
417
base_module_registry = module ._diffusers_hook
@@ -437,7 +439,8 @@ def post_forward(self, module, output):
437
439
for i in range (num_executed - 1 ):
438
440
name1 , _ = self .execution_order [i ]
439
441
name2 , _ = self .execution_order [i + 1 ]
440
- logger .debug (f"Applying lazy prefetch group offloading from { name1 } to { name2 } " )
442
+ if not torch .compiler .is_compiling ():
443
+ logger .debug (f"Applying lazy prefetch group offloading from { name1 } to { name2 } " )
441
444
group_offloading_hooks [i ].next_group = group_offloading_hooks [i + 1 ].group
442
445
group_offloading_hooks [i ].next_group .onload_self = False
443
446
0 commit comments