Skip to content

Commit 9db9be6

Browse files
authored
[tests] Add fast test slices for HiDream-Image (#11953)
update
1 parent d87134a commit 9db9be6

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

tests/pipelines/hidream_image/test_pipeline_hidream.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,11 +146,15 @@ def test_inference(self):
146146
inputs = self.get_dummy_inputs(device)
147147
image = pipe(**inputs)[0]
148148
generated_image = image[0]
149-
150149
self.assertEqual(generated_image.shape, (128, 128, 3))
151-
expected_image = torch.randn(128, 128, 3).numpy()
152-
max_diff = np.abs(generated_image - expected_image).max()
153-
self.assertLessEqual(max_diff, 1e10)
150+
151+
# fmt: off
152+
expected_slice = np.array([0.4507, 0.5256, 0.4205, 0.5791, 0.4848, 0.4831, 0.4443, 0.5107, 0.6586, 0.3163, 0.7318, 0.5933, 0.6252, 0.5512, 0.5357, 0.5983])
153+
# fmt: on
154+
155+
generated_slice = generated_image.flatten()
156+
generated_slice = np.concatenate([generated_slice[:8], generated_slice[-8:]])
157+
self.assertTrue(np.allclose(generated_slice, expected_slice, atol=1e-3))
154158

155159
def test_inference_batch_single_identical(self):
156160
super().test_inference_batch_single_identical(expected_max_diff=3e-4)

0 commit comments

Comments
 (0)