Skip to content

More rewrites for ExtractDiag#2045

Open
jessegrabowski wants to merge 15 commits into
pymc-devs:mainfrom
jessegrabowski:extract-diag-rewrites
Open

More rewrites for ExtractDiag#2045
jessegrabowski wants to merge 15 commits into
pymc-devs:mainfrom
jessegrabowski:extract-diag-rewrites

Conversation

@jessegrabowski
Copy link
Copy Markdown
Member

@jessegrabowski jessegrabowski commented Apr 12, 2026

These came up when I was working on #2032. They're not related to that perse so I split them off. We're missing a lot of simple rewrites for ExtractDiag. I added:

  • ExtractDiag(Eye) -> Ones or Zeros, depending on k
  • ExtractDiag(Eye * x) -> Alloc(x) (new shape)
  • ExtractDiag(Zeros / Ones) -> Zeros / Ones of new shape
  • ExtractDiag(Alloc) -> Alloc (with new shape)
  • ExtractDiag(Elemwise(a, b)) -> Elemwise(ExtractDiag(a), ExtractDiag(b)) (plus some broadcasting logic)
  • ExtractDiag(Transpose(x), offset=k) -> ExtractDiag(X, offset=-k) (I understand transpose is just a free view but it removes a blocker to further rewrites that no longer need to see through the transpose)

@ricardoV94
Copy link
Copy Markdown
Member

ricardoV94 commented Apr 14, 2026

I've been playing with AdvancedSubtensor rewrites from analysis of the Wishart PR and we should talk.

I feel AdvancedSubtensor rewrites are more general for these cases and not necessarily harder. We may want to rewrite ExtractDiagonal (and AllocDiagona) as the AdvancedSubtensor/AdvancedSetSubtensor version, then immediately call on these rewrites when we know they'll help.

Comment thread pytensor/tensor/rewriting/basic.py Outdated
Comment thread pytensor/tensor/rewriting/basic.py Outdated
Comment thread pytensor/tensor/rewriting/basic.py Outdated
@jessegrabowski
Copy link
Copy Markdown
Member Author

I'll pause on this pending that discussion then. I will say though that some of these have big savings, specifically the elemwise lift.

@ricardoV94
Copy link
Copy Markdown
Member

ricardoV94 commented Apr 23, 2026

ExtractDiag(Transpose(x), offset=k) -> ExtractDiag(X, offset=-k) (I understand transpose is just a free view but it removes a blocker to further rewrites that no longer need to see through the transpose)

No reason not to do this. This is actually something that generalizes to other Ops somewhat. Sum of Transpose -> Sum with transposed axes, and the like. Always good to remove cruft that doesn't affect computation

Otherwise all rewrites you mention would work out of the box if we materialize ExtractDiag (and AllocDia) as the equivalent advanced subtensor / advanced set_subtensor and let the rewrites from #2061 act.

Except for

ExtractDiag(Elemwise(a, b)) -> Elemwise(ExtractDiag(a), ExtractDiag(b)) (plus some broadcasting logic)

Which I think we should tackle now. I had been worried about duplicate indices, but if they are constant (or provably unique like created from symbolic arange) we don't need to worry about. We should always reduce before we compute.


We need to think about this. Maybe ExtractDiag AllocDiag could be an OFG so it's trivial for rewrites to materialize the low lever IR and reuse the general read-write rewrite logic we have?

@jessegrabowski
Copy link
Copy Markdown
Member Author

@ricardoV94 how does this PR need to change now that #2061 is merged

@ricardoV94
Copy link
Copy Markdown
Member

ricardoV94 commented Apr 29, 2026

@jessegrabowski what I would suggest trying (may prove not the best solution) is to lower all ExtractDiag of Diag | Eye | Alloc | Eye * x -> AdvancedSubtensor(arange(n), arange(n)) of AdvancedSetSubtensor(zeros(n, n), 1 | x, arange(n), arange(n)), and see if our existing rewrites simplify all. This lowering would be done in the rewrite, when we think it's going to work as is done in:

