Commit d87134a

[tests] Add test slices for Cosmos (#11955)

* test
* try fix

1 parent 67a8ec8

4 files changed: +32 −16 lines

tests/pipelines/cosmos/test_cosmos.py (8 additions, 4 deletions)

@@ -153,11 +153,15 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
-
         self.assertEqual(generated_video.shape, (9, 3, 32, 32))
-        expected_video = torch.randn(9, 3, 32, 32)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.0, 0.9686, 0.8549, 0.8078, 0.0, 0.8431, 1.0, 0.4863, 0.7098, 0.1098, 0.8157, 0.4235, 0.6353, 0.2549, 0.5137, 0.5333])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

     def test_callback_inputs(self):
         sig = inspect.signature(self.pipeline_class.__call__)
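
All four files switch to the same comparison pattern: instead of checking the generated output against fresh torch.randn noise with an effectively unbounded tolerance (1e10), each test flattens the output, keeps its first and last eight values, and compares them to a hard-coded slice at atol=1e-3. Below is a minimal, self-contained sketch of that pattern, separate from the diff; the check_output_slice helper and the dummy tensor are illustrative only and do not exist in the repository.

import torch


def check_output_slice(generated: torch.Tensor, expected_slice: torch.Tensor, atol: float = 1e-3) -> bool:
    """Compare the first and last 8 flattened values of an output against a stored slice."""
    flat = generated.flatten()
    generated_slice = torch.cat([flat[:8], flat[-8:]])
    return torch.allclose(generated_slice, expected_slice, atol=atol)


# A deterministic dummy "video" shaped like the pipeline output (frames, channels, height, width).
generated_video = torch.linspace(0.0, 1.0, 9 * 3 * 32 * 32).reshape(9, 3, 32, 32)
expected_slice = torch.cat([generated_video.flatten()[:8], generated_video.flatten()[-8:]])
assert check_output_slice(generated_video, expected_slice)

Pinning only sixteen values keeps the hard-coded data small while still catching numerical regressions anywhere in the head or tail of the flattened output.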

tests/pipelines/cosmos/test_cosmos2_text2image.py (8 additions, 4 deletions)

@@ -140,11 +140,15 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         image = pipe(**inputs).images
         generated_image = image[0]
-
         self.assertEqual(generated_image.shape, (3, 32, 32))
-        expected_video = torch.randn(3, 32, 32)
-        max_diff = np.abs(generated_image - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.451, 0.451, 0.4471, 0.451, 0.451, 0.451, 0.451, 0.451, 0.4784, 0.4784, 0.4784, 0.4784, 0.4784, 0.4902, 0.4588, 0.5333])
+        # fmt: on
+
+        generated_slice = generated_image.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

     def test_callback_inputs(self):
         sig = inspect.signature(self.pipeline_class.__call__)

tests/pipelines/cosmos/test_cosmos2_video2world.py (8 additions, 4 deletions)

@@ -147,11 +147,15 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
-
         self.assertEqual(generated_video.shape, (9, 3, 32, 32))
-        expected_video = torch.randn(9, 3, 32, 32)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.451, 0.451, 0.4471, 0.451, 0.451, 0.451, 0.451, 0.451, 0.5098, 0.5137, 0.5176, 0.5098, 0.5255, 0.5412, 0.5098, 0.5059])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

     def test_components_function(self):
         init_components = self.get_dummy_components()

tests/pipelines/cosmos/test_cosmos_video2world.py (8 additions, 4 deletions)

@@ -159,11 +159,15 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
-
         self.assertEqual(generated_video.shape, (9, 3, 32, 32))
-        expected_video = torch.randn(9, 3, 32, 32)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.0, 0.8275, 0.7529, 0.7294, 0.0, 0.6, 1.0, 0.3804, 0.6667, 0.0863, 0.8784, 0.5922, 0.6627, 0.2784, 0.5725, 0.7765])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

     def test_components_function(self):
         init_components = self.get_dummy_components()
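
If a later change intentionally alters these dummy pipeline outputs, the hard-coded slices have to be regenerated from a trusted run. A hedged sketch of one way to print a replacement value follows; the format_expected_slice helper is hypothetical and is not part of this commit or the repository.

import torch


def format_expected_slice(output: torch.Tensor, decimals: int = 4) -> str:
    """Render the first/last 8 flattened values as a ready-to-paste expected_slice line."""
    flat = output.flatten()
    values = torch.cat([flat[:8], flat[-8:]])
    rendered = ", ".join(f"{v:.{decimals}f}" for v in values.tolist())
    return f"expected_slice = torch.tensor([{rendered}])"


# Usage inside a test, after producing `generated_video` deterministically on CPU:
# print(format_expected_slice(generated_video))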
