
Commit 18c8f10

a-r-r-o-w and DN6 authored
[refactor] Flux/Chroma single file implementation + Attention Dispatcher (#11916)
* update
* update
* add coauthor Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>
* improve test
* handle ip adapter params correctly
* fix chroma qkv fusion test
* fix fastercache implementation
* fix more tests
* fight more tests
* add back set_attention_backend
* update
* update
* make style
* make fix-copies
* make ip adapter processor compatible with attention dispatcher
* refactor chroma as well
* remove rmsnorm assert
* minify and deprecate npu/xla processors

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
1 parent 7298bdd · commit 18c8f10

24 files changed: +2329 -1006 lines
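The commit message mentions adding back set_attention_backend. As a rough orientation for the diffs below, here is a minimal sketch (not taken from this commit) of how that model-level switch would be used, assuming set_attention_backend() accepts a backend name string recognized by the new attention dispatcher:

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Assumed usage: route every attention module in the transformer through one dispatcher backend.
pipe.transformer.set_attention_backend("native")  # "native" assumed to map to torch SDPA

image = pipe("a photo of a cat", num_inference_steps=28).images[0]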

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -163,6 +163,7 @@
         [
             "AllegroTransformer3DModel",
             "AsymmetricAutoencoderKL",
+            "AttentionBackendName",
             "AuraFlowTransformer2DModel",
             "AutoencoderDC",
             "AutoencoderKL",
@@ -238,6 +239,7 @@
             "VQModel",
             "WanTransformer3DModel",
             "WanVACETransformer3DModel",
+            "attention_backend",
         ]
     )
     _import_structure["modular_pipelines"].extend(
@@ -815,6 +817,7 @@
         from .models import (
             AllegroTransformer3DModel,
             AsymmetricAutoencoderKL,
+            AttentionBackendName,
             AuraFlowTransformer2DModel,
             AutoencoderDC,
             AutoencoderKL,
@@ -889,6 +892,7 @@
             VQModel,
             WanTransformer3DModel,
             WanVACETransformer3DModel,
+            attention_backend,
         )
         from .modular_pipelines import (
             ComponentsManager,
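The additions above export AttentionBackendName and attention_backend from the top-level diffusers namespace. A hedged sketch of how they could be used, assuming attention_backend is a context manager that temporarily selects the dispatched backend and AttentionBackendName enumerates the registered backend identifiers (the NATIVE member below is an assumption, not confirmed by this diff):

from diffusers import AttentionBackendName, attention_backend

# Inspect which backend identifiers are registered with the dispatcher.
print([backend.value for backend in AttentionBackendName])

# Temporarily dispatch attention through the selected backend for this block only.
with attention_backend(AttentionBackendName.NATIVE):  # assumed member name; a plain string may also be accepted
    ...  # run the denoiser forward pass here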

src/diffusers/hooks/faster_cache.py

Lines changed: 2 additions & 1 deletion
@@ -18,6 +18,7 @@
 
 import torch
 
+from ..models.attention import AttentionModuleMixin
 from ..models.attention_processor import Attention, MochiAttention
 from ..models.modeling_outputs import Transformer2DModelOutput
 from ..utils import logging
@@ -567,7 +568,7 @@ def high_frequency_weight_callback(module: torch.nn.Module) -> float:
     _apply_faster_cache_on_denoiser(module, config)
 
     for name, submodule in module.named_modules():
-        if not isinstance(submodule, _ATTENTION_CLASSES):
+        if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
             continue
         if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
             _apply_faster_cache_on_attention_class(name, submodule, config)
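The widened isinstance check above lets FasterCache pick up attention layers built on the new AttentionModuleMixin in addition to the legacy Attention/MochiAttention classes. A standalone sketch of the same filtering pattern; the _TRANSFORMER_BLOCK_IDENTIFIERS value below is illustrative, not copied from the hook module:

import re

import torch

from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import Attention, MochiAttention

_ATTENTION_CLASSES = (Attention, MochiAttention)
_TRANSFORMER_BLOCK_IDENTIFIERS = ("transformer_blocks", "single_transformer_blocks")  # illustrative patterns

def iter_cacheable_attention_modules(model: torch.nn.Module):
    """Yield (name, module) pairs that the widened isinstance check would select."""
    for name, submodule in model.named_modules():
        # Accept legacy Attention classes as well as anything built on AttentionModuleMixin.
        if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
            continue
        if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
            yield name, submodule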

src/diffusers/hooks/pyramid_attention_broadcast.py

Lines changed: 2 additions & 1 deletion
@@ -18,6 +18,7 @@
 
 import torch
 
+from ..models.attention import AttentionModuleMixin
 from ..models.attention_processor import Attention, MochiAttention
 from ..utils import logging
 from .hooks import HookRegistry, ModelHook
@@ -227,7 +228,7 @@ def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAt
         config.spatial_attention_block_skip_range = 2
 
     for name, submodule in module.named_modules():
-        if not isinstance(submodule, _ATTENTION_CLASSES):
+        if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
             # PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
             # cannot be applied to this layer. For custom layers, users can extend this functionality and implement
             # their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
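The same widening applies here, and the inline comment notes that custom layers can participate by implementing their own PAB logic. A hypothetical sketch showing that a user-defined module which subclasses AttentionModuleMixin now passes the isinstance filter (the class below is illustrative; the mixin's processor-management API is not exercised):

import torch
import torch.nn.functional as F

from diffusers.models.attention import AttentionModuleMixin

class MyCustomAttention(torch.nn.Module, AttentionModuleMixin):  # hypothetical layer
    def __init__(self, dim: int):
        super().__init__()
        self.to_q = torch.nn.Linear(dim, dim)
        self.to_k = torch.nn.Linear(dim, dim)
        self.to_v = torch.nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        q, k, v = self.to_q(hidden_states), self.to_k(hidden_states), self.to_v(hidden_states)
        return F.scaled_dot_product_attention(q, k, v)

# The widened check in apply_pyramid_attention_broadcast would now consider this module.
print(isinstance(MyCustomAttention(64), AttentionModuleMixin))  # True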

src/diffusers/loaders/ip_adapter.py

Lines changed: 5 additions & 4 deletions
@@ -40,8 +40,6 @@
 from ..models.attention_processor import (
     AttnProcessor,
     AttnProcessor2_0,
-    FluxAttnProcessor2_0,
-    FluxIPAdapterJointAttnProcessor2_0,
     IPAdapterAttnProcessor,
     IPAdapterAttnProcessor2_0,
     IPAdapterXFormersAttnProcessor,
@@ -867,6 +865,9 @@ def unload_ip_adapter(self):
         >>> ...
         ```
         """
+        # TODO: once the 1.0.0 deprecations are in, we can move the imports to top-level
+        from ..models.transformers.transformer_flux import FluxAttnProcessor, FluxIPAdapterAttnProcessor
+
         # remove CLIP image encoder
         if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
             self.image_encoder = None
@@ -886,9 +887,9 @@ def unload_ip_adapter(self):
         # restore original Transformer attention processors layers
         attn_procs = {}
         for name, value in self.transformer.attn_processors.items():
-            attn_processor_class = FluxAttnProcessor2_0()
+            attn_processor_class = FluxAttnProcessor()
             attn_procs[name] = (
-                attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
+                attn_processor_class if isinstance(value, FluxIPAdapterAttnProcessor) else value.__class__()
             )
         self.transformer.set_attn_processor(attn_procs)
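The renamed processor classes now live in transformer_flux, and the restore logic compares against FluxIPAdapterAttnProcessor. A sketch of the same pattern as a standalone helper, assuming a loaded FluxTransformer2DModel whose attn_processors/set_attn_processor API matches what the diff uses:

from diffusers import FluxTransformer2DModel
from diffusers.models.transformers.transformer_flux import FluxAttnProcessor, FluxIPAdapterAttnProcessor

def restore_default_flux_processors(transformer: FluxTransformer2DModel) -> None:
    """Swap IP-adapter processors back to the default FluxAttnProcessor, mirroring unload_ip_adapter."""
    attn_procs = {}
    for name, value in transformer.attn_processors.items():
        attn_procs[name] = FluxAttnProcessor() if isinstance(value, FluxIPAdapterAttnProcessor) else value.__class__()
    transformer.set_attn_processor(attn_procs)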

src/diffusers/loaders/transformer_flux.py

Lines changed: 2 additions & 4 deletions
@@ -86,9 +86,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
         return image_projection
 
     def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
-        from ..models.attention_processor import (
-            FluxIPAdapterJointAttnProcessor2_0,
-        )
+        from ..models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
 
         if low_cpu_mem_usage:
             if is_accelerate_available():
@@ -120,7 +118,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_
         else:
             cross_attention_dim = self.config.joint_attention_dim
         hidden_size = self.inner_dim
-        attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
+        attn_processor_class = FluxIPAdapterAttnProcessor
         num_image_text_embeds = []
         for state_dict in state_dicts:
             if "proj.weight" in state_dict["image_proj"]:

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,7 @@
 
 if is_torch_available():
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
+    _import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
     _import_structure["auto_model"] = ["AutoModel"]
     _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
     _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
@@ -112,6 +113,7 @@
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
         from .adapter import MultiAdapter, T2IAdapter
+        from .attention_dispatch import AttentionBackendName, attention_backend
         from .auto_model import AutoModel
         from .autoencoders import (
             AsymmetricAutoencoderKL,

0 commit comments
