Optimize matmuls involving block diagonal matrices #1493


Merged: 2 commits merged into pymc-devs:main from the block-diag-dot-rewrite branch on Jul 26, 2025

Conversation

@jessegrabowski (Member) commented Jun 21, 2025

Description

This PR adds a rewrite to optimize matrix multiplication involving block diagonal matrices. When we have a matrix X = BlockDiag(A, B) and compute Z = X @ Y, there is no interaction between the A part and the B part of X. The dot can therefore be computed as row_stack(A @ Y[:A.shape[1]], B @ Y[A.shape[1]:]) (or, in the general case, Y can be split into n pieces with the appropriate shapes, and we do row_stack([diag_component @ y_split for diag_component, y_split in zip(BlockDiag.inputs, split(Y, *args))])). In the case where the blockdiag matrix is right-multiplying, you instead col_stack and slice on axis=1.
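For illustration, here is a minimal NumPy sketch of the identity the rewrite exploits (array names and shapes are arbitrary placeholders; this is not the PyTensor implementation):

import numpy as np
from scipy.linalg import block_diag

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 2))      # first diagonal block
B = rng.normal(size=(3, 5))      # second diagonal block
Y = rng.normal(size=(2 + 5, 6))  # rows match the total column count of the blocks

# Dense computation through the full block diagonal matrix
Z_dense = block_diag(A, B) @ Y

# Equivalent computation on the blocks: slice Y by the blocks' column counts,
# multiply each block by its slice, then stack the results row-wise
Z_blocks = np.concatenate([A @ Y[:2], B @ Y[2:]], axis=0)

np.testing.assert_allclose(Z_dense, Z_blocks)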

Anyway, it's a lot faster to do this, because matmul cost grows much faster than linearly in the input dimension, so doing several smaller multiplications is preferred. Here are the benchmarks: small has n=10, medium has n=100, large has n=1000. In all cases the rewrite gives at least a 2x speedup.

---------------------------------------------------------------------------------------------------------------- benchmark: 6 tests ----------------------------------------------------------------------------------------------------------------
Name (time in us)                                                       Min                   Max                  Mean              StdDev              Median                 IQR             Outliers           OPS            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_block_diag_dot_to_dot_concat_benchmark[small-rewrite]           4.5830 (1.0)         90.7090 (2.13)         5.3007 (1.0)        1.6155 (2.53)       5.2080 (1.0)        0.1660 (1.0)         67;560  188,654.6546 (1.0)       12533           1
test_block_diag_dot_to_dot_concat_benchmark[small-no_rewrite]        8.5420 (1.86)        90.1670 (2.12)        10.1183 (1.91)       1.6055 (2.51)      10.0000 (1.92)       0.1680 (1.01)      430;2150   98,830.6599 (0.52)      18721           1

test_block_diag_dot_to_dot_concat_benchmark[medium-rewrite]          6.1250 (1.34)        44.8750 (1.05)         7.2724 (1.37)       0.6386 (1.0)        7.4170 (1.42)       0.2490 (1.50)     7575;7886  137,505.3510 (0.73)      35875           1
test_block_diag_dot_to_dot_concat_benchmark[medium-no_rewrite]      14.0420 (3.06)        42.6250 (1.0)         16.5707 (3.13)       1.3341 (2.09)      17.2500 (3.31)       2.1660 (13.05)     1174;108   60,347.4538 (0.32)      12177           1

test_block_diag_dot_to_dot_concat_benchmark[large-rewrite]          14.6660 (3.20)       248.2920 (5.83)        16.5375 (3.12)       4.7284 (7.40)      16.1250 (3.10)       0.4590 (2.76)      249;1621   60,468.5555 (0.32)      18765           1
test_block_diag_dot_to_dot_concat_benchmark[large-no_rewrite]      788.6250 (172.08)   1,982.7500 (46.52)    1,019.2728 (192.29)   150.6524 (235.91)   987.3335 (189.58)   130.6250 (786.86)      132;63      981.0916 (0.01)        734           1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1493.org.readthedocs.build/en/1493/

@jessegrabowski requested review from Copilot and ricardoV94 and removed the request for Copilot on June 21, 2025 20:14
Copilot

This comment was marked as outdated.

@ricardoV94 added the enhancement (New feature or request) label on Jun 21, 2025

codecov bot commented Jun 21, 2025

Codecov Report

❌ Patch coverage is 98.00000% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 81.54%. Comparing base (d4e8f73) to head (b2cba9f).
⚠️ Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/tensor/rewriting/math.py 96.55% 0 Missing and 1 partial ⚠️

❌ Your patch status has failed because the patch coverage (98.00%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #1493      +/-   ##
==========================================
+ Coverage   81.53%   81.54%   +0.01%     
==========================================
  Files         230      230              
  Lines       53012    53048      +36     
  Branches     9412     9419       +7     
==========================================
+ Hits        43222    43257      +35     
  Misses       7361     7361              
- Partials     2429     2430       +1     
Files with missing lines Coverage Δ
pytensor/tensor/extra_ops.py 89.18% <100.00%> (+0.30%) ⬆️
pytensor/xtensor/rewriting/shape.py 98.01% <100.00%> (-0.20%) ⬇️
pytensor/tensor/rewriting/math.py 90.43% <96.55%> (+0.10%) ⬆️

@ricardoV94 (Member) left a comment

Looks great! Some minor optimization questions.

@jessegrabowski force-pushed the block-diag-dot-rewrite branch from 68591ad to 7c3820b on July 25, 2025 13:08
@jessegrabowski (Member, Author)

I'm still struggling with ExpandDims "blocking" between Dot and BlockDiag in the batch case here, do you have a suggestion for how to handle it?

@ricardoV94 (Member)

What sort of graph exactly?

@jessegrabowski (Member, Author) commented Jul 25, 2025

Actually the issue is the Join in the blockwise case. Given a graph like:

Matmul [id A] <Tensor3(float64, shape=(3, 10, 10))> 4
 ├─ ExpandDims{axis=0} [id B] <Tensor3(float64, shape=(1, 10, 10))> 3
 │  └─ d [id C] <Matrix(float64, shape=(10, 10))>
 └─ Blockwise{BlockDiagonal{n_inputs=3}, (m0,n0),(m1,n1),(m2,n2)->(m,n)} [id D] <Tensor3(float64, shape=(3, 10, 10))> 2
    ├─ ExpandDims{axis=0} [id E] <Tensor3(float64, shape=(1, 4, 2))> 1
    │  └─ a [id F] <Matrix(float64, shape=(4, 2))>
    ├─ b [id G] <Tensor3(float64, shape=(3, 2, 4))>
    └─ ExpandDims{axis=0} [id H] <Tensor3(float64, shape=(1, 4, 4))> 0
       └─ c [id I] <Matrix(float64, shape=(4, 4))>

We pull out the 3 inputs and split d into [(1, 10, 4), (1, 10, 2), (1, 10, 4)], then do the 3 dots, obtaining results of shapes [(1, 10, 2), (3, 10, 4), (1, 10, 4)].

This should all be correct, but now the join fails, because the batch dim on the 3 results is not the same. Looking at it in the debugger, it seems like I need to tile a and c along the broadcastable dims until their shapes match up. Is there a helper for that?

To clarify that last point, here is res[:, 0, :2] from the matmul, which is the contribution from the (1, 10, 2) component:

array([[ 0.30917466, -1.10133142],
       [ 0.30917466, -1.10133142],
       [ 0.30917466, -1.10133142]])

And the (1, 10, 2) component itself:

array([[[ 0.30917466, -1.10133142],
        [-2.7780056 ,  6.51444629],
        [ 0.79439975, -2.1491907 ],
        [-3.09616644, -0.1655169 ],
        [ 0.28896778,  0.65660826],
        [ 2.88199648, -3.4479807 ],
        [-2.07519668,  1.33874492],
        [-0.57444061, -3.40710469],
        [-2.41102468,  2.96805356],
        [ 1.06383867, -0.90597861]]])

As you can see, it's just the first row duplicated 3 times.
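(For illustration, the fix amounts to broadcasting the per-block results to a common batch shape before joining them. A minimal NumPy sketch of the idea, using the shapes from the example above, not the actual rewrite code:)

import numpy as np

rng = np.random.default_rng(0)
# Per-block dot results with mismatched batch dimensions, mirroring the shapes
# above: two broadcastable (size-1) batch dims and one true batch of 3
pieces = [
    rng.normal(size=(1, 10, 2)),
    rng.normal(size=(3, 10, 4)),
    rng.normal(size=(1, 10, 4)),
]

# Joining directly fails because the leading dims differ, so first broadcast
# every piece to the common batch size
batch = max(p.shape[0] for p in pieces)
pieces = [np.broadcast_to(p, (batch, *p.shape[1:])) for p in pieces]

result = np.concatenate(pieces, axis=-1)
print(result.shape)  # (3, 10, 10)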

@ricardoV94 (Member)

Is there a helper for that?

That's what I'm proposing in the first part of #1552 ;)

@jessegrabowski (Member, Author)

I added a concat_with_broadcast helper, basically just copying your code from #1552

The last case I'm still struggling with is when there are Dimshuffles between the Dot and the BlockDiag, for example:

Transpose{axes=[2, 0, 1, 3]} [id A] 10
 └─ Reshape{4} [id B] 9
    ├─ Dot [id C] 8
    │  ├─ Squeeze{axis=0} [id D] 7
    │  │  └─ Reshape{3} [id E] 6
    │  │     ├─ Blockwise{BlockDiagonal{n_inputs=3}, (m0,n0),(m1,n1),(m2,n2)->(m,n)} [id F] 5
    │  │     │  ├─ ExpandDims{axis=0} [id G] 4
    │  │     │  │  └─ a [id H]
    │  │     │  ├─ b [id I]
    │  │     │  └─ ExpandDims{axis=0} [id J] 3
    │  │     │     └─ c [id K]
    │  │     └─ [ 1 -1 10] [id L]
    │  └─ Squeeze{axis=0} [id M] 2
    │     └─ Reshape{3} [id N] 1
    │        ├─ Transpose{axes=[1, 2, 0, 3]} [id O] 0
    │        │  └─ d [id P]
    │        └─ [ 1 10 -1] [id Q]
    └─ [ 3 10  3 10] [id R]

I put the rewrite into specialize because I was hoping by that point these dimshuffles would have been cleaned up, but no such luck.

@jessegrabowski (Member, Author)

I cobbled together something that works, but I'm not sure if it's ideal.

@jessegrabowski force-pushed the block-diag-dot-rewrite branch from a9655b0 to abdce8e on July 26, 2025 06:08
@ricardoV94 (Member)

Yeah you want the rewrite before the batched_to_core_dot in specialize. Register it in canonicalize or stabilize?

@jessegrabowski (Member, Author)

It's in stabilize now
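(For reference, registering a node rewrite in the stabilize pass looks roughly like the sketch below; the function name and body are placeholders, not the code added in this PR:)

from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.math import Dot
from pytensor.tensor.rewriting.basic import register_stabilize


@register_stabilize
@node_rewriter([Dot])
def local_block_diag_dot_sketch(fgraph, node):
    # Placeholder body: look for a BlockDiagonal operand among node.inputs and,
    # if found, return the replacement outputs; returning None leaves the node
    # untouched.
    return None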

@jessegrabowski force-pushed the block-diag-dot-rewrite branch from abdce8e to 749bfda on July 26, 2025 09:01
Use new helper in xt.concat

Co-authored-by: Ricardo <ricardo.vieira1994@gmail.com>
@jessegrabowski force-pushed the block-diag-dot-rewrite branch from 749bfda to 08e6736 on July 26, 2025 09:27
@jessegrabowski requested a review from Copilot on July 26, 2025 09:27
@Copilot (Copilot AI) left a comment

Pull Request Overview

This PR optimizes matrix multiplication involving block diagonal matrices by rewriting dot products to operate on smaller sub-matrices instead of the full block diagonal matrix. The optimization leverages the sparse structure of block diagonal matrices where only the diagonal blocks contain non-zero values.

  • Adds a new rewrite rule that transforms dot(block_diag(A, B), C) into concat(dot(A, C_split), dot(B, C_split))
  • Introduces a concat_with_broadcast utility function for concatenating tensors with automatic broadcasting
  • Provides significant performance improvements (2x+ speedup) by avoiding expensive operations on large sparse matrices

Reviewed Changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
pytensor/tensor/rewriting/math.py Adds the main optimization rewrite rule for block diagonal matrix multiplication
pytensor/tensor/extra_ops.py Implements concat_with_broadcast utility function for tensor concatenation with broadcasting
pytensor/xtensor/rewriting/shape.py Refactors existing code to use the new concat_with_broadcast function
tests/tensor/test_extra_ops.py Adds comprehensive tests for the new concat_with_broadcast function
tests/tensor/rewriting/test_math.py Adds tests and benchmarks for the block diagonal optimization

@ricardoV94 (Member) left a comment

Looks great

I suspect we'll be missing some cases when there's a Dimshuffle between the BlockDiag and the Dot, like block_diag(...) @ tensor3, where the block diag will have an expand_dims that I don't think is lifted into the block_diag?

Also we should rewrite block_diag(...).mT -> block_diag(e0.mT, ..., eN.mT) to lift the transpose through the inputs. As well as some elemwise(block_diag(...)) like mul.

In general we should have a list of elemwise functions that map zero -> zero; this comes up in a lot of these linalg rewrites.

All this should be a separate issue; this PR is concise and clean as it is.
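(As a quick sanity check of the transpose identity mentioned above, a small NumPy sketch with arbitrary block shapes:)

import numpy as np
from scipy.linalg import block_diag

rng = np.random.default_rng(0)
e0 = rng.normal(size=(4, 2))
e1 = rng.normal(size=(3, 5))

# Transposing a block diagonal matrix equals building the block diagonal
# matrix from the transposed blocks, so the transpose can be lifted through
# the BlockDiagonal inputs.
np.testing.assert_allclose(block_diag(e0, e1).T, block_diag(e0.T, e1.T))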

@ricardoV94 (Member)

Is this related to the COLA issue as well (to link if so)?

Co-authored-by: Ricardo <ricardo.vieira1994@gmail.com>
@jessegrabowski force-pushed the block-diag-dot-rewrite branch from 08e6736 to b2cba9f on July 26, 2025 09:57
@jessegrabowski (Member, Author)

Is this related to the COLA issue as well (to link if so)?

We don't have this one on the COLA list, because we didn't have a subsection for dot. I will add it.

@jessegrabowski merged commit 6f8bb55 into pymc-devs:main on Jul 26, 2025
71 checks passed
Labels
enhancement (New feature or request), graph rewriting, linalg (Linear algebra), performance
Development

Successfully merging this pull request may close these issues.

Add rewrite to optimize block_diag(a, b) @ c
2 participants