29
29
constant ,
30
30
expand_dims ,
31
31
get_underlying_scalar_constant_value ,
32
- join ,
33
32
moveaxis ,
34
33
ones_like ,
35
34
register_infer_shape ,
41
40
from pytensor .tensor .blockwise import Blockwise
42
41
from pytensor .tensor .elemwise import CAReduce , DimShuffle , Elemwise
43
42
from pytensor .tensor .exceptions import NotScalarConstantError
44
- from pytensor .tensor .extra_ops import broadcast_arrays
43
+ from pytensor .tensor .extra_ops import broadcast_arrays , concat_with_broadcast
45
44
from pytensor .tensor .math import (
46
45
Dot ,
47
46
Prod ,
@@ -151,6 +150,7 @@ def local_0_dot_x(fgraph, node):
151
150
152
151
153
152
@register_stabilize
153
+ @register_specialize
154
154
@node_rewriter ([Blockwise ])
155
155
def local_block_diag_dot_to_dot_block_diag (fgraph , node ):
156
156
r"""
@@ -174,29 +174,31 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node):
174
174
):
175
175
continue
176
176
177
- op = client .op
177
+ [blockdiag_result ] = node .outputs
178
+ blockdiag_inputs = node .inputs
178
179
179
- client_idx = client .inputs . index ( node . outputs [ 0 ])
180
+ dot_op = client .op
180
181
182
+ client_idx = client .inputs .index (blockdiag_result )
181
183
other_input = client .inputs [1 - client_idx ]
182
- components = node .inputs
183
184
184
185
split_axis = - 2 if client_idx == 0 else - 1
185
186
shape_idx = - 1 if client_idx == 0 else - 2
186
187
187
188
other_dot_input_split = split (
188
189
other_input ,
189
- splits_size = [component .shape [shape_idx ] for component in components ],
190
- n_splits = len (components ),
190
+ splits_size = [component .shape [shape_idx ] for component in blockdiag_inputs ],
191
+ n_splits = len (blockdiag_inputs ),
191
192
axis = split_axis ,
192
193
)
193
- new_components = [
194
- op (component , other_split )
194
+
195
+ split_dot_results = [
196
+ dot_op (component , other_split )
195
197
if client_idx == 0
196
- else op (other_split , component )
197
- for component , other_split in zip (components , other_dot_input_split )
198
+ else dot_op (other_split , component )
199
+ for component , other_split in zip (blockdiag_inputs , other_dot_input_split )
198
200
]
199
- new_output = join ( split_axis , * new_components )
201
+ new_output = concat_with_broadcast ( split_dot_results , dim = split_axis )
200
202
201
203
copy_stack_trace (node .outputs [0 ], new_output )
202
204
return {client .outputs [0 ]: new_output }
0 commit comments