[torch.compile] Bunch of small changes needed for enabling torch.compile#3130
[torch.compile] Bunch of small changes needed for enabling torch.compile#3130pggPL wants to merge 5 commits into
Conversation
…stants; fix SP memory leak; test suite hook-up Wrap CommOverlapCore pybind11 methods that return compile-time constants so torch.compile(fullgraph=True) can trace through them without graph breaks: - `is_fp8_ubuf()` → `ub_is_fp8()` / `get_ub_is_fp8()` in base.py; `_ub_is_fp8()` in gemm.py - `with_cublasmp()` → `ub_is_cublasmp()` in base.py All callers in linear.py, layernorm_linear.py, layernorm_mlp.py, base.py, gemm.py, userbuffers_backward_linear.py and userbuffers_forward_linear.py updated. Fix quantized grad_output not being freed early for column-parallel SP backward. Row-parallel SP already called clear_tensor_data(grad_output) to release the gathered tensor; column-parallel SP quantizes grad_output to Float8TensorStorage but never freed it before returning. Under torch.compile reduce-overhead this leaves 3 live pool tensors at recording end and triggers "Detected 3 tensor(s) in the cudagraph pool not tracked as outputs". Extend the existing clear_tensor_data guard to cover both parallel modes. Fix custom-recipe quantizer state being re-initialised on every forward call even when the recipe object has not changed. The existing early-exit for CustomRecipeState was missing an identity check on the recipe object, so any repeated call with the same recipe would bypass the early-return and rebuild quantizers unnecessarily. Add `if recipe_state.recipe is recipe: return` to restore the intended caching behaviour. Add test_torch_compile.py to L0_pytorch_unittest so the autocast and existing compile tests run in CI. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…-accumulator booleans LinearBwdArgs stored the entire FP8 recipe object so the backward could extract fp8_gemm_dgrad.use_split_accumulator and fp8_gemm_wgrad.use_split_accumulator at GEMM time. Recipe objects hold process-group references and are not serialisable as compile-time constants, making them incompatible with torch.compile custom-op paths. Replace fp8_recipe with two plain bool fields: - dgrad_use_split_accumulator (default _2X_ACC_DGRAD) - wgrad_use_split_accumulator (default _2X_ACC_WGRAD) These are resolved once in _linear_setup_ctx and passed into the args struct, so the backward consumes scalars instead of a live recipe object. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR bundles five targeted
Confidence Score: 5/5All changes are well-scoped and non-breaking; the column-SP FP8 tensor free is correctly placed after the wgrad GEMM and the freed variable is not accessed again. Each change is narrowly targeted: the split-accumulator refactor preserves the same defaults for non-FP8 and reproduces the same hasattr guards for FP8 recipes; the grad_output free only triggers after the GEMM that consumed it; the CustomRecipeState identity check uses No files require special attention. Important Files Changed
Sequence Diagram%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
participant M as Linear.forward()
participant GSM as FP8GlobalStateManager
participant FA as LinearFwdArgs
participant BA as LinearBwdArgs
participant BW as _linear_backward()
M->>GSM: get_fp8_recipe()
GSM-->>M: _recipe
M->>M: resolve dgrad/wgrad split-accumulator bools
M->>FA: fwd_args(dgrad_use_split_accumulator, wgrad_use_split_accumulator)
FA->>BA: _linear_setup_ctx copies plain bools
Note over FA,BA: No recipe object stored
BW->>BW: "use_split_accumulator = bwd_args.dgrad_use_split_accumulator"
BW->>BW: dgrad GEMM
BW->>BW: "use_split_accumulator = bwd_args.wgrad_use_split_accumulator"
BW->>BW: wgrad GEMM
BW->>BW: clear_tensor_data(grad_output) [row-SP or col-SP+FP8]
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
participant M as Linear.forward()
participant GSM as FP8GlobalStateManager
participant FA as LinearFwdArgs
participant BA as LinearBwdArgs
participant BW as _linear_backward()
M->>GSM: get_fp8_recipe()
GSM-->>M: _recipe
M->>M: resolve dgrad/wgrad split-accumulator bools
M->>FA: fwd_args(dgrad_use_split_accumulator, wgrad_use_split_accumulator)
FA->>BA: _linear_setup_ctx copies plain bools
Note over FA,BA: No recipe object stored
BW->>BW: "use_split_accumulator = bwd_args.dgrad_use_split_accumulator"
BW->>BW: dgrad GEMM
BW->>BW: "use_split_accumulator = bwd_args.wgrad_use_split_accumulator"
BW->>BW: wgrad GEMM
BW->>BW: clear_tensor_data(grad_output) [row-SP or col-SP+FP8]
Reviews (3): Last reviewed commit: "Provide explicit QuantizerRoles in torch..." | Re-trigger Greptile |
|
/te-ci pytorch L1 |
…t_result get_ub_is_fp8 bakes is_fp8_ubuf() as a compile-time constant; without a reset, destroy_ub + re-init with different FP8 settings would read stale values until recompile. Only affects in-memory caches, not disk. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
ToyLinear now overrides get_quantizer_roles so CustomRecipeState doesn't hit the no-roles warning, which graph-breaks under fullgraph=True. qfactory dispatches on role.tensor_type instead of a pre-baked string key. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
| # Compiled graphs may have baked is_fp8_ubuf() via assume_constant_result; | ||
| # reset so re-init with different settings doesn't read stale constants. | ||
| torch.compiler.reset() |
There was a problem hiding this comment.
The current helper call sites are all inside @no_torch_dynamo() forwards and the added test_torch_compile.py coverage does not exercise user buffers or it's done implicitly in the test?
Is it possible avoid a process-wide compiler reset on UB teardown, or add a targeted compiled UB test that proves the stale-constant case and justifies this global invalidation?
| python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4" | ||
| python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_mxfp8.xml $TE_PATH/tests/pytorch/mxfp8 || test_fail "test_mxfp8" | ||
| python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor.xml $TE_PATH/tests/pytorch/test_quantized_tensor.py || test_fail "test_quantized_tensor.py" | ||
| python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_torch_compile.xml $TE_PATH/tests/pytorch/test_torch_compile.py || test_fail "test_torch_compile.py" |
There was a problem hiding this comment.
That file only compiles a local ToyLinear helper and torch.nn.Linear under te.autocast. It does not instantiate changed in this PR te.Linear, LayerNormLinear, or LayerNormMLP, and it has no UB, sequence_parallel/parallel_mode.
What tests would fail without changes to layernorm_linear, layernorm_mlp files?
There was a problem hiding this comment.
I fix the issue that the test was not connected to the CI.
Currently it tests only if te.autocast() can be traced inside torch.compile.
This is first of series of PRs and I change here only small things to make next PRs cleaner.
Description
Small standalone fixes extracted from a larger torch.compile branch, going directly from main. Two independent changes: making Userbuffers pybind11 queries compile-friendly, and freeing quantized grad_output early for column-parallel SP. Plus a custom-recipe caching fix, a split-accumulator refactor, and a CI test hook-up.
Type of change
Changes
Checklist: