Skip to content

[PyTorch] Enable Weight Preswizzling during Quantization in TE Modules#3093

Merged
vthumbe1503 merged 10 commits into
NVIDIA:mainfrom
cael-ling:feature/nvfp4-weight-swizzle-cache
Jun 24, 2026
Merged

[PyTorch] Enable Weight Preswizzling during Quantization in TE Modules#3093
vthumbe1503 merged 10 commits into
NVIDIA:mainfrom
cael-ling:feature/nvfp4-weight-swizzle-cache

Conversation

@cael-ling

@cael-ling cael-ling commented Jun 5, 2026

Copy link
Copy Markdown
Contributor

Description

Motivation

For block-scaled NVFP4, a cached weight is used in two GEMMs per step — fprop (row-wise scales) and dgrad (column-wise scales) — and each GEMM needs its scale factors in the GEMM-swizzled layout. Today that swizzle is recomputed lazily inside general_gemm on every micro-batch and thrown away, so with N micro-batches the weight scale swizzle runs 2*N times per step even though the weight is quantized only once, which hurts performance. (Activation quantizers already set optimize_for_gemm=True and were pre-swizzled; only the weight was missed.)

What this PR does

This PR sets weight_quantizer.optimize_for_gemm=True in both MXFP8/NVFP4 recipes to enable pre-swizzling of weights. This also allows us to save the preswizzled quantized weights in cache when we enable workspace caching. NOTE: optimize_for_gemm isnt set in case of FP8/FP4 primary weights(quantized_model_init) since it we cant dequantize swizzled weights for now(needed in optimizer.step)

  • Applied to Linear, LayerNormLinear, LayerNormMLP (fc1 + fc2) and GroupedLinear (per expert).

  • No-op for recipes whose scales do not require swizzling (e.g. per-tensor FP8).

  • Swizzling is a pure layout permutation, so numerics are unchanged.

  • New tests/pytorch/nvfp4/test_weight_swizzle_in_layers.py

    • When workspace is cached, this makes sure_with_gemm_swizzled_scales is set and persisted on the cached workspace.
    • _with_gemm_swizzled_scales is not set in case of module initialized with fp8 primary weights(quantized_model_init)
  • pytest tests/pytorch/test_numerics.py -k "linear or layernorm or mlp" — no regressions.

  • pytest tests/pytorch/test_grouped_linear.py -k "not grouped_tensor and not fused_path" — no regressions.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…obatches

For block-scaled NVFP4 a cached weight participates in two GEMMs per step:
fprop (rowwise scales) and dgrad (columnwise scales). The GEMM-ready scale
swizzle was recomputed lazily inside every GEMM and discarded, so with N
microbatches the weight scale swizzle ran 2*N times per step even though the
weight is quantized only once.

Because weight RHT is disabled, the weight scales are not swizzled by the
cast-fusion path; with optimize_for_gemm off they also skip the post-quantize
fallback swizzle, so the only swizzle site left for the weight is the lazy one
inside general_gemm (swizzle_scales_for_gemm), which re-runs on every GEMM.
(Activation input/grad_output quantizers already set optimize_for_gemm=True, so
they were pre-swizzled via cast-fusion/fallback; only the weight was missed.)

Set weight_quantizer.optimize_for_gemm=True on the cached, non-FSDP path so the
swizzle is done once at quantize time (via the post-quantize fallback),
persisted on the cached workspace (_with_gemm_swizzled_scales=True), and reused
by every GEMM (swizzle_scales_for_gemm early-returns) -> 2 swizzles per step
instead of 2*N. Applied to Linear, LayerNormLinear, LayerNormMLP (fc1+fc2) and
GroupedLinear (per expert).

Gated to the cached path (is_first_microbatch is not None) with fsdp_group is
None and not is_fsdp2: FSDP/FSDP2 all-gather weights using the un-swizzled
scale layout, so pre-swizzling is unsupported there. No-op for recipes whose
scales do not require swizzling (e.g. per-tensor FP8). Swizzling is a pure
layout permutation, so numerics are unchanged.

Add tests/pytorch/nvfp4/test_nvfp4_weight_swizzle_cache.py verifying the cached
eager-swizzle path matches the lazy-swizzle baseline (fprop + dgrad) for
Linear/LayerNormLinear/GroupedLinear and that the swizzled flag is persisted.

Signed-off-by: Cael Ling <caell@nvidia.com>
@cael-ling cael-ling requested a review from ksivaman as a code owner June 5, 2026 14:29
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Jun 5, 2026
@greptile-apps

