Description
`concatenate` requires all non-axis dimensions to match. This is annoying when we want to, say, pad a dimension with another vector via `concatenate([some_matrix, some_vector], axis=-1)`. Instead, we need to manually do something like `concatenate([some_matrix, broadcast_to(some_vector, (some_matrix.shape[0], some_vector.shape[0]))], axis=-1)`.
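For concreteness, here is the current pattern sketched in NumPy (the shapes are made up for illustration; the same dance is needed with PyTensor's symbolic `concatenate`/`broadcast_to`):

```python
import numpy as np

# Hypothetical shapes: a (3, 4) matrix and a length-2 vector.
some_matrix = np.arange(12.0).reshape(3, 4)
some_vector = np.array([1.0, 2.0])

# np.concatenate([some_matrix, some_vector], axis=-1) raises, because the
# vector has fewer dimensions; we must broadcast it up manually first:
padded = np.broadcast_to(some_vector, (some_matrix.shape[0], some_vector.shape[0]))
out = np.concatenate([some_matrix, padded], axis=-1)
assert out.shape == (3, 6)
```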
I suggest adding a helper that does the broadcasting automatically. It is basically the same code we need to lower `xtensor.concat` to tensor operations:
pytensor/pytensor/xtensor/rewriting/shape.py
Lines 68 to 100 in 12213d0
```python
@register_lower_xtensor
@node_rewriter(tracks=[Concat])
def lower_concat(fgraph, node):
    out_dims = node.outputs[0].type.dims
    concat_dim = node.op.dim
    concat_axis = out_dims.index(concat_dim)

    # Convert input XTensors to Tensors and align batch dimensions
    tensor_inputs = [lower_aligned(inp, out_dims) for inp in node.inputs]

    # Broadcast non-concatenated dimensions of each input
    non_concat_shape = [None] * len(out_dims)
    for tensor_inp in tensor_inputs:
        # TODO: This is assuming the graph is correct and every non-concat dimension matches in shape at runtime
        # I'm running this as "shape_unsafe" to simplify the logic / returned graph
        for i, (bcast, sh) in enumerate(
            zip(tensor_inp.type.broadcastable, tensor_inp.shape)
        ):
            if bcast or i == concat_axis or non_concat_shape[i] is not None:
                continue
            non_concat_shape[i] = sh

    assert non_concat_shape.count(None) == 1

    bcast_tensor_inputs = []
    for tensor_inp in tensor_inputs:
        # We modify the concat_axis in place, as we don't need the list anywhere else
        non_concat_shape[concat_axis] = tensor_inp.shape[concat_axis]
        bcast_tensor_inputs.append(broadcast_to(tensor_inp, non_concat_shape))

    joined_tensor = join(concat_axis, *bcast_tensor_inputs)
    new_out = xtensor_from_tensor(joined_tensor, dims=out_dims)
    return [new_out]
```
We can refactor that into a helper that works directly on tensor inputs, offer it to users, and reuse it in the lowering rewrite. Call it `concat_with_broadcast` or `xconcat`?
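To illustrate the intended semantics, here is a minimal NumPy sketch of such a helper (the name `concat_with_broadcast` is just the suggestion above; the real helper would build symbolic `broadcast_to`/`join` ops and, like the rewrite, assume non-concat dimensions agree at runtime):

```python
import numpy as np

def concat_with_broadcast(arrays, axis=-1):
    """Concatenate arrays, broadcasting non-concatenated dimensions first.

    NumPy sketch of the proposed semantics; conflicting non-broadcastable
    shapes are not checked beyond what np.broadcast_to raises, mirroring
    the "shape_unsafe" assumption in the lowering rewrite.
    """
    arrays = [np.asarray(a) for a in arrays]
    ndim = max(a.ndim for a in arrays)
    # Left-pad lower-dimensional inputs with length-1 (broadcastable) dims
    arrays = [a.reshape((1,) * (ndim - a.ndim) + a.shape) for a in arrays]
    axis = axis % ndim

    # Collect the target shape for every non-concatenated dimension
    target = [1] * ndim
    for a in arrays:
        for i, s in enumerate(a.shape):
            if i != axis and s != 1:
                target[i] = s

    # Broadcast each input, keeping its own length along the concat axis
    bcast = []
    for a in arrays:
        shape = target.copy()
        shape[axis] = a.shape[axis]
        bcast.append(np.broadcast_to(a, shape))
    return np.concatenate(bcast, axis=axis)
```

With this, the motivating example becomes `concat_with_broadcast([some_matrix, some_vector], axis=-1)` with no manual `broadcast_to`.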