@register_stabilize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([ExtractDiag])
def local_extract_diag_of_write(fgraph, node):
"""Delegate ``extract_diag(advanced_inc_subtensor(...))`` to the constant-indices rewrite.
Rewrites ``extract_diag(x, offset=k)`` as the equivalent
``x[..., arange(d) + max(0, -k), arange(d) + max(0, k), ...]`` and
calls ``local_advanced_read_of_write_constant_indices`` to do the
work. Since ``extract_diag`` is a zero-copy view, we only commit the
replacement when the downstream rewrite eliminates the gather.
Requires statically-known sizes on the two diagonal axes.
"""
op = node.op
inner = node.inputs[0]
# AdvancedIncSubtensor1 is intentionally not accepted: it writes whole
# rows/slices on a single axis, not specific (i, j) positions, so it
# can't express "write the diagonal" the way two paired index arrays can.
if not (inner.owner and isinstance(inner.owner.op, AdvancedIncSubtensor)):
return None
# Need static sizes on the two diagonal axes to build constant indices.
dim_a = inner.type.shape[op.axis1]
dim_b = inner.type.shape[op.axis2]
if dim_a is None or dim_b is None:
return None
k = op.offset
row_offset = max(0, -k)
col_offset = max(0, k)
d = min(dim_a - row_offset, dim_b - col_offset)
if d <= 0:
return None
# Build equivalent AdvancedSubtensor: inner[..., arange(d) + row_offset, ..., arange(d) + col_offset, ...]
base_arange = np.arange(d, dtype=np.int64)
rows = pytensor.tensor.as_tensor_variable(base_arange + row_offset)
cols = pytensor.tensor.as_tensor_variable(base_arange + col_offset)
idxs = [slice(None)] * inner.type.ndim
idxs[op.axis1] = rows
idxs[op.axis2] = cols
equiv = inner[tuple(idxs)]
if not (equiv.owner and isinstance(equiv.owner.op, AdvancedSubtensor)):
return None
# Delegate to the general read-after-write rewrite.
result = local_advanced_read_of_write_constant_indices.fn(fgraph, equiv.owner)
if not result:
return None
# Stay zero-copy where possible: when the simplification reduced to a
# gather of the inner write's base at our diagonal-arange pattern (i.e.
# the no-coverage case where the write is irrelevant for this read),
# re-emit as ExtractDiag so we keep the view semantics of the original.
base = inner.owner.inputs[0]
[result_var] = result
if (
result_var.owner
and isinstance(result_var.owner.op, AdvancedSubtensor)
and result_var.owner.inputs[0] is base
):
out = base.diagonal(offset=k, axis1=op.axis1, axis2=op.axis2)
copy_stack_trace(node.outputs[0], out)
return [out]
copy_stack_trace(node.outputs[0], result)
return result

(the k changes the arange bit)

Make sure to reuse the same arange(n) between AdvancedSubtensor and AdvancedSetSubtensor in case it's not constant for our rewrites that work with symbolic x. If n is constant use constant indices for the more general rewrites.

I think this will cover all ExtractDiag of Diag writes.

Separately keep the transpose -> k=-k rewrite and the ExtractDiag(Elemwise). Just be careful about broadcasted inputs, if you weren't already: ExtractDiag(mat + v) -> ExtractDiag(mat) + v (possibly with some squeezing/dimshuffle.

We don't have that rewrite for Elemwise, and it would be nice to keep having an ExtractDiag in other end anyway, because ExtractDiag is always simpler to reason about.

@jessegrabowski jessegrabowski force-pushed the extract-diag-rewrites branch from e161f89 to 6d5b95b Compare May 2, 2026 20:34
@jessegrabowski jessegrabowski requested a review from ricardoV94 May 5, 2026 02:35
Comment thread pytensor/tensor/rewriting/basic.py Outdated
Comment thread pytensor/tensor/rewriting/basic.py Outdated
Comment thread pytensor/tensor/rewriting/basic.py Outdated
Comment thread pytensor/tensor/rewriting/basic.py Outdated
Comment thread pytensor/tensor/rewriting/subtensor.py Outdated
Comment thread pytensor/tensor/rewriting/subtensor.py Outdated
Comment thread pytensor/tensor/rewriting/subtensor.py Outdated
Comment thread tests/tensor/rewriting/test_basic.py Outdated
Comment thread tests/tensor/rewriting/test_basic.py Outdated
Comment thread tests/tensor/rewriting/test_basic.py Outdated
expected = pt.diagonal(x) + c.squeeze(axis=1)
assert_equal_computations([rewritten], [expected])

def test_extract_diag_of_elemwise_row_broadcast_offset(self):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should test with batch dims at least once

