Skip to content

Commit 7c3820b

Browse files
look for _matmul in local_block_diag_dot_to_dot_block_diag
1 parent 993fb64 commit 7c3820b

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
zeros,
3939
zeros_like,
4040
)
41+
from pytensor.tensor.blockwise import Blockwise
4142
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
4243
from pytensor.tensor.exceptions import NotScalarConstantError
4344
from pytensor.tensor.extra_ops import broadcast_arrays
@@ -169,7 +170,7 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node):
169170
isinstance(client.op, Dot)
170171
and all(input.ndim == 2 for input in client.inputs)
171172
)
172-
or client.op == _matrix_matrix_matmul
173+
or client.op == _matmul
173174
):
174175
continue
175176

0 commit comments

Comments
 (0)