@@ -4750,41 +4750,56 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
4750
4750
4751
4751
@pytest .mark .parametrize ("left_multiply" , [True , False ], ids = ["left" , "right" ])
4752
4752
@pytest .mark .parametrize (
4753
- "batch_left " , [True , False ], ids = ["batched_left " , "unbatched_left " ]
4753
+ "batch_blockdiag " , [True , False ], ids = ["batch_blockdiag " , "unbatched_blockdiag " ]
4754
4754
)
4755
4755
@pytest .mark .parametrize (
4756
- "batch_right " , [True , False ], ids = ["batched_right " , "unbatched_right " ]
4756
+ "batch_other " , [True , False ], ids = ["batched_other " , "unbatched_other " ]
4757
4757
)
4758
- def test_local_block_diag_dot_to_dot_block_diag (left_multiply , batch_left , batch_right ):
4758
+ def test_local_block_diag_dot_to_dot_block_diag (
4759
+ left_multiply , batch_blockdiag , batch_other
4760
+ ):
4759
4761
"""
4760
4762
Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
4761
4763
"""
4764
+
4765
+ def has_blockdiag (graph ):
4766
+ return any (
4767
+ (
4768
+ var .owner
4769
+ and (
4770
+ isinstance (var .owner .op , BlockDiagonal )
4771
+ or (
4772
+ isinstance (var .owner .op , Blockwise )
4773
+ and isinstance (var .owner .op .core_op , BlockDiagonal )
4774
+ )
4775
+ )
4776
+ )
4777
+ for var in ancestors ([graph ])
4778
+ )
4779
+
4762
4780
a = tensor ("a" , shape = (4 , 2 ))
4763
- b = tensor ("b" , shape = (2 , 4 ) if not batch_left else (3 , 2 , 4 ))
4781
+ b = tensor ("b" , shape = (2 , 4 ) if not batch_blockdiag else (3 , 2 , 4 ))
4764
4782
c = tensor ("c" , shape = (4 , 4 ))
4765
- d = tensor ("d" , shape = (10 , 10 ))
4766
- e = tensor ("e" , shape = (10 , 10 ) if not batch_right else (3 , 1 , 10 , 10 ))
4767
-
4768
4783
x = pt .linalg .block_diag (a , b , c )
4769
4784
4785
+ d = tensor ("d" , shape = (10 , 10 ) if not batch_other else (3 , 1 , 10 , 10 ))
4786
+
4770
4787
# Test multiple clients are all rewritten
4771
4788
if left_multiply :
4772
- out = [ x @ d , x @ e ]
4789
+ out = x @ d
4773
4790
else :
4774
- out = [ d @ x , e @ x ]
4791
+ out = d @ x
4775
4792
4776
- with config .change_flags (optimizer_verbose = True ):
4777
- fn = pytensor .function ([a , b , c , d , e ], out , mode = rewrite_mode )
4778
-
4779
- assert not any (
4780
- isinstance (node .op , BlockDiagonal ) for node in fn .maker .fgraph .toposort ()
4781
- )
4793
+ assert has_blockdiag (out )
4794
+ fn = pytensor .function ([a , b , c , d ], out , mode = rewrite_mode )
4795
+ assert not has_blockdiag (fn .maker .fgraph .outputs [0 ])
4782
4796
4783
4797
fn_expected = pytensor .function (
4784
- [a , b , c , d , e ],
4798
+ [a , b , c , d ],
4785
4799
out ,
4786
4800
mode = Mode (linker = "py" , optimizer = None ),
4787
4801
)
4802
+ assert has_blockdiag (fn_expected .maker .fgraph .outputs [0 ])
4788
4803
4789
4804
# TODO: Count Dots
4790
4805
@@ -4793,18 +4808,15 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply, batch_left, batch
4793
4808
b_val = rng .normal (size = b .type .shape ).astype (b .type .dtype )
4794
4809
c_val = rng .normal (size = c .type .shape ).astype (c .type .dtype )
4795
4810
d_val = rng .normal (size = d .type .shape ).astype (d .type .dtype )
4796
- e_val = rng .normal (size = e .type .shape ).astype (e .type .dtype )
4797
4811
4798
- rewrite_outs = fn (a_val , b_val , c_val , d_val , e_val )
4799
- expected_outs = fn_expected (a_val , b_val , c_val , d_val , e_val )
4800
-
4801
- for out , expected in zip (rewrite_outs , expected_outs ):
4802
- np .testing .assert_allclose (
4803
- out ,
4804
- expected ,
4805
- atol = 1e-6 if config .floatX == "float32" else 1e-12 ,
4806
- rtol = 1e-6 if config .floatX == "float32" else 1e-12 ,
4807
- )
4812
+ rewrite_out = fn (a_val , b_val , c_val , d_val )
4813
+ expected_out = fn_expected (a_val , b_val , c_val , d_val )
4814
+ np .testing .assert_allclose (
4815
+ rewrite_out ,
4816
+ expected_out ,
4817
+ atol = 1e-6 if config .floatX == "float32" else 1e-12 ,
4818
+ rtol = 1e-6 if config .floatX == "float32" else 1e-12 ,
4819
+ )
4808
4820
4809
4821
4810
4822
@pytest .mark .parametrize ("rewrite" , [True , False ], ids = ["rewrite" , "no_rewrite" ])
0 commit comments