Comment thread pytensor/tensor/rewriting/basic.py Outdated
Comment thread pytensor/tensor/rewriting/subtensor.py Outdated
Comment thread pytensor/tensor/rewriting/subtensor.py Outdated
Comment thread pytensor/tensor/rewriting/subtensor.py Outdated
Comment thread pytensor/tensor/rewriting/subtensor.py Outdated
Comment thread pytensor/tensor/rewriting/subtensor_lift.py Outdated
@ricardoV94
Copy link
Copy Markdown
Member

ricardoV94 commented May 6, 2026

@jessegrabowski I pushed some slop WIP:

  1. Extend subtensor_of_alloc to advanced indices (when provably not more work)
  2. Extend subtensor_of_alloc to batch_dims (elemwise an blockwise) (when provably not more work)
  3. Rewrite to recognize x[arange(n), arange(n)] -> ExtractDiag, and x[arange(n)] -> x[slice]

Together with the ExtractDiag/AllocDiag lowering this subsumes most rewrites you thought about:

  1. ExtractDiag(Eye)
  2. Extractdiag(Elemwise)
  3. ExtractDiag(Eye * x) (first elemwise mul then eye, no special case needed)
  4. ExtractDiag(Alloc)

Leaving only the ExtractDiag(Transpose(x))

Only the read_of_write needs constant indices / static shape to do its magic

The rewrite back into ExtractDiag and slice means we don't have to worry about rematerializing eagerly (to not pay the copy cost that AdvancedSubtensor takes)

The same template can be used for the other subtensor_lift (CAReduce, Softmax)... The reason why I think this is worth the trouble is that is more general (users are not penalized for having written extract_diag by hand), and extends to other common views (lower_diag, blocks, split, etc...). The only reason ExtractDiag is really special is that it's a view, and not a copy, so we just need to respect that optmization path.

@ricardoV94
Copy link
Copy Markdown
Member

ricardoV94 commented May 7, 2026

Time to squash commits, remove the alloc_like bad patch in favor of #2118 and ask the bot to logically split the changes, so we can reassess?

The whole pipeline makes sense to me. Lifting ExtractDiag by itself would always be simpler but:

  1. Users may have written x[pt.diag_indices] form
  2. If we want to help users we need a way to detect it
  3. If we can detect it we can temporarily lower extract_diag to the advanced subtensor repr, and later recover it
  4. If we can do this we gain:
    1. Rewrites that work with arbitrary advanced indexing will automatically apply to it
    2. Rewrite that would only work with arbitrary advanced indexing (looking at read_of_write_constant_indices) will apply

There were however many careful checks that were needed, (are things non-negative, are the bounds >= dim length, ...etc). This are really things that showed up in some form or another every time I started working with index operations, so while it may seem overkill here I think they will pay dividends. In fact I already cherry picked the non negative thing to #2098

What I ended up not bothering with for now: ExtractDiag(Eye | AllocDiag), those use the symbolic form, since we didn't have rewrites for these (read_of_write would require static shape to be able to do anything, which is strictly worse). We can revisit extension for arbitrary constant read of eye / alloc_diag later. I also don't think we have an arbitrary read of extract_diag, something we can visit later also if it shows up

@ricardoV94
Copy link
Copy Markdown
Member

Also a quick test that x[arange(k, x.shape, i)] -> x[slice], and x[pt.diag_indices()] -> x.diagonal would be good. These are nice gains regardless

@jessegrabowski jessegrabowski force-pushed the extract-diag-rewrites branch from 7df266a to 67f878a Compare May 10, 2026 01:48
@jessegrabowski
Copy link
Copy Markdown
Member Author

Time to squash commits, remove the alloc_like bad patch in favor of #2118 and ask the bot to logically split the changes, so we can reassess?

I reset the branch and had claude help me collect all the changes into 12 logical commits. lmk what you think. There is zero diff to what we had before i did the re-arrangement, and relevant tests pass on every commit.

@ricardoV94
Copy link
Copy Markdown
Member

I've noticed these commit messages are becoming more verbose than the code changes

@jessegrabowski
Copy link
Copy Markdown
Member Author

I've noticed these commit messages are becoming more verbose than the code changes

I always give it specific instructions to keep the commit messages <80 characters. I notice you don't so I specifically left it. Happy to fix.

Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
jessegrabowski and others added 13 commits May 10, 2026 19:01
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Co-Authored-By: Ricardo Vieira <ricardo.vieira1994@gmail.com>
@jessegrabowski jessegrabowski force-pushed the extract-diag-rewrites branch from 67f878a to e55cb89 Compare May 10, 2026 23:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request graph rewriting

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants