Implement xtensor-like concat helper #1552

@ricardoV94

Description

concatenate requires all non-axis dimensions to match. This is annoying when we want to, say, pad a dimension with another vector, as in concatenate([some_matrix, some_vector], axis=-1). We need to manually do something like concatenate([some_matrix, broadcast_to(some_vector, (some_matrix.shape[0], some_vector.shape[0]))], axis=-1).
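To make the problem concrete, here is a small sketch using NumPy as a stand-in, on the assumption that its concatenate and broadcast_to follow the same shape rules as the tensor versions discussed here. The array sizes are arbitrary and chosen only for illustration.

```python
import numpy as np

some_matrix = np.zeros((3, 2))
some_vector = np.arange(4.0)

# Direct concatenation is rejected because the non-axis dimensions
# (here even the number of dimensions) do not match.
try:
    np.concatenate([some_matrix, some_vector], axis=-1)
    direct_concat_failed = False
except ValueError:
    direct_concat_failed = True

# Manual workaround: broadcast the vector across the rows of the matrix,
# then concatenate along the last axis.
padded = np.concatenate(
    [
        some_matrix,
        np.broadcast_to(
            some_vector, (some_matrix.shape[0], some_vector.shape[0])
        ),
    ],
    axis=-1,
)
```

Every row of the result ends with a copy of some_vector, which is what the one-line concatenate call was meant to express.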

I suggest adding a helper that does the broadcasting automatically. It is basically the same code we need to lower xtensor.concat to tensor operations:

@register_lower_xtensor
@node_rewriter(tracks=[Concat])
def lower_concat(fgraph, node):
    out_dims = node.outputs[0].type.dims
    concat_dim = node.op.dim
    concat_axis = out_dims.index(concat_dim)

    # Convert input XTensors to Tensors and align batch dimensions
    tensor_inputs = [lower_aligned(inp, out_dims) for inp in node.inputs]

    # Broadcast non-concatenated dimensions of each input
    non_concat_shape = [None] * len(out_dims)
    for tensor_inp in tensor_inputs:
        # TODO: This is assuming the graph is correct and every non-concat dimension matches in shape at runtime
        # I'm running this as "shape_unsafe" to simplify the logic / returned graph
        for i, (bcast, sh) in enumerate(
            zip(tensor_inp.type.broadcastable, tensor_inp.shape)
        ):
            if bcast or i == concat_axis or non_concat_shape[i] is not None:
                continue
            non_concat_shape[i] = sh

    assert non_concat_shape.count(None) == 1

    bcast_tensor_inputs = []
    for tensor_inp in tensor_inputs:
        # We modify the concat_axis in place, as we don't need the list anywhere else
        non_concat_shape[concat_axis] = tensor_inp.shape[concat_axis]
        bcast_tensor_inputs.append(broadcast_to(tensor_inp, non_concat_shape))

    joined_tensor = join(concat_axis, *bcast_tensor_inputs)
    new_out = xtensor_from_tensor(joined_tensor, dims=out_dims)
    return [new_out]

We can refactor that into a helper that works directly on tensor inputs, offer it to users, and reuse in the lowering rewrite.

Call it concat_with_broadcast or xconcat?
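As a rough illustration of the behaviour such a helper could have, here is a minimal NumPy sketch. It is not the PyTensor implementation: the name concat_with_broadcast is just one of the candidates above, the signature is hypothetical, and inputs are assumed to be aligned from the right as in ordinary NumPy broadcasting. A symbolic version would presumably use the tensor broadcast_to and join shown in the rewrite instead.

```python
import numpy as np


def concat_with_broadcast(arrays, axis=0):
    """Concatenate arrays, broadcasting every non-concatenated dimension.

    Hypothetical helper: name and signature are assumptions for illustration.
    """
    arrays = [np.asarray(a) for a in arrays]
    ndim = max(a.ndim for a in arrays)
    if ndim == 0:
        raise ValueError("zero-dimensional arrays cannot be concatenated")
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of bounds for {ndim} dimensions")
    axis = axis % ndim  # normalize a negative axis

    # Left-pad shapes with ones so all inputs have the same number of dimensions.
    arrays = [a.reshape((1,) * (ndim - a.ndim) + a.shape) for a in arrays]

    # Find the common shape of the non-concatenated dimensions by masking
    # the concatenation axis with length 1 and broadcasting the shapes.
    masked_shapes = [a.shape[:axis] + (1,) + a.shape[axis + 1:] for a in arrays]
    target_shape = list(np.broadcast_shapes(*masked_shapes))

    broadcasted = []
    for a in arrays:
        # Each input keeps its own length along the concatenation axis.
        target_shape[axis] = a.shape[axis]
        broadcasted.append(np.broadcast_to(a, tuple(target_shape)))
    return np.concatenate(broadcasted, axis=axis)


some_matrix = np.ones((3, 2))
some_vector = np.arange(4.0)
result = concat_with_broadcast([some_matrix, some_vector], axis=-1)
```

Unlike the lowering rewrite, which runs as "shape_unsafe", this sketch raises a ValueError when non-concatenated dimensions cannot be broadcast together, because np.broadcast_shapes checks them eagerly.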

Metadata

Labels: enhancement (New feature or request)
