12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- import gc
15
+ import tempfile
16
16
import unittest
17
17
18
+ import numpy as np
18
19
import torch
19
20
from transformers import AutoTokenizer , T5EncoderModel
20
- import numpy as np
21
- import tempfile
22
21
23
- from diffusers import AutoencoderKLWan , WanPipeline , WanTransformer3DModel , UniPCMultistepScheduler
22
+ from diffusers import AutoencoderKLWan , UniPCMultistepScheduler , WanPipeline , WanTransformer3DModel
24
23
from diffusers .utils .testing_utils import (
25
- backend_empty_cache ,
26
24
enable_full_determinism ,
27
- require_torch_accelerator ,
28
- slow ,
29
25
torch_device ,
30
26
)
31
27
32
28
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS , TEXT_TO_IMAGE_IMAGE_PARAMS , TEXT_TO_IMAGE_PARAMS
33
29
from ..test_pipelines_common import PipelineTesterMixin
34
30
35
31
36
-
37
-
38
32
enable_full_determinism ()
39
33
40
34
@@ -139,7 +133,9 @@ def test_inference(self):
139
133
device = "cpu"
140
134
141
135
components = self .get_dummy_components ()
142
- pipe = self .pipeline_class (** components , )
136
+ pipe = self .pipeline_class (
137
+ ** components ,
138
+ )
143
139
pipe .to (device )
144
140
pipe .set_progress_bar_config (disable = None )
145
141
@@ -159,14 +155,13 @@ def test_inference(self):
159
155
@unittest .skip ("Test not supported" )
160
156
def test_attention_slicing_forward_pass (self ):
161
157
pass
162
-
163
- def test_save_load_optional_components (self , expected_max_difference = 1e-4 ):
164
158
159
+ def test_save_load_optional_components (self , expected_max_difference = 1e-4 ):
165
160
optional_component = "transformer"
166
161
167
162
components = self .get_dummy_components ()
168
163
components [optional_component ] = None
169
- components ["boundary_ratio" ] = 1.0 # for wan 2.2 14B, transformer is not used when boundary_ratio is 1.0
164
+ components ["boundary_ratio" ] = 1.0 # for wan 2.2 14B, transformer is not used when boundary_ratio is 1.0
170
165
171
166
pipe = self .pipeline_class (** components )
172
167
for component in pipe .components .values ():
@@ -189,9 +184,10 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4):
189
184
pipe_loaded .to (torch_device )
190
185
pipe_loaded .set_progress_bar_config (disable = None )
191
186
192
- self .assertTrue (getattr (pipe_loaded , "transformer" ) is None ,
193
- f"`transformer` did not stay set to None after loading." ,
194
- )
187
+ self .assertTrue (
188
+ getattr (pipe_loaded , "transformer" ) is None ,
189
+ "`transformer` did not stay set to None after loading." ,
190
+ )
195
191
196
192
inputs = self .get_dummy_inputs (generator_device )
197
193
torch .manual_seed (0 )
@@ -201,7 +197,6 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4):
201
197
self .assertLess (max_diff , expected_max_difference )
202
198
203
199
204
-
205
200
class Wan225BPipelineFastTests (PipelineTesterMixin , unittest .TestCase ):
206
201
pipeline_class = WanPipeline
207
202
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs" }
@@ -230,8 +225,8 @@ def get_dummy_components(self):
230
225
out_channels = 12 ,
231
226
is_residual = True ,
232
227
patch_size = 2 ,
233
- latents_mean = [0.0 ] * 48 ,
234
- latents_std = [1.0 ] * 48 ,
228
+ latents_mean = [0.0 ] * 48 ,
229
+ latents_std = [1.0 ] * 48 ,
235
230
dim_mult = [1 , 1 , 1 , 1 ],
236
231
num_res_blocks = 1 ,
237
232
scale_factor_spatial = 16 ,
@@ -295,7 +290,9 @@ def test_inference(self):
295
290
device = "cpu"
296
291
297
292
components = self .get_dummy_components ()
298
- pipe = self .pipeline_class (** components , )
293
+ pipe = self .pipeline_class (
294
+ ** components ,
295
+ )
299
296
pipe .to (device )
300
297
pipe .set_progress_bar_config (disable = None )
301
298
@@ -311,7 +308,10 @@ def test_inference(self):
311
308
312
309
generated_slice = generated_video .flatten ()
313
310
generated_slice = torch .cat ([generated_slice [:8 ], generated_slice [- 8 :]])
314
- self .assertTrue (torch .allclose (generated_slice , expected_slice , atol = 1e-3 ), f"generated_slice: { generated_slice } , expected_slice: { expected_slice } " )
311
+ self .assertTrue (
312
+ torch .allclose (generated_slice , expected_slice , atol = 1e-3 ),
313
+ f"generated_slice: { generated_slice } , expected_slice: { expected_slice } " ,
314
+ )
315
315
316
316
@unittest .skip ("Test not supported" )
317
317
def test_attention_slicing_forward_pass (self ):
@@ -327,7 +327,6 @@ def test_components_function(self):
327
327
self .assertTrue (set (pipe .components .keys ()) == set (init_components .keys ()))
328
328
329
329
def test_save_load_optional_components (self , expected_max_difference = 1e-4 ):
330
-
331
330
optional_component = "transformer_2"
332
331
333
332
components = self .get_dummy_components ()
@@ -353,16 +352,17 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4):
353
352
pipe_loaded .to (torch_device )
354
353
pipe_loaded .set_progress_bar_config (disable = None )
355
354
356
- self .assertTrue (getattr (pipe_loaded , optional_component ) is None ,
357
- f"`{ optional_component } ` did not stay set to None after loading." ,
358
- )
355
+ self .assertTrue (
356
+ getattr (pipe_loaded , optional_component ) is None ,
357
+ f"`{ optional_component } ` did not stay set to None after loading." ,
358
+ )
359
359
360
360
inputs = self .get_dummy_inputs (generator_device )
361
361
torch .manual_seed (0 )
362
362
output_loaded = pipe_loaded (** inputs )[0 ]
363
363
364
364
max_diff = np .abs (output .detach ().cpu ().numpy () - output_loaded .detach ().cpu ().numpy ()).max ()
365
365
self .assertLess (max_diff , expected_max_difference )
366
-
366
+
367
367
def test_inference_batch_single_identical (self ):
368
- self ._test_inference_batch_single_identical (expected_max_diff = 2e-3 )
368
+ self ._test_inference_batch_single_identical (expected_max_diff = 2e-3 )
0 commit comments