
Commit a7bce5f
commit message: tests
parent: 67347f6

1 file changed: +27 additions, -27 deletions


tests/pipelines/wan/test_wan_22.py

--- a/tests/pipelines/wan/test_wan_22.py
+++ b/tests/pipelines/wan/test_wan_22.py
@@ -12,29 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import gc
+import tempfile
 import unittest

+import numpy as np
 import torch
 from transformers import AutoTokenizer, T5EncoderModel
-import numpy as np
-import tempfile

-from diffusers import AutoencoderKLWan, WanPipeline, WanTransformer3DModel, UniPCMultistepScheduler
+from diffusers import AutoencoderKLWan, UniPCMultistepScheduler, WanPipeline, WanTransformer3DModel
 from diffusers.utils.testing_utils import (
-    backend_empty_cache,
     enable_full_determinism,
-    require_torch_accelerator,
-    slow,
     torch_device,
 )

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import PipelineTesterMixin


-
-
 enable_full_determinism()


@@ -139,7 +133,9 @@ def test_inference(self):
         device = "cpu"

         components = self.get_dummy_components()
-        pipe = self.pipeline_class(**components, )
+        pipe = self.pipeline_class(
+            **components,
+        )
         pipe.to(device)
         pipe.set_progress_bar_config(disable=None)

@@ -159,14 +155,13 @@ def test_inference(self):
     @unittest.skip("Test not supported")
     def test_attention_slicing_forward_pass(self):
         pass
-
-    def test_save_load_optional_components(self, expected_max_difference=1e-4):

+    def test_save_load_optional_components(self, expected_max_difference=1e-4):
         optional_component = "transformer"

         components = self.get_dummy_components()
         components[optional_component] = None
-        components["boundary_ratio"] = 1.0 # for wan 2.2 14B, transformer is not used when boundary_ratio is 1.0
+        components["boundary_ratio"] = 1.0  # for wan 2.2 14B, transformer is not used when boundary_ratio is 1.0

         pipe = self.pipeline_class(**components)
         for component in pipe.components.values():
@@ -189,9 +184,10 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4):
         pipe_loaded.to(torch_device)
         pipe_loaded.set_progress_bar_config(disable=None)

-        self.assertTrue(getattr(pipe_loaded, "transformer") is None,
-            f"`transformer` did not stay set to None after loading.",
-        )
+        self.assertTrue(
+            getattr(pipe_loaded, "transformer") is None,
+            "`transformer` did not stay set to None after loading.",
+        )

         inputs = self.get_dummy_inputs(generator_device)
         torch.manual_seed(0)
@@ -201,7 +197,6 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4):
         self.assertLess(max_diff, expected_max_difference)


-
 class Wan225BPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     pipeline_class = WanPipeline
     params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
@@ -230,8 +225,8 @@ def get_dummy_components(self):
             out_channels=12,
             is_residual=True,
             patch_size=2,
-            latents_mean = [0.0] * 48,
-            latents_std = [1.0] * 48,
+            latents_mean=[0.0] * 48,
+            latents_std=[1.0] * 48,
             dim_mult=[1, 1, 1, 1],
             num_res_blocks=1,
             scale_factor_spatial=16,
@@ -295,7 +290,9 @@ def test_inference(self):
         device = "cpu"

         components = self.get_dummy_components()
-        pipe = self.pipeline_class(**components, )
+        pipe = self.pipeline_class(
+            **components,
+        )
         pipe.to(device)
         pipe.set_progress_bar_config(disable=None)

@@ -311,7 +308,10 @@ def test_inference(self):

         generated_slice = generated_video.flatten()
         generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
-        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3), f"generated_slice: {generated_slice}, expected_slice: {expected_slice}")
+        self.assertTrue(
+            torch.allclose(generated_slice, expected_slice, atol=1e-3),
+            f"generated_slice: {generated_slice}, expected_slice: {expected_slice}",
+        )

     @unittest.skip("Test not supported")
     def test_attention_slicing_forward_pass(self):
@@ -327,7 +327,6 @@ def test_components_function(self):
         self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))

     def test_save_load_optional_components(self, expected_max_difference=1e-4):
-
         optional_component = "transformer_2"

         components = self.get_dummy_components()
@@ -353,16 +352,17 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4):
         pipe_loaded.to(torch_device)
         pipe_loaded.set_progress_bar_config(disable=None)

-        self.assertTrue(getattr(pipe_loaded, optional_component) is None,
-            f"`{optional_component}` did not stay set to None after loading.",
-        )
+        self.assertTrue(
+            getattr(pipe_loaded, optional_component) is None,
+            f"`{optional_component}` did not stay set to None after loading.",
+        )

         inputs = self.get_dummy_inputs(generator_device)
         torch.manual_seed(0)
         output_loaded = pipe_loaded(**inputs)[0]

         max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
         self.assertLess(max_diff, expected_max_difference)
-
+
     def test_inference_batch_single_identical(self):
-        self._test_inference_batch_single_identical(expected_max_diff=2e-3)
+        self._test_inference_batch_single_identical(expected_max_diff=2e-3)
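
Both test classes above exercise the same save/load round trip for an optional component: build the pipeline with the component set to None, serialize it, reload it, and assert the component stayed None and the outputs still match. Per the comment in the diff, the 14B variant can skip "transformer" because it is never invoked when boundary_ratio is 1.0. Below is a minimal sketch of that round-trip pattern, assuming the standard diffusers save_pretrained / from_pretrained pipeline methods; the helper name and the pipeline_class / components / inputs parameters are stand-ins for the test fixtures, not diffusers APIs.

import tempfile

import numpy as np
import torch


def roundtrip_optional_component(pipeline_class, components, optional_component, inputs):
    # Drop the optional component before constructing the pipeline.
    components[optional_component] = None
    pipe = pipeline_class(**components)
    pipe.set_progress_bar_config(disable=None)

    torch.manual_seed(0)
    output = pipe(**inputs)[0]

    # Serialize and reload; the absent component must survive the round trip.
    with tempfile.TemporaryDirectory() as tmpdir:
        pipe.save_pretrained(tmpdir)
        pipe_loaded = pipeline_class.from_pretrained(tmpdir)

    assert getattr(pipe_loaded, optional_component) is None, (
        f"`{optional_component}` did not stay set to None after loading."
    )

    # The reloaded pipeline should reproduce the original output.
    torch.manual_seed(0)
    output_loaded = pipe_loaded(**inputs)[0]
    max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
    assert max_diff < 1e-4  # expected_max_difference in the tests above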
