Skip to content

Commit 68591ad

Browse files
committed
Cleanup test
1 parent 9455b86 commit 68591ad

File tree

1 file changed

+39
-27
lines changed

1 file changed

+39
-27
lines changed

tests/tensor/rewriting/test_math.py

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4659,41 +4659,56 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
46594659

46604660
@pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
46614661
@pytest.mark.parametrize(
4662-
"batch_left", [True, False], ids=["batched_left", "unbatched_left"]
4662+
"batch_blockdiag", [True, False], ids=["batch_blockdiag", "unbatched_blockdiag"]
46634663
)
46644664
@pytest.mark.parametrize(
4665-
"batch_right", [True, False], ids=["batched_right", "unbatched_right"]
4665+
"batch_other", [True, False], ids=["batched_other", "unbatched_other"]
46664666
)
4667-
def test_local_block_diag_dot_to_dot_block_diag(left_multiply, batch_left, batch_right):
4667+
def test_local_block_diag_dot_to_dot_block_diag(
4668+
left_multiply, batch_blockdiag, batch_other
4669+
):
46684670
"""
46694671
Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
46704672
"""
4673+
4674+
def has_blockdiag(graph):
4675+
return any(
4676+
(
4677+
var.owner
4678+
and (
4679+
isinstance(var.owner.op, BlockDiagonal)
4680+
or (
4681+
isinstance(var.owner.op, Blockwise)
4682+
and isinstance(var.owner.op.core_op, BlockDiagonal)
4683+
)
4684+
)
4685+
)
4686+
for var in ancestors([graph])
4687+
)
4688+
46714689
a = tensor("a", shape=(4, 2))
4672-
b = tensor("b", shape=(2, 4) if not batch_left else (3, 2, 4))
4690+
b = tensor("b", shape=(2, 4) if not batch_blockdiag else (3, 2, 4))
46734691
c = tensor("c", shape=(4, 4))
4674-
d = tensor("d", shape=(10, 10))
4675-
e = tensor("e", shape=(10, 10) if not batch_right else (3, 1, 10, 10))
4676-
46774692
x = pt.linalg.block_diag(a, b, c)
46784693

4694+
d = tensor("d", shape=(10, 10) if not batch_other else (3, 1, 10, 10))
4695+
46794696
# Test multiple clients are all rewritten
46804697
if left_multiply:
4681-
out = [x @ d, x @ e]
4698+
out = x @ d
46824699
else:
4683-
out = [d @ x, e @ x]
4700+
out = d @ x
46844701

4685-
with config.change_flags(optimizer_verbose=True):
4686-
fn = pytensor.function([a, b, c, d, e], out, mode=rewrite_mode)
4687-
4688-
assert not any(
4689-
isinstance(node.op, BlockDiagonal) for node in fn.maker.fgraph.toposort()
4690-
)
4702+
assert has_blockdiag(out)
4703+
fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode)
4704+
assert not has_blockdiag(fn.maker.fgraph.outputs[0])
46914705

46924706
fn_expected = pytensor.function(
4693-
[a, b, c, d, e],
4707+
[a, b, c, d],
46944708
out,
46954709
mode=Mode(linker="py", optimizer=None),
46964710
)
4711+
assert has_blockdiag(fn_expected.maker.fgraph.outputs[0])
46974712

46984713
# TODO: Count Dots
46994714

@@ -4702,18 +4717,15 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply, batch_left, batch
47024717
b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
47034718
c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
47044719
d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
4705-
e_val = rng.normal(size=e.type.shape).astype(e.type.dtype)
47064720

4707-
rewrite_outs = fn(a_val, b_val, c_val, d_val, e_val)
4708-
expected_outs = fn_expected(a_val, b_val, c_val, d_val, e_val)
4709-
4710-
for out, expected in zip(rewrite_outs, expected_outs):
4711-
np.testing.assert_allclose(
4712-
out,
4713-
expected,
4714-
atol=1e-6 if config.floatX == "float32" else 1e-12,
4715-
rtol=1e-6 if config.floatX == "float32" else 1e-12,
4716-
)
4721+
rewrite_out = fn(a_val, b_val, c_val, d_val)
4722+
expected_out = fn_expected(a_val, b_val, c_val, d_val)
4723+
np.testing.assert_allclose(
4724+
rewrite_out,
4725+
expected_out,
4726+
atol=1e-6 if config.floatX == "float32" else 1e-12,
4727+
rtol=1e-6 if config.floatX == "float32" else 1e-12,
4728+
)
47174729

47184730

47194731
@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])

0 commit comments

Comments
 (0)