Stop subtensor merge bonanza #2098

Merged
ricardoV94 merged 8 commits into pymc-devs:main from ricardoV94:stop_subtensor_merge on May 15, 2026

@ricardoV94 ricardoV94 commented Apr 29, 2026

Member

Summary

Closes #112
Closes #1283

local_subtensor_merge no longer expands into combinatorial switch/min/max trees on Scan outputs with symbolic shapes. The scan memory rewrites (scan_reduce_nsteps, scan_reduce_trace) still trim buffers all the way down to taps + 1 whenever a chain of constant-bound subtensors ends in a constant scalar.
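One way to see why a fully general merge gets expensive: the merged bounds of x[a:b][c:d] depend on the length of x, and short inputs clamp differently from long ones. The sketch below uses plain Python ranges as a stand-in for tensors (it is not PyTensor code, and merged_indices is a hypothetical helper), relying on the fact that slicing a range yields another range.

```python
def merged_indices(n, first, second):
    """Indices of x selected by x[first][second] when len(x) == n.

    Python resolves the composition for us: slicing a range yields
    another range, so range(n)[first][second] is the merged selection.
    """
    return range(n)[first][second]


s = slice(1, -1)

# With a known length the merged bounds are plain integers.
print(merged_indices(10, s, s))  # range(2, 8)

# On short inputs the bounds clamp, so the number of selected elements is
# not a single closed-form expression in n. A symbolic merge has to encode
# this case analysis explicitly when n is unknown.
for n in range(6):
    print(n, len(merged_indices(n, s, s)))
```

When the length is symbolic, each clamp has to be kept in the graph, which is consistent with the switch/min/max growth described above.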

What changed

1. Split scan/rewriting.py into a package

The 3100-line monolith is split into focused modules (db, inplace, io, merge, push_out, trace, utils) with matching test files. No behavioral changes.

2. Fix Scan.infer_shape for 0-d untraced sit_sot outputs

When scan_reduce_trace converts a sit_sot into an untraced sit_sot (0-d output), infer_shape now handles the missing dimension correctly.

3. Gate local_subtensor_merge to prevent symbolic blowup

Only call the symbolic merge_two_slices when slice bounds (and shapes, for slice+scalar) are constant. A new _merge_slice_into_slice_no_shape_ref handles the common slice-on-slice cases without consulting the shape: forward × forward (sign-aware bound combination, e.g. x[1:-1][1:-1] → x[2:-2]), x[a:b][::-1], x[::-1][a:b], x[::-1][a:b:-1], and x[a:b:-1][::-1] (the last restricted to non-negative bounds).
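A minimal sketch of the forward × forward case, assuming step-1 slices with constant bounds and using Python lists in place of tensors. merge_forward_slices is a hypothetical helper written for illustration, not the PR's _merge_slice_into_slice_no_shape_ref; it only merges when the result can be expressed without the length and returns None otherwise.

```python
from itertools import product


def merge_forward_slices(outer, inner):
    """Merge x[a:b][c:d] into one slice without knowing len(x).

    Hypothetical helper: handles non-negative starts whose stops are either
    both end-relative (negative or None) or both non-negative.
    """
    (a, b), (c, d) = outer, inner
    a, c = a or 0, c or 0
    if a < 0 or c < 0:
        return None  # a negative start would need the length
    start = a + c
    b_from_end = b is None or b < 0
    d_from_end = d is None or d < 0
    if b_from_end and d_from_end:
        # Both stops count back from the end, so the offsets add up.
        stop = (b or 0) + (d or 0)
        return slice(start, stop if stop else None)
    if not b_from_end and not d_from_end:
        # Both stops count from the front; the tighter one wins.
        return slice(start, min(b, a + d))
    return None  # mixed signs would need the length


print(merge_forward_slices((1, -1), (1, -1)))  # slice(2, -2, None)

# Brute-force check against list slicing for small lengths.
bounds = [None, 0, 1, 2, 3, -1, -2, -3]
for n, (a, b), (c, d) in product(range(8), product(bounds, repeat=2), product(bounds, repeat=2)):
    merged = merge_forward_slices((a, b), (c, d))
    if merged is not None:
        x = list(range(n))
        assert x[merged] == x[a:b][c:d]
```

Bailing out on mixed-sign bounds keeps the merged slice free of any reference to the shape, which is the property the gate is meant to protect.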

Aggressive scalar-into-slice merging moves to local_subtensor_merge_integer (shape_unsafe), registered in both canonicalize and specialize. This subsumes the old while_scan_merge_subtensor_last_element scan rewrite. Constant checks use direct isinstance(v, Constant) rather than recursive get_scalar_constant_value — by the time these rewrites fire, canonicalization has already simplified index expressions to direct constants.
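To illustrate what "shape_unsafe" means for the scalar-into-slice merge, here is a sketch on Python lists, assuming a step-1 slice and constant bounds. merge_scalar_into_slice is a hypothetical helper, not the PR's rewrite: it returns an index that agrees with x[start:stop][idx] whenever that expression is in bounds, but performs no bounds check of its own.

```python
def merge_scalar_into_slice(start, stop, idx):
    """Index i such that x[i] == x[start:stop][idx] for in-bounds idx.

    Hypothetical helper for step-1 slices; returns None when the merged
    index would depend on len(x).
    """
    start = 0 if start is None else start
    if idx >= 0 and start >= 0:
        return start + idx  # count forward from the slice start
    if idx < 0 and (stop is None or stop < 0):
        return idx + (stop or 0)  # count backward from the slice stop
    return None


x = list(range(10))
print(x[2:5][1], x[merge_scalar_into_slice(2, 5, 1)])  # 3 3
print(x[3:][-1], x[merge_scalar_into_slice(3, None, -1)])  # 9 9

# The merge is "unsafe" because it drops the bounds check of the slice:
# x[2:5][4] raises IndexError, while the merged index 6 reads a valid element.
try:
    x[2:5][4]
except IndexError:
    print("original raises, merged gives", x[merge_scalar_into_slice(2, 5, 4)])
```

Valid index graphs give the same answer either way; only graphs that would have raised behave differently.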

4. Split scan_save_mem into scan_reduce_nsteps and scan_reduce_trace

  • scan_reduce_nsteps — when every client of a Scan output is a constant scalar index, reduce n_steps to the minimum that covers those reads and rewrite each client to a negative index against the trimmed trace.
  • scan_reduce_trace — shortens outer buffers and n_steps to the smallest range any client actually reads. Walks slice chains directly off the graph (no get_canonical_form_slice), reading buffer requirements straight from the (now-folded) negative indices. Caps prealloc extra_size at n_steps to avoid uninitialized slots.
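The n_steps reduction can be sketched with a toy loop in place of Scan. This assumes every read is a constant non-negative index and ignores initial-state slots in the buffer; run_scan and reduce_nsteps are hypothetical names, not PyTensor functions.

```python
def run_scan(step, x0, n_steps):
    """Toy scan: returns the trace [step(x0), step(step(x0)), ...]."""
    trace, state = [], x0
    for _ in range(n_steps):
        state = step(state)
        trace.append(state)
    return trace


def reduce_nsteps(indices):
    """Smallest n_steps that covers all constant non-negative reads, plus
    the equivalent negative indices into the trimmed trace."""
    new_n_steps = max(indices) + 1
    return new_n_steps, [i - new_n_steps for i in indices]


def step(state):
    return 2 * state + 1


reads = [0, 2, 4]
full = run_scan(step, 1, n_steps=1000)

new_n_steps, negative_reads = reduce_nsteps(reads)
short = run_scan(step, 1, n_steps=new_n_steps)

print(new_n_steps, negative_reads)  # 5 [-5, -3, -1]
print([full[i] for i in reads] == [short[j] for j in negative_reads])  # True
```

Expressing the reads as negative indices ties them to the end of the trimmed trace, which is the form the buffer-shortening step then reads its requirements from.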

Invariant honored

result[...][-1] (static [-1] through any chain of constant-bound, ±1-step slices) always reduces to a unit buffer, with both constant and symbolic n_steps.
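As an illustration of what a unit buffer buys, the sketch below runs a toy recurrence while keeping only the most recent states in a rolling buffer. It is a plain-Python analogy under the assumption of a single-tap recurrence; scan_last_states is a hypothetical name and not how Scan allocates its buffers.

```python
from collections import deque


def scan_last_states(step, x0, n_steps, keep=1):
    """Toy scan that keeps only the last `keep` states in a rolling buffer."""
    buffer = deque([x0], maxlen=keep)
    for _ in range(n_steps):
        # Appending to a full deque evicts the oldest state.
        buffer.append(step(buffer[-1]))
    return list(buffer)


def step(state):
    return 3 * state + 1


n_steps = 50
full_trace = []
state = 1
for _ in range(n_steps):
    state = step(state)
    full_trace.append(state)

# Reading result[-1] only needs the final state, so one slot is enough.
print(scan_last_states(step, 1, n_steps, keep=1) == full_trace[-1:])  # True
# Reading result[-3:] needs three slots.
print(scan_last_states(step, 1, n_steps, keep=3) == full_trace[-3:])  # True
```

Memory use depends on `keep` rather than on n_steps, which is why the reduction still applies when n_steps is symbolic.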

Verified

All patterns below reduce to a 0-d untraced scan output (unit buffer) with both static and symbolic n_steps:

result[-1], result[-3:][-1], result[3:][-1], result[::-1][0], result[5::-1][-1], result[:-1][::-1][-1]

The result[-1:] variant also reduces the buffer but keeps a 1-d output via ExpandDims.
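For reference, this is what each verified pattern reads when evaluated on a plain Python list standing in for the trace (element i is the state after step i). It only demonstrates indexing semantics, not the rewrite itself.

```python
n_steps = 10
result = list(range(n_steps))  # stand-in trace: result[i] == i

patterns = {
    "result[-1]": result[-1],
    "result[-3:][-1]": result[-3:][-1],
    "result[3:][-1]": result[3:][-1],
    "result[::-1][0]": result[::-1][0],
    "result[5::-1][-1]": result[5::-1][-1],
    "result[:-1][::-1][-1]": result[:-1][::-1][-1],
}

# Every chain collapses to one fixed position of the trace.
for expr, position in patterns.items():
    print(f"{expr:24} reads position {position}")

# The slice variant selects the same state but keeps a leading dimension.
print(result[-1:])  # [9]
```

The first four chains read the last state and the last two read the first state, so in every case a single state is all that has to be kept.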

Benchmarks

Compile time and post-rewrite apply-node counts, with fusion excluded from the mode (excluding("fusion")) so the rewriter output is visible.

| Benchmark | Before | After | Speedup |
|---|---|---|---|
| x[1:-1] × 3 | 82 apply, 228 ms | 2 apply, 27 ms | |
| x[1:-1] × 5 | 166 apply, 490 ms | 2 apply, 45 ms | 11× |
| x[1:-1] × 8 | 292 apply, 835 ms | 2 apply, 74 ms | 11× |
| grad(xs[-1], x0) symbolic-n Scan (#112) | 348 apply, 1260 ms | 22 apply, 114 ms | 11× |
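The ratios can be recomputed from the before/after figures reported above; the snippet below does only that arithmetic and introduces no new measurements.

```python
# (apply nodes, milliseconds) before and after, copied from the benchmark figures.
benchmarks = {
    "x[1:-1] x 3": ((82, 228), (2, 27)),
    "x[1:-1] x 5": ((166, 490), (2, 45)),
    "x[1:-1] x 8": ((292, 835), (2, 74)),
    "grad(xs[-1], x0) symbolic-n Scan": ((348, 1260), (22, 114)),
}

speedups = {}
for name, ((nodes_before, ms_before), (nodes_after, ms_after)) in benchmarks.items():
    speedups[name] = ms_before / ms_after
    print(
        f"{name}: {speedups[name]:.1f}x faster compile, "
        f"{nodes_before / nodes_after:.1f}x fewer apply nodes"
    )
```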

Limitations (deliberate)

  • scan_reduce_nsteps bails on while-scans (can't statically bound iteration count).
  • Subtensor merge doesn't handle slice steps with magnitude > 1.
  • concatenate([rev, zeros])[k] from the while-scan gradient path is opaque to subtensor merge (would need #919, "Consider lifting Subtensor through Joins").

@ricardoV94

Member Author

Possibly fixes #1288; need to check.

@ricardoV94 ricardoV94 force-pushed the stop_subtensor_merge branch 6 times, most recently from 30c36f6 to 8456244 Compare April 30, 2026 21:05
@ricardoV94

Member Author

Compared to #2109, the CI time for numba/scan went from 12m30s to 10m (py3.14) and from 17m to 13m (py3.11).

@ricardoV94 ricardoV94 force-pushed the stop_subtensor_merge branch 3 times, most recently from 6c30c01 to 5c5cbc1 Compare May 8, 2026 00:21
@ricardoV94

Member Author

Split the scan rewrites into separate files; don't be shocked by the LOC.

@ricardoV94 ricardoV94 force-pushed the stop_subtensor_merge branch from 5c5cbc1 to 0640656 Compare May 8, 2026 00:35
@ricardoV94 ricardoV94 marked this pull request as ready for review May 8, 2026 00:42
@ricardoV94

ricardoV94 commented May 8, 2026

Member Author

Ready for review, all goals achieved.

  1. subtensor_merge kept under control
  2. subtensor_merge_unsafe handles the idx[slice][idx] case that's important for scan_save_mem
  3. scan_save_mem doesn't blow up graph either
  4. scan_save_mem recognizes out[-1] and out[-1:] (only the last state needed) and optimizes the buffer accordingly, even with symbolic n_steps.
  5. split concerns between reduce_nsteps and reduce_buffersize (the latter is rather destructive, you can't autodiff after it; the former is fine)
  6. Decided to simplify code and rely on subtensor_merge_unsafe instead of adding a cute Assert(n_steps>min) on scan. Code is in the history if anybody is interested. Either way it only guards against invalid index graphs, doesn't break any correct code.

Comment thread pytensor/tensor/rewriting/subtensor.py Outdated
@ricardoV94 ricardoV94 force-pushed the stop_subtensor_merge branch 2 times, most recently from 25c14d1 to d2febf5 Compare May 8, 2026 13:56
Comment thread pytensor/tensor/rewriting/subtensor.py
Comment thread pytensor/tensor/rewriting/subtensor.py
Comment thread pytensor/tensor/rewriting/subtensor.py
Comment thread pytensor/tensor/rewriting/subtensor.py Outdated
@ricardoV94 ricardoV94 force-pushed the stop_subtensor_merge branch 2 times, most recently from 170cd36 to 6447ec8 Compare May 8, 2026 21:56
Comment thread pytensor/tensor/rewriting/subtensor.py Outdated
@ricardoV94 ricardoV94 force-pushed the stop_subtensor_merge branch from 58c873b to 90ab7c9 Compare May 13, 2026 13:41
@ricardoV94 ricardoV94 force-pushed the stop_subtensor_merge branch 3 times, most recently from 937300f to 07b0589 Compare May 14, 2026 09:48
@ricardoV94 ricardoV94 requested a review from jessegrabowski May 14, 2026 10:13
@ricardoV94

Member Author

@jessegrabowski I've addressed your comments and further simplified as much as I could.

The commit that refactors scan/rewriting.py into scan/rewriting/*.py is the first one; it should contain no code changes.

@jessegrabowski jessegrabowski left a comment

Member


Looks good to me. I don't claim to 100% understand the last 2 commits (scan_reduce_trace and scan_reduce_nsteps), and I didn't make the effort of pulling down the commit and having claude slam the load-bearing functions with as many test cases as humanly possible. I assume you have, though, so I'm happy to bring this in.

Reorganize the monolithic rewriting module into focused submodules:
trace, db, inplace, io, merge, push_out, and utils.
Split tests/scan/test_rewriting.py to mirror the structure.
@ricardoV94 ricardoV94 force-pushed the stop_subtensor_merge branch 2 times, most recently from d6e315e to e326aa3 Compare May 15, 2026 20:24
The unconditional merge produced large switch/min/max trees whenever any
component of the chain (slice bounds or shapes) was symbolic — most
visibly on Scan outputs whose stripping slice is rolled together with a
client index. Add ``_can_merge_simply``: only merge slice+slice when both
steps are constant, and slice+scalar when all components and shapes are
constant. Non-mergeable dimensions stay as a separate outer Subtensor.

Add ``local_subtensor_merge_unsafe`` (tag ``shape_unsafe``) that handles
slice+scalar with step ±1 more aggressively without bounds checks, for
the cases where the safe merge bails out.
Add @register_canonicalize("shape_unsafe") so integer subtensor chain
merges fire during canonicalize. Remove while_scan_merge_subtensor_last_element
whose job is now subsumed by the upstream merge.

scan_reduce_nsteps (position 1.611): reduce n_steps to the minimum
that covers all constant-index clients.

scan_reduce_trace (position 1.612): shorten per-output trace buffers
to the tail each client actually reads.

scan_remove_unused now runs first (1.605) so these rewrites don't
process outputs that will be pruned.

Replace get_canonical_form_slice with direct slice walking via
_strip_chain_negative_start (for out[init_l:][-k:] chains) and
_translate_positive_to_negative (for constant-length buffers).
Cap prealloc extra_size at n_steps.
@ricardoV94 ricardoV94 force-pushed the stop_subtensor_merge branch from e326aa3 to da43b8a Compare May 15, 2026 20:38
@ricardoV94 ricardoV94 merged commit 5ba565a into pymc-devs:main May 15, 2026
66 checks passed

Development

Successfully merging this pull request may close these issues.

Inplace on sit-sot / mit-sot when nsteps is symbolic
local_subtensor_merge can complicate graphs

2 participants