greptile-apps Bot commented Jun 5, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR pre-swizzles block-scaled weight scale factors (NVFP4, MXFP8) once at quantize time instead of lazily inside every GEMM call, reducing swizzle operations from 2*N to 2 per step for cached (non-FSDP2, non-primary-FP8) weights. A paired CUDA-layer change replaces cudaMemsetAsync for scale-buffer padding with a noop-aware kernel that correctly skips zeroing when a skip_update_flag is active, fixing a subtle bug where cached padding could be corrupted.

  • Sets weight_quantizer.optimize_for_gemm = True in _get_weight_quantizers() for Linear, LayerNormLinear, LayerNormMLP, and GroupedLinear when not primary_weights_in_fp8, correctly excluding the quantized_model_init path where FP8 weight parameters are all-gathered/optimizer-updated in their unswizzled layout.
  • Adds zero_scales_kernel in quantize_mxfp8.cuh to replace two cudaMemsetAsync calls that previously zeroed valid cached scales when the noop flag was set.
  • New test_weight_swizzle_in_layers.py covers all four module types × two block-scaled recipes (MXFP8, NVFP4) with workspace-flag assertions and bit-exact numerical comparison between the cached and uncached code paths.

Confidence Score: 5/5

Safe to merge; the change is a pure performance optimisation with no numeric side-effects and the guard on primary_weights_in_fp8 correctly excludes every path where pre-swizzling would be unsafe.

The not primary_weights_in_fp8 guard is the right invariant: FSDP + primary FP8 weights is explicitly disallowed at the API level (RuntimeError), so all FSDP cases have BF16 primary weights and receive the pre-swizzle correctly. The CUDA kernel replacement fixes an existing subtle bug (zeroing valid cached scales on noop) and is straightforward. The bit-exact numerical tests across all four module types and both block-scaled recipes give strong correctness confidence.

No files require special attention; the identical comment-block typos in linear.py, layernorm_linear.py, and grouped_linear.py are cosmetic only.

Important Files Changed

Filename Overview
transformer_engine/pytorch/module/linear.py Adds optimize_for_gemm = True to _get_weight_quantizers() guarded on not primary_weights_in_fp8; typos in added comment ("wont", "primay").
transformer_engine/pytorch/module/layernorm_linear.py Same optimize_for_gemm change as linear.py; same typos in the shorter comment block.
transformer_engine/pytorch/module/layernorm_mlp.py Sets optimize_for_gemm = True on both fc1 and fc2 weight quantizers; comment here is more thorough and correctly spelled.
transformer_engine/pytorch/module/grouped_linear.py Adds per-expert optimize_for_gemm = True in the num_gemms loop; same comment typos as linear.py.
transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh Replaces cudaMemsetAsync for scale-padding zero-fill with a noop-aware custom kernel; correctly guards launch with size_bytes > 0 and uses cudaGetLastError() for launch-error detection.
tests/pytorch/test_weight_swizzle_in_layers.py New test covering workspace-flag assertions and bit-exact numerical comparison for all four module types × two recipes; added to CI harness.
qa/L0_pytorch_unittest/test.sh Adds the new test file to the L0 CI test suite.

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant FWD as Forward pass
    participant GWQ as _get_weight_quantizers()
    participant QW as quantize_weight()
    participant WS as _fp8_workspaces
    participant GEMM as general_gemm()

    Note over FWD,GEMM: First microbatch (is_first_microbatch=True, not primary_weights_in_fp8, not FSDP2)
    FWD->>GWQ: _get_weight_quantizers()
    GWQ-->>FWD: "quantizer (optimize_for_gemm=True)"
    FWD->>QW: "quantize_weight(cache=True)"
    QW->>QW: quantize + swizzle scales eagerly
    QW-->>WS: "cache workspace (_with_gemm_swizzled_scales=True)"
    QW-->>GEMM: pre-swizzled workspace → GEMM (skip lazy swizzle)

    Note over FWD,GEMM: Later microbatches (is_first_microbatch=False)
    FWD->>WS: get cached workspace
    WS-->>GEMM: reuse pre-swizzled workspace (no re-quantize, no re-swizzle)

    Note over FWD,GEMM: primary_weights_in_fp8=True path (quantized_model_init)
    FWD->>GWQ: _get_weight_quantizers()
    GWQ-->>FWD: "quantizer (optimize_for_gemm=False, unchanged)"
    FWD->>GEMM: lazy swizzle inside every GEMM call
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
    participant FWD as Forward pass
    participant GWQ as _get_weight_quantizers()
    participant QW as quantize_weight()
    participant WS as _fp8_workspaces
    participant GEMM as general_gemm()

    Note over FWD,GEMM: First microbatch (is_first_microbatch=True, not primary_weights_in_fp8, not FSDP2)
    FWD->>GWQ: _get_weight_quantizers()
    GWQ-->>FWD: "quantizer (optimize_for_gemm=True)"
    FWD->>QW: "quantize_weight(cache=True)"
    QW->>QW: quantize + swizzle scales eagerly
    QW-->>WS: "cache workspace (_with_gemm_swizzled_scales=True)"
    QW-->>GEMM: pre-swizzled workspace → GEMM (skip lazy swizzle)

    Note over FWD,GEMM: Later microbatches (is_first_microbatch=False)
    FWD->>WS: get cached workspace
    WS-->>GEMM: reuse pre-swizzled workspace (no re-quantize, no re-swizzle)

    Note over FWD,GEMM: primary_weights_in_fp8=True path (quantized_model_init)
    FWD->>GWQ: _get_weight_quantizers()
    GWQ-->>FWD: "quantizer (optimize_for_gemm=False, unchanged)"
    FWD->>GEMM: lazy swizzle inside every GEMM call
