Stop subtensor merge bonanza#2098
Conversation
|
Possibly fixes: #1288 need to check |
30c36f6 to
8456244
Compare
|
Comparing to #2109 the CI for numba/scan went from 12m30->10m (py3.14) and 17m->13 (py3.11) |
6c30c01 to
5c5cbc1
Compare
|
Split scan rewrite into separate files, don't be shocked by the LOC |
5c5cbc1 to
0640656
Compare
|
Ready for review, all goals achieved.
|
25c14d1 to
d2febf5
Compare
170cd36 to
6447ec8
Compare
58c873b to
90ab7c9
Compare
937300f to
07b0589
Compare
|
@jessegrabowski I've addressed your comments and further simplified as much as I could. The commit that refactors scan/rewriting.py into scan/rewriting/*.py is the first, that should have no code change |
jessegrabowski
left a comment
There was a problem hiding this comment.
looks good to me. I don't claim to 100% understand the last 2 commits (scan_reduce_trace and and and scan_reduce_nsteps) and I didn't do the effort of pulling down the commit and having claude slam the load-bearing functions with as many test cases as humanly possible. I assume you have though so I'm happy to bring this in.
Reorganize the monolithic rewriting module into focused submodules: trace, db, inplace, io, merge, push_out, and utils. Split tests/scan/test_rewriting.py to mirror the structure.
d6e315e to
e326aa3
Compare
The unconditional merge produced large switch/min/max trees whenever any component of the chain (slice bounds or shapes) was symbolic — most visibly on Scan outputs whose stripping slice is rolled together with a client index. Add ``_can_merge_simply``: only merge slice+slice when both steps are constant, and slice+scalar when all components and shapes are constant. Non-mergeable dimensions stay as a separate outer Subtensor. Add ``local_subtensor_merge_unsafe`` (tag ``shape_unsafe``) that handles slice+scalar with step ±1 more aggressively without bounds checks, for the cases where the safe merge bails out.
Add @register_canonicalize("shape_unsafe") so integer subtensor chain
merges fire during canonicalize. Remove while_scan_merge_subtensor_last_element
whose job is now subsumed by the upstream merge.
scan_reduce_nsteps (position 1.611): reduce n_steps to the minimum that covers all constant-index clients. scan_reduce_trace (position 1.612): shorten per-output trace buffers to the tail each client actually reads. scan_remove_unused now runs first (1.605) so these rewrites don't process outputs that will be pruned.
Replace get_canonical_form_slice with direct slice walking via _strip_chain_negative_start (for out[init_l:][-k:] chains) and _translate_positive_to_negative (for constant-length buffers). Cap prealloc extra_size at n_steps.
e326aa3 to
da43b8a
Compare
Summary
Closes #112
Closes #1283
local_subtensor_mergeno longer expands into combinatorialswitch/min/maxtrees on Scan outputs with symbolic shapes. The scan memory rewrites (scan_reduce_nsteps,scan_reduce_trace) still trim buffers all the way down totaps + 1whenever a chain of constant-bound subtensors ends in a constant scalar.What changed
1. Split
scan/rewriting.pyinto a packageThe 3100-line monolith is split into focused modules (
db,inplace,io,merge,push_out,trace,utils) with matching test files. No behavioral changes.2. Fix
Scan.infer_shapefor 0-d untraced sit_sot outputsWhen
scan_reduce_traceconverts a sit_sot into an untraced sit_sot (0-d output),infer_shapenow handles the missing dimension correctly.3. Gate
local_subtensor_mergeto prevent symbolic blowupOnly call the symbolic
merge_two_sliceswhen slice bounds (and shapes, for slice+scalar) are constant. A new_merge_slice_into_slice_no_shape_refhandles the common slice-on-slice cases without consulting shape: forward × forward (sign-aware bound combination, e.g.x[1:-1][1:-1]→x[2:-2]),x[a:b][::-1],x[::-1][a:b],x[::-1][a:b:-1], andx[a:b:-1][::-1](last restricted to non-negative bounds).Aggressive scalar-into-slice merging moves to
local_subtensor_merge_integer(shape_unsafe), registered in bothcanonicalizeandspecialize. This subsumes the oldwhile_scan_merge_subtensor_last_elementscan rewrite. Constant checks use directisinstance(v, Constant)rather than recursiveget_scalar_constant_value— by the time these rewrites fire, canonicalization has already simplified index expressions to direct constants.4. Split
scan_save_memintoscan_reduce_nstepsandscan_reduce_tracescan_reduce_nsteps— when every client of a Scan output is a constant scalar index, reducen_stepsto the minimum that covers those reads and rewrite each client to a negative index against the trimmed trace.scan_reduce_trace— shortens outer buffers andn_stepsto the smallest range any client actually reads. Walks slice chains directly off the graph (noget_canonical_form_slice), reading buffer requirements straight from the (now-folded) negative indices. Caps preallocextra_sizeatn_stepsto avoid uninitialized slots.Invariant honored
result[...][-1](static[-1]through any chain of constant-bound, ±1-step slices) always reduces to a unit buffer, with both constant and symbolicn_steps.Verified
All patterns below reduce to a 0-d untraced scan output (unit buffer) with both static and symbolic
n_steps:result[-1],result[-3:][-1],result[3:][-1],result[::-1][0],result[5::-1][-1],result[:-1][::-1][-1]The
result[-1:]variant also reduces the buffer but keeps a 1-d output viaExpandDims.Benchmarks
Compile-time and post-rewrite node counts, mode
excluding("fusion")so the rewriter output is visible.x[1:-1]× 3x[1:-1]× 5x[1:-1]× 8grad(xs[-1], x0)symbolic-nScan (#112)Limitations (deliberate)
scan_reduce_nstepsbails on while-scans (can't statically bound iteration count).concatenate([rev, zeros])[k]from the while-scan gradient path is opaque to subtensor merge (would need Consider lifting Subtensor through Joins #919).