Skip to content

Commit 6736e8e

Browse files
pair coding results
1 parent 558acbc commit 6736e8e

File tree

2 files changed

+50
-55
lines changed

2 files changed

+50
-55
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 25 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -162,59 +162,40 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node):
162162
if not isinstance(node.op.core_op, BlockDiagonal):
163163
return
164164

165-
def check_for_block_diag(x):
166-
return x.owner and (
167-
isinstance(x.owner.op, BlockDiagonal)
168-
or isinstance(x.owner.op, Blockwise)
169-
and isinstance(x.owner.op.core_op, BlockDiagonal)
170-
)
171-
172165
# Check that the BlockDiagonal is an input to a Dot node:
173166
for client in get_clients_at_depth(fgraph, node, depth=1):
174-
if not isinstance(client.op, Dot):
167+
if not (
168+
(
169+
isinstance(client.op, Dot)
170+
and all(input.ndim == 2 for input in client.inputs)
171+
)
172+
or client.op == _matrix_matrix_matmul
173+
):
175174
continue
176175

177176
op = client.op
178-
x, y = client.inputs
179177

180-
if not (check_for_block_diag(x) or check_for_block_diag(y)):
181-
continue
178+
client_idx = client.inputs.index(node.outputs[0])
182179

183-
# Case 1: Only one input is BlockDiagonal. In this case, multiply all components of the block-diagonal with the
184-
# non-block diagonal, and return a new block diagonal
185-
if check_for_block_diag(x) and not check_for_block_diag(y):
186-
components = x.owner.inputs
187-
y_splits = split(
188-
y,
189-
splits_size=[component.shape[-1] for component in components],
190-
n_splits=len(components),
191-
)
192-
new_components = [
193-
op(component, y_split)
194-
for component, y_split in zip(components, y_splits)
195-
]
196-
new_output = join(0, *new_components)
197-
198-
elif not check_for_block_diag(x) and check_for_block_diag(y):
199-
components = y.owner.inputs
200-
x_splits = split(
201-
x,
202-
splits_size=[component.shape[0] for component in components],
203-
n_splits=len(components),
204-
axis=1,
205-
)
180+
other_input = client.inputs[1 - client_idx]
181+
components = node.inputs
206182

207-
new_components = [
208-
op(x_split, component)
209-
for component, x_split in zip(components, x_splits)
210-
]
211-
new_output = join(1, *new_components)
183+
split_axis = -2 if client_idx == 0 else -1
184+
shape_idx = -1 if client_idx == 0 else -2
212185

213-
# Case 2: Both inputs are BlockDiagonal. Do nothing
214-
else:
215-
# TODO: If shapes are statically known and all components have equal shapes, we could rewrite
216-
# this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)])
217-
continue
186+
other_dot_input_split = split(
187+
other_input,
188+
splits_size=[component.shape[shape_idx] for component in components],
189+
n_splits=len(components),
190+
axis=split_axis,
191+
)
192+
new_components = [
193+
op(component, other_split)
194+
if client_idx == 0
195+
else op(other_split, component)
196+
for component, other_split in zip(components, other_dot_input_split)
197+
]
198+
new_output = join(split_axis, *new_components)
218199

219200
copy_stack_trace(node.outputs[0], new_output)
220201
return {client.outputs[0]: new_output}

tests/tensor/rewriting/test_math.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4749,15 +4749,21 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
47494749

47504750

47514751
@pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
4752-
def test_local_block_diag_dot_to_dot_block_diag(left_multiply):
4752+
@pytest.mark.parametrize(
4753+
"batch_left", [True, False], ids=["batched_left", "unbatched_left"]
4754+
)
4755+
@pytest.mark.parametrize(
4756+
"batch_right", [True, False], ids=["batched_right", "unbatched_right"]
4757+
)
4758+
def test_local_block_diag_dot_to_dot_block_diag(left_multiply, batch_left, batch_right):
47534759
"""
47544760
Test that dot(block_diag(x, y), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
47554761
"""
47564762
a = tensor("a", shape=(4, 2))
4757-
b = tensor("b", shape=(2, 4))
4763+
b = tensor("b", shape=(2, 4) if not batch_left else (3, 2, 4))
47584764
c = tensor("c", shape=(4, 4))
47594765
d = tensor("d", shape=(10, 10))
4760-
e = tensor("e", shape=(10, 10))
4766+
e = tensor("e", shape=(10, 10) if not batch_right else (3, 1, 10, 10))
47614767

47624768
x = pt.linalg.block_diag(a, b, c)
47634769

@@ -4767,30 +4773,38 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply):
47674773
else:
47684774
out = [d @ x, e @ x]
47694775

4770-
fn = pytensor.function([a, b, c, d, e], out, mode=rewrite_mode)
4776+
with config.change_flags(optimizer_verbose=True):
4777+
fn = pytensor.function([a, b, c, d, e], out, mode=rewrite_mode)
4778+
47714779
assert not any(
47724780
isinstance(node.op, BlockDiagonal) for node in fn.maker.fgraph.toposort()
47734781
)
47744782

47754783
fn_expected = pytensor.function(
47764784
[a, b, c, d, e],
47774785
out,
4778-
mode=rewrite_mode.excluding("local_block_diag_dot_to_dot_block_diag"),
4786+
mode=Mode(linker="py", optimizer=None),
47794787
)
47804788

4789+
# TODO: Count Dots
4790+
47814791
rng = np.random.default_rng()
47824792
a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
47834793
b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
47844794
c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
47854795
d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
47864796
e_val = rng.normal(size=e.type.shape).astype(e.type.dtype)
47874797

4788-
np.testing.assert_allclose(
4789-
fn(a_val, b_val, c_val, d_val, e_val),
4790-
fn_expected(a_val, b_val, c_val, d_val, e_val),
4791-
atol=1e-6 if config.floatX == "float32" else 1e-12,
4792-
rtol=1e-6 if config.floatX == "float32" else 1e-12,
4793-
)
4798+
rewrite_outs = fn(a_val, b_val, c_val, d_val, e_val)
4799+
expected_outs = fn_expected(a_val, b_val, c_val, d_val, e_val)
4800+
4801+
for out, expected in zip(rewrite_outs, expected_outs):
4802+
np.testing.assert_allclose(
4803+
out,
4804+
expected,
4805+
atol=1e-6 if config.floatX == "float32" else 1e-12,
4806+
rtol=1e-6 if config.floatX == "float32" else 1e-12,
4807+
)
47944808

47954809

47964810
@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])

0 commit comments

Comments
 (0)