From cdcac4a5dd38137f43e0a54fc6a7b63fb52b8167 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 30 Jul 2025 22:08:42 +0200 Subject: [PATCH 01/13] up --- src/diffusers/pipelines/wan/pipeline_wan.py | 6 +++--- src/diffusers/pipelines/wan/pipeline_wan_i2v.py | 10 ++++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index f52bf33d810b..39a1ab2f46bf 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -131,9 +131,9 @@ def __init__( self, tokenizer: AutoTokenizer, text_encoder: UMT5EncoderModel, - transformer: WanTransformer3DModel, vae: AutoencoderKLWan, scheduler: FlowMatchEulerDiscreteScheduler, + transformer: Optional[WanTransformer3DModel] = None, transformer_2: Optional[WanTransformer3DModel] = None, boundary_ratio: Optional[float] = None, expand_timesteps: bool = False, # Wan2.2 ti2v @@ -526,7 +526,7 @@ def __call__( device=device, ) - transformer_dtype = self.transformer.dtype + transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype prompt_embeds = prompt_embeds.to(transformer_dtype) if negative_prompt_embeds is not None: negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) @@ -536,7 +536,7 @@ def __call__( timesteps = self.scheduler.timesteps # 5. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels + num_channels_latents = self.transformer.config.in_channels if self.transformer is not None else self.transformer_2.config.in_channels latents = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index a072824a4854..0eb65eb88eb9 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -162,17 +162,17 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae" _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] - _optional_components = ["transformer_2", "image_encoder", "image_processor"] + _optional_components = ["transformer", "transformer_2", "image_encoder", "image_processor"] def __init__( self, tokenizer: AutoTokenizer, text_encoder: UMT5EncoderModel, - transformer: WanTransformer3DModel, vae: AutoencoderKLWan, scheduler: FlowMatchEulerDiscreteScheduler, image_processor: CLIPImageProcessor = None, image_encoder: CLIPVisionModel = None, + transformer: WanTransformer3DModel = None, transformer_2: WanTransformer3DModel = None, boundary_ratio: Optional[float] = None, expand_timesteps: bool = False, @@ -669,12 +669,13 @@ def __call__( ) # Encode image embedding - transformer_dtype = self.transformer.dtype + transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype prompt_embeds = prompt_embeds.to(transformer_dtype) if negative_prompt_embeds is not None: negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) - if self.config.boundary_ratio is None and not self.config.expand_timesteps: + # only wan 2.1 i2v transformer accepts image_embeds + if self.transformer is not None and self.transformer.config.added_kv_proj_dim is not None: if image_embeds is None: if last_image is None: image_embeds = self.encode_image(image, device) @@ -709,6 +710,7 @@ 
def __call__( last_image, ) if self.config.expand_timesteps: + # wan 2.2 5b i2v use firt_frame_mask to mask timesteps latents, condition, first_frame_mask = latents_outputs else: latents, condition = latents_outputs From 5c995d93c92c9c3026a080dd7e9d8a8ad96dc7ae Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 30 Jul 2025 22:55:46 +0200 Subject: [PATCH 02/13] up --- src/diffusers/pipelines/wan/pipeline_wan.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index 39a1ab2f46bf..78fe71ea9138 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -125,7 +125,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin): model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae" _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] - _optional_components = ["transformer_2"] + _optional_components = ["transformer", "transformer_2"] def __init__( self, @@ -536,7 +536,11 @@ def __call__( timesteps = self.scheduler.timesteps # 5. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels if self.transformer is not None else self.transformer_2.config.in_channels + num_channels_latents = ( + self.transformer.config.in_channels + if self.transformer is not None + else self.transformer_2.config.in_channels + ) latents = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, From 2fd1e25cc17152af4fe1f0da6e2b0183de5895c3 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 2 Aug 2025 02:09:36 +0200 Subject: [PATCH 03/13] make it work with batch_sie >1 --- src/diffusers/models/transformers/transformer_wan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 8a18ea5f3e2a..2b6d5953fc4f 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -324,7 +324,7 @@ def forward( ): timestep = self.timesteps_proj(timestep) if timestep_seq_len is not None: - timestep = timestep.unflatten(0, (1, timestep_seq_len)) + timestep = timestep.unflatten(0, (-1, timestep_seq_len)) time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: From 5e17dde3bb445a91e3337b36dc0f625d4c4a6f90 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 2 Aug 2025 02:09:47 +0200 Subject: [PATCH 04/13] add tests --- tests/pipelines/wan/test_wan.py | 18 +- tests/pipelines/wan/test_wan_22.py | 368 ++++++++++++++++++ .../pipelines/wan/test_wan_image_to_video.py | 39 +- 3 files changed, 371 insertions(+), 54 deletions(-) create mode 100644 tests/pipelines/wan/test_wan_22.py diff --git a/tests/pipelines/wan/test_wan.py b/tests/pipelines/wan/test_wan.py index a7e4e27813b3..048ff259642c 100644 --- a/tests/pipelines/wan/test_wan.py +++ b/tests/pipelines/wan/test_wan.py @@ -85,29 +85,13 @@ def get_dummy_components(self): rope_max_seq_len=32, ) - torch.manual_seed(0) - transformer_2 = WanTransformer3DModel( - patch_size=(1, 2, 2), - num_attention_heads=2, - attention_head_dim=12, - in_channels=16, - out_channels=16, - text_dim=32, - freq_dim=256, - ffn_dim=32, - num_layers=2, - cross_attn_norm=True, - qk_norm="rms_norm_across_heads", - rope_max_seq_len=32, - ) - components = { "transformer": transformer, "vae": 
vae, "scheduler": scheduler, "text_encoder": text_encoder, "tokenizer": tokenizer, - "transformer_2": transformer_2, + "transformer_2": None, } return components diff --git a/tests/pipelines/wan/test_wan_22.py b/tests/pipelines/wan/test_wan_22.py new file mode 100644 index 000000000000..3c8444f69943 --- /dev/null +++ b/tests/pipelines/wan/test_wan_22.py @@ -0,0 +1,368 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import torch +from transformers import AutoTokenizer, T5EncoderModel +import numpy as np +import tempfile + +from diffusers import AutoencoderKLWan, WanPipeline, WanTransformer3DModel, UniPCMultistepScheduler +from diffusers.utils.testing_utils import ( + backend_empty_cache, + enable_full_determinism, + require_torch_accelerator, + slow, + torch_device, +) + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + + + +enable_full_determinism() + + +class Wan22PipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = WanPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + test_optional_components = ["transformer_2", "transformer"] + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + transformer = WanTransformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=16, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + ) + + torch.manual_seed(0) + transformer_2 = WanTransformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=16, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "transformer_2": transformer_2, + "boundary_ratio": 0.875, + } + return 
components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "dance monkey", + "negative_prompt": "negative", # TODO + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "height": 16, + "width": 16, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components, ) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + + # fmt: off + expected_slice = torch.tensor([0.4525, 0.452, 0.4485, 0.4534, 0.4524, 0.4529, 0.454, 0.453, 0.5127, 0.5326, 0.5204, 0.5253, 0.5439, 0.5424, 0.5133, 0.5078]) + # fmt: on + + generated_slice = generated_video.flatten() + generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) + self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3)) + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + def test_save_load_optional_components(self, expected_max_difference=1e-4): + + optional_component = "transformer" + + components = self.get_dummy_components() + components[optional_component] = None + components["boundary_ratio"] = 1.0 # for wan 2.2 14B, transformer is not used when boundary_ratio is 1.0 + + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + torch.manual_seed(0) + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=False) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + self.assertTrue(getattr(pipe_loaded, "transformer") is None, + f"`transformer` did not stay set to None after loading.", + ) + + inputs = self.get_dummy_inputs(generator_device) + torch.manual_seed(0) + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max() + self.assertLess(max_diff, expected_max_difference) + + + +class Wan225BPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = WanPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=48, + in_channels=12, + out_channels=12, + is_residual=True, + 
patch_size=2, + latents_mean = [0.0] * 48, + latents_std = [1.0] * 48, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + scale_factor_spatial=16, + scale_factor_temporal=4, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + transformer = WanTransformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=48, + out_channels=48, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "transformer_2": None, + "boundary_ratio": None, + "expand_timesteps": True, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "dance monkey", + "negative_prompt": "negative", # TODO + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "height": 32, + "width": 32, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components, ) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + self.assertEqual(generated_video.shape, (9, 3, 32, 32)) + + # fmt: off + expected_slice = torch.tensor([[0.4814, 0.5336, 0.5094, 0.4922, 0.5061, 0.4923, 0.5043, 0.4923, 0.6821, + 0.5965, 0.6753, 0.6014, 0.6939, 0.6076, 0.5133, 0.5651]]) + # fmt: on + + generated_slice = generated_video.flatten() + generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) + self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3), f"generated_slice: {generated_slice}, expected_slice: {expected_slice}") + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + def test_components_function(self): + init_components = self.get_dummy_components() + init_components.pop("boundary_ratio") + init_components.pop("expand_timesteps") + pipe = self.pipeline_class(**init_components) + + self.assertTrue(hasattr(pipe, "components")) + self.assertTrue(set(pipe.components.keys()) == set(init_components.keys())) + + def test_save_load_optional_components(self, expected_max_difference=1e-4): + + optional_component = "transformer_2" + + components = self.get_dummy_components() + components[optional_component] = None + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + torch.manual_seed(0) + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=False) + pipe_loaded = 
self.pipeline_class.from_pretrained(tmpdir) + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + self.assertTrue(getattr(pipe_loaded, optional_component) is None, + f"`{optional_component}` did not stay set to None after loading.", + ) + + inputs = self.get_dummy_inputs(generator_device) + torch.manual_seed(0) + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max() + self.assertLess(max_diff, expected_max_difference) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=2e-3) \ No newline at end of file diff --git a/tests/pipelines/wan/test_wan_image_to_video.py b/tests/pipelines/wan/test_wan_image_to_video.py index c693f4fcb247..d58113200884 100644 --- a/tests/pipelines/wan/test_wan_image_to_video.py +++ b/tests/pipelines/wan/test_wan_image_to_video.py @@ -86,23 +86,6 @@ def get_dummy_components(self): image_dim=4, ) - torch.manual_seed(0) - transformer_2 = WanTransformer3DModel( - patch_size=(1, 2, 2), - num_attention_heads=2, - attention_head_dim=12, - in_channels=36, - out_channels=16, - text_dim=32, - freq_dim=256, - ffn_dim=32, - num_layers=2, - cross_attn_norm=True, - qk_norm="rms_norm_across_heads", - rope_max_seq_len=32, - image_dim=4, - ) - torch.manual_seed(0) image_encoder_config = CLIPVisionConfig( hidden_size=4, @@ -126,7 +109,7 @@ def get_dummy_components(self): "tokenizer": tokenizer, "image_encoder": image_encoder, "image_processor": image_processor, - "transformer_2": transformer_2, + "transformer_2": None, } return components @@ -242,24 +225,6 @@ def get_dummy_components(self): pos_embed_seq_len=2 * (4 * 4 + 1), ) - torch.manual_seed(0) - transformer_2 = WanTransformer3DModel( - patch_size=(1, 2, 2), - num_attention_heads=2, - attention_head_dim=12, - in_channels=36, - out_channels=16, - text_dim=32, - freq_dim=256, - ffn_dim=32, - num_layers=2, - cross_attn_norm=True, - qk_norm="rms_norm_across_heads", - rope_max_seq_len=32, - image_dim=4, - pos_embed_seq_len=2 * (4 * 4 + 1), - ) - torch.manual_seed(0) image_encoder_config = CLIPVisionConfig( hidden_size=4, @@ -283,7 +248,7 @@ def get_dummy_components(self): "tokenizer": tokenizer, "image_encoder": image_encoder, "image_processor": image_processor, - "transformer_2": transformer_2, + "transformer_2": None, } return components From a7bce5fa7428ff75b800692a8278a2ab71a482df Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 2 Aug 2025 02:10:52 +0200 Subject: [PATCH 05/13] tests --- tests/pipelines/wan/test_wan_22.py | 54 +++++++++++++++--------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/tests/pipelines/wan/test_wan_22.py b/tests/pipelines/wan/test_wan_22.py index 3c8444f69943..2055701cd68b 100644 --- a/tests/pipelines/wan/test_wan_22.py +++ b/tests/pipelines/wan/test_wan_22.py @@ -12,20 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
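The Wan 2.2 14B tests above load both `transformer` and `transformer_2` and set `boundary_ratio=0.875`: the checkpoint is a two-expert denoiser where the high-noise expert covers timesteps above the boundary and the low-noise expert covers the rest, which is also why `boundary_ratio=1.0` makes the first transformer safe to drop. A minimal sketch of that per-step selection, assuming a 1000-step flow-matching schedule; `pick_expert` is an illustrative helper, not the pipeline's actual control flow.

def pick_expert(t, transformer, transformer_2, boundary_ratio, num_train_timesteps=1000):
    # Single-expert checkpoints (Wan 2.1, Wan 2.2 5B): whichever transformer exists runs every step.
    if boundary_ratio is None:
        return transformer if transformer is not None else transformer_2
    boundary = boundary_ratio * num_train_timesteps
    # High-noise expert above the boundary, low-noise expert below it; with
    # boundary_ratio == 1.0 no step reaches the boundary, so `transformer` may be None.
    if transformer is not None and t >= boundary:
        return transformer
    return transformer_2
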
-import gc +import tempfile import unittest +import numpy as np import torch from transformers import AutoTokenizer, T5EncoderModel -import numpy as np -import tempfile -from diffusers import AutoencoderKLWan, WanPipeline, WanTransformer3DModel, UniPCMultistepScheduler +from diffusers import AutoencoderKLWan, UniPCMultistepScheduler, WanPipeline, WanTransformer3DModel from diffusers.utils.testing_utils import ( - backend_empty_cache, enable_full_determinism, - require_torch_accelerator, - slow, torch_device, ) @@ -33,8 +29,6 @@ from ..test_pipelines_common import PipelineTesterMixin - - enable_full_determinism() @@ -139,7 +133,9 @@ def test_inference(self): device = "cpu" components = self.get_dummy_components() - pipe = self.pipeline_class(**components, ) + pipe = self.pipeline_class( + **components, + ) pipe.to(device) pipe.set_progress_bar_config(disable=None) @@ -159,14 +155,13 @@ def test_inference(self): @unittest.skip("Test not supported") def test_attention_slicing_forward_pass(self): pass - - def test_save_load_optional_components(self, expected_max_difference=1e-4): + def test_save_load_optional_components(self, expected_max_difference=1e-4): optional_component = "transformer" components = self.get_dummy_components() components[optional_component] = None - components["boundary_ratio"] = 1.0 # for wan 2.2 14B, transformer is not used when boundary_ratio is 1.0 + components["boundary_ratio"] = 1.0 # for wan 2.2 14B, transformer is not used when boundary_ratio is 1.0 pipe = self.pipeline_class(**components) for component in pipe.components.values(): @@ -189,9 +184,10 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4): pipe_loaded.to(torch_device) pipe_loaded.set_progress_bar_config(disable=None) - self.assertTrue(getattr(pipe_loaded, "transformer") is None, - f"`transformer` did not stay set to None after loading.", - ) + self.assertTrue( + getattr(pipe_loaded, "transformer") is None, + "`transformer` did not stay set to None after loading.", + ) inputs = self.get_dummy_inputs(generator_device) torch.manual_seed(0) @@ -201,7 +197,6 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4): self.assertLess(max_diff, expected_max_difference) - class Wan225BPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = WanPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} @@ -230,8 +225,8 @@ def get_dummy_components(self): out_channels=12, is_residual=True, patch_size=2, - latents_mean = [0.0] * 48, - latents_std = [1.0] * 48, + latents_mean=[0.0] * 48, + latents_std=[1.0] * 48, dim_mult=[1, 1, 1, 1], num_res_blocks=1, scale_factor_spatial=16, @@ -295,7 +290,9 @@ def test_inference(self): device = "cpu" components = self.get_dummy_components() - pipe = self.pipeline_class(**components, ) + pipe = self.pipeline_class( + **components, + ) pipe.to(device) pipe.set_progress_bar_config(disable=None) @@ -311,7 +308,10 @@ def test_inference(self): generated_slice = generated_video.flatten() generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) - self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3), f"generated_slice: {generated_slice}, expected_slice: {expected_slice}") + self.assertTrue( + torch.allclose(generated_slice, expected_slice, atol=1e-3), + f"generated_slice: {generated_slice}, expected_slice: {expected_slice}", + ) @unittest.skip("Test not supported") def test_attention_slicing_forward_pass(self): @@ -327,7 +327,6 @@ def test_components_function(self): 
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys())) def test_save_load_optional_components(self, expected_max_difference=1e-4): - optional_component = "transformer_2" components = self.get_dummy_components() @@ -353,9 +352,10 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4): pipe_loaded.to(torch_device) pipe_loaded.set_progress_bar_config(disable=None) - self.assertTrue(getattr(pipe_loaded, optional_component) is None, - f"`{optional_component}` did not stay set to None after loading.", - ) + self.assertTrue( + getattr(pipe_loaded, optional_component) is None, + f"`{optional_component}` did not stay set to None after loading.", + ) inputs = self.get_dummy_inputs(generator_device) torch.manual_seed(0) @@ -363,6 +363,6 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4): max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max() self.assertLess(max_diff, expected_max_difference) - + def test_inference_batch_single_identical(self): - self._test_inference_batch_single_identical(expected_max_diff=2e-3) \ No newline at end of file + self._test_inference_batch_single_identical(expected_max_diff=2e-3) From 1189a35edd934b67c4c8a5c3fd40f93f4319fe4c Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 3 Aug 2025 20:00:02 +0200 Subject: [PATCH 06/13] add more tests --- tests/pipelines/wan/test_wan.py | 38 ++ .../wan/test_wan_22_image_to_video.py | 388 ++++++++++++++++++ .../pipelines/wan/test_wan_image_to_video.py | 86 +++- 3 files changed, 502 insertions(+), 10 deletions(-) create mode 100644 tests/pipelines/wan/test_wan_22_image_to_video.py diff --git a/tests/pipelines/wan/test_wan.py b/tests/pipelines/wan/test_wan.py index 048ff259642c..0aeb2a6b6b73 100644 --- a/tests/pipelines/wan/test_wan.py +++ b/tests/pipelines/wan/test_wan.py @@ -139,6 +139,44 @@ def test_inference(self): def test_attention_slicing_forward_pass(self): pass + # _optional_components include transformer, transformer_2, but only transformer_2 is optional for this wan2.1 t2v pipeline + def test_save_load_optional_components(self, expected_max_difference=1e-4): + optional_component = "transformer_2" + + components = self.get_dummy_components() + components[optional_component] = None + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + torch.manual_seed(0) + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=False) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + self.assertTrue( + getattr(pipe_loaded, optional_component) is None, + f"`{optional_component}` did not stay set to None after loading.", + ) + + inputs = self.get_dummy_inputs(generator_device) + torch.manual_seed(0) + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max() + self.assertLess(max_diff, expected_max_difference) @slow @require_torch_accelerator diff --git 
a/tests/pipelines/wan/test_wan_22_image_to_video.py b/tests/pipelines/wan/test_wan_22_image_to_video.py new file mode 100644 index 000000000000..ed30dddb80a1 --- /dev/null +++ b/tests/pipelines/wan/test_wan_22_image_to_video.py @@ -0,0 +1,388 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import AutoencoderKLWan, UniPCMultistepScheduler, WanImageToVideoPipeline, WanTransformer3DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_device, +) + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin +from PIL import Image + + +enable_full_determinism() + + +class Wan22ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = WanImageToVideoPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + transformer = WanTransformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=36, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + ) + + torch.manual_seed(0) + transformer_2 = WanTransformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=36, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "transformer_2": transformer_2, + "image_encoder": None, + "image_processor": None, + "boundary_ratio": 0.875, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = 
torch.Generator(device=device).manual_seed(seed) + image_height = 16 + image_width = 16 + image = Image.new("RGB", (image_width, image_height)) + inputs = { + "image": image, + "prompt": "dance monkey", + "negative_prompt": "negative", # TODO + "height": image_height, + "width": image_width, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class( + **components, + ) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + + # fmt: off + expected_slice = torch.tensor([0.4527, 0.4526, 0.4498, 0.4539, 0.4521, 0.4524, 0.4533, 0.4535, 0.5154, + 0.5353, 0.5200, 0.5174, 0.5434, 0.5301, 0.5199, 0.5216]) + # fmt: on + + generated_slice = generated_video.flatten() + generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) + self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3), f"generated_slice: {generated_slice}, expected_slice: {expected_slice}") + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + def test_save_load_optional_components(self, expected_max_difference=1e-4): + optional_component = ["transformer", "image_encoder", "image_processor"] + + components = self.get_dummy_components() + for component in optional_component: + components[component] = None + components["boundary_ratio"] = 1.0 # for wan 2.2 14B, transformer is not used when boundary_ratio is 1.0 + + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + torch.manual_seed(0) + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=False) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + self.assertTrue( + getattr(pipe_loaded, "transformer") is None, + "`transformer` did not stay set to None after loading.", + ) + + inputs = self.get_dummy_inputs(generator_device) + torch.manual_seed(0) + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max() + self.assertLess(max_diff, expected_max_difference) + + +class Wan225BImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = WanImageToVideoPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def 
get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=48, + in_channels=12, + out_channels=12, + is_residual=True, + patch_size=2, + latents_mean=[0.0] * 48, + latents_std=[1.0] * 48, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + scale_factor_spatial=16, + scale_factor_temporal=4, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + transformer = WanTransformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=48, + out_channels=48, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "transformer_2": None, + "image_encoder": None, + "image_processor": None, + "boundary_ratio": None, + "expand_timesteps": True, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + image_height = 32 + image_width = 32 + image = Image.new("RGB", (image_width, image_height)) + inputs = { + "image": image, + "prompt": "dance monkey", + "negative_prompt": "negative", # TODO + "height": image_height, + "width": image_width, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class( + **components, + ) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + self.assertEqual(generated_video.shape, (9, 3, 32, 32)) + + # fmt: off + expected_slice = torch.tensor([[0.4833, 0.4305, 0.5100, 0.4299, 0.5056, 0.4298, 0.5052, 0.4332, 0.5550, + 0.6092, 0.5536, 0.5928, 0.5199, 0.5864, 0.6705, 0.5493]]) + # fmt: on + + generated_slice = generated_video.flatten() + generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) + self.assertTrue( + torch.allclose(generated_slice, expected_slice, atol=1e-3), + f"generated_slice: {generated_slice}, expected_slice: {expected_slice}", + ) + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + def test_components_function(self): + init_components = self.get_dummy_components() + init_components.pop("boundary_ratio") + init_components.pop("expand_timesteps") + pipe = self.pipeline_class(**init_components) + + self.assertTrue(hasattr(pipe, "components")) + self.assertTrue(set(pipe.components.keys()) == set(init_components.keys())) + + def test_save_load_optional_components(self, expected_max_difference=1e-4): + optional_component = ["transformer_2", "image_encoder", "image_processor"] + + components = self.get_dummy_components() + for component in optional_component: + components[component] = None + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): 
+ if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + torch.manual_seed(0) + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=False) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for component in optional_component: + self.assertTrue( + getattr(pipe_loaded, component) is None, + f"`{component}` did not stay set to None after loading.", + ) + + inputs = self.get_dummy_inputs(generator_device) + torch.manual_seed(0) + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max() + self.assertLess(max_diff, expected_max_difference) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=2e-3) + + @unittest.skip("Test not supported") + def test_callback_inputs(self): + pass diff --git a/tests/pipelines/wan/test_wan_image_to_video.py b/tests/pipelines/wan/test_wan_image_to_video.py index d58113200884..d50b9b0055d2 100644 --- a/tests/pipelines/wan/test_wan_image_to_video.py +++ b/tests/pipelines/wan/test_wan_image_to_video.py @@ -165,11 +165,44 @@ def test_attention_slicing_forward_pass(self): def test_inference_batch_single_identical(self): pass - @unittest.skip( - "TODO: refactor this test: one component can be optional for certain checkpoints but not for others" - ) - def test_save_load_optional_components(self): - pass + # _optional_components include transformer, transformer_2 and image_encoder, image_processor, but only transformer_2 is optional for wan2.1 i2v pipeline + def test_save_load_optional_components(self, expected_max_difference=1e-4): + optional_component = "transformer_2" + + components = self.get_dummy_components() + components[optional_component] = None + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + torch.manual_seed(0) + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=False) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + self.assertTrue( + getattr(pipe_loaded, optional_component) is None, + f"`{optional_component}` did not stay set to None after loading.", + ) + + inputs = self.get_dummy_inputs(generator_device) + torch.manual_seed(0) + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max() + self.assertLess(max_diff, expected_max_difference) class WanFLFToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): @@ -306,8 +339,41 @@ def 
test_attention_slicing_forward_pass(self): def test_inference_batch_single_identical(self): pass - @unittest.skip( - "TODO: refactor this test: one component can be optional for certain checkpoints but not for others" - ) - def test_save_load_optional_components(self): - pass + # _optional_components include transformer, transformer_2 and image_encoder, image_processor, but only transformer_2 is optional for wan2.1 FLFT2V pipeline + def test_save_load_optional_components(self, expected_max_difference=1e-4): + optional_component = "transformer_2" + + components = self.get_dummy_components() + components[optional_component] = None + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + torch.manual_seed(0) + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=False) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + self.assertTrue( + getattr(pipe_loaded, optional_component) is None, + f"`{optional_component}` did not stay set to None after loading.", + ) + + inputs = self.get_dummy_inputs(generator_device) + torch.manual_seed(0) + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max() + self.assertLess(max_diff, expected_max_difference) \ No newline at end of file From d710e593e341e8409d10d71c1687267d513af4f3 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 3 Aug 2025 21:50:33 +0200 Subject: [PATCH 07/13] style --- tests/pipelines/wan/test_wan.py | 3 +++ tests/pipelines/wan/test_wan_22_image_to_video.py | 9 ++++++--- tests/pipelines/wan/test_wan_image_to_video.py | 6 ++++-- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/tests/pipelines/wan/test_wan.py b/tests/pipelines/wan/test_wan.py index 0aeb2a6b6b73..90b7978ec760 100644 --- a/tests/pipelines/wan/test_wan.py +++ b/tests/pipelines/wan/test_wan.py @@ -13,8 +13,10 @@ # limitations under the License. 
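The round trip exercised by these save/load tests leans on `_optional_components`: a component passed as `None` is written to `model_index.json` as a null entry by `save_pretrained`, and `from_pretrained` restores it as `None` instead of looking for a subfolder. A short usage sketch; the checkpoint paths are hypothetical and only the `transformer_2=None` handling is the point.

import torch

from diffusers import WanPipeline

# Hypothetical local checkpoint: a Wan 2.1 T2V export has no second expert,
# so transformer_2 simply stays None.
pipe = WanPipeline.from_pretrained("./wan21-t2v-checkpoint", transformer_2=None, torch_dtype=torch.bfloat16)

pipe.save_pretrained("./wan21-t2v-resaved")    # transformer_2 recorded as a null entry
reloaded = WanPipeline.from_pretrained("./wan21-t2v-resaved")
assert reloaded.transformer_2 is None          # the optional component stays None after reloading
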
import gc +import tempfile import unittest +import numpy as np import torch from transformers import AutoTokenizer, T5EncoderModel @@ -178,6 +180,7 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4): max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max() self.assertLess(max_diff, expected_max_difference) + @slow @require_torch_accelerator class WanPipelineIntegrationTests(unittest.TestCase): diff --git a/tests/pipelines/wan/test_wan_22_image_to_video.py b/tests/pipelines/wan/test_wan_22_image_to_video.py index ed30dddb80a1..ac764e992746 100644 --- a/tests/pipelines/wan/test_wan_22_image_to_video.py +++ b/tests/pipelines/wan/test_wan_22_image_to_video.py @@ -17,6 +17,7 @@ import numpy as np import torch +from PIL import Image from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKLWan, UniPCMultistepScheduler, WanImageToVideoPipeline, WanTransformer3DModel @@ -27,7 +28,6 @@ from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import PipelineTesterMixin -from PIL import Image enable_full_determinism() @@ -157,7 +157,10 @@ def test_inference(self): generated_slice = generated_video.flatten() generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) - self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3), f"generated_slice: {generated_slice}, expected_slice: {expected_slice}") + self.assertTrue( + torch.allclose(generated_slice, expected_slice, atol=1e-3), + f"generated_slice: {generated_slice}, expected_slice: {expected_slice}", + ) @unittest.skip("Test not supported") def test_attention_slicing_forward_pass(self): @@ -371,7 +374,7 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4): self.assertTrue( getattr(pipe_loaded, component) is None, f"`{component}` did not stay set to None after loading.", - ) + ) inputs = self.get_dummy_inputs(generator_device) torch.manual_seed(0) diff --git a/tests/pipelines/wan/test_wan_image_to_video.py b/tests/pipelines/wan/test_wan_image_to_video.py index d50b9b0055d2..1c938ce2dea3 100644 --- a/tests/pipelines/wan/test_wan_image_to_video.py +++ b/tests/pipelines/wan/test_wan_image_to_video.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
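The `expand_timesteps` path used by the Wan 2.2 5B tests hands the transformer one timestep per latent frame (with `first_frame_mask` masking the conditioned frame's timesteps, per the comment in the first patch), so the flattened timestep projection carries `batch * seq_len` entries rather than `batch`. That is the case the earlier `unflatten(0, (-1, timestep_seq_len))` change accommodates; a small shape check with sizes chosen only for illustration.

import torch

batch, seq_len, dim = 2, 9, 256
# Per-frame timesteps are flattened before the sinusoidal projection...
flat_proj = torch.randn(batch * seq_len, dim)
# ...and restored afterwards; (-1, seq_len) infers the batch instead of assuming it is 1,
# which is what makes batch_size > 1 work.
restored = flat_proj.unflatten(0, (-1, seq_len))
assert restored.shape == (batch, seq_len, dim)
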
+import tempfile import unittest +import numpy as np import torch from PIL import Image from transformers import ( @@ -25,7 +27,7 @@ ) from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanImageToVideoPipeline, WanTransformer3DModel -from diffusers.utils.testing_utils import enable_full_determinism +from diffusers.utils.testing_utils import enable_full_determinism, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import PipelineTesterMixin @@ -376,4 +378,4 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4): output_loaded = pipe_loaded(**inputs)[0] max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max() - self.assertLess(max_diff, expected_max_difference) \ No newline at end of file + self.assertLess(max_diff, expected_max_difference) From bc7a8339a9f5d4cdb6c1b4456d6ab031de00d4fb Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 3 Aug 2025 23:11:33 +0200 Subject: [PATCH 08/13] up --- src/diffusers/pipelines/wan/pipeline_wan_i2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index 0eb65eb88eb9..b7fd0b05980f 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -675,7 +675,7 @@ def __call__( negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) # only wan 2.1 i2v transformer accepts image_embeds - if self.transformer is not None and self.transformer.config.added_kv_proj_dim is not None: + if self.transformer is not None and self.transformer.config.image_dim is not None: if image_embeds is None: if last_image is None: image_embeds = self.encode_image(image, device) From 359bc7bc8ae424247796bd872ab004a20c970cc4 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 3 Aug 2025 23:33:21 +0200 Subject: [PATCH 09/13] jpdate test after fixing 5b patchify --- tests/pipelines/wan/test_wan_22.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/wan/test_wan_22.py b/tests/pipelines/wan/test_wan_22.py index 2055701cd68b..b2539c716860 100644 --- a/tests/pipelines/wan/test_wan_22.py +++ b/tests/pipelines/wan/test_wan_22.py @@ -302,8 +302,8 @@ def test_inference(self): self.assertEqual(generated_video.shape, (9, 3, 32, 32)) # fmt: off - expected_slice = torch.tensor([[0.4814, 0.5336, 0.5094, 0.4922, 0.5061, 0.4923, 0.5043, 0.4923, 0.6821, - 0.5965, 0.6753, 0.6014, 0.6939, 0.6076, 0.5133, 0.5651]]) + expected_slice = torch.tensor([[[0.4814, 0.4298, 0.5094, 0.4289, 0.5061, 0.4301, 0.5043, 0.4284, 0.5375, + 0.5965, 0.5527, 0.6014, 0.5228, 0.6076, 0.6644, 0.5651]]]) # fmt: on generated_slice = generated_video.flatten() From c82d4fae47c380dd57ee07a3dc38049f32b57e1c Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Sun, 3 Aug 2025 11:38:33 -1000 Subject: [PATCH 10/13] Update tests/pipelines/wan/test_wan_22.py --- tests/pipelines/wan/test_wan_22.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/pipelines/wan/test_wan_22.py b/tests/pipelines/wan/test_wan_22.py index b2539c716860..7370efca019b 100644 --- a/tests/pipelines/wan/test_wan_22.py +++ b/tests/pipelines/wan/test_wan_22.py @@ -50,7 +50,6 @@ class Wan22PipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) test_xformers_attention = False supports_dduf = False - test_optional_components = ["transformer_2", "transformer"] def 
get_dummy_components(self): torch.manual_seed(0) From 7a66c4c2279f8d56d365d3c1ca720540e4f6e5ed Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 3 Aug 2025 23:43:02 +0200 Subject: [PATCH 11/13] up --- tests/pipelines/wan/test_wan_22_image_to_video.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/pipelines/wan/test_wan_22_image_to_video.py b/tests/pipelines/wan/test_wan_22_image_to_video.py index ac764e992746..3f72a74e4498 100644 --- a/tests/pipelines/wan/test_wan_22_image_to_video.py +++ b/tests/pipelines/wan/test_wan_22_image_to_video.py @@ -195,10 +195,11 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4): pipe_loaded.to(torch_device) pipe_loaded.set_progress_bar_config(disable=None) - self.assertTrue( - getattr(pipe_loaded, "transformer") is None, - "`transformer` did not stay set to None after loading.", - ) + for component in optional_component: + self.assertTrue( + getattr(pipe_loaded, component) is None, + f"`{component}` did not stay set to None after loading.", + ) inputs = self.get_dummy_inputs(generator_device) torch.manual_seed(0) From a48bd362ca08016559588bbe9b6210c4ab2d1d07 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Sun, 3 Aug 2025 22:34:30 -1000 Subject: [PATCH 12/13] Update tests/pipelines/wan/test_wan_22.py --- tests/pipelines/wan/test_wan_22.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/wan/test_wan_22.py b/tests/pipelines/wan/test_wan_22.py index 7370efca019b..0b43b5ff1282 100644 --- a/tests/pipelines/wan/test_wan_22.py +++ b/tests/pipelines/wan/test_wan_22.py @@ -116,7 +116,7 @@ def get_dummy_inputs(self, device, seed=0): generator = torch.Generator(device=device).manual_seed(seed) inputs = { "prompt": "dance monkey", - "negative_prompt": "negative", # TODO + "negative_prompt": "negative", "generator": generator, "num_inference_steps": 2, "guidance_scale": 6.0, From 2ea15e71f6c46d7cd0f8ac0a49f0503303943dce Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 4 Aug 2025 08:41:04 +0000 Subject: [PATCH 13/13] Apply style fixes --- tests/pipelines/wan/test_wan_22.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/wan/test_wan_22.py b/tests/pipelines/wan/test_wan_22.py index 0b43b5ff1282..9fdae6698069 100644 --- a/tests/pipelines/wan/test_wan_22.py +++ b/tests/pipelines/wan/test_wan_22.py @@ -116,7 +116,7 @@ def get_dummy_inputs(self, device, seed=0): generator = torch.Generator(device=device).manual_seed(seed) inputs = { "prompt": "dance monkey", - "negative_prompt": "negative", + "negative_prompt": "negative", "generator": generator, "num_inference_steps": 2, "guidance_scale": 6.0,