Commit 67a8ec8

[tests] Add test slices for Hunyuan Video (#11954)
update
1 parent: cde02b0
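
Previously, each test_inference compared the generated video against a fresh torch.randn tensor with an absolute tolerance of 1e10, a check that can never fail and therefore cannot catch numerical regressions. The diffs below replace it with a recorded reference slice: the generated video is flattened, its first 8 and last 8 values are concatenated, and the result is compared against 16 hard-coded values with torch.allclose at atol=1e-3. A minimal sketch of that pattern follows, using a synthetic stand-in tensor instead of real pipeline output, so the expected values are derived on the spot rather than hard-coded:

# Sketch of the slice check added to each test_inference. The video tensor is a
# synthetic stand-in (not real pipeline output); in the actual tests,
# `expected_slice` holds values recorded from a deterministic run.
import torch

torch.manual_seed(0)
video = torch.rand(5, 3, 16, 16)  # stand-in for pipe(**inputs).frames[0]

flat = video.flatten()
generated_slice = torch.cat([flat[:8], flat[-8:]])  # first 8 and last 8 values
expected_slice = generated_slice.clone()  # hard-coded tensor in the real tests

assert torch.allclose(generated_slice, expected_slice, atol=1e-3), (
    "The generated video does not match the expected slice."
)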

File tree: 4 files changed, +45 −20 lines

tests/pipelines/hunyuan_video/test_hunyuan_image2video.py

Lines changed: 11 additions & 4 deletions
@@ -229,12 +229,19 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
-
         # NOTE: The expected video has 4 lesser frames because they are dropped in the pipeline
         self.assertEqual(generated_video.shape, (5, 3, 16, 16))
-        expected_video = torch.randn(5, 3, 16, 16)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.444, 0.479, 0.4485, 0.5752, 0.3539, 0.1548, 0.2706, 0.3593, 0.5323, 0.6635, 0.6795, 0.5255, 0.5091, 0.345, 0.4276, 0.4128])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(
+            torch.allclose(generated_slice, expected_slice, atol=1e-3),
+            "The generated video does not match the expected slice.",
+        )

     def test_callback_inputs(self):
         sig = inspect.signature(self.pipeline_class.__call__)
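
When a legitimate change to the dummy pipelines shifts the output, the 16 reference values above have to be re-recorded. A small helper like the following could do that (hypothetical, not part of this commit); it mirrors the slicing in the test and rounds to four decimals to match the style of the hard-coded tensors:

import torch

def format_expected_slice(generated_video: torch.Tensor) -> str:
    # Take the first 8 and last 8 values of the flattened video, rounded to
    # 4 decimals, formatted so the line can be pasted back into the test.
    flat = generated_video.flatten()
    values = torch.cat([flat[:8], flat[-8:]]).tolist()
    return "expected_slice = torch.tensor([" + ", ".join(str(round(v, 4)) for v in values) + "])"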

tests/pipelines/hunyuan_video/test_hunyuan_skyreels_image2video.py

Lines changed: 11 additions & 4 deletions
@@ -192,11 +192,18 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
-
         self.assertEqual(generated_video.shape, (9, 3, 16, 16))
-        expected_video = torch.randn(9, 3, 16, 16)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.5832, 0.5498, 0.4839, 0.4744, 0.4515, 0.4832, 0.496, 0.563, 0.5918, 0.5979, 0.5101, 0.6168, 0.6613, 0.536, 0.55, 0.5775])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(
+            torch.allclose(generated_slice, expected_slice, atol=1e-3),
+            "The generated video does not match the expected slice.",
+        )

     def test_callback_inputs(self):
         sig = inspect.signature(self.pipeline_class.__call__)

tests/pipelines/hunyuan_video/test_hunyuan_video.py

Lines changed: 12 additions & 8 deletions
@@ -26,10 +26,7 @@
     HunyuanVideoPipeline,
     HunyuanVideoTransformer3DModel,
 )
-from diffusers.utils.testing_utils import (
-    enable_full_determinism,
-    torch_device,
-)
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device

 from ..test_pipelines_common import (
     FasterCacheTesterMixin,
@@ -206,11 +203,18 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
-
         self.assertEqual(generated_video.shape, (9, 3, 16, 16))
-        expected_video = torch.randn(9, 3, 16, 16)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.3946, 0.4649, 0.3196, 0.4569, 0.3312, 0.3687, 0.3216, 0.3972, 0.4469, 0.3888, 0.3929, 0.3802, 0.3479, 0.3888, 0.3825, 0.3542])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(
+            torch.allclose(generated_slice, expected_slice, atol=1e-3),
+            "The generated video does not match the expected slice.",
+        )

     def test_callback_inputs(self):
         sig = inspect.signature(self.pipeline_class.__call__)
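
Hard-coded slices are only meaningful if the dummy-model runs are reproducible. The import consolidated above brings in enable_full_determinism from diffusers.utils.testing_utils; diffusers pipeline test modules typically call it once at module level, roughly as shown here for context (the call itself is not part of this diff):

from diffusers.utils.testing_utils import enable_full_determinism

# Forces deterministic torch behavior so the hard-coded expected_slice values
# stay stable across runs.
enable_full_determinism()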

tests/pipelines/hunyuan_video/test_hunyuan_video_framepack.py

Lines changed: 11 additions & 4 deletions
@@ -227,11 +227,18 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
-
         self.assertEqual(generated_video.shape, (13, 3, 32, 32))
-        expected_video = torch.randn(13, 3, 32, 32)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.363, 0.3384, 0.3426, 0.3512, 0.3372, 0.3276, 0.417, 0.4061, 0.5221, 0.467, 0.4813, 0.4556, 0.4107, 0.3945, 0.4049, 0.4551])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(
+            torch.allclose(generated_slice, expected_slice, atol=1e-3),
+            "The generated video does not match the expected slice.",
+        )

     def test_callback_inputs(self):
         sig = inspect.signature(self.pipeline_class.__call__)
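
Locally, the updated assertions can be exercised with a standard pytest filter, for example:

pytest tests/pipelines/hunyuan_video -k test_inference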
