Skip to content

Commit 2ff808d

Browse files
committed
address review comments
1 parent 22f3273 commit 2ff808d

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

src/diffusers/modular_pipelines/wan/denoise.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,7 @@
3535

3636

3737
class WanI2VLoopBeforeDenoiser(PipelineBlock):
38-
model_name = "stable-diffusion-xl"
38+
model_name = "wan"
3939

4040
@property
4141
def expected_components(self) -> List[ComponentSpec]:
@@ -72,15 +72,15 @@ def intermediate_inputs(self) -> List[str]:
7272
def intermediate_outputs(self) -> List[OutputParam]:
7373
return [
7474
OutputParam(
75-
"concatenated_latents",
75+
"latent_model_inputs",
7676
type_hint=torch.Tensor,
7777
description="The concatenated noisy and conditioning latents to use for the denoising process.",
7878
),
7979
]
8080

8181
@torch.no_grad()
8282
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: int):
83-
block_state.concatenated_latents = torch.cat([block_state.latents, block_state.latent_condition], dim=1)
83+
block_state.latent_model_inputs = torch.cat([block_state.latents, block_state.latent_condition], dim=1)
8484
return components, block_state
8585

8686

@@ -215,13 +215,13 @@ def inputs(self) -> List[Tuple[str, Any]]:
215215
def intermediate_inputs(self) -> List[str]:
216216
return [
217217
InputParam(
218-
"concatenated_latents",
218+
"latent_model_inputs",
219219
required=True,
220220
type_hint=torch.Tensor,
221221
description="The initial latents to use for the denoising process.",
222222
),
223223
InputParam(
224-
"encoder_hidden_states_image",
224+
"image_embeds",
225225
required=True,
226226
type_hint=torch.Tensor,
227227
description="The encoder hidden states for the image inputs.",
@@ -272,10 +272,10 @@ def __call__(
272272
# Predict the noise residual
273273
# store the noise_pred in guider_state_batch so that we can apply guidance across all batches
274274
guider_state_batch.noise_pred = components.transformer(
275-
hidden_states=block_state.concatenated_latents.to(transformer_dtype),
275+
hidden_states=block_state.latent_model_inputs.to(transformer_dtype),
276276
timestep=t.flatten(),
277277
encoder_hidden_states=prompt_embeds.to(transformer_dtype),
278-
encoder_hidden_states_image=block_state.encoder_hidden_states_image.to(transformer_dtype),
278+
encoder_hidden_states_image=block_state.image_embeds.to(transformer_dtype),
279279
attention_kwargs=block_state.attention_kwargs,
280280
return_dict=False,
281281
)[0]

src/diffusers/modular_pipelines/wan/encoders.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -296,7 +296,7 @@ def inputs(self) -> List[InputParam]:
296296
def intermediate_outputs(self) -> List[OutputParam]:
297297
return [
298298
OutputParam(
299-
"encoder_hidden_states_image",
299+
"image_embeds",
300300
type_hint=torch.Tensor,
301301
description="image embeddings used to guide the image generation",
302302
),
@@ -335,7 +335,7 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe
335335
if block_state.last_image is not None:
336336
image = [block_state.image, block_state.last_image]
337337

338-
block_state.encoder_hidden_states_image = self.encode_image(components, image, block_state.device)
338+
block_state.image_embeds = self.encode_image(components, image, block_state.device)
339339

340340
# Add outputs
341341
self.set_block_state(state, block_state)

0 commit comments

Comments
 (0)