Loading

Reviews (5): Last reviewed commit: "Merge branch 'main' into feature/nvfp4-w..." | Re-trigger Greptile

Comment thread tests/pytorch/nvfp4/test_nvfp4_weight_swizzle_cache.py Outdated

@vthumbe1503 vthumbe1503 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apart from FSDP2 condition being irrelevant, LGTM

Comment on lines +178 to +185
@pytest.mark.parametrize("kind", ["Linear", "LayerNormLinear"])
def test_lazy_path_not_swizzled(kind):
"""Without weight caching (is_first_microbatch=None) no workspace is created
and the optimization stays off — guards against accidentally always-on."""
torch.manual_seed(0)
device = "cuda"
recipe = NVFP4BlockScaling(disable_stochastic_rounding=True)
module = _make_module(kind, 1024, 1024, device)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit:

Suggested change
@pytest.mark.parametrize("kind", ["Linear", "LayerNormLinear"])
def test_lazy_path_not_swizzled(kind):
"""Without weight caching (is_first_microbatch=None) no workspace is created
and the optimization stays offguards against accidentally always-on."""
torch.manual_seed(0)
device = "cuda"
recipe = NVFP4BlockScaling(disable_stochastic_rounding=True)
module = _make_module(kind, 1024, 1024, device)
@pytest.mark.parametrize("layer_type", ["Linear", "LayerNormLinear"])
def test_lazy_path_not_swizzled(kind):
"""Without weight caching (is_first_microbatch=None) no workspace is created
and the optimization stays offguards against accidentally always-on."""
torch.manual_seed(0)
device = "cuda"
recipe = NVFP4BlockScaling(disable_stochastic_rounding=True)
module = _make_module(layer_type, 1024, 1024, device)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — test_lazy_path_not_swizzled now parametrizes over all four module kinds.

x = x.detach().clone().requires_grad_(True)
module.zero_grad(set_to_none=True) # per-micro-batch grads (no accumulation)
with te.autocast(enabled=True, recipe=recipe):
out = module(x, is_first_microbatch=is_first)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If absence of m_splits argument is the only reason for creating new test for grouped_linear below, then can we add a check on the module in terms of passing m_splits only if module is GroupedLinear, instead of duplicating the test?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — folded GroupedLinear into the parametrized test_weight_swizzle_cache_numerics, passing m_splits only for GroupedLinear (in _step); removed the duplicated grouped-only test.

Comment on lines +1738 to +1749
# Pre-swizzle (and cache) the weight scale factors when the quantized
# weights are cached across microbatches, so the per-GEMM scale swizzle
# (fprop rowwise + dgrad columnwise, redone every microbatch) collapses
# from 2*num_microbatches kernels to 2 per step per expert. Gated to the
# cached, non-FSDP path (FSDP/FSDP2 all-gather weights with un-swizzled
# scales; see NVFP4Tensor.fsdp_pre_all_gather), so pre-swizzling is
# unsupported there. No-op for non-swizzled recipes (e.g. per-tensor FP8).
if cache_weight and self.fsdp_group is None and not self.is_fsdp2:
for weight_quantizer in weight_quantizers:
if weight_quantizer is not None:
weight_quantizer.optimize_for_gemm = True

@vthumbe1503 vthumbe1503 Jun 12, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think the comment is relevant In case of FSDP/FSDP2,
For FSDP, The scales are not sharded, and the whole scales are replicated across ranks today. So it doesnt matter if scales are swizzled or not. cc: @denera. Also NVFP4 pre allgather is function specific to FSDP2 not FSDP.
For FSDP2, we havent been caching weights as it causes memory bloating. And weight caching as a mechanism doesnt fit well with fsdp2. This was done for linear and layer_norm_linear but apparently not for grouped_linear in this PR #2805. But fixing that for grouped_linear might be byond scope of this PR. Even if weight caching is still kept as it is, current behavior is to save the entire weight instead of shard in the workspace and so swizzling being present shouldnt cause any issue.

