Commit 11218cf

Rewrite Blockwise IncSubtensor
Also cover cases of AdvancedIncSubtensor with batch indices that were not supported before
1 parent 55d39ed commit 11218cf
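
For context, here is a minimal sketch (not part of the commit) of the kind of graph the new rewrite targets. The helper `vectorize_graph` is assumed to be importable from `pytensor.graph.replace`, and the shapes are purely illustrative:

import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

x = pt.tensor("x", shape=(5,))
idx = pt.tensor("idx", shape=(2,), dtype="int64")
y = pt.tensor("y", shape=(2,))
out = x[idx].set(y)  # core AdvancedIncSubtensor (set_instead_of_inc=True)

# Giving x a leading batch dimension wraps the core op in Blockwise
batch_x = pt.tensor("batch_x", shape=(3, 5))
batch_out = vectorize_graph(out, replace={x: batch_x})

# The rewrite below turns Blockwise(AdvancedIncSubtensor)(batch_x, y, idx)
# into plain batch_x[:, idx].set(y), dropping the Blockwise wrapper.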

File tree: 3 files changed, +331 −138 lines


pytensor/tensor/rewriting/subtensor.py

Lines changed: 144 additions & 51 deletions
@@ -24,6 +24,7 @@
     ScalarFromTensor,
     TensorFromScalar,
     alloc,
+    arange,
     cast,
     concatenate,
     expand_dims,
@@ -34,9 +35,10 @@
     switch,
 )
 from pytensor.tensor.basic import constant as tensor_constant
-from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.blockwise import Blockwise, _squeeze_left
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
+from pytensor.tensor.extra_ops import broadcast_to
 from pytensor.tensor.math import (
     add,
     and_,
@@ -58,6 +60,7 @@
 )
 from pytensor.tensor.shape import (
     shape_padleft,
+    shape_padright,
     shape_tuple,
 )
 from pytensor.tensor.sharedvar import TensorSharedVariable
@@ -1580,6 +1583,9 @@ def local_blockwise_of_subtensor(fgraph, node):
     """Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.

     Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x has one batch dimension, and a/b none
+
+    TODO: Handle batched indices like we do with blockwise of inc_subtensor
+    TODO: Extend to AdvanceSubtensor
     """
     if not isinstance(node.op.core_op, Subtensor):
         return
@@ -1600,64 +1606,151 @@ def local_blockwise_of_subtensor(fgraph, node):
 @register_stabilize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Blockwise])
-def local_blockwise_advanced_inc_subtensor(fgraph, node):
-    """Rewrite blockwise advanced inc_subtensor whithout batched indexes as an inc_subtensor with prepended empty slices."""
-    if not isinstance(node.op.core_op, AdvancedIncSubtensor):
-        return None
+def local_blockwise_inc_subtensor(fgraph, node):
+    """Rewrite blockwised inc_subtensors.

-    x, y, *idxs = node.inputs
+    Note: The reason we don't apply this rewrite eagerly in the `vectorize_node` dispatch
+    Is that we often have batch dimensions from alloc of shapes/reshape that can be removed by rewrites

-    # It is currently not possible to Vectorize such AdvancedIncSubtensor, but we check again just in case
-    if any(
-        (
-            isinstance(idx, SliceType | NoneTypeT)
-            or (idx.type.dtype == "bool" and idx.type.ndim > 0)
-        )
-        for idx in idxs
-    ):
+    such as x[:vectorized(w.shape[0])].set(y), that will later be rewritten as x[:w.shape[1]].set(y),
+    and can be safely rewritten without Blockwise.
+    """
+    core_op = node.op.core_op
+    if not isinstance(core_op, AdvancedIncSubtensor | IncSubtensor):
         return None

-    op: Blockwise = node.op  # type: ignore
-    batch_ndim = op.batch_ndim(node)
-
-    new_idxs = []
-    for idx in idxs:
-        if all(idx.type.broadcastable[:batch_ndim]):
-            new_idxs.append(idx.squeeze(tuple(range(batch_ndim))))
-        else:
-            # Rewrite does not apply
+    x, y, *idxs = node.inputs
+    [out] = node.outputs
+    if isinstance(node.op.core_op, AdvancedIncSubtensor):
+        if any(
+            (
+                # Blockwise requires all inputs to be tensors so it is not possible
+                # to wrap an AdvancedIncSubtensor with slice / newaxis inputs, but we check again just in case
+                # If this is ever supported we need to pay attention to special behavior of numpy when advanced indices
+                # are separated by basic indices
+                isinstance(idx, SliceType | NoneTypeT)
+                # Also get out if we have boolean indices as they cross dimension boundaries
+                # / can't be safely broadcasted depending on their runtime content
+                or (idx.type.dtype == "bool")
+            )
+            for idx in idxs
+        ):
             return None

-    x_batch_bcast = x.type.broadcastable[:batch_ndim]
-    y_batch_bcast = y.type.broadcastable[:batch_ndim]
-    if any(xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast, strict=True)):
-        # Need to broadcast batch x dims
-        batch_shape = tuple(
-            x_dim if (not xb or yb) else y_dim
-            for xb, x_dim, yb, y_dim in zip(
-                x_batch_bcast,
+    batch_ndim = node.op.batch_ndim(node)
+    idxs_core_ndim = [len(inp_sig) for inp_sig in node.op.inputs_sig[2:]]
+    max_idx_core_ndim = max(idxs_core_ndim, default=0)
+
+    # Step 1. Broadcast buffer to batch_shape
+    if x.type.broadcastable != out.type.broadcastable:
+        batch_shape = [1] * batch_ndim
+        for inp in node.inputs:
+            for i, (broadcastable, batch_dim) in enumerate(
+                zip(inp.type.broadcastable[:batch_ndim], tuple(inp.shape)[:batch_ndim])
+            ):
+                if broadcastable:
+                    # This dimension is broadcastable, it doesn't provide shape information
+                    continue
+                if batch_shape[i] != 1:
+                    # We already found a source of shape for this batch dimension
+                    continue
+                batch_shape[i] = batch_dim
+        x = broadcast_to(x, (*batch_shape, *x.shape[batch_ndim:]))
+        assert x.type.broadcastable == out.type.broadcastable
+
+    # Step 2. Massage indices so they respect blockwise semantics
+    if isinstance(core_op, IncSubtensor):
+        # For basic IncSubtensor there are two cases:
+        # 1. Slice entries -> We need to squeeze away dummy dimensions so we can convert back to slice
+        # 2. Integers -> Can be used as is, but we try to squeeze away dummy batch dimensions
+        #    in case we can end up with a basic IncSubtensor again
+        core_idxs = []
+        counter = 0
+        for idx in core_op.idx_list:
+            if isinstance(idx, slice):
+                # Squeeze away dummy dimensions so we can convert to slice
+                new_entries = [None, None, None]
+                for i, entry in enumerate((idx.start, idx.stop, idx.step)):
+                    if entry is None:
+                        continue
+                    else:
+                        new_entries[i] = new_entry = idxs[counter].squeeze()
+                        counter += 1
+                        if new_entry.ndim > 0:
+                            # If the slice entry has dimensions after the squeeze we can't convert it to a slice
+                            # We could try to convert to equivalent integer indices, but nothing guarantees
+                            # that the slice is "square".
+                            return None
+                core_idxs.append(slice(*new_entries))
+            else:
+                core_idxs.append(_squeeze_left(idxs[counter]))
+                counter += 1
+    else:
+        # For AdvancedIncSubtensor we have tensor integer indices,
+        # We need to expand batch indexes on the right, so they don't interact with core index dimensions
+        # We still squeeze on the left in case that allows us to use simpler indices
+        core_idxs = [
+            _squeeze_left(
+                shape_padright(idx, max_idx_core_ndim - idx_core_ndim),
+                stop_at_dim=batch_ndim,
+            )
+            for idx, idx_core_ndim in zip(idxs, idxs_core_ndim)
+        ]
+
+    # Step 3. Create new indices for the new batch dimension of x
+    if not all(
+        all(idx.type.broadcastable[:batch_ndim])
+        for idx in idxs
+        if not isinstance(idx, slice)
+    ):
+        # If indices have batch dimensions in the indices, they will interact with the new dimensions of x
+        # We build vectorized indexing with new arange indices that do not interact with core indices or each other
+        # (i.e., they broadcast)
+
+        # Note: due to how numpy handles non-consecutive advanced indexing (transposing it to the front),
+        # we don't want to create a mix of slice(None), and arange() indices for the new batch dimension,
+        # even if not all batch dimensions have corresponding batch indices.
+        batch_slices = [
+            shape_padright(arange(x_batch_shape, dtype="int64"), n)
+            for (x_batch_shape, n) in zip(
                 tuple(x.shape)[:batch_ndim],
-                y_batch_bcast,
-                tuple(y.shape)[:batch_ndim],
-                strict=True,
+                reversed(range(max_idx_core_ndim, max_idx_core_ndim + batch_ndim)),
             )
-        )
-        core_shape = tuple(x.shape)[batch_ndim:]
-        x = alloc(x, *batch_shape, *core_shape)
-
-    new_idxs = [slice(None)] * batch_ndim + new_idxs
-    x_view = x[tuple(new_idxs)]
-
-    # We need to introduce any implicit expand_dims on core dimension of y
-    y_core_ndim = y.type.ndim - batch_ndim
-    if (missing_y_core_ndim := x_view.type.ndim - batch_ndim - y_core_ndim) > 0:
-        missing_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
-        y = expand_dims(y, missing_axes)
-
-    symbolic_idxs = x_view.owner.inputs[1:]
-    new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
-    copy_stack_trace(node.outputs, new_out)
-    return new_out
+        ]
+    else:
+        # In the case we don't have batch indices,
+        # we can use slice(None) to broadcast the core indices to each new batch dimension of x / y
+        batch_slices = [slice(None)] * batch_ndim
+
+    new_idxs = (*batch_slices, *core_idxs)
+    x_view = x[new_idxs]
+
+    # Step 4. Introduce any implicit expand_dims on core dimension of y
+    missing_y_core_ndim = x_view.type.ndim - y.type.ndim
+    implicit_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
+    y = _squeeze_left(expand_dims(y, implicit_axes), stop_at_dim=batch_ndim)
+
+    if isinstance(core_op, IncSubtensor):
+        # Check if we can still use a basic IncSubtensor
+        if isinstance(x_view.owner.op, Subtensor):
+            new_props = core_op._props_dict()
+            new_props["idx_list"] = x_view.owner.op.idx_list
+            new_core_op = type(core_op)(**new_props)
+            symbolic_idxs = x_view.owner.inputs[1:]
+            new_out = new_core_op(x, y, *symbolic_idxs)
+        else:
+            # We need to use AdvancedSet/IncSubtensor
+            if core_op.set_instead_of_inc:
+                new_out = x[new_idxs].set(y)
+            else:
+                new_out = x[new_idxs].inc(y)
+    else:
+        # AdvancedIncSubtensor takes symbolic indices/slices directly, no need to create a new op
+        symbolic_idxs = x_view.owner.inputs[1:]
+        new_out = core_op(x, y, *symbolic_idxs)
+
+    copy_stack_trace(out, new_out)
+    return [new_out]


 @node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
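
Two NumPy-level sketches of the tricks used above (assumptions for illustration, not code from the commit). First, a hypothetical analogue of `_squeeze_left(x, stop_at_dim)` as inferred from its call sites in Step 2 and Step 4: drop the leading run of length-1 dimensions, squeezing at most `stop_at_dim` of them.

import numpy as np

def squeeze_left(x, stop_at_dim=None):
    # count the leading run of length-1 dimensions
    n = 0
    while n < x.ndim and x.shape[n] == 1:
        n += 1
    if stop_at_dim is not None:
        n = min(n, stop_at_dim)
    return x.squeeze(axis=tuple(range(n))) if n else x

a = np.ones((1, 1, 4, 1))
assert squeeze_left(a).shape == (4, 1)                    # both leading 1s removed
assert squeeze_left(a, stop_at_dim=1).shape == (1, 4, 1)  # squeeze at most one dim

Second, the equivalence Step 3 relies on: when an integer index has a batch dimension, a right-padded arange over the batch dimension of x broadcasts against it, so each batch member only indexes its own core entries.

x = np.zeros((3, 5))                       # one batch dimension, core shape (5,)
idx = np.array([[0, 1], [2, 3], [4, 0]])   # batched integer index, shape (3, 2)
y = np.ones((3, 2))

# Per-batch semantics of Blockwise(AdvancedIncSubtensor): x[b][idx[b]] = y[b]
expected = x.copy()
for b in range(3):
    expected[b][idx[b]] = y[b]

# Vectorized form used by the rewrite: an arange over the batch dimension,
# padded on the right so it broadcasts against the core index
batch_idx = np.arange(3)[:, None]          # shape (3, 1)
vectorized = x.copy()
vectorized[batch_idx, idx] = y

assert np.array_equal(expected, vectorized)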

pytensor/tensor/subtensor.py

Lines changed: 0 additions & 1 deletion
@@ -1433,7 +1433,6 @@ def _process(self, idxs, op_inputs, pstate):
 pprint.assign(Subtensor, SubtensorPrinter())


-# TODO: Implement similar vectorize for Inc/SetSubtensor
 @_vectorize_node.register(Subtensor)
 def vectorize_subtensor(op: Subtensor, node, batch_x, *batch_idxs):
     """Rewrite subtensor with non-batched indexes as another Subtensor with prepended empty slices."""
