Commit 23b3781

Add concat_with_broadcast helper function
Co-authored-by: Ricardo <ricardo.vieira1994@gmail.com>
1 parent d4e8f73 commit 23b3781

2 files changed: +67 -1

pytensor/tensor/extra_ops.py

Lines changed: 38 additions & 1 deletion

@@ -27,7 +27,7 @@
 from pytensor.scalar import upcast
 from pytensor.tensor import TensorLike, as_tensor_variable
 from pytensor.tensor import basic as ptb
-from pytensor.tensor.basic import alloc, second
+from pytensor.tensor.basic import alloc, join, second
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import abs as pt_abs
 from pytensor.tensor.math import all as pt_all
@@ -2018,6 +2018,42 @@ def broadcast_with_others(a, others):
     return brodacasted_vars


+def concat_with_broadcast(tensor_list, dim=0):
+    """
+    Concatenate a list of tensors, broadcasting the non-concatenated dimensions to align.
+    """
+    dim = dim if dim >= 0 else tensor_list[0].ndim + dim
+    non_concat_shape = [None] * tensor_list[0].ndim
+
+    # If all inputs have the same broadcastable dim, it's not really a broadcast
+    all_bcast = [
+        all(t.type.broadcastable[i] for t in tensor_list)
+        for i in range(tensor_list[0].ndim)
+    ]
+
+    for tensor_inp in tensor_list:
+        for i, (bcast, sh) in enumerate(
+            zip(tensor_inp.type.broadcastable, tensor_inp.shape)
+        ):
+            if (
+                (bcast and not all_bcast[i])
+                or i == dim
+                or non_concat_shape[i] is not None
+            ):
+                continue
+            non_concat_shape[i] = sh
+
+    assert non_concat_shape.count(None) == 1
+
+    bcast_tensor_inputs = []
+    for tensor_inp in tensor_list:
+        # We modify the concat axis entry in place, as we don't need the list anywhere else
+        non_concat_shape[dim] = tensor_inp.shape[dim]
+        bcast_tensor_inputs.append(broadcast_to(tensor_inp, non_concat_shape))
+
+    return join(dim, *bcast_tensor_inputs)
+
+
 __all__ = [
     "searchsorted",
     "cumsum",
@@ -2035,6 +2071,7 @@ def broadcast_with_others(a, others):
     "ravel_multi_index",
     "broadcast_shape",
     "broadcast_to",
+    "concat_with_broadcast",
     "geomspace",
     "logspace",
     "linspace",

tests/tensor/test_extra_ops.py

Lines changed: 29 additions & 0 deletions

@@ -1333,3 +1333,32 @@ def test_space_ops(op, dtype, start, stop, num_samples, endpoint, axis):
         atol=1e-6 if config.floatX.endswith("64") else 1e-4,
         rtol=1e-6 if config.floatX.endswith("64") else 1e-4,
     )
+
+
+def test_concat_with_broadcast():
+    rng = np.random.default_rng()
+    a = pt.tensor("a", shape=(1, 3, 5))
+    b = pt.tensor("b", shape=(5, 3, 10))
+
+    c = pt.concat_with_broadcast([a, b], dim=-1)
+    fn = function([a, b], c, mode="FAST_COMPILE")
+    assert c.type.shape == (5, 3, 15)
+
+    a_val = rng.normal(size=(1, 3, 5))
+    b_val = rng.normal(size=(5, 3, 10))
+    c_val = fn(a_val, b_val)
+
+    # The result should be a tile + concat
+    np.testing.assert_allclose(c_val[:, :, :5], np.tile(a_val, (5, 1, 1)))
+    np.testing.assert_allclose(c_val[:, :, 5:], b_val)
+
+    # If a and b already conform, the result should be the same as a concatenation
+    a = pt.tensor("a", shape=(1, 1, 3, 5, 10))
+    b = pt.tensor("b", shape=(1, 1, 3, 2, 10))
+    c = pt.concat_with_broadcast([a, b], dim=-2)
+    assert c.type.shape == (1, 1, 3, 7, 10)
+
+    fn = function([a, b], c, mode="FAST_COMPILE")
+    a_val, b_val = rng.normal(size=(1, 1, 3, 5, 10)), rng.normal(size=(1, 1, 3, 2, 10))
+    c_val = fn(a_val, b_val)
+    np.testing.assert_allclose(c_val, np.concatenate([a_val, b_val], axis=-2))
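For readers who want the semantics without building a PyTensor graph, the behaviour the test checks corresponds to this pure-NumPy paraphrase (an illustrative sketch, not code from the commit): broadcast every non-concatenated dimension to the inputs' common size, then concatenate along the requested dimension.

import numpy as np

def concat_with_broadcast_np(arrays, dim=0):
    # Hypothetical NumPy equivalent of the new helper, for illustration only.
    ndim = arrays[0].ndim
    dim = dim if dim >= 0 else ndim + dim
    # Common size of each non-concat dim; incompatible (non-1) mismatches
    # will make np.broadcast_to raise, mirroring a broadcasting error.
    target = [max(a.shape[i] for a in arrays) for i in range(ndim)]
    pieces = []
    for a in arrays:
        shape = list(target)
        shape[dim] = a.shape[dim]  # keep each input's own length along the concat dim
        pieces.append(np.broadcast_to(a, shape))
    return np.concatenate(pieces, axis=dim)

# Matches the shapes exercised in the test above:
a = np.ones((1, 3, 5))
b = np.ones((5, 3, 10))
assert concat_with_broadcast_np([a, b], dim=-1).shape == (5, 3, 15)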
