Optimize matmuls involving block diagonal matrices #1493
Conversation
Codecov Report

❌ Your patch status has failed because the patch coverage (98.00%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

@@            Coverage Diff             @@
##             main    #1493      +/-   ##
==========================================
+ Coverage   81.53%   81.54%   +0.01%
==========================================
  Files         230      230
  Lines       53012    53048      +36
  Branches     9412     9419       +7
==========================================
+ Hits        43222    43257      +35
  Misses       7361     7361
- Partials     2429     2430       +1
Looks great! Some minor optimization questions
force-pushed from 68591ad to 7c3820b
I'm still struggling with ExpandDims "blocking" between Dot and BlockDiag in the batch case here, do you have a suggestion for how to handle it?
What sort of graph exactly?
Actually the issue is the Join in the blockwise case. Given a graph like [...], we pull out the 3 inputs and split the other operand. This should all be correct, but now the join fails, because the batch dim on the 3 results is not the same. Looking at it in the debugger, it seems like I need to tile the un-batched results up to the full batch size. To clarify that last point, here is one of the results before tiling: [...] and the tiled version: [...] As you can see, it's just the first row duplicated 3 times.
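A minimal sketch of the kind of mismatch being described, with hypothetical shapes (an illustration, not the actual graph from the PR):

```python
import pytensor.tensor as pt

# One blockwise result carries the batch dim; another comes out with a
# broadcastable batch dim of 1 (e.g. after expand_dims), so a plain Join fails.
a = pt.tensor3("a", shape=(3, 2, 4))  # batched block result
b = pt.tensor3("b", shape=(1, 5, 4))  # un-batched block result, expand_dims'd

# pt.join(1, a, b) is invalid: the batch dims (3 vs 1) disagree.
# Broadcasting (the symbolic analogue of tiling) the smaller one first works:
b_full = pt.broadcast_to(b, (3, 5, 4))
out = pt.join(1, a, b_full)  # shape (3, 7, 4)
```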
That's what I'm proposing in the first part of #1552 ;)
I added a `concat_with_broadcast` helper. The last case I'm still struggling with is when there are Dimshuffles between the Dot and the BlockDiag, for example:
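(Presumably a pattern like the following, reconstructed from the review discussion below rather than taken verbatim from the PR:)

```python
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import block_diag

A = pt.matrix("A")
B = pt.matrix("B")
y = pt.tensor3("y")

# block_diag(A, B) is 2D while y is 3D, so the batched matmul inserts an
# expand_dims (a DimShuffle) on the block_diag output before the Blockwise
# dot, and a rewrite matching dot(block_diag(...), y) directly won't fire.
out = block_diag(A, B) @ y
pytensor.dprint(out)  # shows the DimShuffle sitting between the two ops
```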
I put the rewrite into specialize because I was hoping by that point these dimshuffles would have been cleaned up, but no such luck.
I cobbled together something that works, but I'm not sure if it's ideal.
force-pushed from a9655b0 to abdce8e
Yeah, you want the rewrite before the batched_to_core_dot in specialize. Register it in canonicalize or stabilize?
It's in stabilize now.
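For reference, registering a node rewrite in the stabilize pass looks roughly like this skeleton (illustrative only; the PR's actual rewrite and its matching logic live in pytensor/tensor/rewriting/math.py):

```python
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.math import Dot
from pytensor.tensor.rewriting.basic import register_stabilize


@register_stabilize
@node_rewriter([Dot])
def local_block_diag_dot_sketch(fgraph, node):
    # Inspect node.inputs for a BlockDiagonal operand; return a list with the
    # replacement output, or None to leave the graph unchanged.
    return None
```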
force-pushed from abdce8e to 749bfda
Use new helper in xt.concat
Co-authored-by: Ricardo <ricardo.vieira1994@gmail.com>
force-pushed from 749bfda to 08e6736
Pull Request Overview
This PR optimizes matrix multiplication involving block diagonal matrices by rewriting dot products to operate on smaller sub-matrices instead of the full block diagonal matrix. The optimization leverages the sparse structure of block diagonal matrices where only the diagonal blocks contain non-zero values.
- Adds a new rewrite rule that transforms `dot(block_diag(A, B), C)` into `concat(dot(A, C_split), dot(B, C_split))`
- Introduces a `concat_with_broadcast` utility function for concatenating tensors with automatic broadcasting (a sketch follows this list)
- Provides significant performance improvements (2x+ speedup) by avoiding expensive operations on large sparse matrices
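A minimal sketch of what such a helper can do; the real concat_with_broadcast in pytensor/tensor/extra_ops.py may differ in signature and details:

```python
from functools import reduce

import pytensor.tensor as pt


def concat_with_broadcast_sketch(tensors, axis=0):
    ndim = tensors[0].ndim
    broadcasted = []
    for t in tensors:
        # Broadcast every dim except the join axis to the common shape.
        target = [
            t.shape[i]
            if i == axis
            else reduce(pt.maximum, [u.shape[i] for u in tensors])
            for i in range(ndim)
        ]
        broadcasted.append(pt.broadcast_to(t, target))
    return pt.join(axis, *broadcasted)
```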
Reviewed Changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.
File | Description
---|---
pytensor/tensor/rewriting/math.py | Adds the main optimization rewrite rule for block diagonal matrix multiplication
pytensor/tensor/extra_ops.py | Implements the concat_with_broadcast utility function for tensor concatenation with broadcasting
pytensor/xtensor/rewriting/shape.py | Refactors existing code to use the new concat_with_broadcast function
tests/tensor/test_extra_ops.py | Adds comprehensive tests for the new concat_with_broadcast function
tests/tensor/rewriting/test_math.py | Adds tests and benchmarks for the block diagonal optimization
Looks great!

I suspect we'll be missing some cases when there's a Dimshuffle between the BlockDiag and the Dot, like `block_diag(...) @ tensor3`, where the block diag will have an expand_dims that I don't think is lifted into the block_diag?

Also, we should rewrite `block_diag(...).mT -> block_diag(e0.mT, ..., eN.mT)` to lift the transpose through the inputs, as well as some `elemwise(block_diag(...))` like `mul`. In general we should have a list of elemwise functions that map zero -> zero; this comes up in a lot of these linalg rewrites.

All this should be a separate issue; this PR is concise and clean as it is.

Is this related to the COLA issue as well (to link, if so)?
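Both proposed lifts are easy to sanity-check numerically; a quick scipy/numpy illustration, not part of the PR:

```python
import numpy as np
from scipy.linalg import block_diag

rng = np.random.default_rng(0)
A = rng.normal(size=(2, 3))
B = rng.normal(size=(4, 4))

# Transpose lift: block_diag(A, B).T == block_diag(A.T, B.T)
assert np.array_equal(block_diag(A, B).T, block_diag(A.T, B.T))

# Elemwise lift holds for functions that map zero -> zero, e.g. scaling:
assert np.array_equal(2.0 * block_diag(A, B), block_diag(2.0 * A, 2.0 * B))
```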
Co-authored-by: Ricardo <ricardo.vieira1994@gmail.com>
force-pushed from 08e6736 to b2cba9f
We don't have this one on the COLA list, because we didn't have a subsection for block-diagonal matrices.
Description
This PR adds a rewrite to optimize matrix multiplication involving block diagonal matrices. When we have a matrix `X = BlockDiag(A, B)` and compute `Z = X @ Y`, there's no interaction between terms in the `A` part and the `B` part of `X`. So the dot can instead be computed as `row_stack(A @ Y[:A.shape[0]], B @ Y[A.shape[0]:])`, or in the general case, `Y` can be split into `n` pieces with appropriate shapes and the result computed as `row_stack([diag_component @ y_split for diag_component, y_split in zip(block_diag.inputs, split(Y, *args))])`. In the case where the block diagonal matrix is right-multiplying, you instead `col_stack` and slice on `axis=1`.

Anyway, it's a lot faster to do this, because matmuls scale really badly in the dimension of the input, so doing two smaller operations is preferred. Here are the benchmarks: small has `n=10`, medium has `n=100`, large has `n=1000`. In all cases the rewrite shows at least a 2x speedup.
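A small numpy illustration of the identity driving the rewrite, with hypothetical shapes (not the PR's benchmark code):

```python
import numpy as np
from scipy.linalg import block_diag

rng = np.random.default_rng(42)
A = rng.normal(size=(3, 3))
B = rng.normal(size=(5, 5))
Y = rng.normal(size=(8, 4))

# Left-multiplying: split Y's rows at A.shape[0], then row_stack the small dots.
full = block_diag(A, B) @ Y
split = np.concatenate([A @ Y[:3], B @ Y[3:]], axis=0)
assert np.allclose(full, split)

# Right-multiplying: split on axis=1 and col_stack instead.
Z = rng.normal(size=(4, 8))
full_r = Z @ block_diag(A, B)
split_r = np.concatenate([Z[:, :3] @ A, Z[:, 3:] @ B], axis=1)
assert np.allclose(full_r, split_r)
```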
Related Issue
- `block_diag(a, b) @ c` #1044
- Implement xtensor-like concat helper #1552

Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1493.org.readthedocs.build/en/1493/