Skip to content

Commit e0bcea7

Browse files
committed
Simplify linalg rewrites with pattern matching
1 parent 7dbc034 commit e0bcea7

File tree

5 files changed

+390
-617
lines changed

5 files changed

+390
-617
lines changed

pytensor/tensor/_linalg/solve/rewriting.py

Lines changed: 47 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from pytensor.tensor.elemwise import DimShuffle
1616
from pytensor.tensor.rewriting.basic import register_specialize
1717
from pytensor.tensor.rewriting.blockwise import blockwise_of
18-
from pytensor.tensor.rewriting.linalg import is_matrix_transpose
1918
from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_solve
2019
from pytensor.tensor.variable import TensorVariable
2120

@@ -79,28 +78,26 @@ def get_root_A(a: TensorVariable) -> tuple[TensorVariable, bool]:
7978
# the root variable is the pre-DimShuffled input.
8079
# Otherwise, `a` is considered the root variable.
8180
# We also return whether the root `a` is transposed.
81+
root_a = a
8282
transposed = False
83-
if a.owner is not None and isinstance(a.owner.op, DimShuffle):
84-
if a.owner.op.is_left_expand_dims:
85-
[a] = a.owner.inputs
86-
elif is_matrix_transpose(a):
87-
[a] = a.owner.inputs
88-
transposed = True
89-
return a, transposed
83+
match a.owner_op_and_inputs:
84+
case (DimShuffle(is_left_expand_dims=True), root_a): # type: ignore[misc]
85+
transposed = False
86+
case (DimShuffle(is_left_expanded_matrix_transpose=True), root_a): # type: ignore[misc]
87+
transposed = True # type: ignore[unreachable]
88+
89+
return root_a, transposed
9090

9191
def find_solve_clients(var, assume_a):
9292
clients = []
9393
for cl, idx in fgraph.clients[var]:
94-
if (
95-
idx == 0
96-
and isinstance(cl.op, Blockwise)
97-
and isinstance(cl.op.core_op, Solve)
98-
and (cl.op.core_op.assume_a == assume_a)
99-
):
100-
clients.append(cl)
101-
elif isinstance(cl.op, DimShuffle) and cl.op.is_left_expand_dims:
102-
# If it's a left expand_dims, recurse on the output
103-
clients.extend(find_solve_clients(cl.outputs[0], assume_a))
94+
match (idx, cl.op, *cl.outputs):
95+
case (0, Blockwise(Solve(assume_a=assume_a_var)), *_) if (
96+
assume_a_var == assume_a
97+
):
98+
clients.append(cl)
99+
case (0, DimShuffle(is_left_expand_dims=True), cl_out):
100+
clients.extend(find_solve_clients(cl_out, assume_a))
104101
return clients
105102

106103
assume_a = node.op.core_op.assume_a
@@ -119,11 +116,11 @@ def find_solve_clients(var, assume_a):
119116

120117
# Find Solves using A.T
121118
for cl, _ in fgraph.clients[A]:
122-
if isinstance(cl.op, DimShuffle) and is_matrix_transpose(cl.out):
123-
A_T = cl.out
124-
A_solve_clients_and_transpose.extend(
125-
(client, True) for client in find_solve_clients(A_T, assume_a)
126-
)
119+
match (cl.op, *cl.outputs):
120+
case (DimShuffle(is_left_expanded_matrix_transpose=True), A_T):
121+
A_solve_clients_and_transpose.extend(
122+
(client, True) for client in find_solve_clients(A_T, assume_a)
123+
)
127124

128125
if not eager and len(A_solve_clients_and_transpose) == 1:
129126
# If theres' a single use don't do it... unless it's being broadcast in a Blockwise (or we're eager)
@@ -185,34 +182,34 @@ def _scan_split_non_sequence_decomposition_and_solve(
185182
changed = False
186183
while True:
187184
for inner_node in new_scan_fgraph.toposort():
188-
if (
189-
isinstance(inner_node.op, Blockwise)
190-
and isinstance(inner_node.op.core_op, Solve)
191-
and inner_node.op.core_op.assume_a in allowed_assume_a
192-
):
193-
A, _b = inner_node.inputs
194-
if all(
195-
(isinstance(root_inp, Constant) or (root_inp in non_sequences))
196-
for root_inp in graph_inputs([A])
185+
match (inner_node.op, *inner_node.inputs):
186+
case (Blockwise(Solve(assume_a=assume_a_var)), A, _b) if (
187+
assume_a_var in allowed_assume_a
197188
):
198-
if new_scan_fgraph is scan_op.fgraph:
199-
# Clone the first time to avoid mutating the original fgraph
200-
new_scan_fgraph, equiv = new_scan_fgraph.clone_get_equiv()
201-
non_sequences = {equiv[non_seq] for non_seq in non_sequences}
202-
inner_node = equiv[inner_node] # type: ignore
203-
204-
replace_dict = _split_decomp_and_solve_steps(
205-
new_scan_fgraph,
206-
inner_node,
207-
eager=True,
208-
allowed_assume_a=allowed_assume_a,
209-
)
210-
assert isinstance(replace_dict, dict) and len(replace_dict) > 0, (
211-
"Rewrite failed"
212-
)
213-
new_scan_fgraph.replace_all(replace_dict.items())
214-
changed = True
215-
break # Break to start over with a fresh toposort
189+
if all(
190+
(isinstance(root_inp, Constant) or (root_inp in non_sequences))
191+
for root_inp in graph_inputs([A])
192+
):
193+
if new_scan_fgraph is scan_op.fgraph:
194+
# Clone the first time to avoid mutating the original fgraph
195+
new_scan_fgraph, equiv = new_scan_fgraph.clone_get_equiv()
196+
non_sequences = {
197+
equiv[non_seq] for non_seq in non_sequences
198+
}
199+
inner_node = equiv[inner_node] # type: ignore
200+
201+
replace_dict = _split_decomp_and_solve_steps(
202+
new_scan_fgraph,
203+
inner_node,
204+
eager=True,
205+
allowed_assume_a=allowed_assume_a,
206+
)
207+
assert (
208+
isinstance(replace_dict, dict) and len(replace_dict) > 0
209+
), "Rewrite failed"
210+
new_scan_fgraph.replace_all(replace_dict.items())
211+
changed = True
212+
break # Break to start over with a fresh toposort
216213
else: # no_break
217214
break # Nothing else changed
218215

0 commit comments

Comments
 (0)