@@ -24,6 +24,7 @@
     ScalarFromTensor,
     TensorFromScalar,
     alloc,
+    arange,
     cast,
     concatenate,
     expand_dims,
@@ -34,9 +35,10 @@
     switch,
 )
 from pytensor.tensor.basic import constant as tensor_constant
-from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.blockwise import Blockwise, _squeeze_left
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
+from pytensor.tensor.extra_ops import broadcast_to
 from pytensor.tensor.math import (
     add,
     and_,
@@ -58,6 +60,7 @@
 )
 from pytensor.tensor.shape import (
     shape_padleft,
+    shape_padright,
     shape_tuple,
 )
 from pytensor.tensor.sharedvar import TensorSharedVariable
@@ -1580,6 +1583,9 @@ def local_blockwise_of_subtensor(fgraph, node):
     """Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.

     Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x has one batch dimension, and a/b none
+
+    TODO: Handle batched indices like we do with blockwise of inc_subtensor
+    TODO: Extend to AdvancedSubtensor
     """
     if not isinstance(node.op.core_op, Subtensor):
         return
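
The one-line docstring example above is easiest to see numerically. Below is a rough numpy sketch (not part of the diff; names and shapes are illustrative) of the equivalence the rewrite relies on, assuming the slice bounds `a`/`b` carry no batch dimensions:

```python
import numpy as np

x = np.arange(12).reshape(3, 4)  # x has one batch dimension (size 3) over 1d cores
a, b = 1, 3                      # scalar slice bounds, no batch dimensions

# Blockwise semantics: apply Subtensor{a:b} to each core row of x independently...
blockwise_result = np.stack([core_x[a:b] for core_x in x])
# ...which is the same as indexing with a prepended full slice.
rewritten_result = x[:, a:b]

np.testing.assert_array_equal(blockwise_result, rewritten_result)
```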
@@ -1600,64 +1606,151 @@ def local_blockwise_of_subtensor(fgraph, node):
 @register_stabilize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Blockwise])
-def local_blockwise_advanced_inc_subtensor(fgraph, node):
-    """Rewrite blockwise advanced inc_subtensor whithout batched indexes as an inc_subtensor with prepended empty slices."""
-    if not isinstance(node.op.core_op, AdvancedIncSubtensor):
-        return None
+def local_blockwise_inc_subtensor(fgraph, node):
+    """Rewrite Blockwise of inc_subtensor.

-    x, y, *idxs = node.inputs
+    Note: The reason we don't apply this rewrite eagerly in the `vectorize_node` dispatch
+    is that we often have batch dimensions from alloc of shapes/reshape that can be removed by rewrites,

-    # It is currently not possible to Vectorize such AdvancedIncSubtensor, but we check again just in case
-    if any(
-        (
-            isinstance(idx, SliceType | NoneTypeT)
-            or (idx.type.dtype == "bool" and idx.type.ndim > 0)
-        )
-        for idx in idxs
-    ):
+    such as x[:vectorized(w.shape[0])].set(y), that will later be rewritten as x[:w.shape[1]].set(y),
+    and can be safely rewritten without Blockwise.
+    """
+    core_op = node.op.core_op
+    if not isinstance(core_op, AdvancedIncSubtensor | IncSubtensor):
         return None

-    op: Blockwise = node.op  # type: ignore
-    batch_ndim = op.batch_ndim(node)
-
-    new_idxs = []
-    for idx in idxs:
-        if all(idx.type.broadcastable[:batch_ndim]):
-            new_idxs.append(idx.squeeze(tuple(range(batch_ndim))))
-        else:
-            # Rewrite does not apply
+    x, y, *idxs = node.inputs
+    [out] = node.outputs
+    if isinstance(node.op.core_op, AdvancedIncSubtensor):
+        if any(
+            (
+                # Blockwise requires all inputs to be tensors, so it is not possible
+                # to wrap an AdvancedIncSubtensor with slice / newaxis inputs, but we check again just in case.
+                # If this is ever supported we need to pay attention to the special behavior of numpy when advanced indices
+                # are separated by basic indices.
+                isinstance(idx, SliceType | NoneTypeT)
+                # Also get out if we have boolean indices, as they cross dimension boundaries
+                # / can't be safely broadcast depending on their runtime content.
+                or (idx.type.dtype == "bool")
+            )
+            for idx in idxs
+        ):
             return None

-    x_batch_bcast = x.type.broadcastable[:batch_ndim]
-    y_batch_bcast = y.type.broadcastable[:batch_ndim]
-    if any(xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast, strict=True)):
-        # Need to broadcast batch x dims
-        batch_shape = tuple(
-            x_dim if (not xb or yb) else y_dim
-            for xb, x_dim, yb, y_dim in zip(
-                x_batch_bcast,
+    batch_ndim = node.op.batch_ndim(node)
+    idxs_core_ndim = [len(inp_sig) for inp_sig in node.op.inputs_sig[2:]]
+    max_idx_core_ndim = max(idxs_core_ndim, default=0)
+
+    # Step 1. Broadcast buffer to batch_shape
+    if x.type.broadcastable != out.type.broadcastable:
+        batch_shape = [1] * batch_ndim
+        for inp in node.inputs:
+            for i, (broadcastable, batch_dim) in enumerate(
+                zip(inp.type.broadcastable[:batch_ndim], tuple(inp.shape)[:batch_ndim])
+            ):
+                if broadcastable:
+                    # This dimension is broadcastable, it doesn't provide shape information
+                    continue
+                if batch_shape[i] != 1:
+                    # We already found a source of shape for this batch dimension
+                    continue
+                batch_shape[i] = batch_dim
+        x = broadcast_to(x, (*batch_shape, *x.shape[batch_ndim:]))
+        assert x.type.broadcastable == out.type.broadcastable
+
+    # Step 2. Massage indices so they respect blockwise semantics
+    if isinstance(core_op, IncSubtensor):
+        # For basic IncSubtensor there are two cases:
+        # 1. Slice entries -> We need to squeeze away dummy dimensions so we can convert back to slice
+        # 2. Integers -> Can be used as is, but we try to squeeze away dummy batch dimensions
+        #    in case we can end up with a basic IncSubtensor again
+        core_idxs = []
+        counter = 0
+        for idx in core_op.idx_list:
+            if isinstance(idx, slice):
+                # Squeeze away dummy dimensions so we can convert to slice
+                new_entries = [None, None, None]
+                for i, entry in enumerate((idx.start, idx.stop, idx.step)):
+                    if entry is None:
+                        continue
+                    else:
+                        new_entries[i] = new_entry = idxs[counter].squeeze()
+                        counter += 1
+                        if new_entry.ndim > 0:
+                            # If the slice entry still has dimensions after the squeeze we can't convert it to a slice.
+                            # We could try to convert to equivalent integer indices, but nothing guarantees
+                            # that the slice is "square".
+                            return None
+                core_idxs.append(slice(*new_entries))
+            else:
+                core_idxs.append(_squeeze_left(idxs[counter]))
+                counter += 1
+    else:
+        # For AdvancedIncSubtensor we have tensor integer indices.
+        # We need to expand batch indices on the right, so they don't interact with core index dimensions.
+        # We still squeeze on the left in case that allows us to use simpler indices.
+        core_idxs = [
+            _squeeze_left(
+                shape_padright(idx, max_idx_core_ndim - idx_core_ndim),
+                stop_at_dim=batch_ndim,
+            )
+            for idx, idx_core_ndim in zip(idxs, idxs_core_ndim)
+        ]
+
+    # Step 3. Create new indices for the new batch dimensions of x
+    if not all(
+        all(idx.type.broadcastable[:batch_ndim])
+        for idx in idxs
+        if not isinstance(idx, slice)
+    ):
+        # If the indices have batch dimensions, they will interact with the new dimensions of x.
+        # We build vectorized indexing with new arange indices that do not interact with core indices or each other
+        # (i.e., they broadcast).
+
+        # Note: due to how numpy handles non-consecutive advanced indexing (transposing it to the front),
+        # we don't want to create a mix of slice(None) and arange() indices for the new batch dimensions,
+        # even if not all batch dimensions have corresponding batch indices.
+        batch_slices = [
+            shape_padright(arange(x_batch_shape, dtype="int64"), n)
+            for (x_batch_shape, n) in zip(
                 tuple(x.shape)[:batch_ndim],
-                y_batch_bcast,
-                tuple(y.shape)[:batch_ndim],
-                strict=True,
+                reversed(range(max_idx_core_ndim, max_idx_core_ndim + batch_ndim)),
             )
-        )
-        core_shape = tuple(x.shape)[batch_ndim:]
-        x = alloc(x, *batch_shape, *core_shape)
-
-    new_idxs = [slice(None)] * batch_ndim + new_idxs
-    x_view = x[tuple(new_idxs)]
-
-    # We need to introduce any implicit expand_dims on core dimension of y
-    y_core_ndim = y.type.ndim - batch_ndim
-    if (missing_y_core_ndim := x_view.type.ndim - batch_ndim - y_core_ndim) > 0:
-        missing_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
-        y = expand_dims(y, missing_axes)
-
-    symbolic_idxs = x_view.owner.inputs[1:]
-    new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
-    copy_stack_trace(node.outputs, new_out)
-    return new_out
+        ]
+    else:
+        # In the case we don't have batch indices,
+        # we can use slice(None) to broadcast the core indices to each new batch dimension of x / y.
+        batch_slices = [slice(None)] * batch_ndim
+
+    new_idxs = (*batch_slices, *core_idxs)
+    x_view = x[new_idxs]
+
+    # Step 4. Introduce any implicit expand_dims on core dimensions of y
+    missing_y_core_ndim = x_view.type.ndim - y.type.ndim
+    implicit_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
+    y = _squeeze_left(expand_dims(y, implicit_axes), stop_at_dim=batch_ndim)
+
+    if isinstance(core_op, IncSubtensor):
+        # Check if we can still use a basic IncSubtensor
+        if isinstance(x_view.owner.op, Subtensor):
+            new_props = core_op._props_dict()
+            new_props["idx_list"] = x_view.owner.op.idx_list
+            new_core_op = type(core_op)(**new_props)
+            symbolic_idxs = x_view.owner.inputs[1:]
+            new_out = new_core_op(x, y, *symbolic_idxs)
+        else:
+            # We need to use AdvancedSet/IncSubtensor
+            if core_op.set_instead_of_inc:
+                new_out = x[new_idxs].set(y)
+            else:
+                new_out = x[new_idxs].inc(y)
+    else:
+        # AdvancedIncSubtensor takes symbolic indices/slices directly, no need to create a new op
+        symbolic_idxs = x_view.owner.inputs[1:]
+        new_out = core_op(x, y, *symbolic_idxs)
+
+    copy_stack_trace(out, new_out)
+    return [new_out]


 @node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
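
To make the intent of the new rewrite concrete, here is a rough, self-contained sketch (not part of the diff; variable names, shapes, and slice bounds are illustrative, and it assumes `vectorize_graph` dispatches this case to a Blockwise of the core set/inc_subtensor): batching only the buffer of a set_subtensor should be numerically equivalent to setting through a prepended full slice, which is the non-Blockwise form this rewrite aims to produce.

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

# Core graph: write y into a slice of a 1d buffer.
x_core = pt.vector("x_core", shape=(5,), dtype="float64")
y = pt.vector("y", shape=(2,), dtype="float64")
core_out = x_core[1:3].set(y)

# Batch the buffer only; depending on the vectorize dispatch this typically
# yields a Blockwise of the core inc_subtensor, the pattern targeted above.
x_batch = pt.matrix("x_batch", shape=(3, 5), dtype="float64")
batched_out = vectorize_graph(core_out, replace={x_core: x_batch})

# The plain form the rewrite should recover: a leading full slice on the buffer.
direct_out = x_batch[:, 1:3].set(y)

# Both graphs compute the same values regardless of which rewrites fire.
fn = pytensor.function([x_batch, y], [batched_out, direct_out])
a, b = fn(np.zeros((3, 5)), np.ones(2))
np.testing.assert_allclose(a, b)
```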