Skip to content

Commit aedc8b4

Browse files
Test case with fake broadcast dim
1 parent 3c50d2a commit aedc8b4

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

pytensor/tensor/extra_ops.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2024,11 +2024,22 @@ def concat_with_broadcast(tensor_list, dim=0):
20242024
"""
20252025
dim = dim if dim > 0 else tensor_list[0].ndim + dim
20262026
non_concat_shape = [None] * tensor_list[0].ndim
2027+
2028+
# If all inputs have the same broadcastable dim, it's not really a broadcast
2029+
all_bcast = [
2030+
all(t.type.broadcastable[i] for t in tensor_list)
2031+
for i in range(tensor_list[0].ndim)
2032+
]
2033+
20272034
for tensor_inp in tensor_list:
20282035
for i, (bcast, sh) in enumerate(
20292036
zip(tensor_inp.type.broadcastable, tensor_inp.shape)
20302037
):
2031-
if bcast or i == dim or non_concat_shape[i] is not None:
2038+
if (
2039+
(bcast and not all_bcast[i])
2040+
or i == dim
2041+
or non_concat_shape[i] is not None
2042+
):
20322043
continue
20332044
non_concat_shape[i] = sh
20342045

tests/tensor/test_extra_ops.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1349,3 +1349,14 @@ def test_concat_with_broadcast():
13491349

13501350
c_val = fn(a_val, b_val)
13511351
np.testing.assert_allclose(c_val[:, :, :5], np.tile(a_val, (5, 1, 1)))
1352+
1353+
# If a and b already conform, the result should be the same as a concatenation
1354+
a = pt.tensor("a", shape=(1, 1, 3, 5, 10))
1355+
b = pt.tensor("b", shape=(1, 1, 3, 2, 10))
1356+
c = pt.concatenate([a, b], axis=-2)
1357+
assert c.type.shape == (1, 1, 3, 7, 10)
1358+
1359+
fn = function([a, b], c, mode="FAST_COMPILE")
1360+
a_val, b_val = rng.normal(size=(1, 1, 3, 5, 10)), rng.normal(size=(1, 1, 3, 2, 10))
1361+
c_val = fn(a_val, b_val)
1362+
np.testing.assert_allclose(c_val, np.concatenate([a_val, b_val], axis=-2))

0 commit comments

Comments
 (0)