Skip to content

Commit 993fb64

Browse files
ricardoV94jessegrabowski
authored andcommitted
Cleanup test
1 parent 6736e8e commit 993fb64

File tree

1 file changed

+39
-27
lines changed

1 file changed

+39
-27
lines changed

tests/tensor/rewriting/test_math.py

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4750,41 +4750,56 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
47504750

47514751
@pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
47524752
@pytest.mark.parametrize(
4753-
"batch_left", [True, False], ids=["batched_left", "unbatched_left"]
4753+
"batch_blockdiag", [True, False], ids=["batch_blockdiag", "unbatched_blockdiag"]
47544754
)
47554755
@pytest.mark.parametrize(
4756-
"batch_right", [True, False], ids=["batched_right", "unbatched_right"]
4756+
"batch_other", [True, False], ids=["batched_other", "unbatched_other"]
47574757
)
4758-
def test_local_block_diag_dot_to_dot_block_diag(left_multiply, batch_left, batch_right):
4758+
def test_local_block_diag_dot_to_dot_block_diag(
4759+
left_multiply, batch_blockdiag, batch_other
4760+
):
47594761
"""
47604762
Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
47614763
"""
4764+
4765+
def has_blockdiag(graph):
4766+
return any(
4767+
(
4768+
var.owner
4769+
and (
4770+
isinstance(var.owner.op, BlockDiagonal)
4771+
or (
4772+
isinstance(var.owner.op, Blockwise)
4773+
and isinstance(var.owner.op.core_op, BlockDiagonal)
4774+
)
4775+
)
4776+
)
4777+
for var in ancestors([graph])
4778+
)
4779+
47624780
a = tensor("a", shape=(4, 2))
4763-
b = tensor("b", shape=(2, 4) if not batch_left else (3, 2, 4))
4781+
b = tensor("b", shape=(2, 4) if not batch_blockdiag else (3, 2, 4))
47644782
c = tensor("c", shape=(4, 4))
4765-
d = tensor("d", shape=(10, 10))
4766-
e = tensor("e", shape=(10, 10) if not batch_right else (3, 1, 10, 10))
4767-
47684783
x = pt.linalg.block_diag(a, b, c)
47694784

4785+
d = tensor("d", shape=(10, 10) if not batch_other else (3, 1, 10, 10))
4786+
47704787
# Test multiple clients are all rewritten
47714788
if left_multiply:
4772-
out = [x @ d, x @ e]
4789+
out = x @ d
47734790
else:
4774-
out = [d @ x, e @ x]
4791+
out = d @ x
47754792

4776-
with config.change_flags(optimizer_verbose=True):
4777-
fn = pytensor.function([a, b, c, d, e], out, mode=rewrite_mode)
4778-
4779-
assert not any(
4780-
isinstance(node.op, BlockDiagonal) for node in fn.maker.fgraph.toposort()
4781-
)
4793+
assert has_blockdiag(out)
4794+
fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode)
4795+
assert not has_blockdiag(fn.maker.fgraph.outputs[0])
47824796

47834797
fn_expected = pytensor.function(
4784-
[a, b, c, d, e],
4798+
[a, b, c, d],
47854799
out,
47864800
mode=Mode(linker="py", optimizer=None),
47874801
)
4802+
assert has_blockdiag(fn_expected.maker.fgraph.outputs[0])
47884803

47894804
# TODO: Count Dots
47904805

@@ -4793,18 +4808,15 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply, batch_left, batch
47934808
b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
47944809
c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
47954810
d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
4796-
e_val = rng.normal(size=e.type.shape).astype(e.type.dtype)
47974811

4798-
rewrite_outs = fn(a_val, b_val, c_val, d_val, e_val)
4799-
expected_outs = fn_expected(a_val, b_val, c_val, d_val, e_val)
4800-
4801-
for out, expected in zip(rewrite_outs, expected_outs):
4802-
np.testing.assert_allclose(
4803-
out,
4804-
expected,
4805-
atol=1e-6 if config.floatX == "float32" else 1e-12,
4806-
rtol=1e-6 if config.floatX == "float32" else 1e-12,
4807-
)
4812+
rewrite_out = fn(a_val, b_val, c_val, d_val)
4813+
expected_out = fn_expected(a_val, b_val, c_val, d_val)
4814+
np.testing.assert_allclose(
4815+
rewrite_out,
4816+
expected_out,
4817+
atol=1e-6 if config.floatX == "float32" else 1e-12,
4818+
rtol=1e-6 if config.floatX == "float32" else 1e-12,
4819+
)
48084820

48094821

48104822
@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])

0 commit comments

Comments
 (0)