We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
_matmul
local_block_diag_dot_to_dot_block_diag
1 parent 993fb64 commit 7c3820bCopy full SHA for 7c3820b
pytensor/tensor/rewriting/math.py
@@ -38,6 +38,7 @@
38
zeros,
39
zeros_like,
40
)
41
+from pytensor.tensor.blockwise import Blockwise
42
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
43
from pytensor.tensor.exceptions import NotScalarConstantError
44
from pytensor.tensor.extra_ops import broadcast_arrays
@@ -169,7 +170,7 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node):
169
170
isinstance(client.op, Dot)
171
and all(input.ndim == 2 for input in client.inputs)
172
- or client.op == _matrix_matrix_matmul
173
+ or client.op == _matmul
174
):
175
continue
176
0 commit comments