More rewrites for ExtractDiag#2045
Conversation
|
I've been playing with AdvancedSubtensor rewrites from analysis of the Wishart PR and we should talk. I feel AdvancedSubtensor rewrites are more general for these cases and not necessarily harder. We may want to rewrite ExtractDiagonal (and AllocDiagona) as the AdvancedSubtensor/AdvancedSetSubtensor version, then immediately call on these rewrites when we know they'll help. |
|
I'll pause on this pending that discussion then. I will say though that some of these have big savings, specifically the elemwise lift. |
No reason not to do this. This is actually something that generalizes to other Ops somewhat. Sum of Transpose -> Sum with transposed axes, and the like. Always good to remove cruft that doesn't affect computation Otherwise all rewrites you mention would work out of the box if we materialize ExtractDiag (and AllocDia) as the equivalent advanced subtensor / advanced set_subtensor and let the rewrites from #2061 act. Except for
Which I think we should tackle now. I had been worried about duplicate indices, but if they are constant (or provably unique like created from symbolic arange) we don't need to worry about. We should always reduce before we compute. We need to think about this. Maybe ExtractDiag AllocDiag could be an OFG so it's trivial for rewrites to materialize the low lever IR and reuse the general read-write rewrite logic we have? |
|
@ricardoV94 how does this PR need to change now that #2061 is merged |
|
@jessegrabowski what I would suggest trying (may prove not the best solution) is to lower all ExtractDiag of Diag | Eye | Alloc | Eye * x -> AdvancedSubtensor(arange(n), arange(n)) of AdvancedSetSubtensor(zeros(n, n), 1 | x, arange(n), arange(n)), and see if our existing rewrites simplify all. This lowering would be done in the rewrite, when we think it's going to work as is done in: pytensor/pytensor/tensor/rewriting/subtensor.py Lines 2151 to 2220 in d35fb51 (the k changes the arange bit) Make sure to reuse the same arange(n) between AdvancedSubtensor and AdvancedSetSubtensor in case it's not constant for our rewrites that work with symbolic x. If n is constant use constant indices for the more general rewrites. I think this will cover all ExtractDiag of Diag writes. Separately keep the transpose -> k=-k rewrite and the ExtractDiag(Elemwise). Just be careful about broadcasted inputs, if you weren't already: ExtractDiag(mat + v) -> ExtractDiag(mat) + v (possibly with some squeezing/dimshuffle. We don't have that rewrite for Elemwise, and it would be nice to keep having an ExtractDiag in other end anyway, because ExtractDiag is always simpler to reason about. |
e161f89 to
6d5b95b
Compare
| expected = pt.diagonal(x) + c.squeeze(axis=1) | ||
| assert_equal_computations([rewritten], [expected]) | ||
|
|
||
| def test_extract_diag_of_elemwise_row_broadcast_offset(self): |
There was a problem hiding this comment.
should test with batch dims at least once
|
@jessegrabowski I pushed some slop WIP:
Together with the ExtractDiag/AllocDiag lowering this subsumes most rewrites you thought about:
Leaving only the ExtractDiag(Transpose(x)) Only the read_of_write needs constant indices / static shape to do its magic The rewrite back into ExtractDiag and slice means we don't have to worry about rematerializing eagerly (to not pay the copy cost that AdvancedSubtensor takes) The same template can be used for the other subtensor_lift (CAReduce, Softmax)... The reason why I think this is worth the trouble is that is more general (users are not penalized for having written extract_diag by hand), and extends to other common views (lower_diag, blocks, split, etc...). The only reason ExtractDiag is really special is that it's a view, and not a copy, so we just need to respect that optmization path. |
797405f to
caf05f7
Compare
|
Time to squash commits, remove the alloc_like bad patch in favor of #2118 and ask the bot to logically split the changes, so we can reassess? The whole pipeline makes sense to me. Lifting ExtractDiag by itself would always be simpler but:
There were however many careful checks that were needed, (are things non-negative, are the bounds >= dim length, ...etc). This are really things that showed up in some form or another every time I started working with index operations, so while it may seem overkill here I think they will pay dividends. In fact I already cherry picked the non negative thing to #2098 What I ended up not bothering with for now: ExtractDiag(Eye | AllocDiag), those use the symbolic form, since we didn't have rewrites for these (read_of_write would require static shape to be able to do anything, which is strictly worse). We can revisit extension for arbitrary constant read of eye / alloc_diag later. I also don't think we have an arbitrary read of extract_diag, something we can visit later also if it shows up |
|
Also a quick test that x[arange(k, x.shape, i)] -> x[slice], and x[pt.diag_indices()] -> x.diagonal would be good. These are nice gains regardless |
7df266a to
67f878a
Compare
I reset the branch and had claude help me collect all the changes into 12 logical commits. lmk what you think. There is zero diff to what we had before i did the re-arrangement, and relevant tests pass on every commit. |
|
I've noticed these commit messages are becoming more verbose than the code changes |
I always give it specific instructions to keep the commit messages <80 characters. I notice you don't so I specifically left it. Happy to fix. |
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
67f878a to
e55cb89
Compare
These came up when I was working on #2032. They're not related to that perse so I split them off. We're missing a lot of simple rewrites for ExtractDiag. I added: