Skip to content

Commit 558acbc

Browse files
use continue on rewrite failures when checking clients
1 parent 3ab2bb0 commit 558acbc

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,13 +172,13 @@ def check_for_block_diag(x):
172172
# Check that the BlockDiagonal is an input to a Dot node:
173173
for client in get_clients_at_depth(fgraph, node, depth=1):
174174
if not isinstance(client.op, Dot):
175-
return
175+
continue
176176

177177
op = client.op
178178
x, y = client.inputs
179179

180180
if not (check_for_block_diag(x) or check_for_block_diag(y)):
181-
return None
181+
continue
182182

183183
# Case 1: Only one input is BlockDiagonal. In this case, multiply all components of the block-diagonal with the
184184
# non-block diagonal, and return a new block diagonal
@@ -214,7 +214,7 @@ def check_for_block_diag(x):
214214
else:
215215
# TODO: If shapes are statically known and all components have equal shapes, we could rewrite
216216
# this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)])
217-
return None
217+
continue
218218

219219
copy_stack_trace(node.outputs[0], new_output)
220220
return {client.outputs[0]: new_output}

0 commit comments

Comments
 (0)