```diff
@@ -35,7 +35,7 @@
 
 
 class WanI2VLoopBeforeDenoiser(PipelineBlock):
-    model_name = "stable-diffusion-xl"
+    model_name = "wan"
 
     @property
     def expected_components(self) -> List[ComponentSpec]:
@@ -72,15 +72,15 @@ def intermediate_inputs(self) -> List[str]:
     def intermediate_outputs(self) -> List[OutputParam]:
         return [
             OutputParam(
-                "concatenated_latents",
+                "latent_model_inputs",
                 type_hint=torch.Tensor,
                 description="The concatenated noisy and conditioning latents to use for the denoising process.",
             ),
         ]
 
     @torch.no_grad()
     def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: int):
-        block_state.concatenated_latents = torch.cat([block_state.latents, block_state.latent_condition], dim=1)
+        block_state.latent_model_inputs = torch.cat([block_state.latents, block_state.latent_condition], dim=1)
         return components, block_state
 
 
@@ -215,13 +215,13 @@ def inputs(self) -> List[Tuple[str, Any]]:
     def intermediate_inputs(self) -> List[str]:
         return [
             InputParam(
-                "concatenated_latents",
+                "latent_model_inputs",
                 required=True,
                 type_hint=torch.Tensor,
                 description="The initial latents to use for the denoising process.",
             ),
             InputParam(
-                "encoder_hidden_states_image",
+                "image_embeds",
                 required=True,
                 type_hint=torch.Tensor,
                 description="The encoder hidden states for the image inputs.",
@@ -272,10 +272,10 @@ def __call__(
             # Predict the noise residual
            # store the noise_pred in guider_state_batch so that we can apply guidance across all batches
            guider_state_batch.noise_pred = components.transformer(
-                hidden_states=block_state.concatenated_latents.to(transformer_dtype),
+                hidden_states=block_state.latent_model_inputs.to(transformer_dtype),
                timestep=t.flatten(),
                encoder_hidden_states=prompt_embeds.to(transformer_dtype),
-                encoder_hidden_states_image=block_state.encoder_hidden_states_image.to(transformer_dtype),
+                encoder_hidden_states_image=block_state.image_embeds.to(transformer_dtype),
                attention_kwargs=block_state.attention_kwargs,
                return_dict=False,
            )[0]
```
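The renames in this diff keep the producer and consumer blocks consistent: `WanI2VLoopBeforeDenoiser` now publishes its intermediate output as `latent_model_inputs`, and the denoiser block declares an `InputParam` with the same name (likewise `image_embeds` for the image encoder hidden states). Below is a minimal standalone sketch, not the diffusers implementation, of why these names must agree: intermediates travel between blocks as attributes on the shared block state, so a stale name on either side fails at runtime. The simplified `BlockState` and the `before_denoiser`/`denoiser` functions here are hypothetical stand-ins.

```python
import torch


class BlockState:
    """Simplified stand-in for the state object shared between pipeline blocks."""


def before_denoiser(state: BlockState) -> BlockState:
    # Producer: concatenate the noisy latents with the conditioning latents
    # along the channel dimension and store the result under the agreed-upon name.
    state.latent_model_inputs = torch.cat([state.latents, state.latent_condition], dim=1)
    return state


def denoiser(state: BlockState) -> torch.Tensor:
    # Consumer: read the same attribute. If the producer still wrote the old
    # name ("concatenated_latents"), this lookup would raise AttributeError.
    return state.latent_model_inputs


state = BlockState()
state.latents = torch.randn(1, 16, 9, 8, 8)           # toy noisy video latents
state.latent_condition = torch.randn(1, 20, 9, 8, 8)  # toy conditioning latents
print(denoiser(before_denoiser(state)).shape)          # torch.Size([1, 36, 9, 8, 8])
```

Declaring the names through `OutputParam`/`InputParam` makes this contract explicit, which is why both sides of each rename have to land in the same change.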