Suggested change
# Pre-swizzle (and cache) the weight scale factors when the quantized
# weights are cached across microbatches, so the per-GEMM scale swizzle
# (fprop rowwise + dgrad columnwise, redone every microbatch) collapses
# from 2*num_microbatches kernels to 2 per step per expert. Gated to the
# cached, non-FSDP path (FSDP/FSDP2 all-gather weights with un-swizzled
# scales; see NVFP4Tensor.fsdp_pre_all_gather), so pre-swizzling is
# unsupported there. No-op for non-swizzled recipes (e.g. per-tensor FP8).
if cache_weight and self.fsdp_group is None and not self.is_fsdp2:
for weight_quantizer in weight_quantizers:
if weight_quantizer is not None:
weight_quantizer.optimize_for_gemm = True
# Pre-swizzle (and cache) the weight scale factors when the quantized
# weights are cached across microbatches, so the per-GEMM scale swizzle
# (fprop rowwise + dgrad columnwise, redone every microbatch) collapses
# from 2*num_microbatches kernels to 2 per step per expert.
# No-op for non-swizzled recipes (e.g. per-tensor FP8).
if cache_weight:
for weight_quantizer in weight_quantizers:
if weight_quantizer is not None:
weight_quantizer.optimize_for_gemm = True

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same applies in other files.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Applied to all four module files.

cael-ling and others added 2 commits June 17, 2026 00:36
Drop the FSDP/FSDP2 gating on optimize_for_gemm in Linear, LayerNormLinear,
LayerNormMLP and GroupedLinear. FSDP1 replicates (does not shard) the scale
factors, so the swizzle layout is irrelevant there, and weights are not cached
under FSDP2; the guard only added a misleading comment and dead conditions.
Pre-swizzle the weight scales whenever the quantized weight is cached.

Tests:
- Fold the GroupedLinear case into the parametrized
  test_weight_swizzle_cache_numerics by passing m_splits only for
  GroupedLinear, removing the duplicated grouped-only test.
- Add LayerNormMLP coverage (fc1 + fc2 two-quantizer path), generalizing
  the cached-workspace-count assertion per module type.
- Parametrize test_lazy_path_not_swizzled over all four module kinds.

Signed-off-by: Cael Ling <caell@nvidia.com>
@cael-ling

Copy link
Copy Markdown
Contributor Author

Pushed a commit addressing the review: removed the irrelevant FSDP gating across all four modules, merged the GroupedLinear test, and added LayerNormMLP coverage. Please take a look, thanks. @vthumbe1503

@cael-ling cael-ling requested a review from vthumbe1503 June 17, 2026 07:45
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

@vthumbe1503 vthumbe1503 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for making the changes. I think that optimize_for_gemm(weight preswizzling) should be enabled for mostly all use-cases except for primary weights in fp8/fp4.

For primary weights in fp8/fp4 it wont work because of

  1. Dequantization needs in optimizer step update which wont work on swizzled weights
  2. FSDP2 allgather is currently supported only for unswizzled weights.

So lets enable it for most use-cases @cael-ling instead of restricting it to weight caching only.
cael-ling#1

…ache

Enable weight swizzling for most cases
@greptile-apps

greptile-apps Bot commented Jun 23, 2026

Copy link
Copy Markdown
Contributor

Want your agent to iterate on Greptile's feedback? Try greploops.

@vthumbe1503

Copy link
Copy Markdown
Collaborator

/te-ci L1 pytorch

@vthumbe1503 vthumbe1503 changed the title NVFP4: cache GEMM-swizzled weight scale factors across micro-batches [PyTorch] PreSwizzle Weights during Quantize for BF16 primary weights Jun 23, 2026
vthumbe1503 and others added 2 commits June 24, 2026 00:12
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ache

mxfp8 zero kernel for noop cuda graph compat
@vthumbe1503

Copy link
Copy Markdown
Collaborator

/te-ci

@vthumbe1503 vthumbe1503 changed the title [PyTorch] PreSwizzle Weights during Quantize for BF16 primary weights [PyTorch] Enable Weight Preswizzling during Quantization in TE Modules Jun 24, 2026
@vthumbe1503

Copy link
Copy Markdown
Collaborator

Pipeline: 55637043
Errors are pre-existing in main

@vthumbe1503 vthumbe1503 merged commit 42cb81b into NVIDIA:main Jun 24, 2026
34 of 43 checks passed
@vthumbe1503 vthumbe1503 added the performance Performance issues label Jun 24, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. performance Performance issues

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants