Commit 3f5b409 ("up")

1 parent 04f2ff0
3 files changed: +29 −2 lines

setup.py

Lines changed: 2 additions & 1 deletion

@@ -116,7 +116,8 @@
     "librosa",
     "numpy",
     "parameterized",
-    "peft>=0.16.1",
+    "peft>=0.15.0",
+    # "peft>=0.16.1",
     "protobuf>=3.20.3,<4",
     "pytest",
     "pytest-timeout",
src/diffusers/dependency_versions_table.py

Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@
     "librosa": "librosa",
     "numpy": "numpy",
     "parameterized": "parameterized",
-    "peft": "peft>=0.16.1",
+    "peft": "peft>=0.15.0",
     "protobuf": "protobuf>=3.20.3,<4",
     "pytest": "pytest",
     "pytest-timeout": "pytest-timeout",

tests/models/transformers/test_models_transformer_flux.py

Lines changed: 26 additions & 0 deletions

@@ -172,6 +172,32 @@ def test_gradient_checkpointing_is_applied(self):
         expected_set = {"FluxTransformer2DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
 
+    # The test exists for cases like
+    # https://github.com/huggingface/diffusers/issues/11874
+    def test_lora_exclude_modules(self):
+        from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict
+
+        lora_rank = 4
+        target_module = "single_transformer_blocks.0.proj_out"
+        adapter_name = "foo"
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        state_dict = model.state_dict()
+        target_mod_shape = state_dict[f"{target_module}.weight"].shape
+        lora_state_dict = {
+            f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22,
+            f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33,
+        }
+        config = LoraConfig(
+            r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"]
+        )
+        inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict)
+        set_peft_model_state_dict(model, lora_state_dict, adapter_name)
+        retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
+        assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all()
+        assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all()
+
 
 class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
     model_class = FluxTransformer2DModel
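
To make the intent of the new test easier to see outside the Flux test harness, here is a small, self-contained sketch (not part of the diff) that exercises the same peft entry points on a toy module. The Toy class and the "blocks.0.proj_out"/"proj_out" names are hypothetical stand-ins for the Flux layout; how the overlap between target_modules and exclude_modules resolves depends on the installed peft version, which is what the version pin above and the new test are meant to pin down.

# Standalone sketch, not part of the commit. Toy and its module names are
# hypothetical stand-ins for FluxTransformer2DModel's layout.
import torch

from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # "blocks.0.proj_out" is the intended LoRA target; the top-level "proj_out"
        # shares the suffix that exclude_modules matches against.
        self.blocks = torch.nn.ModuleList([torch.nn.ModuleDict({"proj_out": torch.nn.Linear(8, 8)})])
        self.proj_out = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.proj_out(self.blocks[0]["proj_out"](x))


model = Toy()
config = LoraConfig(r=4, target_modules=["blocks.0.proj_out"], exclude_modules=["proj_out"])

try:
    inject_adapter_in_model(config, model, adapter_name="foo")
    # If the explicit target survives the exclusion, its lora_A/lora_B tensors appear here.
    print(sorted(get_peft_model_state_dict(model, adapter_name="foo")))
except ValueError as exc:
    # If the exclusion also removes the target, peft has nothing left to adapt and raises.
    print(f"no module was adapted: {exc}")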
