Skip to content

Commit a9655b0

Browse files
Look deeper for Dots
1 parent aedc8b4 commit a9655b0

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,6 @@ def local_0_dot_x(fgraph, node):
150150

151151

152152
@register_stabilize
153-
@register_specialize
154153
@node_rewriter([Blockwise])
155154
def local_block_diag_dot_to_dot_block_diag(fgraph, node):
156155
r"""
@@ -164,7 +163,9 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node):
164163
return
165164

166165
# Check that the BlockDiagonal is an input to a Dot node:
167-
for client in get_clients_at_depth(fgraph, node, depth=1):
166+
for client in itertools.chain.from_iterable(
167+
get_clients_at_depth(fgraph, node, depth=i) for i in [1, 2]
168+
):
168169
if not (
169170
(
170171
isinstance(client.op, Dot)
@@ -179,7 +180,17 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node):
179180

180181
dot_op = client.op
181182

182-
client_idx = client.inputs.index(blockdiag_result)
183+
try:
184+
client_idx = client.inputs.index(blockdiag_result)
185+
except ValueError:
186+
# If the blockdiag result is not an input to the dot, there is at least one Op between them (usually a
187+
# DimShuffle). In this case, we need to figure out which of the inputs of the dot eventually leads to the
188+
# blockdiag result.
189+
for ancestor in client.inputs:
190+
if ancestor.owner and blockdiag_result in ancestor.owner.inputs:
191+
client_idx = client.inputs.index(ancestor)
192+
break
193+
183194
other_input = client.inputs[1 - client_idx]
184195

185196
split_axis = -2 if client_idx == 0 else -1

tests/tensor/rewriting/test_math.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4794,14 +4794,26 @@ def has_blockdiag(graph):
47944794
fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode)
47954795
assert not has_blockdiag(fn.maker.fgraph.outputs[0])
47964796

4797+
n_dots_rewrite = sum(
4798+
isinstance(node.op, Dot | Dot22)
4799+
or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot | Dot22))
4800+
for node in fn.maker.fgraph.apply_nodes
4801+
)
4802+
assert n_dots_rewrite == 3
4803+
47974804
fn_expected = pytensor.function(
47984805
[a, b, c, d],
47994806
out,
48004807
mode=Mode(linker="py", optimizer=None),
48014808
)
48024809
assert has_blockdiag(fn_expected.maker.fgraph.outputs[0])
48034810

4804-
# TODO: Count Dots
4811+
n_dots_no_rewrite = sum(
4812+
isinstance(node.op, Dot | Dot22)
4813+
or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot | Dot22))
4814+
for node in fn_expected.maker.fgraph.apply_nodes
4815+
)
4816+
assert n_dots_no_rewrite == 1
48054817

48064818
rng = np.random.default_rng()
48074819
a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)

0 commit comments

Comments
 (0)