34
34
logger = logging .get_logger (__name__ ) # pylint: disable=invalid-name
35
35
36
36
37
class WanI2VLoopBeforeDenoiser(PipelineBlock):
    """Denoise-loop sub-block that prepares the denoiser input for the Wan
    image-to-video (I2V) pipeline by concatenating the noisy latents with the
    image-conditioning latents along the channel dimension."""

    # Fixed: was "stable-diffusion-xl", a copy-paste leftover — every sibling
    # block in this file (e.g. WanLoopDenoiser, WanI2VLoopDenoiser) uses "wan".
    model_name = "wan"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        # Only the scheduler is needed; this block does no model forward pass.
        return [
            ComponentSpec("scheduler", UniPCMultistepScheduler),
        ]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that prepares the latent input for the denoiser. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `WanI2VDenoiseLoopWrapper`)"
        )

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        # Annotation fixed from List[str]: this returns InputParam specs,
        # mirroring the List[OutputParam] annotation on intermediate_outputs.
        return [
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process.",
            ),
            InputParam(
                "latent_condition",
                required=True,
                type_hint=torch.Tensor,
                description="The latent condition to use for the denoising process.",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "concatenated_latents",
                type_hint=torch.Tensor,
                description="The concatenated noisy and conditioning latents to use for the denoising process.",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: int):
        # Channel-wise (dim=1) concatenation; the I2V transformer consumes the
        # noisy latents and the conditioning latents as one stacked input.
        block_state.concatenated_latents = torch.cat([block_state.latents, block_state.latent_condition], dim=1)
        return components, block_state
