
Commit 203dc52

[modular] add Modular flux for text-to-image (#11995)
* start flux.
* more
* up
* up
* up
* up
* get back the deleted files.
* up
* empathy
1 parent: 56d4387

File tree

13 files changed: +1373 additions, -5 deletions
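The commit's net effect is two new public entry points, FluxAutoBlocks and FluxModularPipeline. Below is a minimal text-to-image sketch of how they would be driven, assuming the Flux blocks follow the same init_pipeline / load_default_components flow as the existing SDXL modular pipeline; the repo id, dtype, and call arguments are illustrative assumptions, not taken from this commit.

import torch

from diffusers.modular_pipelines import FluxAutoBlocks

# Assumed usage, mirroring the StableDiffusionXLAutoBlocks pattern; not part of this commit.
blocks = FluxAutoBlocks()
pipe = blocks.init_pipeline("black-forest-labs/FLUX.1-dev")  # illustrative repo id
pipe.load_default_components(torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(prompt="a photo of a cat", num_inference_steps=28, output="images")[0]
image.save("flux_t2i.png")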

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -364,6 +364,8 @@
 else:
     _import_structure["modular_pipelines"].extend(
         [
+            "FluxAutoBlocks",
+            "FluxModularPipeline",
             "StableDiffusionXLAutoBlocks",
             "StableDiffusionXLModularPipeline",
             "WanAutoBlocks",
@@ -999,6 +1001,8 @@
         from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
     else:
         from .modular_pipelines import (
+            FluxAutoBlocks,
+            FluxModularPipeline,
             StableDiffusionXLAutoBlocks,
             StableDiffusionXLModularPipeline,
             WanAutoBlocks,
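These two hunks register the new names in the lazy _import_structure and mirror them in the eager-import branch, so after this commit both resolve from the package root. A quick smoke test (imports only, no weights downloaded):

from diffusers import FluxAutoBlocks, FluxModularPipeline

print(FluxAutoBlocks)
print(FluxModularPipeline)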

src/diffusers/hooks/_helpers.py

Lines changed: 8 additions & 0 deletions

@@ -107,6 +107,7 @@ def _register(cls):
 def _register_attention_processors_metadata():
     from ..models.attention_processor import AttnProcessor2_0
     from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
+    from ..models.transformers.transformer_flux import FluxAttnProcessor
     from ..models.transformers.transformer_wan import WanAttnProcessor2_0

     # AttnProcessor2_0
@@ -132,6 +133,11 @@ def _register_attention_processors_metadata():
             skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0,
         ),
     )
+    # FluxAttnProcessor
+    AttentionProcessorRegistry.register(
+        model_class=FluxAttnProcessor,
+        metadata=AttentionProcessorMetadata(skip_processor_output_fn=_skip_proc_output_fn_Attention_FluxAttnProcessor),
+    )


 def _register_transformer_blocks_metadata():
@@ -271,4 +277,6 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
 _skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
 _skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
 _skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
+# not sure what this is yet.
+_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
 # fmt: on
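For context on what this hunk plugs into: the registry maps an attention-processor class to metadata whose skip_processor_output_fn says what a skipped processor should hand back (here, the unchanged hidden_states). A self-contained toy of the same pattern, with simplified names that are not diffusers' actual classes:

from dataclasses import dataclass
from typing import Callable, Dict, Type


@dataclass
class ProcMetadata:
    # Called in place of the real processor when a hook skips it;
    # must return what the processor would normally return.
    skip_output_fn: Callable


class ProcRegistry:
    _registry: Dict[Type, "ProcMetadata"] = {}

    @classmethod
    def register(cls, model_class: Type, metadata: ProcMetadata) -> None:
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> ProcMetadata:
        return cls._registry[model_class]


def _skip_ret_hidden_states(self, hidden_states, *args, **kwargs):
    # Identity skip: pass the input hidden states straight through.
    return hidden_states


class ToyFluxAttnProcessor:  # stand-in for FluxAttnProcessor
    pass


ProcRegistry.register(ToyFluxAttnProcessor, ProcMetadata(skip_output_fn=_skip_ret_hidden_states))
print(ProcRegistry.get(ToyFluxAttnProcessor).skip_output_fn(None, "hidden"))  # -> "hidden"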

src/diffusers/modular_pipelines/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -41,6 +41,7 @@
     ]
     _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
     _import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
+    _import_structure["flux"] = ["FluxAutoBlocks", "FluxModularPipeline"]
     _import_structure["components_manager"] = ["ComponentsManager"]

 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -51,6 +52,7 @@
         from ..utils.dummy_pt_objects import *  # noqa F403
     else:
         from .components_manager import ComponentsManager
+        from .flux import FluxAutoBlocks, FluxModularPipeline
         from .modular_pipeline import (
             AutoPipelineBlocks,
             BlockState,
src/diffusers/modular_pipelines/flux/__init__.py

Lines changed: 66 additions & 0 deletions

@@ -0,0 +1,66 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+    DIFFUSERS_SLOW_IMPORT,
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    get_objects_from_module,
+    is_torch_available,
+    is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+    if not (is_transformers_available() and is_torch_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ...utils import dummy_torch_and_transformers_objects  # noqa F403
+
+    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+    _import_structure["encoders"] = ["FluxTextEncoderStep"]
+    _import_structure["modular_blocks"] = [
+        "ALL_BLOCKS",
+        "AUTO_BLOCKS",
+        "TEXT2IMAGE_BLOCKS",
+        "FluxAutoBeforeDenoiseStep",
+        "FluxAutoBlocks",
+        "FluxAutoBlocks",
+        "FluxAutoDecodeStep",
+        "FluxAutoDenoiseStep",
+    ]
+    _import_structure["modular_pipeline"] = ["FluxModularPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+    try:
+        if not (is_transformers_available() and is_torch_available()):
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
+    else:
+        from .encoders import FluxTextEncoderStep
+        from .modular_blocks import (
+            ALL_BLOCKS,
+            AUTO_BLOCKS,
+            TEXT2IMAGE_BLOCKS,
+            FluxAutoBeforeDenoiseStep,
+            FluxAutoBlocks,
+            FluxAutoDecodeStep,
+            FluxAutoDenoiseStep,
+        )
+        from .modular_pipeline import FluxModularPipeline
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
+
+    for name, value in _dummy_objects.items():
+        setattr(sys.modules[__name__], name, value)
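The new package init defers all real imports through _LazyModule: a submodule loads only when one of its exported names is first touched, which keeps importing diffusers cheap at startup. A toy illustration of that mechanism follows (not diffusers' actual _LazyModule; demonstrated with stdlib modules so it runs standalone):

import importlib
import sys
import types


class ToyLazyModule(types.ModuleType):
    # Toy stand-in for diffusers' _LazyModule: import on first attribute access.

    def __init__(self, name: str, import_structure: dict):
        super().__init__(name)
        # Map each exported name to the module that actually defines it.
        self._name_to_module = {
            attr: mod for mod, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, name):
        if name not in self._name_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {name!r}")
        module = importlib.import_module(self._name_to_module[name])
        value = getattr(module, name)
        setattr(self, name, value)  # cache so later lookups bypass __getattr__
        return value


lazy = ToyLazyModule("lazy_demo", {"math": ["sqrt"], "json": ["dumps"]})
sys.modules["lazy_demo"] = lazy
print(lazy.sqrt(9.0))        # math is imported here, on first access
print(lazy.dumps({"a": 1}))  # json is imported here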
