
Commit 6f8bb55

Optimize matmuls involving block diagonal matrices (#1493)

* Add `concat_with_broadcast` helper function and use the new helper in xt.concat
* Add block_diag dot rewrite

Co-authored-by: Ricardo <ricardo.vieira1994@gmail.com>
1 parent d4e8f73 commit 6f8bb55

File tree: 5 files changed, +271 −29 lines

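For context, the pattern this commit targets exploits block structure: multiplying block_diag(A, B) by C is the same as splitting C into row blocks and multiplying each diagonal block by its own slice, which replaces one large O(n^3) dot with two much smaller ones. A minimal NumPy/SciPy sketch of that identity (block sizes are illustrative, not taken from the commit):

import numpy as np
from scipy.linalg import block_diag

rng = np.random.default_rng(0)
A = rng.normal(size=(2, 3))
B = rng.normal(size=(4, 5))
C = rng.normal(size=(8, 7))  # 8 rows = 3 + 5, matching block_diag(A, B).shape[1]

# Naive form: materialize the (6, 8) block-diagonal matrix, then one big dot
naive = block_diag(A, B) @ C

# Rewritten form: split C into row blocks and dot each diagonal block separately
rewritten = np.concatenate([A @ C[:3], B @ C[3:]], axis=0)

np.testing.assert_allclose(naive, rewritten)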

pytensor/tensor/extra_ops.py

Lines changed: 38 additions & 1 deletion
@@ -27,7 +27,7 @@
 from pytensor.scalar import upcast
 from pytensor.tensor import TensorLike, as_tensor_variable
 from pytensor.tensor import basic as ptb
-from pytensor.tensor.basic import alloc, second
+from pytensor.tensor.basic import alloc, join, second
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import abs as pt_abs
 from pytensor.tensor.math import all as pt_all
@@ -2018,6 +2018,42 @@ def broadcast_with_others(a, others):
     return brodacasted_vars


+def concat_with_broadcast(tensor_list, axis=0):
+    """
+    Concatenate a list of tensors, broadcasting the non-concatenated dimensions to align.
+    """
+    if not tensor_list:
+        raise ValueError("Cannot concatenate an empty list of tensors.")
+
+    ndim = tensor_list[0].ndim
+    if not all(t.ndim == ndim for t in tensor_list):
+        raise TypeError(
+            "Only tensors with the same number of dimensions can be concatenated. "
+            f"Input ndims were: {[x.ndim for x in tensor_list]}"
+        )
+
+    axis = normalize_axis_index(axis=axis, ndim=ndim)
+    non_concat_shape = [1 if i != axis else None for i in range(ndim)]
+
+    for tensor_inp in tensor_list:
+        for i, (bcast, sh) in enumerate(
+            zip(tensor_inp.type.broadcastable, tensor_inp.shape)
+        ):
+            if bcast or i == axis:
+                continue
+            non_concat_shape[i] = sh
+
+    assert non_concat_shape.count(None) == 1
+
+    bcast_tensor_inputs = []
+    for tensor_inp in tensor_list:
+        # We modify the concat_axis in place, as we don't need the list anywhere else
+        non_concat_shape[axis] = tensor_inp.shape[axis]
+        bcast_tensor_inputs.append(broadcast_to(tensor_inp, non_concat_shape))
+
+    return join(axis, *bcast_tensor_inputs)
+
+
 __all__ = [
     "searchsorted",
     "cumsum",
@@ -2035,6 +2071,7 @@ def broadcast_with_others(a, others):
     "ravel_multi_index",
     "broadcast_shape",
     "broadcast_to",
+    "concat_with_broadcast",
     "geomspace",
     "logspace",
     "linspace",
pytensor/tensor/rewriting/math.py

Lines changed: 67 additions & 5 deletions
@@ -32,18 +32,21 @@
     moveaxis,
     ones_like,
     register_infer_shape,
+    split,
     switch,
     zeros,
     zeros_like,
 )
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
-from pytensor.tensor.extra_ops import broadcast_arrays
+from pytensor.tensor.extra_ops import broadcast_arrays, concat_with_broadcast
 from pytensor.tensor.math import (
     Dot,
     Prod,
     Sum,
     _conj,
+    _dot,
     _matmul,
     add,
     digamma,
@@ -96,6 +99,7 @@
 from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
 from pytensor.tensor.rewriting.linalg import is_matrix_transpose
 from pytensor.tensor.shape import Shape, Shape_i
+from pytensor.tensor.slinalg import BlockDiagonal
 from pytensor.tensor.subtensor import Subtensor
 from pytensor.tensor.type import (
     complex_dtypes,
@@ -146,6 +150,68 @@ def local_0_dot_x(fgraph, node):
     return [zeros((x.shape[0], y.shape[1]), dtype=node.outputs[0].type.dtype)]


+@register_stabilize
+@node_rewriter([Blockwise])
+def local_block_diag_dot_to_dot_block_diag(fgraph, node):
+    r"""
+    Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C), dot(B, C))``
+
+    BlockDiag results in the creation of a matrix of shape ``(n1 * n2, m1 * m2)``. Because dot has complexity
+    of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than
+    a single dot on the larger matrix.
+    """
+    if not isinstance(node.op.core_op, BlockDiagonal):
+        return
+
+    # Check that the BlockDiagonal is an input to a Dot node:
+    for client in itertools.chain.from_iterable(
+        get_clients_at_depth(fgraph, node, depth=i) for i in [1, 2]
+    ):
+        if client.op not in (_dot, _matmul):
+            continue
+
+        [blockdiag_result] = node.outputs
+        blockdiag_inputs = node.inputs
+
+        dot_op = client.op
+
+        try:
+            client_idx = client.inputs.index(blockdiag_result)
+        except ValueError:
+            # If the blockdiag result is not an input to the dot, there is at least one Op between them (usually a
+            # DimShuffle). In this case, we need to figure out which of the inputs of the dot eventually leads to the
+            # blockdiag result.
+            for ancestor in client.inputs:
+                if ancestor.owner and blockdiag_result in ancestor.owner.inputs:
+                    client_idx = client.inputs.index(ancestor)
+                    break
+
+        other_input = client.inputs[1 - client_idx]
+
+        split_axis = -2 if client_idx == 0 else -1
+        split_size_axis = -1 if client_idx == 0 else -2
+
+        other_dot_input_split = split(
+            other_input,
+            splits_size=[
+                component.shape[split_size_axis] for component in blockdiag_inputs
+            ],
+            n_splits=len(blockdiag_inputs),
+            axis=split_axis,
+        )
+
+        split_dot_results = [
+            dot_op(component, other_split)
+            if client_idx == 0
+            else dot_op(other_split, component)
+            for component, other_split in zip(blockdiag_inputs, other_dot_input_split)
+        ]
+        new_output = concat_with_broadcast(split_dot_results, axis=split_axis)
+
+        copy_stack_trace(node.outputs[0], new_output)
+        return {client.outputs[0]: new_output}
+
+
 @register_canonicalize
 @node_rewriter([Dot, _matmul])
 def local_lift_transpose_through_dot(fgraph, node):
@@ -2582,7 +2648,6 @@ def add_calculate(num, denum, aslist=False, out_type=None):
     name="add_canonizer_group",
 )

-
 register_canonicalize(local_add_canonizer, "shape_unsafe", name="local_add_canonizer")


@@ -3720,7 +3785,6 @@ def logmexpm1_to_log1mexp(fgraph, node):
 )
 register_stabilize(logdiffexp_to_log1mexpdiff, name="logdiffexp_to_log1mexpdiff")

-
 # log(sigmoid(x) / (1 - sigmoid(x))) -> x
 # i.e logit(sigmoid(x)) -> x
 local_logit_sigmoid = PatternNodeRewriter(
@@ -3734,7 +3798,6 @@ def logmexpm1_to_log1mexp(fgraph, node):
 register_canonicalize(local_logit_sigmoid)
 register_specialize(local_logit_sigmoid)

-
 # sigmoid(log(x / (1-x)) -> x
 # i.e., sigmoid(logit(x)) -> x
 local_sigmoid_logit = PatternNodeRewriter(
@@ -3775,7 +3838,6 @@ def local_useless_conj(fgraph, node):

 register_specialize(local_polygamma_to_tri_gamma)

-
 local_log_kv = PatternNodeRewriter(
     # Rewrite log(kv(v, x)) = log(kve(v, x) * exp(-x)) -> log(kve(v, x)) - x
     # During stabilize -x is converted to -1.0 * x
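To see the new rewrite in action, a hedged sketch (block shapes are illustrative, not from the commit) that compiles a block_diag matmul and checks the result against a dense NumPy computation; after rewriting, the compiled graph should contain per-block dots and a join rather than one dot on the assembled (5, 5) matrix:

import numpy as np
import pytensor
import pytensor.tensor as pt

A = pt.tensor("A", shape=(2, 2))
B = pt.tensor("B", shape=(3, 3))
C = pt.tensor("C", shape=(5, 4))

out = pt.linalg.block_diag(A, B) @ C
fn = pytensor.function([A, B, C], out)

# Inspect the optimized graph; the BlockDiagonal feeding a single dot
# should have been replaced by smaller dots concatenated together.
pytensor.dprint(fn)

rng = np.random.default_rng(0)
vals = [rng.normal(size=v.type.shape).astype(v.dtype) for v in (A, B, C)]
dense = np.block([
    [vals[0], np.zeros((2, 3))],
    [np.zeros((3, 2)), vals[1]],
])
np.testing.assert_allclose(fn(*vals), dense @ vals[2], rtol=1e-6, atol=1e-6)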

pytensor/xtensor/rewriting/shape.py

Lines changed: 2 additions & 23 deletions
@@ -2,8 +2,8 @@
 from pytensor.graph import node_rewriter
 from pytensor.tensor import (
     broadcast_to,
+    concat_with_broadcast,
     expand_dims,
-    join,
     moveaxis,
     specify_shape,
     squeeze,
@@ -74,28 +74,7 @@ def lower_concat(fgraph, node):

     # Convert input XTensors to Tensors and align batch dimensions
     tensor_inputs = [lower_aligned(inp, out_dims) for inp in node.inputs]
-
-    # Broadcast non-concatenated dimensions of each input
-    non_concat_shape = [None] * len(out_dims)
-    for tensor_inp in tensor_inputs:
-        # TODO: This is assuming the graph is correct and every non-concat dimension matches in shape at runtime
-        # I'm running this as "shape_unsafe" to simplify the logic / returned graph
-        for i, (bcast, sh) in enumerate(
-            zip(tensor_inp.type.broadcastable, tensor_inp.shape)
-        ):
-            if bcast or i == concat_axis or non_concat_shape[i] is not None:
-                continue
-            non_concat_shape[i] = sh
-
-    assert non_concat_shape.count(None) == 1
-
-    bcast_tensor_inputs = []
-    for tensor_inp in tensor_inputs:
-        # We modify the concat_axis in place, as we don't need the list anywhere else
-        non_concat_shape[concat_axis] = tensor_inp.shape[concat_axis]
-        bcast_tensor_inputs.append(broadcast_to(tensor_inp, non_concat_shape))
-
-    joined_tensor = join(concat_axis, *bcast_tensor_inputs)
+    joined_tensor = concat_with_broadcast(tensor_inputs, axis=concat_axis)
     new_out = xtensor_from_tensor(joined_tensor, dims=out_dims)
     return [new_out]
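At the xtensor level the user-facing behaviour is unchanged; only the lowering is simplified to reuse concat_with_broadcast. As a rough sketch of the kind of graph that goes through lower_concat, assuming pytensor.xtensor exposes an xarray-style xtensor/concat API (the names and signatures below are assumptions, not taken from this commit):

import pytensor.xtensor as ptx

# Assumed API, mirroring xarray: variables are declared with named dims.
x = ptx.xtensor("x", dims=("a", "b"), shape=(2, 3))
y = ptx.xtensor("y", dims=("a", "b"), shape=(2, 5))

# Concatenation along "b"; lower_concat now lowers this through
# concat_with_broadcast instead of re-implementing the broadcasting inline.
z = ptx.concat([x, y], dim="b")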

tests/tensor/rewriting/test_math.py

Lines changed: 119 additions & 0 deletions
@@ -115,6 +115,7 @@
     simplify_mul,
 )
 from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape, specify_shape
+from pytensor.tensor.slinalg import BlockDiagonal
 from pytensor.tensor.type import (
     TensorType,
     cmatrix,
@@ -4745,3 +4746,121 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
         out.eval({a: a_test, b: b_test}, mode=test_mode),
         rewritten_out.eval({a: a_test, b: b_test}, mode=test_mode),
     )
+
+
+@pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
+@pytest.mark.parametrize(
+    "batch_blockdiag", [True, False], ids=["batch_blockdiag", "unbatched_blockdiag"]
+)
+@pytest.mark.parametrize(
+    "batch_other", [True, False], ids=["batched_other", "unbatched_other"]
+)
+def test_local_block_diag_dot_to_dot_block_diag(
+    left_multiply, batch_blockdiag, batch_other
+):
+    """
+    Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
+    """
+
+    def has_blockdiag(graph):
+        return any(
+            (
+                var.owner
+                and (
+                    isinstance(var.owner.op, BlockDiagonal)
+                    or (
+                        isinstance(var.owner.op, Blockwise)
+                        and isinstance(var.owner.op.core_op, BlockDiagonal)
+                    )
+                )
+            )
+            for var in ancestors([graph])
+        )

+    a = tensor("a", shape=(4, 2))
+    b = tensor("b", shape=(2, 4) if not batch_blockdiag else (3, 2, 4))
+    c = tensor("c", shape=(4, 4))
+    x = pt.linalg.block_diag(a, b, c)
+
+    d = tensor("d", shape=(10, 10) if not batch_other else (3, 1, 10, 10))
+
+    # Test multiple clients are all rewritten
+    if left_multiply:
+        out = x @ d
+    else:
+        out = d @ x
+
+    assert has_blockdiag(out)
+    fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode)
+    assert not has_blockdiag(fn.maker.fgraph.outputs[0])
+
+    n_dots_rewrite = sum(
+        isinstance(node.op, Dot | Dot22)
+        or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot | Dot22))
+        for node in fn.maker.fgraph.apply_nodes
+    )
+    assert n_dots_rewrite == 3
+
+    fn_expected = pytensor.function(
+        [a, b, c, d],
+        out,
+        mode=Mode(linker="py", optimizer=None),
+    )
+    assert has_blockdiag(fn_expected.maker.fgraph.outputs[0])
+
+    n_dots_no_rewrite = sum(
+        isinstance(node.op, Dot | Dot22)
+        or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot | Dot22))
+        for node in fn_expected.maker.fgraph.apply_nodes
+    )
+    assert n_dots_no_rewrite == 1
+
+    rng = np.random.default_rng()
+    a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
+    b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
+    c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
+    d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
+
+    rewrite_out = fn(a_val, b_val, c_val, d_val)
+    expected_out = fn_expected(a_val, b_val, c_val, d_val)
+    np.testing.assert_allclose(
+        rewrite_out,
+        expected_out,
+        atol=1e-6 if config.floatX == "float32" else 1e-12,
+        rtol=1e-6 if config.floatX == "float32" else 1e-12,
+    )
+
+
+@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
+@pytest.mark.parametrize("size", [10, 100, 1000], ids=["small", "medium", "large"])
+def test_block_diag_dot_to_dot_concat_benchmark(benchmark, size, rewrite):
+    rng = np.random.default_rng()
+    a_size = int(rng.uniform(1, int(0.8 * size)))
+    b_size = int(rng.uniform(1, int(0.8 * (size - a_size))))
+    c_size = size - a_size - b_size
+
+    a = tensor("a", shape=(a_size, a_size))
+    b = tensor("b", shape=(b_size, b_size))
+    c = tensor("c", shape=(c_size, c_size))
+    d = tensor("d", shape=(size,))
+
+    x = pt.linalg.block_diag(a, b, c)
+    out = x @ d
+
+    mode = get_default_mode()
+    if not rewrite:
+        mode = mode.excluding("local_block_diag_dot_to_dot_block_diag")
+    fn = pytensor.function([a, b, c, d], out, mode=mode)
+
+    a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
+    b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
+    c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
+    d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
+
+    benchmark(
+        fn,
+        a_val,
+        b_val,
+        c_val,
+        d_val,
+    )
