[PyTorch] Preserve fprop operands for dequantized backward override #3141
negvet wants to merge 2 commits into
Signed-off-by: Evgeny <etsykunov@nvidia.com>
cc @zianglih
for more information, see https://pre-commit.ci
/te-ci L0 L1
Greptile Summary
This PR fixes a semantic conflict between
Confidence Score: 4/5
Safe to merge; the two-line production change is a targeted guard that correctly forces save_original_input=False when backward_override=dequantized, matching the already-existing pattern for high_precision.
The fix and its test coverage are clean and well-structured. The only gap is that high_precision overriding save_original_input=False is tested for Linear but has no equivalent test for GroupedLinear, leaving a small blind spot in case that path regresses.
The test file would benefit from a GroupedLinear counterpart to test_linear_backward_override_high_precision_forces_save_original_input; production files are straightforward and need no further review.
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["forward() called\n(Linear / GroupedLinear)"] --> B{fp8 enabled?}
B -- No --> C["backward_override = None"]
B -- Yes --> D["backward_override =\nrecipe.backward_override"]
C --> E{override value?}
D --> E
E -- high_precision --> F["save_original_input = True\n(force original tensor)"]
E -- dequantized --> G["save_original_input = False\n(NEW: force quantized tensor)"]
E -- None --> H["save_original_input =\nmodule constructor value"]
F --> I["Save plain torch.Tensor\n(original input) for backward"]
G --> J["Save QuantizedTensor\n(rowwise-only FP8) for backward"]
H --> K{constructor value?}
K -- True --> I
K -- False --> J
I --> L["Backward uses original\nunquantized activations"]
J --> M["Backward dequantizes\nfprop-quantized operand"]
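The decision flow in the chart above can be sketched in plain Python. This is a minimal illustration only: the function name `resolve_save_original_input` and its parameters are hypothetical and are not taken from the Transformer Engine source; only the three override values and the resulting flag come from this PR.

```python
from typing import Optional


def resolve_save_original_input(
    fp8_enabled: bool,
    recipe_backward_override: Optional[str],
    constructor_value: bool,
) -> bool:
    """Return the effective save_original_input flag for one forward call.

    Hypothetical helper that mirrors the flowchart; not the library's code.
    """
    # Without FP8 there is no recipe to consult, so no override applies.
    backward_override = recipe_backward_override if fp8_enabled else None

    if backward_override == "high_precision":
        # Backward must see the original, unquantized input.
        return True
    if backward_override == "dequantized":
        # Backward must see the fprop-quantized operand (dequantized later).
        return False
    if backward_override is None:
        # No override: fall back to what the module was constructed with.
        return constructor_value
    raise ValueError(f"unknown backward_override: {backward_override!r}")


# The case this PR addresses: the constructor asked for the original input,
# but the "dequantized" override takes precedence.
effective = resolve_save_original_input(True, "dequantized", constructor_value=True)
print(effective)
```

Keeping the resolution in one place makes the precedence explicit: the recipe override, when present, always wins over the constructor argument.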
Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
assert_close(y_test_detached, y_ref, rtol=0, atol=0, check_dtype=True)
assert_close(dx_test, dx_ref, rtol=0, atol=0, check_dtype=True)
for test_dw, ref_dw in zip(dw_test, dw_ref):
    assert_close(test_dw, ref_dw, rtol=0, atol=0, check_dtype=True)
Missing GroupedLinear + high_precision test
There is a symmetric gap in the new test suite: test_linear_backward_override_high_precision_forces_save_original_input verifies that high_precision overrides save_original_input=False for te.Linear, but no equivalent test exists for te.GroupedLinear. The high_precision branch in grouped_linear.py has been in place since #2644 and the lack of coverage means a future regression there would go undetected by this test file.
Thanks for the fix!
Description
Follow-up to #2644, which introduced NVTE_BACKWARD_OVERRIDE=high_precision|dequantized.
high_precision is intended to use the original unquantized tensor in backward, while dequantized is intended to use the tensor dequantized from the forward-quantized one. However, save_original_input=True could override the dequantized behavior in Linear and GroupedLinear, causing backward to use the original input instead of the fprop-quantized operand.
This PR makes the override semantics explicit:
- backward_override="high_precision" forces save_original_input=True
- backward_override="dequantized" forces save_original_input=False
Fixes # (issue)
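To make the difference between the two override modes concrete, here is a toy numeric sketch in pure Python. It rests on loud assumptions: the quantizer below is a simple symmetric integer-grid stand-in, not the FP8 cast Transformer Engine performs, and the helper names `quantize`, `dequantize`, and `weight_grad` are hypothetical. It only shows why the weight gradient depends on which tensor was saved for backward.

```python
def quantize(values, levels=15):
    """Toy symmetric quantizer (a stand-in for FP8, not the real cast)."""
    amax = max(abs(v) for v in values) or 1.0
    scale = amax / levels
    codes = [round(v / scale) for v in values]
    return codes, scale


def dequantize(codes, scale):
    """Map integer codes back to floats on the quantization grid."""
    return [c * scale for c in codes]


def weight_grad(grad_output, saved_input):
    """dW for a one-output linear layer y = sum(w_i * x_i): dW_i = dy * x_i."""
    return [grad_output * x for x in saved_input]


x = [0.31, -1.00, 0.77]  # forward input
dy = 2.0                 # upstream gradient

codes, scale = quantize(x)
x_deq = dequantize(codes, scale)

# high_precision: backward is computed from the original input.
dw_high_precision = weight_grad(dy, x)
# dequantized: backward is computed from dequantize(quantize(x)).
dw_dequantized = weight_grad(dy, x_deq)

for hp, dq in zip(dw_high_precision, dw_dequantized):
    print(f"high_precision={hp:+.4f}  dequantized={dq:+.4f}")
```

If save_original_input=True were allowed to win under the dequantized override, backward would compute the first set of gradients instead of the second, which is the mismatch this PR removes.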
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: