
TE EP MXFP8 fails the assertion last_dim % MXFP8_BLOCK_SIZE == 0 #2966

Description

@faradawn

Summary

transformer_engine.pytorch.ops.GroupedLinear fails under the MXFP8BlockScaling recipe in a Mixtral MoE expert-parallel (EP) run when the per-expert token counts (split sizes) are not multiples of 32.

Repro context

  • Model: Mixtral-8x7B
  • GPUs: 8x B300
  • EP size: 2
  • Batch size: 8
  • Sequence length: 8192
  • Precision recipe: MXFP8BlockScaling
  • Expert FFN path: transformer_engine.pytorch.ops.GroupedLinear
  • Dispatcher: NCCL all-to-all
  • Tutorial/example PR: Add examples for MoE models - Mixtral in TE #2642

Repro command:

cd /lustre/fsw/coreai_prod_infbench/faradawny/TransformerEngine/docs/examples/te_mixtral

torchrun --standalone --nproc_per_node=8 run_finetune_ep.py \
  --improvement 8 \
  --ep-size 2 \
  --batch-size 8 \
  --max-seq-length 8192 \
  --warmup-steps 5 \
  --train-steps 10 \
  2>&1 | tee logs/sweep_seq8k_ep2_8gpus_sequential_ops/seq8k_batch8_ep2_tier8_sequential_ops_mxfp8.log

Error

[rank0]:   File "/lustre/fsw/coreai_prod_infbench/faradawny/TransformerEngine/docs/examples/te_mixtral/te_mixtral.py", line 687, in _expert_ffn
[rank0]:     gate_up_output = self.experts_gate_up(tokens, split_sizes)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/ops/op.py", line 522, in forward
[rank0]:     return OperationFuser([self])(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/ops/basic/grouped_linear.py", line 739, in fuser_forward
[rank0]:     xs = tex.split_quantize(x, split_sizes_int, input_quantizers)
[rank0]:          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: /workspace/TransformerEngine/transformer_engine/pytorch/csrc/quantizer.cpp:1668 in function get_scale_shape: Assertion failed: last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0. MXFP8 requires tensor dims that are divisible by 32 (got shape=(2283,4096))

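Other ranks fail the same way with shapes such as (1441,4096), (1178,4096), and (1225,4096). In every case the last dimension (4096) is a multiple of 32; the failing term is the row count of the split (numel / last_dim), i.e. the number of tokens routed to that expert.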
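The assertion in get_scale_shape requires both the last dimension and the product of the leading dimensions to be multiples of the MXFP8 block size (32). The sketch below re-implements that condition in plain Python to show why each reported per-split shape trips it. The helper name mxfp8_shape_ok is hypothetical and is not part of Transformer Engine.

```python
import math

MXFP8_BLOCK_SIZE = 32  # block size named in the assertion message


def mxfp8_shape_ok(shape, block=MXFP8_BLOCK_SIZE):
    """Hypothetical mirror of the check quoted in the error message:
    last_dim % block == 0 and (numel / last_dim) % block == 0."""
    last_dim = shape[-1]
    numel = math.prod(shape)
    return last_dim % block == 0 and (numel // last_dim) % block == 0


# Per-split shapes reported in the error logs above.
reported = [(2283, 4096), (1441, 4096), (1178, 4096), (1225, 4096)]
for shape in reported:
    rows = shape[0]
    if mxfp8_shape_ok(shape):
        print(shape, "passes")
    else:
        # The hidden size 4096 is fine; the token count is what breaks the check.
        print(shape, f"fails: {rows} % {MXFP8_BLOCK_SIZE} = {rows % MXFP8_BLOCK_SIZE}")
```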

Expected

The Sequential Ops grouped path should either handle MXFP8 padding internally per split, or provide a clear documented requirement/workaround for MoE token splits whose per-expert token counts are not multiples of 32.
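Until padding is handled internally, one possible caller-side workaround (an assumption, not documented Transformer Engine behavior) is to pad each expert's token block up to the next multiple of 32 before the grouped op and drop the padding rows from its output. The sketch below shows only the bookkeeping, with rows modeled as Python list entries; the function names round_up, pad_splits, and unpad_splits are hypothetical.

```python
# Hypothetical caller-side workaround: pad every expert's token block up to a
# multiple of the MXFP8 block size, then strip the padding rows afterwards.
# Rows are plain list entries here; with tensors the same offsets apply.
MXFP8_BLOCK_SIZE = 32


def round_up(n, multiple=MXFP8_BLOCK_SIZE):
    """Smallest multiple of `multiple` that is >= n."""
    return -(-n // multiple) * multiple


def pad_splits(rows, split_sizes, multiple=MXFP8_BLOCK_SIZE, fill=0):
    """Append `fill` rows after each expert's tokens.

    Returns the padded rows and the padded split sizes (each a multiple of `multiple`).
    """
    padded_rows, padded_sizes = [], []
    start = 0
    for size in split_sizes:
        target = round_up(size, multiple)
        padded_rows.extend(rows[start:start + size])
        padded_rows.extend([fill] * (target - size))
        padded_sizes.append(target)
        start += size
    return padded_rows, padded_sizes


def unpad_splits(padded_rows, split_sizes, multiple=MXFP8_BLOCK_SIZE):
    """Inverse of pad_splits: keep only the real rows of each expert's block."""
    rows = []
    start = 0
    for size in split_sizes:
        rows.extend(padded_rows[start:start + size])
        start += round_up(size, multiple)
    return rows


# Usage: two experts that received 2283 and 1441 tokens (counts from the error log).
split_sizes = [2283, 1441]
tokens = list(range(sum(split_sizes)))  # stand-ins for token rows
padded, padded_sizes = pad_splits(tokens, split_sizes)
print(padded_sizes)  # [2304, 1472]
assert unpad_splits(padded, split_sizes) == tokens
```

With PyTorch tensors, the padded split sizes are what would be passed to the grouped op. Zero-filled rows leave the weight gradient unchanged because they contribute nothing to X^T dY, so the cost is the extra copy and the wasted GEMM rows, which internal per-split padding would avoid.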

Metadata

Assignees: No one assigned

Labels: bug (Something isn't working)