@@ -4659,41 +4659,56 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
4659
4659
4660
4660
@pytest .mark .parametrize ("left_multiply" , [True , False ], ids = ["left" , "right" ])
4661
4661
@pytest .mark .parametrize (
4662
- "batch_left " , [True , False ], ids = ["batched_left " , "unbatched_left " ]
4662
+ "batch_blockdiag " , [True , False ], ids = ["batch_blockdiag " , "unbatched_blockdiag " ]
4663
4663
)
4664
4664
@pytest .mark .parametrize (
4665
- "batch_right " , [True , False ], ids = ["batched_right " , "unbatched_right " ]
4665
+ "batch_other " , [True , False ], ids = ["batched_other " , "unbatched_other " ]
4666
4666
)
4667
- def test_local_block_diag_dot_to_dot_block_diag (left_multiply , batch_left , batch_right ):
4667
+ def test_local_block_diag_dot_to_dot_block_diag (
4668
+ left_multiply , batch_blockdiag , batch_other
4669
+ ):
4668
4670
"""
4669
4671
Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
4670
4672
"""
4673
+
4674
+ def has_blockdiag (graph ):
4675
+ return any (
4676
+ (
4677
+ var .owner
4678
+ and (
4679
+ isinstance (var .owner .op , BlockDiagonal )
4680
+ or (
4681
+ isinstance (var .owner .op , Blockwise )
4682
+ and isinstance (var .owner .op .core_op , BlockDiagonal )
4683
+ )
4684
+ )
4685
+ )
4686
+ for var in ancestors ([graph ])
4687
+ )
4688
+
4671
4689
a = tensor ("a" , shape = (4 , 2 ))
4672
- b = tensor ("b" , shape = (2 , 4 ) if not batch_left else (3 , 2 , 4 ))
4690
+ b = tensor ("b" , shape = (2 , 4 ) if not batch_blockdiag else (3 , 2 , 4 ))
4673
4691
c = tensor ("c" , shape = (4 , 4 ))
4674
- d = tensor ("d" , shape = (10 , 10 ))
4675
- e = tensor ("e" , shape = (10 , 10 ) if not batch_right else (3 , 1 , 10 , 10 ))
4676
-
4677
4692
x = pt .linalg .block_diag (a , b , c )
4678
4693
4694
+ d = tensor ("d" , shape = (10 , 10 ) if not batch_other else (3 , 1 , 10 , 10 ))
4695
+
4679
4696
# Test multiple clients are all rewritten
4680
4697
if left_multiply :
4681
- out = [ x @ d , x @ e ]
4698
+ out = x @ d
4682
4699
else :
4683
- out = [ d @ x , e @ x ]
4700
+ out = d @ x
4684
4701
4685
- with config .change_flags (optimizer_verbose = True ):
4686
- fn = pytensor .function ([a , b , c , d , e ], out , mode = rewrite_mode )
4687
-
4688
- assert not any (
4689
- isinstance (node .op , BlockDiagonal ) for node in fn .maker .fgraph .toposort ()
4690
- )
4702
+ assert has_blockdiag (out )
4703
+ fn = pytensor .function ([a , b , c , d ], out , mode = rewrite_mode )
4704
+ assert not has_blockdiag (fn .maker .fgraph .outputs [0 ])
4691
4705
4692
4706
fn_expected = pytensor .function (
4693
- [a , b , c , d , e ],
4707
+ [a , b , c , d ],
4694
4708
out ,
4695
4709
mode = Mode (linker = "py" , optimizer = None ),
4696
4710
)
4711
+ assert has_blockdiag (fn_expected .maker .fgraph .outputs [0 ])
4697
4712
4698
4713
# TODO: Count Dots
4699
4714
@@ -4702,18 +4717,15 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply, batch_left, batch
4702
4717
b_val = rng .normal (size = b .type .shape ).astype (b .type .dtype )
4703
4718
c_val = rng .normal (size = c .type .shape ).astype (c .type .dtype )
4704
4719
d_val = rng .normal (size = d .type .shape ).astype (d .type .dtype )
4705
- e_val = rng .normal (size = e .type .shape ).astype (e .type .dtype )
4706
4720
4707
- rewrite_outs = fn (a_val , b_val , c_val , d_val , e_val )
4708
- expected_outs = fn_expected (a_val , b_val , c_val , d_val , e_val )
4709
-
4710
- for out , expected in zip (rewrite_outs , expected_outs ):
4711
- np .testing .assert_allclose (
4712
- out ,
4713
- expected ,
4714
- atol = 1e-6 if config .floatX == "float32" else 1e-12 ,
4715
- rtol = 1e-6 if config .floatX == "float32" else 1e-12 ,
4716
- )
4721
+ rewrite_out = fn (a_val , b_val , c_val , d_val )
4722
+ expected_out = fn_expected (a_val , b_val , c_val , d_val )
4723
+ np .testing .assert_allclose (
4724
+ rewrite_out ,
4725
+ expected_out ,
4726
+ atol = 1e-6 if config .floatX == "float32" else 1e-12 ,
4727
+ rtol = 1e-6 if config .floatX == "float32" else 1e-12 ,
4728
+ )
4717
4729
4718
4730
4719
4731
@pytest .mark .parametrize ("rewrite" , [True , False ], ids = ["rewrite" , "no_rewrite" ])
0 commit comments