Commit 749bfda

block_diag dot rewrite
Co-authored-by: Ricardo <ricardo.vieira1994@gmail.com>
1 parent 2aad463 commit 749bfda

2 files changed: +189 additions, -5 deletions

pytensor/tensor/rewriting/math.py

Lines changed: 70 additions & 5 deletions
@@ -32,13 +32,15 @@
     moveaxis,
     ones_like,
     register_infer_shape,
+    split,
     switch,
     zeros,
     zeros_like,
 )
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
-from pytensor.tensor.extra_ops import broadcast_arrays
+from pytensor.tensor.extra_ops import broadcast_arrays, concat_with_broadcast
 from pytensor.tensor.math import (
     Dot,
     Prod,
@@ -96,6 +98,7 @@
 from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
 from pytensor.tensor.rewriting.linalg import is_matrix_transpose
 from pytensor.tensor.shape import Shape, Shape_i
+from pytensor.tensor.slinalg import BlockDiagonal
 from pytensor.tensor.subtensor import Subtensor
 from pytensor.tensor.type import (
     complex_dtypes,
@@ -146,6 +149,72 @@ def local_0_dot_x(fgraph, node):
         return [zeros((x.shape[0], y.shape[1]), dtype=node.outputs[0].type.dtype)]


+@register_stabilize
+@node_rewriter([Blockwise])
+def local_block_diag_dot_to_dot_block_diag(fgraph, node):
+    r"""
+    Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C), dot(B, C))``
+
+    BlockDiag results in the creation of a matrix of shape ``(n1 * n2, m1 * m2)``. Because dot has complexity
+    of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than
+    a single dot on the larger matrix.
+    """
+    if not isinstance(node.op.core_op, BlockDiagonal):
+        return
+
+    # Check that the BlockDiagonal is an input to a Dot node:
+    for client in itertools.chain.from_iterable(
+        get_clients_at_depth(fgraph, node, depth=i) for i in [1, 2]
+    ):
+        if not (
+            (
+                isinstance(client.op, Dot)
+                and all(input.ndim == 2 for input in client.inputs)
+            )
+            or client.op == _matmul
+        ):
+            continue
+
+        [blockdiag_result] = node.outputs
+        blockdiag_inputs = node.inputs
+
+        dot_op = client.op
+
+        try:
+            client_idx = client.inputs.index(blockdiag_result)
+        except ValueError:
+            # If the blockdiag result is not an input to the dot, there is at least one Op between them (usually a
+            # DimShuffle). In this case, we need to figure out which of the inputs of the dot eventually leads to the
+            # blockdiag result.
+            for ancestor in client.inputs:
+                if ancestor.owner and blockdiag_result in ancestor.owner.inputs:
+                    client_idx = client.inputs.index(ancestor)
+                    break
+
+        other_input = client.inputs[1 - client_idx]
+
+        split_axis = -2 if client_idx == 0 else -1
+        shape_idx = -1 if client_idx == 0 else -2
+
+        other_dot_input_split = split(
+            other_input,
+            splits_size=[component.shape[shape_idx] for component in blockdiag_inputs],
+            n_splits=len(blockdiag_inputs),
+            axis=split_axis,
+        )
+
+        split_dot_results = [
+            dot_op(component, other_split)
+            if client_idx == 0
+            else dot_op(other_split, component)
+            for component, other_split in zip(blockdiag_inputs, other_dot_input_split)
+        ]
+        new_output = concat_with_broadcast(split_dot_results, axis=split_axis)
+
+        copy_stack_trace(node.outputs[0], new_output)
+        return {client.outputs[0]: new_output}
+
+
 @register_canonicalize
 @node_rewriter([Dot, _matmul])
 def local_lift_transpose_through_dot(fgraph, node):
@@ -2582,7 +2651,6 @@ def add_calculate(num, denum, aslist=False, out_type=None):
     name="add_canonizer_group",
 )

-
 register_canonicalize(local_add_canonizer, "shape_unsafe", name="local_add_canonizer")


@@ -3720,7 +3788,6 @@ def logmexpm1_to_log1mexp(fgraph, node):
 )
 register_stabilize(logdiffexp_to_log1mexpdiff, name="logdiffexp_to_log1mexpdiff")

-
 # log(sigmoid(x) / (1 - sigmoid(x))) -> x
 # i.e logit(sigmoid(x)) -> x
 local_logit_sigmoid = PatternNodeRewriter(
@@ -3734,7 +3801,6 @@ def logmexpm1_to_log1mexp(fgraph, node):
 register_canonicalize(local_logit_sigmoid)
 register_specialize(local_logit_sigmoid)

-
 # sigmoid(log(x / (1-x)) -> x
 # i.e., sigmoid(logit(x)) -> x
 local_sigmoid_logit = PatternNodeRewriter(
@@ -3775,7 +3841,6 @@ def local_useless_conj(fgraph, node):

 register_specialize(local_polygamma_to_tri_gamma)

-
 local_log_kv = PatternNodeRewriter(
     # Rewrite log(kv(v, x)) = log(kve(v, x) * exp(-x)) -> log(kve(v, x)) - x
     # During stabilize -x is converted to -1.0 * x
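
The docstring of local_block_diag_dot_to_dot_block_diag above states the identity the rewrite relies on. As a quick sanity check (not part of this commit; the block shapes below are made up), the identity can be verified with plain NumPy/SciPy:

# Illustrative sketch only: dot(block_diag(A, B), C) equals the row-wise
# concatenation of dot(A, C_top) and dot(B, C_bottom), which is what the
# rewrite exploits to avoid one large matrix multiplication.
import numpy as np
from scipy.linalg import block_diag

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 2))   # hypothetical block shapes
B = rng.normal(size=(3, 5))
C = rng.normal(size=(7, 6))   # 7 = 2 + 5, the number of columns of block_diag(A, B)

full = block_diag(A, B) @ C                              # one (7, 7) @ (7, 6) dot
split = np.concatenate([A @ C[:2], B @ C[2:]], axis=0)   # two small dots, then concat

np.testing.assert_allclose(full, split)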

tests/tensor/rewriting/test_math.py

Lines changed: 119 additions & 0 deletions
@@ -115,6 +115,7 @@
     simplify_mul,
 )
 from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape, specify_shape
+from pytensor.tensor.slinalg import BlockDiagonal
 from pytensor.tensor.type import (
     TensorType,
     cmatrix,
@@ -4745,3 +4746,121 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
         out.eval({a: a_test, b: b_test}, mode=test_mode),
         rewritten_out.eval({a: a_test, b: b_test}, mode=test_mode),
     )
+
+
+@pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
+@pytest.mark.parametrize(
+    "batch_blockdiag", [True, False], ids=["batch_blockdiag", "unbatched_blockdiag"]
+)
+@pytest.mark.parametrize(
+    "batch_other", [True, False], ids=["batched_other", "unbatched_other"]
+)
+def test_local_block_diag_dot_to_dot_block_diag(
+    left_multiply, batch_blockdiag, batch_other
+):
+    """
+    Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
+    """
+
+    def has_blockdiag(graph):
+        return any(
+            (
+                var.owner
+                and (
+                    isinstance(var.owner.op, BlockDiagonal)
+                    or (
+                        isinstance(var.owner.op, Blockwise)
+                        and isinstance(var.owner.op.core_op, BlockDiagonal)
+                    )
+                )
+            )
+            for var in ancestors([graph])
+        )
+
+    a = tensor("a", shape=(4, 2))
+    b = tensor("b", shape=(2, 4) if not batch_blockdiag else (3, 2, 4))
+    c = tensor("c", shape=(4, 4))
+    x = pt.linalg.block_diag(a, b, c)
+
+    d = tensor("d", shape=(10, 10) if not batch_other else (3, 1, 10, 10))
+
+    # Test multiple clients are all rewritten
+    if left_multiply:
+        out = x @ d
+    else:
+        out = d @ x
+
+    assert has_blockdiag(out)
+    fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode)
+    assert not has_blockdiag(fn.maker.fgraph.outputs[0])
+
+    n_dots_rewrite = sum(
+        isinstance(node.op, Dot | Dot22)
+        or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot | Dot22))
+        for node in fn.maker.fgraph.apply_nodes
+    )
+    assert n_dots_rewrite == 3
+
+    fn_expected = pytensor.function(
+        [a, b, c, d],
+        out,
+        mode=Mode(linker="py", optimizer=None),
+    )
+    assert has_blockdiag(fn_expected.maker.fgraph.outputs[0])
+
+    n_dots_no_rewrite = sum(
+        isinstance(node.op, Dot | Dot22)
+        or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot | Dot22))
+        for node in fn_expected.maker.fgraph.apply_nodes
+    )
+    assert n_dots_no_rewrite == 1
+
+    rng = np.random.default_rng()
+    a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
+    b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
+    c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
+    d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
+
+    rewrite_out = fn(a_val, b_val, c_val, d_val)
+    expected_out = fn_expected(a_val, b_val, c_val, d_val)
+    np.testing.assert_allclose(
+        rewrite_out,
+        expected_out,
+        atol=1e-6 if config.floatX == "float32" else 1e-12,
+        rtol=1e-6 if config.floatX == "float32" else 1e-12,
+    )
+
+
+@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
+@pytest.mark.parametrize("size", [10, 100, 1000], ids=["small", "medium", "large"])
+def test_block_diag_dot_to_dot_concat_benchmark(benchmark, size, rewrite):
+    rng = np.random.default_rng()
+    a_size = int(rng.uniform(0, size))
+    b_size = int(rng.uniform(0, size - a_size))
+    c_size = size - a_size - b_size
+
+    a = tensor("a", shape=(a_size, a_size))
+    b = tensor("b", shape=(b_size, b_size))
+    c = tensor("c", shape=(c_size, c_size))
+    d = tensor("d", shape=(size,))
+
+    x = pt.linalg.block_diag(a, b, c)
+    out = x @ d
+
+    mode = get_default_mode()
+    if not rewrite:
+        mode = mode.excluding("local_block_diag_dot_to_dot_block_diag")
+    fn = pytensor.function([a, b, c, d], out, mode=mode)
+
+    a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
+    b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
+    c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
+    d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
+
+    benchmark(
+        fn,
+        a_val,
+        b_val,
+        c_val,
+        d_val,
+    )
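
The benchmark above compares the compiled function with and without the rewrite enabled. As a rough back-of-the-envelope illustration of why the split form is expected to win (this count is not part of the commit; the sizes are made up), the scalar multiplications in each form can be compared directly:

# Illustrative multiplication counts only: one dot on the full block-diagonal
# matrix vs. two dots on the individual blocks (sizes are hypothetical).
n1, n2, m = 400, 600, 8   # block sizes and number of columns of the other operand

single_dot = (n1 + n2) ** 2 * m        # dot(block_diag(A, B), C)
split_dots = n1**2 * m + n2**2 * m     # dot(A, C_top) + dot(B, C_bottom)

print(f"single dot: {single_dot:,} multiplies")
print(f"split dots: {split_dots:,} multiplies")
# With these sizes the split form does roughly half the multiplications.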
