
Commit b0214a4

Add concat_with_broadcast helper function

Use new helper in xt.concat

Co-authored-by: Ricardo <ricardo.vieira1994@gmail.com>
1 parent d4e8f73 commit b0214a4

File tree

3 files changed: +85 additions, -24 deletions


pytensor/tensor/extra_ops.py

Lines changed: 38 additions & 1 deletion
@@ -27,7 +27,7 @@
 from pytensor.scalar import upcast
 from pytensor.tensor import TensorLike, as_tensor_variable
 from pytensor.tensor import basic as ptb
-from pytensor.tensor.basic import alloc, second
+from pytensor.tensor.basic import alloc, join, second
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import abs as pt_abs
 from pytensor.tensor.math import all as pt_all
@@ -2018,6 +2018,42 @@ def broadcast_with_others(a, others):
     return brodacasted_vars


+def concat_with_broadcast(tensor_list, axis=0):
+    """
+    Concatenate a list of tensors, broadcasting the non-concatenated dimensions to align.
+    """
+    if not tensor_list:
+        raise ValueError("Cannot concatenate an empty list of tensors.")
+
+    ndim = tensor_list[0].ndim
+    if not all(t.ndim == ndim for t in tensor_list):
+        raise TypeError(
+            "Only tensors with the same number of dimensions can be concatenated. "
+            f"Input ndims were: {[x.ndim for x in tensor_list]}"
+        )
+
+    axis = normalize_axis_index(axis=axis, ndim=ndim)
+    non_concat_shape = [1 if i != axis else None for i in range(ndim)]
+
+    for tensor_inp in tensor_list:
+        for i, (bcast, sh) in enumerate(
+            zip(tensor_inp.type.broadcastable, tensor_inp.shape)
+        ):
+            if bcast or i == axis:
+                continue
+            non_concat_shape[i] = sh
+
+    assert non_concat_shape.count(None) == 1
+
+    bcast_tensor_inputs = []
+    for tensor_inp in tensor_list:
+        # We modify the concat_axis in place, as we don't need the list anywhere else
+        non_concat_shape[axis] = tensor_inp.shape[axis]
+        bcast_tensor_inputs.append(broadcast_to(tensor_inp, non_concat_shape))
+
+    return join(axis, *bcast_tensor_inputs)
+
+
 __all__ = [
     "searchsorted",
     "cumsum",
@@ -2035,6 +2071,7 @@ def broadcast_with_others(a, others):
     "ravel_multi_index",
     "broadcast_shape",
     "broadcast_to",
+    "concat_with_broadcast",
     "geomspace",
     "logspace",
     "linspace",

pytensor/xtensor/rewriting/shape.py

Lines changed: 2 additions & 23 deletions
@@ -2,8 +2,8 @@
 from pytensor.graph import node_rewriter
 from pytensor.tensor import (
     broadcast_to,
+    concat_with_broadcast,
     expand_dims,
-    join,
     moveaxis,
     specify_shape,
     squeeze,
@@ -74,28 +74,7 @@ def lower_concat(fgraph, node):

     # Convert input XTensors to Tensors and align batch dimensions
     tensor_inputs = [lower_aligned(inp, out_dims) for inp in node.inputs]
-
-    # Broadcast non-concatenated dimensions of each input
-    non_concat_shape = [None] * len(out_dims)
-    for tensor_inp in tensor_inputs:
-        # TODO: This is assuming the graph is correct and every non-concat dimension matches in shape at runtime
-        # I'm running this as "shape_unsafe" to simplify the logic / returned graph
-        for i, (bcast, sh) in enumerate(
-            zip(tensor_inp.type.broadcastable, tensor_inp.shape)
-        ):
-            if bcast or i == concat_axis or non_concat_shape[i] is not None:
-                continue
-            non_concat_shape[i] = sh
-
-    assert non_concat_shape.count(None) == 1
-
-    bcast_tensor_inputs = []
-    for tensor_inp in tensor_inputs:
-        # We modify the concat_axis in place, as we don't need the list anywhere else
-        non_concat_shape[concat_axis] = tensor_inp.shape[concat_axis]
-        bcast_tensor_inputs.append(broadcast_to(tensor_inp, non_concat_shape))
-
-    joined_tensor = join(concat_axis, *bcast_tensor_inputs)
+    joined_tensor = concat_with_broadcast(tensor_inputs, axis=concat_axis)
     new_out = xtensor_from_tensor(joined_tensor, dims=out_dims)
     return [new_out]
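
For context, a hedged sketch of the xtensor-level operation this rewrite lowers. The `xtensor` and `concat` names below are assumptions inferred from the commit message's "xt.concat" and from `lower_concat` above; the exact public API may differ:

import pytensor.xtensor as ptx  # assumed module path

# Hypothetical dims/shapes: "x" is broadcastable over dim "a", "y" is not.
x = ptx.xtensor("x", dims=("a", "b"), shape=(1, 3))
y = ptx.xtensor("y", dims=("a", "b"), shape=(5, 3))

# Concatenating along "b": after this commit, lowering emits a single
# concat_with_broadcast call instead of inlining the broadcast logic.
z = ptx.concat([x, y], dim="b")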

tests/tensor/test_extra_ops.py

Lines changed: 45 additions & 0 deletions
@@ -1333,3 +1333,48 @@ def test_space_ops(op, dtype, start, stop, num_samples, endpoint, axis):
         atol=1e-6 if config.floatX.endswith("64") else 1e-4,
         rtol=1e-6 if config.floatX.endswith("64") else 1e-4,
     )
+
+
+def test_concat_with_broadcast():
+    rng = np.random.default_rng()
+    a = pt.tensor("a", shape=(1, 3, 5))
+    b = pt.tensor("b", shape=(5, 3, 10))
+
+    c = pt.concat_with_broadcast([a, b], axis=2)
+    fn = function([a, b], c, mode="FAST_COMPILE")
+    assert c.type.shape == (5, 3, 15)
+
+    a_val = rng.normal(size=(1, 3, 5)).astype(config.floatX)
+    b_val = rng.normal(size=(5, 3, 10)).astype(config.floatX)
+    c_val = fn(a_val, b_val)
+
+    # The result should be a tile + concat
+    np.testing.assert_allclose(c_val[:, :, :5], np.tile(a_val, (5, 1, 1)))
+    np.testing.assert_allclose(c_val[:, :, 5:], b_val)
+
+    # If a and b already conform, the result should be the same as a concatenation
+    a = pt.tensor("a", shape=(1, 1, 3, 5, 10))
+    b = pt.tensor("b", shape=(1, 1, 3, 2, 10))
+    c = pt.concat_with_broadcast([a, b], axis=-2)
+    assert c.type.shape == (1, 1, 3, 7, 10)
+
+    fn = function([a, b], c, mode="FAST_COMPILE")
+    a_val = rng.normal(size=(1, 1, 3, 5, 10)).astype(config.floatX)
+    b_val = rng.normal(size=(1, 1, 3, 2, 10)).astype(config.floatX)
+    c_val = fn(a_val, b_val)
+    np.testing.assert_allclose(c_val, np.concatenate([a_val, b_val], axis=-2))
+
+    c = pt.concat_with_broadcast([a], axis=0)
+    fn = function([a], c, mode="FAST_COMPILE")
+    np.testing.assert_allclose(fn(a_val), a_val)
+
+    with pytest.raises(ValueError, match="Cannot concatenate an empty list of tensors"):
+        pt.concat_with_broadcast([], axis=0)
+
+    with pytest.raises(
+        TypeError,
+        match="Only tensors with the same number of dimensions can be concatenated.",
+    ):
+        a = pt.tensor("a", shape=(1, 3, 5))
+        b = pt.tensor("b", shape=(3, 5))
+        pt.concat_with_broadcast([a, b], axis=1)
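
For a quick cross-check of the semantics these tests encode, here is a NumPy-only reference sketch; the helper name and its body are illustrative, not part of the commit:

import numpy as np

def concat_with_broadcast_ref(arrays, axis=0):
    # Reference semantics: broadcast every non-concatenated dimension to a
    # common size, then concatenate along `axis`.
    ndim = arrays[0].ndim
    axis = axis % ndim
    target = [1] * ndim
    for arr in arrays:
        for i, size in enumerate(arr.shape):
            if i != axis and size != 1:
                target[i] = size
    pieces = []
    for arr in arrays:
        target[axis] = arr.shape[axis]  # the concat axis itself is never broadcast
        pieces.append(np.broadcast_to(arr, target))
    return np.concatenate(pieces, axis=axis)

# Mirrors the first case in test_concat_with_broadcast:
a = np.random.normal(size=(1, 3, 5))
b = np.random.normal(size=(5, 3, 10))
assert concat_with_broadcast_ref([a, b], axis=2).shape == (5, 3, 15)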
