Skip to content

Commit 3c50d2a

Browse files
Use concat_with_broadcast in join split dot results
1 parent b084a3e commit 3c50d2a

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 14 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,6 @@
2929
constant,
3030
expand_dims,
3131
get_underlying_scalar_constant_value,
32-
join,
3332
moveaxis,
3433
ones_like,
3534
register_infer_shape,
@@ -41,7 +40,7 @@
4140
from pytensor.tensor.blockwise import Blockwise
4241
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
4342
from pytensor.tensor.exceptions import NotScalarConstantError
44-
from pytensor.tensor.extra_ops import broadcast_arrays
43+
from pytensor.tensor.extra_ops import broadcast_arrays, concat_with_broadcast
4544
from pytensor.tensor.math import (
4645
Dot,
4746
Prod,
@@ -151,6 +150,7 @@ def local_0_dot_x(fgraph, node):
151150

152151

153152
@register_stabilize
153+
@register_specialize
154154
@node_rewriter([Blockwise])
155155
def local_block_diag_dot_to_dot_block_diag(fgraph, node):
156156
r"""
@@ -174,29 +174,31 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node):
174174
):
175175
continue
176176

177-
op = client.op
177+
[blockdiag_result] = node.outputs
178+
blockdiag_inputs = node.inputs
178179

179-
client_idx = client.inputs.index(node.outputs[0])
180+
dot_op = client.op
180181

182+
client_idx = client.inputs.index(blockdiag_result)
181183
other_input = client.inputs[1 - client_idx]
182-
components = node.inputs
183184

184185
split_axis = -2 if client_idx == 0 else -1
185186
shape_idx = -1 if client_idx == 0 else -2
186187

187188
other_dot_input_split = split(
188189
other_input,
189-
splits_size=[component.shape[shape_idx] for component in components],
190-
n_splits=len(components),
190+
splits_size=[component.shape[shape_idx] for component in blockdiag_inputs],
191+
n_splits=len(blockdiag_inputs),
191192
axis=split_axis,
192193
)
193-
new_components = [
194-
op(component, other_split)
194+
195+
split_dot_results = [
196+
dot_op(component, other_split)
195197
if client_idx == 0
196-
else op(other_split, component)
197-
for component, other_split in zip(components, other_dot_input_split)
198+
else dot_op(other_split, component)
199+
for component, other_split in zip(blockdiag_inputs, other_dot_input_split)
198200
]
199-
new_output = join(split_axis, *new_components)
201+
new_output = concat_with_broadcast(split_dot_results, dim=split_axis)
200202

201203
copy_stack_trace(node.outputs[0], new_output)
202204
return {client.outputs[0]: new_output}

0 commit comments

Comments
 (0)