
Commit 7ae4442 ("update")
1 parent: 183bcd5

File tree: 5 files changed, +297 / -51 lines


src/diffusers/modular_pipelines/wan/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -25,12 +25,14 @@
     _import_structure["modular_blocks"] = [
         "ALL_BLOCKS",
         "AUTO_BLOCKS",
+        "IMAGE2VIDEO_BLOCKS",
         "TEXT2VIDEO_BLOCKS",
         "WanAutoBeforeDenoiseStep",
         "WanAutoBlocks",
         "WanAutoBlocks",
         "WanAutoDecodeStep",
         "WanAutoDenoiseStep",
+        "WanAutoVaeEncoderStep",
     ]
     _import_structure["modular_pipeline"] = ["WanModularPipeline"]

@@ -45,11 +47,13 @@
         from .modular_blocks import (
            ALL_BLOCKS,
            AUTO_BLOCKS,
+           IMAGE2VIDEO_BLOCKS,
            TEXT2VIDEO_BLOCKS,
            WanAutoBeforeDenoiseStep,
            WanAutoBlocks,
            WanAutoDecodeStep,
            WanAutoDenoiseStep,
+           WanAutoVaeEncoderStep,
        )
        from .modular_pipeline import WanModularPipeline
 else:
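
For reference, a minimal usage sketch of the newly exported names. The public import path, the assumption that IMAGE2VIDEO_BLOCKS is an ordered mapping of block names to block classes (mirroring TEXT2VIDEO_BLOCKS), and the no-argument instantiation of WanAutoBlocks are assumptions, not something this commit spells out:

# Sketch only: paths and call signatures below are assumed, not taken from the commit.
from diffusers.modular_pipelines.wan import IMAGE2VIDEO_BLOCKS, WanAutoBlocks

# List the blocks that make up the image-to-video preset.
for name, block_cls in IMAGE2VIDEO_BLOCKS.items():
    print(name, "->", block_cls.__name__)

# The auto blocks are expected to bundle both the text2video and image2video paths.
blocks = WanAutoBlocks()
print(blocks)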

src/diffusers/modular_pipelines/wan/before_denoise.py

Lines changed: 4 additions & 1 deletion
@@ -282,7 +282,10 @@ def intermediate_outputs(self) -> List[OutputParam]:
         return [
             OutputParam(
                 "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
-            )
+            ),
+            OutputParam("height", type_hint=int),
+            OutputParam("width", type_hint=int),
+            OutputParam("num_frames", type_hint=int),
         ]

     @staticmethod
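
With height, width, and num_frames now published as intermediate outputs, later blocks can read them from the shared pipeline state instead of re-declaring them as user inputs. A hypothetical consumer block (class name and import path below are illustrative assumptions, not part of the commit) would declare them like this:

from typing import List

# Assumed import path for the modular-pipeline primitives used throughout this diff.
from diffusers.modular_pipelines import InputParam, PipelineBlock


class MyWanConsumerBlock(PipelineBlock):  # hypothetical downstream block
    model_name = "wan"

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        # Values produced upstream by the prepare-latents step in this commit.
        return [
            InputParam("height", type_hint=int),
            InputParam("width", type_hint=int),
            InputParam("num_frames", type_hint=int),
        ]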

src/diffusers/modular_pipelines/wan/denoise.py

Lines changed: 180 additions & 4 deletions
@@ -34,6 +34,56 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+class WanI2VLoopBeforeDenoiser(PipelineBlock):
+    model_name = "wan"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("scheduler", UniPCMultistepScheduler),
+        ]
+
+    @property
+    def description(self) -> str:
+        return (
+            "Step within the denoising loop that prepares the latent input for the denoiser. "
+            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+            "object (e.g. `WanI2VDenoiseLoopWrapper`)"
+        )
+
+    @property
+    def intermediate_inputs(self) -> List[str]:
+        return [
+            InputParam(
+                "latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The initial latents to use for the denoising process.",
+            ),
+            InputParam(
+                "latent_condition",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The latent condition to use for the denoising process.",
+            ),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                "concatenated_latents",
+                type_hint=torch.Tensor,
+                description="The concatenated noisy and conditioning latents to use for the denoising process.",
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: int):
+        block_state.concatenated_latents = torch.cat([block_state.latents, block_state.latent_condition], dim=1)
+        return components, block_state
+
+
 class WanLoopDenoiser(PipelineBlock):
     model_name = "wan"

@@ -102,7 +152,7 @@ def __call__(
         components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)

         # Prepare mini-batches according to guidance method and `guider_input_fields`
-        # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
+        # Each guider_state_batch will have .prompt_embeds.
         # e.g. for CFG, we prepare two batches: one for uncond, one for cond
         # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
         # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
@@ -120,7 +170,112 @@ def __call__(
             guider_state_batch.noise_pred = components.transformer(
                 hidden_states=block_state.latents.to(transformer_dtype),
                 timestep=t.flatten(),
-                encoder_hidden_states=prompt_embeds,
+                encoder_hidden_states=prompt_embeds.to(transformer_dtype),
+                attention_kwargs=block_state.attention_kwargs,
+                return_dict=False,
+            )[0]
+        components.guider.cleanup_models(components.transformer)
+
+        # Perform guidance
+        block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
+
+        return components, block_state
+
+
+class WanI2VLoopDenoiser(PipelineBlock):
+    model_name = "wan"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec(
+                "guider",
+                ClassifierFreeGuidance,
+                config=FrozenDict({"guidance_scale": 5.0}),
+                default_creation_method="from_config",
+            ),
+            ComponentSpec("transformer", WanTransformer3DModel),
+        ]
+
+    @property
+    def description(self) -> str:
+        return (
+            "Step within the denoising loop that denoises the latents with guidance. "
+            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+            "object (e.g. `WanDenoiseLoopWrapper`)"
+        )
+
+    @property
+    def inputs(self) -> List[Tuple[str, Any]]:
+        return [
+            InputParam("attention_kwargs"),
+        ]
+
+    @property
+    def intermediate_inputs(self) -> List[str]:
+        return [
+            InputParam(
+                "concatenated_latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The initial latents to use for the denoising process.",
+            ),
+            InputParam(
+                "encoder_hidden_states_image",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The encoder hidden states for the image inputs.",
+            ),
+            InputParam(
+                "num_inference_steps",
+                required=True,
+                type_hint=int,
+                description="The number of inference steps to use for the denoising process.",
+            ),
+            InputParam(
+                kwargs_type="guider_input_fields",
+                description=(
+                    "All conditional model inputs that need to be prepared with guider. "
+                    "It should contain prompt_embeds/negative_prompt_embeds. "
+                    "Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
+                ),
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(
+        self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
+    ) -> PipelineState:
+        # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
+        # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
+        guider_input_fields = {
+            "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
+        }
+        transformer_dtype = components.transformer.dtype
+
+        components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
+
+        # Prepare mini-batches according to guidance method and `guider_input_fields`
+        # Each guider_state_batch will have .prompt_embeds.
+        # e.g. for CFG, we prepare two batches: one for uncond, one for cond
+        # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
+        # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
+        guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
+
+        # run the denoiser for each guidance batch
+        for guider_state_batch in guider_state:
+            components.guider.prepare_models(components.transformer)
+            cond_kwargs = guider_state_batch.as_dict()
+            cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
+            prompt_embeds = cond_kwargs.pop("prompt_embeds")
+
+            # Predict the noise residual
+            # store the noise_pred in guider_state_batch so that we can apply guidance across all batches
+            guider_state_batch.noise_pred = components.transformer(
+                hidden_states=block_state.concatenated_latents.to(transformer_dtype),
+                timestep=t.flatten(),
+                encoder_hidden_states=prompt_embeds.to(transformer_dtype),
+                encoder_hidden_states_image=block_state.encoder_hidden_states_image.to(transformer_dtype),
                 attention_kwargs=block_state.attention_kwargs,
                 return_dict=False,
             )[0]
@@ -247,7 +402,7 @@ class WanDenoiseStep(WanDenoiseLoopWrapper):
         WanLoopDenoiser,
         WanLoopAfterDenoiser,
     ]
-    block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+    block_names = ["denoiser", "after_denoiser"]

     @property
     def description(self) -> str:
@@ -257,5 +412,26 @@ def description(self) -> str:
             "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
             " - `WanLoopDenoiser`\n"
             " - `WanLoopAfterDenoiser`\n"
-            "This block supports both text2vid tasks."
+            "This block supports the text2vid task."
+        )
+
+
+class WanI2VDenoiseStep(WanDenoiseLoopWrapper):
+    block_classes = [
+        WanI2VLoopBeforeDenoiser,
+        WanI2VLoopDenoiser,
+        WanLoopAfterDenoiser,
+    ]
+    block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+    @property
+    def description(self) -> str:
+        return (
+            "Denoise step that iteratively denoises the latents with conditional first- and last-frame support. \n"
+            "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
+            "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+            " - `WanI2VLoopBeforeDenoiser`\n"
+            " - `WanI2VLoopDenoiser`\n"
+            " - `WanLoopAfterDenoiser`\n"
+            "This block supports the image-to-video and first-last-frame-to-video tasks."
         )
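
As a shape sketch of what the new before-denoiser block does: WanI2VLoopBeforeDenoiser concatenates the noisy latents with the mask-plus-image conditioning latents along the channel dimension at every step. The concrete sizes below (16 latent channels, 4 mask channels, a 21-frame, 60x104 latent grid) are typical Wan values assumed for illustration and are not stated in this commit:

import torch

# (batch, z_dim, latent_frames, latent_height, latent_width) -- sizes assumed for illustration
latents = torch.randn(1, 16, 21, 60, 104)
# 4 mask channels + 16 encoded first/last-frame channels
latent_condition = torch.randn(1, 20, 21, 60, 104)

# This is the operation performed in WanI2VLoopBeforeDenoiser.__call__
concatenated_latents = torch.cat([latents, latent_condition], dim=1)
print(concatenated_latents.shape)  # torch.Size([1, 36, 21, 60, 104])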

src/diffusers/modular_pipelines/wan/encoders.py

Lines changed: 35 additions & 30 deletions
@@ -259,7 +259,7 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe
         return components, state


-class WanImageEncodeStep(PipelineBlock):
+class WanImageEncoderStep(PipelineBlock):
     model_name = "wan"

     @property
@@ -368,15 +368,15 @@ def inputs(self) -> List[InputParam]:
         return [
             InputParam("image", required=True),
             InputParam("last_image", required=False),
-            InputParam("height", type_hint=int),
-            InputParam("width", type_hint=int),
-            InputParam("num_frames", type_hint=int),
         ]

     @property
     def intermediate_inputs(self) -> List[InputParam]:
         return [
-            InputParam("num_channels_latents", type_hint=int),
+            InputParam("height", type_hint=int),
+            InputParam("width", type_hint=int),
+            InputParam("num_frames", type_hint=int),
+            InputParam("batch_size", type_hint=int),
             InputParam("generator"),
             InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
         ]
@@ -388,11 +388,12 @@ def intermediate_outputs(self) -> List[OutputParam]:
                 "latent_condition",
                 type_hint=torch.Tensor,
                 description="The latents representing the reference first-frame/last-frame for conditioned video generation.",
-            )
+            ),
+            OutputParam("num_channels_latents", type_hint=int),
         ]

+    @staticmethod
     def _encode_vae_image(
-        self,
         components: WanModularPipeline,
         batch_size: int,
         height: int,
@@ -404,11 +405,13 @@ def _encode_vae_image(
         last_image: Optional[torch.Tensor] = None,
         generator: Optional[torch.Generator] = None,
     ):
-        latent_height = height // self.vae_scale_factor_spatial
-        latent_width = width // self.vae_scale_factor_spatial
+        latent_height = height // components.vae_scale_factor_spatial
+        latent_width = width // components.vae_scale_factor_spatial

         latents_mean = (
-            torch.tensor(components.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype)
+            torch.tensor(components.vae.config.latents_mean)
+            .view(1, components.vae.config.z_dim, 1, 1, 1)
+            .to(device, dtype)
         )
         latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
             1, components.vae.config.z_dim, 1, 1, 1
@@ -429,11 +432,11 @@

         if isinstance(generator, list):
             latent_condition = [
-                retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator
+                retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax") for _ in generator
             ]
             latent_condition = torch.cat(latent_condition)
         else:
-            latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
+            latent_condition = retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax")
         latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)

         latent_condition = latent_condition.to(dtype)
@@ -445,9 +448,13 @@
         else:
             mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
         first_frame_mask = mask_lat_size[:, :, 0:1]
-        first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
+        first_frame_mask = torch.repeat_interleave(
+            first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
+        )
         mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
-        mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width)
+        mask_lat_size = mask_lat_size.view(
+            batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
+        )
         mask_lat_size = mask_lat_size.transpose(1, 2)
         mask_lat_size = mask_lat_size.to(latent_condition.device)
         latent_condition = torch.concat([mask_lat_size, latent_condition], dim=1)
@@ -458,32 +465,30 @@
     def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
         block_state.device = components._execution_device
-        block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
-        block_state.num_channels_latents = self.vae.config.z_dim
-        block_state.batch_size = (
-            block_state.batch_size if block_state.batch_size is not None else block_state.image.shape[0]
-        )
+        block_state.num_channels_latents = components.vae.config.z_dim

-        block_state.image = self.video_processor.preprocess(
+        block_state.image = components.video_processor.preprocess(
            block_state.image, height=block_state.height, width=block_state.width
        ).to(block_state.device, dtype=torch.float32)
+
         if block_state.last_image is not None:
-            block_state.last_image = self.video_processor.preprocess(
+            block_state.last_image = components.video_processor.preprocess(
                block_state.last_image, height=block_state.height, width=block_state.width
            ).to(block_state.device, dtype=torch.float32)

         block_state.latent_condition = self._encode_vae_image(
             components,
-            batch_size=block_state.batch_size,
-            height=block_state.height,
-            width=block_state.width,
-            num_frames=block_state.num_frames,
-            image=block_state.image,
-            device=block_state.device,
-            dtype=block_state.dtype,
-            last_image=block_state.last_image,
-            generator=block_state.generator,
+            block_state.batch_size,
+            block_state.height,
+            block_state.width,
+            block_state.num_frames,
+            block_state.image,
+            block_state.device,
+            block_state.dtype,
+            block_state.last_image,
+            block_state.generator,
         )

         self.set_block_state(state, block_state)
+
         return components, state
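
A standalone sketch of the conditioning-mask layout that _encode_vae_image builds before concatenating it with the encoded first/last-frame latents. The sizes (81 frames, temporal scale factor 4, 60x104 latent grid) are assumed for illustration and follow the shown else branch, which keeps the first and last frames:

import torch

batch_size, num_frames = 1, 81
vae_scale_factor_temporal = 4
latent_height, latent_width = 60, 104

# 1 where a frame is conditioned on, 0 elsewhere (first and last frame kept here).
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0

# Expand the first frame so the mask aligns with the temporally compressed latents,
# then fold the frame axis into (latent_frames, vae_scale_factor_temporal) and swap it
# into the channel position.
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=vae_scale_factor_temporal)
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
mask_lat_size = mask_lat_size.view(batch_size, -1, vae_scale_factor_temporal, latent_height, latent_width)
mask_lat_size = mask_lat_size.transpose(1, 2)
print(mask_lat_size.shape)  # torch.Size([1, 4, 21, 60, 104])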