37
87
class WanLoopDenoiser (PipelineBlock ):
38
88
model_name = "wan"
39
89
@@ -102,7 +152,7 @@ def __call__(
102
152
components .guider .set_state (step = i , num_inference_steps = block_state .num_inference_steps , timestep = t )
103
153
104
154
# Prepare mini‐batches according to guidance method and `guider_input_fields`
105
- # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds .
155
+ # Each guider_state_batch will have .prompt_embeds.
106
156
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
107
157
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
108
158
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
@@ -120,7 +170,112 @@ def __call__(
120
170
guider_state_batch .noise_pred = components .transformer (
121
171
hidden_states = block_state .latents .to (transformer_dtype ),
122
172
timestep = t .flatten (),
123
- encoder_hidden_states = prompt_embeds ,
173
+ encoder_hidden_states = prompt_embeds .to (transformer_dtype ),
174
+ attention_kwargs = block_state .attention_kwargs ,
175
+ return_dict = False ,
176
+ )[0 ]
177
+ components .guider .cleanup_models (components .transformer )
178
+
179
+ # Perform guidance
180
+ block_state .noise_pred , block_state .scheduler_step_kwargs = components .guider (guider_state )
181
+
182
+ return components , block_state
183
+
184
+
185
+ class WanI2VLoopDenoiser (PipelineBlock ):
186
+ model_name = "wan"
187
+
188
+ @property
189
+ def expected_components (self ) -> List [ComponentSpec ]:
190
+ return [
191
+ ComponentSpec (
192
+ "guider" ,
193
+ ClassifierFreeGuidance ,
194
+ config = FrozenDict ({"guidance_scale" : 5.0 }),
195
+ default_creation_method = "from_config" ,
196
+ ),
197
+ ComponentSpec ("transformer" , WanTransformer3DModel ),
198
+ ]
199
+
200
+ @property
201
+ def description (self ) -> str :
202
+ return (
203
+ "Step within the denoising loop that denoise the latents with guidance. "
204
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
205
+ "object (e.g. `WanDenoiseLoopWrapper`)"
206
+ )
207
+
208
+ @property
209
+ def inputs (self ) -> List [Tuple [str , Any ]]:
210
+ return [
211
+ InputParam ("attention_kwargs" ),
212
+ ]
213
+
214
+ @property
215
+ def intermediate_inputs (self ) -> List [str ]:
216
+ return [
217
+ InputParam (
218
+ "concatenated_latents" ,
219
+ required = True ,
220
+ type_hint = torch .Tensor ,
221
+ description = "The initial latents to use for the denoising process." ,
222
+ ),
223
+ InputParam (
224
+ "encoder_hidden_states_image" ,
225
+ required = True ,
226
+ type_hint = torch .Tensor ,
227
+ description = "The encoder hidden states for the image inputs." ,
228
+ ),
229
+ InputParam (
230
+ "num_inference_steps" ,
231
+ required = True ,
232
+ type_hint = int ,
233
+ description = "The number of inference steps to use for the denoising process." ,
234
+ ),
235
+ InputParam (
236
+ kwargs_type = "guider_input_fields" ,
237
+ description = (
238
+ "All conditional model inputs that need to be prepared with guider. "
239
+ "It should contain prompt_embeds/negative_prompt_embeds. "
240
+ "Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
241
+ ),
242
+ ),
243
+ ]
244
+
245
+ @torch .no_grad ()
246
+ def __call__ (
247
+ self , components : WanModularPipeline , block_state : BlockState , i : int , t : torch .Tensor
248
+ ) -> PipelineState :
249
+ # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
250
+ # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
251
+ guider_input_fields = {
252
+ "prompt_embeds" : ("prompt_embeds" , "negative_prompt_embeds" ),
253
+ }
254
+ transformer_dtype = components .transformer .dtype
255
+
256
+ components .guider .set_state (step = i , num_inference_steps = block_state .num_inference_steps , timestep = t )
257
+
258
+ # Prepare mini‐batches according to guidance method and `guider_input_fields`
259
+ # Each guider_state_batch will have .prompt_embeds.
260
+ # e.g. for CFG, we prepare two batches: one for uncond, one for cond
261
+ # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
262
+ # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
263
+ guider_state = components .guider .prepare_inputs (block_state , guider_input_fields )
264
+
265
+ # run the denoiser for each guidance batch
266
+ for guider_state_batch in guider_state :
267
+ components .guider .prepare_models (components .transformer )
268
+ cond_kwargs = guider_state_batch .as_dict ()
269
+ cond_kwargs = {k : v for k , v in cond_kwargs .items () if k in guider_input_fields }
270
+ prompt_embeds = cond_kwargs .pop ("prompt_embeds" )
271
+
272
+ # Predict the noise residual
273
+ # store the noise_pred in guider_state_batch so that we can apply guidance across all batches
274
+ guider_state_batch .noise_pred = components .transformer (
275
+ hidden_states = block_state .concatenated_latents .to (transformer_dtype ),
276
+ timestep = t .flatten (),
277
+ encoder_hidden_states = prompt_embeds .to (transformer_dtype ),
278
+ encoder_hidden_states_image = block_state .encoder_hidden_states_image .to (transformer_dtype ),
124
279
attention_kwargs = block_state .attention_kwargs ,
125
280
return_dict = False ,
126
281
)[0 ]
@@ -247,7 +402,7 @@ class WanDenoiseStep(WanDenoiseLoopWrapper):
247
402
WanLoopDenoiser ,
248
403
WanLoopAfterDenoiser ,
249
404
]
250
- block_names = ["before_denoiser" , " denoiser" , "after_denoiser" ]
405
+ block_names = ["denoiser" , "after_denoiser" ]
251
406
252
407
@property
253
408
def description (self ) -> str :
@@ -257,5 +412,26 @@ def description(self) -> str:
257
412
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n "
258
413
" - `WanLoopDenoiser`\n "
259
414
" - `WanLoopAfterDenoiser`\n "
260
- "This block supports both text2vid tasks."
415
+ "This block supports the text2vid task."
416
+ )
417
+
418
+
419
class WanI2VDenoiseStep(WanDenoiseLoopWrapper):
    """Image-to-video denoise step: iterates the denoising loop defined in
    `WanDenoiseLoopWrapper.__call__`, running the I2V-specific before/denoiser
    sub-blocks plus the shared after-denoiser sub-block each iteration."""

    block_classes = [
        WanI2VLoopBeforeDenoiser,
        WanI2VLoopDenoiser,
        WanLoopAfterDenoiser,
    ]
    # One name per entry in block_classes, in execution order.
    block_names = ["before_denoiser", "denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        # Fixed: the last sub-block is the shared `WanLoopAfterDenoiser`
        # (see block_classes above) — the previous text named a nonexistent
        # `WanI2VLoopAfterDenoiser`. Also fixed the "sequencially" typo.
        return (
            "Denoise step that iteratively denoises the latents with conditional first- and last-frame support. \n"
            "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
            " - `WanI2VLoopBeforeDenoiser`\n"
            " - `WanI2VLoopDenoiser`\n"
            " - `WanLoopAfterDenoiser`\n"
            "This block supports the image-to-video and first-last-frame-to-video tasks."
        )
0 commit comments