[PyTorch] Enable Weight Preswizzling during Quantization in TE Modules#3093
Conversation
…obatches For block-scaled NVFP4 a cached weight participates in two GEMMs per step: fprop (rowwise scales) and dgrad (columnwise scales). The GEMM-ready scale swizzle was recomputed lazily inside every GEMM and discarded, so with N microbatches the weight scale swizzle ran 2*N times per step even though the weight is quantized only once. Because weight RHT is disabled, the weight scales are not swizzled by the cast-fusion path; with optimize_for_gemm off they also skip the post-quantize fallback swizzle, so the only swizzle site left for the weight is the lazy one inside general_gemm (swizzle_scales_for_gemm), which re-runs on every GEMM. (Activation input/grad_output quantizers already set optimize_for_gemm=True, so they were pre-swizzled via cast-fusion/fallback; only the weight was missed.) Set weight_quantizer.optimize_for_gemm=True on the cached, non-FSDP path so the swizzle is done once at quantize time (via the post-quantize fallback), persisted on the cached workspace (_with_gemm_swizzled_scales=True), and reused by every GEMM (swizzle_scales_for_gemm early-returns) -> 2 swizzles per step instead of 2*N. Applied to Linear, LayerNormLinear, LayerNormMLP (fc1+fc2) and GroupedLinear (per expert). Gated to the cached path (is_first_microbatch is not None) with fsdp_group is None and not is_fsdp2: FSDP/FSDP2 all-gather weights using the un-swizzled scale layout, so pre-swizzling is unsupported there. No-op for recipes whose scales do not require swizzling (e.g. per-tensor FP8). Swizzling is a pure layout permutation, so numerics are unchanged. Add tests/pytorch/nvfp4/test_nvfp4_weight_swizzle_cache.py verifying the cached eager-swizzle path matches the lazy-swizzle baseline (fprop + dgrad) for Linear/LayerNormLinear/GroupedLinear and that the swizzled flag is persisted. Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR pre-swizzles block-scaled weight scale factors (NVFP4, MXFP8) once at quantize time instead of lazily inside every GEMM call, reducing swizzle operations from
Confidence Score: 5/5Safe to merge; the change is a pure performance optimisation with no numeric side-effects and the guard on The No files require special attention; the identical comment-block typos in Important Files Changed
Sequence Diagram%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
participant FWD as Forward pass
participant GWQ as _get_weight_quantizers()
participant QW as quantize_weight()
participant WS as _fp8_workspaces
participant GEMM as general_gemm()
Note over FWD,GEMM: First microbatch (is_first_microbatch=True, not primary_weights_in_fp8, not FSDP2)
FWD->>GWQ: _get_weight_quantizers()
GWQ-->>FWD: "quantizer (optimize_for_gemm=True)"
FWD->>QW: "quantize_weight(cache=True)"
QW->>QW: quantize + swizzle scales eagerly
QW-->>WS: "cache workspace (_with_gemm_swizzled_scales=True)"
QW-->>GEMM: pre-swizzled workspace → GEMM (skip lazy swizzle)
Note over FWD,GEMM: Later microbatches (is_first_microbatch=False)
FWD->>WS: get cached workspace
WS-->>GEMM: reuse pre-swizzled workspace (no re-quantize, no re-swizzle)
Note over FWD,GEMM: primary_weights_in_fp8=True path (quantized_model_init)
FWD->>GWQ: _get_weight_quantizers()
GWQ-->>FWD: "quantizer (optimize_for_gemm=False, unchanged)"
FWD->>GEMM: lazy swizzle inside every GEMM call
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
participant FWD as Forward pass
participant GWQ as _get_weight_quantizers()
participant QW as quantize_weight()
participant WS as _fp8_workspaces
participant GEMM as general_gemm()
Note over FWD,GEMM: First microbatch (is_first_microbatch=True, not primary_weights_in_fp8, not FSDP2)
FWD->>GWQ: _get_weight_quantizers()
GWQ-->>FWD: "quantizer (optimize_for_gemm=True)"
FWD->>QW: "quantize_weight(cache=True)"
QW->>QW: quantize + swizzle scales eagerly
QW-->>WS: "cache workspace (_with_gemm_swizzled_scales=True)"
QW-->>GEMM: pre-swizzled workspace → GEMM (skip lazy swizzle)
Note over FWD,GEMM: Later microbatches (is_first_microbatch=False)
FWD->>WS: get cached workspace
WS-->>GEMM: reuse pre-swizzled workspace (no re-quantize, no re-swizzle)
Note over FWD,GEMM: primary_weights_in_fp8=True path (quantized_model_init)
FWD->>GWQ: _get_weight_quantizers()
GWQ-->>FWD: "quantizer (optimize_for_gemm=False, unchanged)"
FWD->>GEMM: lazy swizzle inside every GEMM call
Reviews (5): Last reviewed commit: "Merge branch 'main' into feature/nvfp4-w..." | Re-trigger Greptile |
| @pytest.mark.parametrize("kind", ["Linear", "LayerNormLinear"]) | ||
| def test_lazy_path_not_swizzled(kind): | ||
| """Without weight caching (is_first_microbatch=None) no workspace is created | ||
| and the optimization stays off — guards against accidentally always-on.""" | ||
| torch.manual_seed(0) | ||
| device = "cuda" | ||
| recipe = NVFP4BlockScaling(disable_stochastic_rounding=True) | ||
| module = _make_module(kind, 1024, 1024, device) |
There was a problem hiding this comment.
Nit:
| @pytest.mark.parametrize("kind", ["Linear", "LayerNormLinear"]) | |
| def test_lazy_path_not_swizzled(kind): | |
| """Without weight caching (is_first_microbatch=None) no workspace is created | |
| and the optimization stays off — guards against accidentally always-on.""" | |
| torch.manual_seed(0) | |
| device = "cuda" | |
| recipe = NVFP4BlockScaling(disable_stochastic_rounding=True) | |
| module = _make_module(kind, 1024, 1024, device) | |
| @pytest.mark.parametrize("layer_type", ["Linear", "LayerNormLinear"]) | |
| def test_lazy_path_not_swizzled(kind): | |
| """Without weight caching (is_first_microbatch=None) no workspace is created | |
| and the optimization stays off — guards against accidentally always-on.""" | |
| torch.manual_seed(0) | |
| device = "cuda" | |
| recipe = NVFP4BlockScaling(disable_stochastic_rounding=True) | |
| module = _make_module(layer_type, 1024, 1024, device) |
There was a problem hiding this comment.
Done — test_lazy_path_not_swizzled now parametrizes over all four module kinds.
| x = x.detach().clone().requires_grad_(True) | ||
| module.zero_grad(set_to_none=True) # per-micro-batch grads (no accumulation) | ||
| with te.autocast(enabled=True, recipe=recipe): | ||
| out = module(x, is_first_microbatch=is_first) |
There was a problem hiding this comment.
If absence of m_splits argument is the only reason for creating new test for grouped_linear below, then can we add a check on the module in terms of passing m_splits only if module is GroupedLinear, instead of duplicating the test?
There was a problem hiding this comment.
Done — folded GroupedLinear into the parametrized test_weight_swizzle_cache_numerics, passing m_splits only for GroupedLinear (in _step); removed the duplicated grouped-only test.
| # Pre-swizzle (and cache) the weight scale factors when the quantized | ||
| # weights are cached across microbatches, so the per-GEMM scale swizzle | ||
| # (fprop rowwise + dgrad columnwise, redone every microbatch) collapses | ||
| # from 2*num_microbatches kernels to 2 per step per expert. Gated to the | ||
| # cached, non-FSDP path (FSDP/FSDP2 all-gather weights with un-swizzled | ||
| # scales; see NVFP4Tensor.fsdp_pre_all_gather), so pre-swizzling is | ||
| # unsupported there. No-op for non-swizzled recipes (e.g. per-tensor FP8). | ||
| if cache_weight and self.fsdp_group is None and not self.is_fsdp2: | ||
| for weight_quantizer in weight_quantizers: | ||
| if weight_quantizer is not None: | ||
| weight_quantizer.optimize_for_gemm = True | ||
|
|
There was a problem hiding this comment.
I dont think the comment is relevant In case of FSDP/FSDP2,
For FSDP, The scales are not sharded, and the whole scales are replicated across ranks today. So it doesnt matter if scales are swizzled or not. cc: @denera. Also NVFP4 pre allgather is function specific to FSDP2 not FSDP.
For FSDP2, we havent been caching weights as it causes memory bloating. And weight caching as a mechanism doesnt fit well with fsdp2. This was done for linear and layer_norm_linear but apparently not for grouped_linear in this PR #2805. But fixing that for grouped_linear might be byond scope of this PR. Even if weight caching is still kept as it is, current behavior is to save the entire weight instead of shard in the workspace and so swizzling being present shouldnt cause any issue.
| # Pre-swizzle (and cache) the weight scale factors when the quantized | |
| # weights are cached across microbatches, so the per-GEMM scale swizzle | |
| # (fprop rowwise + dgrad columnwise, redone every microbatch) collapses | |
| # from 2*num_microbatches kernels to 2 per step per expert. Gated to the | |
| # cached, non-FSDP path (FSDP/FSDP2 all-gather weights with un-swizzled | |
| # scales; see NVFP4Tensor.fsdp_pre_all_gather), so pre-swizzling is | |
| # unsupported there. No-op for non-swizzled recipes (e.g. per-tensor FP8). | |
| if cache_weight and self.fsdp_group is None and not self.is_fsdp2: | |
| for weight_quantizer in weight_quantizers: | |
| if weight_quantizer is not None: | |
| weight_quantizer.optimize_for_gemm = True | |
| # Pre-swizzle (and cache) the weight scale factors when the quantized | |
| # weights are cached across microbatches, so the per-GEMM scale swizzle | |
| # (fprop rowwise + dgrad columnwise, redone every microbatch) collapses | |
| # from 2*num_microbatches kernels to 2 per step per expert. | |
| # No-op for non-swizzled recipes (e.g. per-tensor FP8). | |
| if cache_weight: | |
| for weight_quantizer in weight_quantizers: | |
| if weight_quantizer is not None: | |
| weight_quantizer.optimize_for_gemm = True | |
There was a problem hiding this comment.
Same applies in other files.
There was a problem hiding this comment.
Applied to all four module files.
Drop the FSDP/FSDP2 gating on optimize_for_gemm in Linear, LayerNormLinear, LayerNormMLP and GroupedLinear. FSDP1 replicates (does not shard) the scale factors, so the swizzle layout is irrelevant there, and weights are not cached under FSDP2; the guard only added a misleading comment and dead conditions. Pre-swizzle the weight scales whenever the quantized weight is cached. Tests: - Fold the GroupedLinear case into the parametrized test_weight_swizzle_cache_numerics by passing m_splits only for GroupedLinear, removing the duplicated grouped-only test. - Add LayerNormMLP coverage (fc1 + fc2 two-quantizer path), generalizing the cached-workspace-count assertion per module type. - Parametrize test_lazy_path_not_swizzled over all four module kinds. Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
|
Pushed a commit addressing the review: removed the irrelevant FSDP gating across all four modules, merged the GroupedLinear test, and added LayerNormMLP coverage. Please take a look, thanks. @vthumbe1503 |
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
There was a problem hiding this comment.
Thanks for making the changes. I think that optimize_for_gemm(weight preswizzling) should be enabled for mostly all use-cases except for primary weights in fp8/fp4.
For primary weights in fp8/fp4 it wont work because of
- Dequantization needs in optimizer step update which wont work on swizzled weights
- FSDP2 allgather is currently supported only for unswizzled weights.
So lets enable it for most use-cases @cael-ling instead of restricting it to weight caching only.
cael-ling#1
…ache Enable weight swizzling for most cases
|
Want your agent to iterate on Greptile's feedback? Try greploops. |
|
/te-ci L1 pytorch |
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ache mxfp8 zero kernel for noop cuda graph compat
|
/te-ci |
|
Pipeline: 55637043 |
Description
Motivation
For block-scaled NVFP4, a cached weight is used in two GEMMs per step — fprop (row-wise scales) and dgrad (column-wise scales) — and each GEMM needs its scale factors in the GEMM-swizzled layout. Today that swizzle is recomputed lazily inside
general_gemmon every micro-batch and thrown away, so withNmicro-batches the weight scale swizzle runs2*Ntimes per step even though the weight is quantized only once, which hurts performance. (Activation quantizers already setoptimize_for_gemm=Trueand were pre-swizzled; only the weight was missed.)What this PR does
This PR sets
weight_quantizer.optimize_for_gemm=Truein both MXFP8/NVFP4 recipes to enable pre-swizzling of weights. This also allows us to save the preswizzled quantized weights in cache when we enable workspace caching. NOTE: optimize_for_gemm isnt set in case of FP8/FP4 primary weights(quantized_model_init) since it we cant dequantize swizzled weights for now(needed in optimizer.step)Applied to
Linear,LayerNormLinear,LayerNormMLP(fc1 + fc2) andGroupedLinear(per expert).No-op for recipes whose scales do not require swizzling (e.g. per-tensor FP8).
Swizzling is a pure layout permutation, so numerics are unchanged.
New
tests/pytorch/nvfp4/test_weight_swizzle_in_layers.py_with_gemm_swizzled_scalesis set and persisted on the cached workspace._with_gemm_swizzled_scalesis not set in case of module initialized with fp8 primary weights(quantized_model_init)pytest tests/pytorch/test_numerics.py -k "linear or layernorm or mlp"— no regressions.pytest tests/pytorch/test_grouped_linear.py -k "not grouped_tensor and not fused_path"— no regressions.Type of change
Changes
Please list the changes introduced in this PR:
Checklist: