Skip to content

Commit cdcac4a

Browse files
committed
up
1 parent 9d313fc commit cdcac4a

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

src/diffusers/pipelines/wan/pipeline_wan.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,9 @@ def __init__(
131131
self,
132132
tokenizer: AutoTokenizer,
133133
text_encoder: UMT5EncoderModel,
134-
transformer: WanTransformer3DModel,
135134
vae: AutoencoderKLWan,
136135
scheduler: FlowMatchEulerDiscreteScheduler,
136+
transformer: Optional[WanTransformer3DModel] = None,
137137
transformer_2: Optional[WanTransformer3DModel] = None,
138138
boundary_ratio: Optional[float] = None,
139139
expand_timesteps: bool = False, # Wan2.2 ti2v
@@ -526,7 +526,7 @@ def __call__(
526526
device=device,
527527
)
528528

529-
transformer_dtype = self.transformer.dtype
529+
transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
530530
prompt_embeds = prompt_embeds.to(transformer_dtype)
531531
if negative_prompt_embeds is not None:
532532
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
@@ -536,7 +536,7 @@ def __call__(
536536
timesteps = self.scheduler.timesteps
537537

538538
# 5. Prepare latent variables
539-
num_channels_latents = self.transformer.config.in_channels
539+
num_channels_latents = self.transformer.config.in_channels if self.transformer is not None else self.transformer_2.config.in_channels
540540
latents = self.prepare_latents(
541541
batch_size * num_videos_per_prompt,
542542
num_channels_latents,

src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,17 +162,17 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
162162

163163
model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae"
164164
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
165-
_optional_components = ["transformer_2", "image_encoder", "image_processor"]
165+
_optional_components = ["transformer", "transformer_2", "image_encoder", "image_processor"]
166166

167167
def __init__(
168168
self,
169169
tokenizer: AutoTokenizer,
170170
text_encoder: UMT5EncoderModel,
171-
transformer: WanTransformer3DModel,
172171
vae: AutoencoderKLWan,
173172
scheduler: FlowMatchEulerDiscreteScheduler,
174173
image_processor: CLIPImageProcessor = None,
175174
image_encoder: CLIPVisionModel = None,
175+
transformer: WanTransformer3DModel = None,
176176
transformer_2: WanTransformer3DModel = None,
177177
boundary_ratio: Optional[float] = None,
178178
expand_timesteps: bool = False,
@@ -669,12 +669,13 @@ def __call__(
669669
)
670670

671671
# Encode image embedding
672-
transformer_dtype = self.transformer.dtype
672+
transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
673673
prompt_embeds = prompt_embeds.to(transformer_dtype)
674674
if negative_prompt_embeds is not None:
675675
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
676676

677-
if self.config.boundary_ratio is None and not self.config.expand_timesteps:
677+
# only wan 2.1 i2v transformer accepts image_embeds
678+
if self.transformer is not None and self.transformer.config.added_kv_proj_dim is not None:
678679
if image_embeds is None:
679680
if last_image is None:
680681
image_embeds = self.encode_image(image, device)
@@ -709,6 +710,7 @@ def __call__(
709710
last_image,
710711
)
711712
if self.config.expand_timesteps:
713+
# wan 2.2 5b i2v uses first_frame_mask to mask timesteps
712714
latents, condition, first_frame_mask = latents_outputs
713715
else:
714716
latents, condition = latents_outputs

0 commit comments

Comments
 (0)