IAttention FP8 #4209

Draft
narendasan wants to merge 2 commits into main from narendasan/quantization_fixes

Conversation

narendasan (Collaborator) commented Apr 23, 2026

Description

Prototype support for the FP8 normalization scale in the IAttention layer.

Structurally, I think this approach is reasonable: extract the relevant graph-level info in a lowering pass and bake it into node metadata that the converter consumes. I don't think we necessarily need the softmax Q/DQ pass if ModelOpt inserts it for us.
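As a rough, illustrative sketch of that handoff (the function name and signature here are assumptions; only the `_fp8_softmax_scale` meta key, the 1/448 E4M3 scale, and the quantize_op argument layout come from the diffs below):

    import torch

    _SDPA_TARGETS = {torch.ops.aten.scaled_dot_product_attention.default}
    FP8_E4M3_MAX = 448.0  # largest representable FP8 E4M3 value

    def annotate_fp8_sdpa_sketch(gm: torch.fx.GraphModule) -> None:
        """Annotate SDPA nodes whose Q/K/V inputs all come from FP8 quantize ops."""

        def is_fp8_quantize(n: object) -> bool:
            # torch.ops.tensorrt.quantize_op(x, amax, num_bits, exponent_bits, ...):
            # exponent_bits == 4 means FP8 E4M3.
            return (
                isinstance(n, torch.fx.Node)
                and n.op == "call_function"
                and "quantize_op" in str(n.target)
                and len(n.args) >= 4
                and n.args[3] == 4
            )

        for node in gm.graph.nodes:
            if node.target in _SDPA_TARGETS and all(
                is_fp8_quantize(a) for a in node.args[:3]
            ):
                # The converter later reads this key and, if present, configures FP8
                # softmax normalization on the IAttention layer instead of decomposing.
                node.meta["_fp8_softmax_scale"] = 1.0 / FP8_E4M3_MAX

The converter side then only needs ctx.current_node.meta.get("_fp8_softmax_scale"), as in the attention.py hunk in the lint diff below.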

cc: @nvyihengz, @yizhuoz004

Fixes #4200, #4167

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that the relevant reviewers are notified

@narendasan narendasan requested a review from zewenli98 April 23, 2026 16:46
@meta-cla meta-cla Bot added the cla signed label Apr 23, 2026
@narendasan narendasan changed the title from "Narendasan/quantization fixes" to "IAttention FP8" Apr 23, 2026
@github-actions github-actions Bot added the labels component: tests, component: lowering, component: conversion, component: core, component: converters, component: build system, component: api [Python], and component: dynamo Apr 23, 2026
github-actions Bot left a comment:

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/attention.py	2026-04-23 16:47:17.496428+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/attention.py	2026-04-23 16:47:37.867662+00:00
@@ -29,11 +29,12 @@
    attention_layer: trt.IAttention,
) -> bool:
    """Set FP8 softmax normalization quantization on the IAttention layer if the current
    node was annotated with a softmax FP8 scale by the fp8_attention_softmax lowering pass.

-    Returns True if FP8 normalization was configured (caller must set decomposable=False)."""
+    Returns True if FP8 normalization was configured (caller must set decomposable=False).
+    """
    if ctx.current_node is None:
        return False
    scale_val = ctx.current_node.meta.get("_fp8_softmax_scale")
    if scale_val is None:
        return False
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2026-04-23 16:47:17.538801+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2026-04-23 16:47:42.928700+00:00
@@ -580,34 +580,34 @@
    # FP8 Q/K/V inputs (exponent_bits=4): SDPA node must be annotated with 1/448.
    gm_fp8 = _build_sdpa_input_quant_graph(exponent_bits=4)
    annotate_fp8_sdpa(gm_fp8, settings)
    sdpa_nodes = [n for n in gm_fp8.graph.nodes if n.target in _SDPA_TARGETS]
    assert sdpa_nodes, "No SDPA node found in graph"
-    assert all("_fp8_softmax_scale" in n.meta for n in sdpa_nodes), (
-        "annotate_fp8_sdpa did not annotate SDPA when Q/K/V inputs are FP8"
-    )
+    assert all(
+        "_fp8_softmax_scale" in n.meta for n in sdpa_nodes
+    ), "annotate_fp8_sdpa did not annotate SDPA when Q/K/V inputs are FP8"
    expected_scale = 1.0 / 448.0
    for n in sdpa_nodes:
-        assert abs(n.meta["_fp8_softmax_scale"] - expected_scale) < 1e-12, (
-            f"Wrong softmax scale: {n.meta['_fp8_softmax_scale']}"
-        )
+        assert (
+            abs(n.meta["_fp8_softmax_scale"] - expected_scale) < 1e-12
+        ), f"Wrong softmax scale: {n.meta['_fp8_softmax_scale']}"

    # INT8 Q/K/V inputs (exponent_bits=0): SDPA node must NOT be annotated.
    gm_int8 = _build_sdpa_input_quant_graph(exponent_bits=0)
    annotate_fp8_sdpa(gm_int8, settings)
    sdpa_int8 = [n for n in gm_int8.graph.nodes if n.target in _SDPA_TARGETS]
-    assert all("_fp8_softmax_scale" not in n.meta for n in sdpa_int8), (
-        "annotate_fp8_sdpa incorrectly annotated SDPA when Q/K/V are INT8"
-    )
+    assert all(
+        "_fp8_softmax_scale" not in n.meta for n in sdpa_int8
+    ), "annotate_fp8_sdpa incorrectly annotated SDPA when Q/K/V are INT8"

    # Only Q and K are FP8-quantized, V is raw: SDPA must NOT be annotated.
    gm_partial = _build_sdpa_input_quant_graph(exponent_bits=4, quantize_v=False)
    annotate_fp8_sdpa(gm_partial, settings)
    sdpa_partial = [n for n in gm_partial.graph.nodes if n.target in _SDPA_TARGETS]
-    assert all("_fp8_softmax_scale" not in n.meta for n in sdpa_partial), (
-        "annotate_fp8_sdpa incorrectly annotated SDPA when V input is not FP8"
-    )
+    assert all(
+        "_fp8_softmax_scale" not in n.meta for n in sdpa_partial
+    ), "annotate_fp8_sdpa incorrectly annotated SDPA when V input is not FP8"


@unittest.skipIf(
    torch.cuda.get_device_capability() < (8, 9),
    "FP8 quantization requires compute capability 8.9 or later",
@@ -649,19 +649,13 @@
        """Mirror of what a modelopt FP8 MHA PyTorch export will look like:
        tensorrt.quantize_op on Q, K, V feeding F.scaled_dot_product_attention."""

        def __init__(self, amax_val: float = 6.0):
            super().__init__()
-            self.register_buffer(
-                "amax_q", torch.tensor(amax_val, dtype=torch.float32)
-            )
-            self.register_buffer(
-                "amax_k", torch.tensor(amax_val, dtype=torch.float32)
-            )
-            self.register_buffer(
-                "amax_v", torch.tensor(amax_val, dtype=torch.float32)
-            )
+            self.register_buffer("amax_q", torch.tensor(amax_val, dtype=torch.float32))
+            self.register_buffer("amax_k", torch.tensor(amax_val, dtype=torch.float32))
+            self.register_buffer("amax_v", torch.tensor(amax_val, dtype=torch.float32))

        def forward(self, q, k, v):
            q_fp8 = torch.ops.tensorrt.quantize_op(q, self.amax_q, 8, 4, False, False)
            k_fp8 = torch.ops.tensorrt.quantize_op(k, self.amax_k, 8, 4, False, False)
            v_fp8 = torch.ops.tensorrt.quantize_op(v, self.amax_v, 8, 4, False, False)
@@ -690,12 +684,11 @@
    engine_json = json.loads(
        inspector.get_engine_information(trt.LayerInformationFormat.JSON)
    )
    layers = engine_json.get("Layers", [])
    layer_names = [
-        layer if isinstance(layer, str) else layer.get("Name", "")
-        for layer in layers
+        layer if isinstance(layer, str) else layer.get("Name", "") for layer in layers
    ]
    assert any("mha" in name.lower() for name in layer_names), (
        f"No fused MHA kernel found in compiled engine. Expected a layer "
        f"containing 'mha' (e.g. _gemm_mha_v2) — TRT fuses FP8 Q/K/V + "
        f"normalization_quantize_to_type into a single MHA kernel. "
@@ -714,8 +707,8 @@
        trt_out = compiled(q, k, v)
    cos = torch.nn.functional.cosine_similarity(
        ref_out.flatten().float().unsqueeze(0),
        trt_out.flatten().float().unsqueeze(0),
    ).item()
-    assert cos > 0.99, (
-        f"FP8 MHA output deviates from PyTorch reference: cosine_similarity={cos}"
-    )
+    assert (
+        cos > 0.99
+    ), f"FP8 MHA output deviates from PyTorch reference: cosine_similarity={cos}"

The review thread below was left on this snippet in attention.py:

    ctx,
    torch.tensor(scale_val, dtype=torch.float32),
    name + "_softmax_fp8_scale",
    dtype=torch.float32,
Contributor commented:

The dtype needs to match the pre-quant QKV dtype, otherwise TRT compilation will fail on some platforms.

narendasan (Author) replied:

Do you know where we can fetch this info?

Contributor replied:

In 7f0d61c I pulled the attention layer's output tensor's dtype.
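For reference, a minimal sketch of how that lookup could work (assuming TensorRT's standard ILayer/ITensor API; illustrative only, not necessarily the exact change in 7f0d61c):

    import tensorrt as trt
    import torch

    _TRT_TO_TORCH = {
        trt.DataType.FLOAT: torch.float32,
        trt.DataType.HALF: torch.float16,
    }
    if hasattr(trt.DataType, "BF16"):  # only present in newer TensorRT builds
        _TRT_TO_TORCH[trt.DataType.BF16] = torch.bfloat16

    def softmax_scale_dtype(attention_layer: trt.ILayer) -> torch.dtype:
        # The attention layer's output tensor carries the pre-quant QKV dtype, so the
        # FP8 softmax scale constant can be created to match it instead of hardcoding fp32.
        return _TRT_TO_TORCH.get(attention_layer.get_output(0).dtype, torch.float32)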

github-actions Bot left a comment:

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/attention.py	2026-04-23 17:31:58.988167+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/attention.py	2026-04-23 17:32:17.137589+00:00
@@ -29,11 +29,12 @@
    attention_layer: trt.IAttention,
) -> bool:
    """Set FP8 softmax normalization quantization on the IAttention layer if the current
    node was annotated with a softmax FP8 scale by the fp8_attention_softmax lowering pass.

-    Returns True if FP8 normalization was configured (caller must set decomposable=False)."""
+    Returns True if FP8 normalization was configured (caller must set decomposable=False).
+    """
    if ctx.current_node is None:
        return False
    scale_val = ctx.current_node.meta.get("_fp8_softmax_scale")
    if scale_val is None:
        return False
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/insert_fp8_softmax_qdq.py	2026-04-23 17:31:58.992065+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/insert_fp8_softmax_qdq.py	2026-04-23 17:32:18.027942+00:00
@@ -115,11 +115,13 @@
        if attn_src.op != "call_function" or attn_src.target not in _MATMUL_TARGETS:
            continue
        if len(attn_src.args) < 2:
            continue
        q_source, k_source = attn_src.args[0], attn_src.args[1]
-        if not (_source_is_fp8_quantize(q_source) and _source_is_fp8_quantize(k_source)):
+        if not (
+            _source_is_fp8_quantize(q_source) and _source_is_fp8_quantize(k_source)
+        ):
            continue

        # Register a per-insertion amax buffer (1.0).
        amax_name = f"_fp8_softmax_qdq_amax_{amax_buffer_idx}"
        amax_buffer_idx += 1
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2026-04-23 17:31:59.020247+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2026-04-23 17:32:21.643502+00:00
@@ -580,34 +580,34 @@
    # FP8 Q/K/V inputs (exponent_bits=4): SDPA node must be annotated with 1/448.
    gm_fp8 = _build_sdpa_input_quant_graph(exponent_bits=4)
    annotate_fp8_sdpa(gm_fp8, settings)
    sdpa_nodes = [n for n in gm_fp8.graph.nodes if n.target in _SDPA_TARGETS]
    assert sdpa_nodes, "No SDPA node found in graph"
-    assert all("_fp8_softmax_scale" in n.meta for n in sdpa_nodes), (
-        "annotate_fp8_sdpa did not annotate SDPA when Q/K/V inputs are FP8"
-    )
+    assert all(
+        "_fp8_softmax_scale" in n.meta for n in sdpa_nodes
+    ), "annotate_fp8_sdpa did not annotate SDPA when Q/K/V inputs are FP8"
    expected_scale = 1.0 / 448.0
    for n in sdpa_nodes:
-        assert abs(n.meta["_fp8_softmax_scale"] - expected_scale) < 1e-12, (
-            f"Wrong softmax scale: {n.meta['_fp8_softmax_scale']}"
-        )
+        assert (
+            abs(n.meta["_fp8_softmax_scale"] - expected_scale) < 1e-12
+        ), f"Wrong softmax scale: {n.meta['_fp8_softmax_scale']}"

    # INT8 Q/K/V inputs (exponent_bits=0): SDPA node must NOT be annotated.
    gm_int8 = _build_sdpa_input_quant_graph(exponent_bits=0)
    annotate_fp8_sdpa(gm_int8, settings)
    sdpa_int8 = [n for n in gm_int8.graph.nodes if n.target in _SDPA_TARGETS]
-    assert all("_fp8_softmax_scale" not in n.meta for n in sdpa_int8), (
-        "annotate_fp8_sdpa incorrectly annotated SDPA when Q/K/V are INT8"
-    )
+    assert all(
+        "_fp8_softmax_scale" not in n.meta for n in sdpa_int8
+    ), "annotate_fp8_sdpa incorrectly annotated SDPA when Q/K/V are INT8"

    # Only Q and K are FP8-quantized, V is raw: SDPA must NOT be annotated.
    gm_partial = _build_sdpa_input_quant_graph(exponent_bits=4, quantize_v=False)
    annotate_fp8_sdpa(gm_partial, settings)
    sdpa_partial = [n for n in gm_partial.graph.nodes if n.target in _SDPA_TARGETS]
-    assert all("_fp8_softmax_scale" not in n.meta for n in sdpa_partial), (
-        "annotate_fp8_sdpa incorrectly annotated SDPA when V input is not FP8"
-    )
+    assert all(
+        "_fp8_softmax_scale" not in n.meta for n in sdpa_partial
+    ), "annotate_fp8_sdpa incorrectly annotated SDPA when V input is not FP8"


@unittest.skipIf(
    torch.cuda.get_device_capability() < (8, 9),
    "FP8 quantization requires compute capability 8.9 or later",
@@ -649,19 +649,13 @@
        """Mirror of what a modelopt FP8 MHA PyTorch export will look like:
        tensorrt.quantize_op on Q, K, V feeding F.scaled_dot_product_attention."""

        def __init__(self, amax_val: float = 6.0):
            super().__init__()
-            self.register_buffer(
-                "amax_q", torch.tensor(amax_val, dtype=torch.float32)
-            )
-            self.register_buffer(
-                "amax_k", torch.tensor(amax_val, dtype=torch.float32)
-            )
-            self.register_buffer(
-                "amax_v", torch.tensor(amax_val, dtype=torch.float32)
-            )
+            self.register_buffer("amax_q", torch.tensor(amax_val, dtype=torch.float32))
+            self.register_buffer("amax_k", torch.tensor(amax_val, dtype=torch.float32))
+            self.register_buffer("amax_v", torch.tensor(amax_val, dtype=torch.float32))

        def forward(self, q, k, v):
            q_fp8 = torch.ops.tensorrt.quantize_op(q, self.amax_q, 8, 4, False, False)
            k_fp8 = torch.ops.tensorrt.quantize_op(k, self.amax_k, 8, 4, False, False)
            v_fp8 = torch.ops.tensorrt.quantize_op(v, self.amax_v, 8, 4, False, False)
@@ -690,12 +684,11 @@
    engine_json = json.loads(
        inspector.get_engine_information(trt.LayerInformationFormat.JSON)
    )
    layers = engine_json.get("Layers", [])
    layer_names = [
-        layer if isinstance(layer, str) else layer.get("Name", "")
-        for layer in layers
+        layer if isinstance(layer, str) else layer.get("Name", "") for layer in layers
    ]
    assert any("mha" in name.lower() for name in layer_names), (
        f"No fused MHA kernel found in compiled engine. Expected a layer "
        f"containing 'mha' (e.g. _gemm_mha_v2) — TRT fuses FP8 Q/K/V + "
        f"normalization_quantize_to_type into a single MHA kernel. "
@@ -714,13 +707,13 @@
        trt_out = compiled(q, k, v)
    cos = torch.nn.functional.cosine_similarity(
        ref_out.flatten().float().unsqueeze(0),
        trt_out.flatten().float().unsqueeze(0),
    ).item()
-    assert cos > 0.99, (
-        f"FP8 MHA output deviates from PyTorch reference: cosine_similarity={cos}"
-    )
+    assert (
+        cos > 0.99
+    ), f"FP8 MHA output deviates from PyTorch reference: cosine_similarity={cos}"


@unittest.skipIf(
    torch.cuda.get_device_capability() < (8, 9),
    "FP8 quantization requires compute capability 8.9 or later",
@@ -749,19 +742,13 @@
    torch.manual_seed(0)

    class FP8MHAModel(torch.nn.Module):
        def __init__(self, amax_val: float = 6.0):
            super().__init__()
-            self.register_buffer(
-                "amax_q", torch.tensor(amax_val, dtype=torch.float32)
-            )
-            self.register_buffer(
-                "amax_k", torch.tensor(amax_val, dtype=torch.float32)
-            )
-            self.register_buffer(
-                "amax_v", torch.tensor(amax_val, dtype=torch.float32)
-            )
+            self.register_buffer("amax_q", torch.tensor(amax_val, dtype=torch.float32))
+            self.register_buffer("amax_k", torch.tensor(amax_val, dtype=torch.float32))
+            self.register_buffer("amax_v", torch.tensor(amax_val, dtype=torch.float32))

        def forward(self, q, k, v):
            q_fp8 = torch.ops.tensorrt.quantize_op(q, self.amax_q, 8, 4, False, False)
            k_fp8 = torch.ops.tensorrt.quantize_op(k, self.amax_k, 8, 4, False, False)
            v_fp8 = torch.ops.tensorrt.quantize_op(v, self.amax_v, 8, 4, False, False)
@@ -791,12 +778,11 @@
    engine_json = json.loads(
        inspector.get_engine_information(trt.LayerInformationFormat.JSON)
    )
    layers = engine_json.get("Layers", [])
    layer_names = [
-        layer if isinstance(layer, str) else layer.get("Name", "")
-        for layer in layers
+        layer if isinstance(layer, str) else layer.get("Name", "") for layer in layers
    ]
    assert any("mha" in name.lower() for name in layer_names), (
        f"No fused MHA kernel found on decomposed path. Expected a layer "
        f"containing 'mha' (e.g. _gemm_mha_v2) — TRT fuses FP8 Q/K/V + "
        f"softmax-output Q/DQ into _gemm_mha_v2 on Method 2 path. "
@@ -816,8 +802,8 @@
        trt_out = compiled(q, k, v)
    cos = torch.nn.functional.cosine_similarity(
        ref_out.flatten().float().unsqueeze(0),
        trt_out.flatten().float().unsqueeze(0),
    ).item()
-    assert cos > 0.99, (
-        f"Decomposed FP8 MHA output deviates from PyTorch reference: cos={cos}"
-    )
+    assert (
+        cos > 0.99
+    ), f"Decomposed FP8 MHA output deviates from PyTorch reference: cos={cos}"

@narendasan narendasan force-pushed the narendasan/quantization_fixes branch from fe7f268 to 211e690 April 23, 2026 17:38
Development

Successfully merging this pull request may close these issues:

  • 🐛 [Bug] Torch-TRT does not translate softmax quantizer generated by modelopt fp8 mha quantization