@@ -172,6 +172,32 @@ def test_gradient_checkpointing_is_applied(self):
         expected_set = {"FluxTransformer2DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
 
+    # The test exists for cases like
+    # https://github.com/huggingface/diffusers/issues/11874
+    def test_lora_exclude_modules(self):
+        from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict
+
+        lora_rank = 4
+        target_module = "single_transformer_blocks.0.proj_out"
+        adapter_name = "foo"
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        state_dict = model.state_dict()
+        target_mod_shape = state_dict[f"{target_module}.weight"].shape
+        lora_state_dict = {
+            f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22,
+            f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33,
+        }
+        config = LoraConfig(
+            r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"]
+        )
+        inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict)
+        set_peft_model_state_dict(model, lora_state_dict, adapter_name)
+        retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
+        assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all()
+        assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all()
+
 
 class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
     model_class = FluxTransformer2DModel