@@ -162,17 +162,17 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
 
     model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
-    _optional_components = ["transformer_2", "image_encoder", "image_processor"]
+    _optional_components = ["transformer", "transformer_2", "image_encoder", "image_processor"]
 
     def __init__(
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        transformer: WanTransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
         image_processor: CLIPImageProcessor = None,
         image_encoder: CLIPVisionModel = None,
+        transformer: WanTransformer3DModel = None,
         transformer_2: WanTransformer3DModel = None,
         boundary_ratio: Optional[float] = None,
         expand_timesteps: bool = False,
@@ -669,12 +669,13 @@ def __call__(
         )
 
         # Encode image embedding
-        transformer_dtype = self.transformer.dtype
+        transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
         prompt_embeds = prompt_embeds.to(transformer_dtype)
         if negative_prompt_embeds is not None:
             negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
 
-        if self.config.boundary_ratio is None and not self.config.expand_timesteps:
+        # only the wan 2.1 i2v transformer accepts image_embeds
+        if self.transformer is not None and self.transformer.config.added_kv_proj_dim is not None:
             if image_embeds is None:
                 if last_image is None:
                     image_embeds = self.encode_image(image, device)
@@ -709,6 +710,7 @@ def __call__(
                 last_image,
             )
         if self.config.expand_timesteps:
+            # wan 2.2 5b i2v uses first_frame_mask to mask timesteps
             latents, condition, first_frame_mask = latents_outputs
         else:
             latents, condition = latents_outputs
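
For orientation, a minimal usage sketch of what this change enables. The checkpoint id below is a placeholder and its component layout is assumed, not taken from this diff: with `transformer` now listed in `_optional_components`, a repo that ships only `transformer_2` can still be loaded, and the working dtype resolves the same way as the fallback added above.

```python
# Sketch only: assumes a hypothetical Wan I2V checkpoint whose repo omits `transformer`
# and provides `transformer_2`; "your-org/wan-i2v-checkpoint" is a placeholder id.
import torch
from diffusers import WanImageToVideoPipeline

pipe = WanImageToVideoPipeline.from_pretrained(
    "your-org/wan-i2v-checkpoint",  # placeholder; any layout with transformer_2 only
    torch_dtype=torch.bfloat16,
)

# Mirrors the fallback introduced in the diff: prefer `transformer`,
# otherwise fall back to `transformer_2` for the dtype.
transformer_dtype = (
    pipe.transformer.dtype if pipe.transformer is not None else pipe.transformer_2.dtype
)
print(transformer_dtype)
```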