pytorch - ✅(Solved) Fix `torch.compile(fullgraph=True)` cannot trace `torch.nn.functional.scaled_mm`, even though the equivalent `torch.ops.aten._scaled_mm_v2.default` call compiles and runs [2 pull requests, 1 comment, 1 participant]

GitHub stats
pytorch/pytorch#180604 (fetched 2026-04-17 08:26:05)
Comments: 1
Participants: 1
Timeline: 74
Reactions: 0
Timeline (top): mentioned ×33, subscribed ×33, labeled ×6, commented ×1

Error Message

The traceback points into torch/nn/functional.py, where scaled_mm calls torch._scaled_mm_v2(...).

Fix Action

Fix / Workaround

Downstream code currently needs a compile-only workaround that bypasses F.scaled_mm and calls the aten op directly. I don't think this should be necessary.

However, the equivalent dispatcher entrypoint works under fullgraph tracing:

torch.ops.aten._scaled_mm_v2.default(...)


PR fix notes

PR #4229: fix blockwise FP8 scaled_mm scale layout in Float8BlockwiseLinear

Description (problem / solution / changelog)

Summary

  • replace the non-Triton blockwise FP8 matmul path with functional scaled_mm semantics for forward, grad_x, and grad_weight
  • fix the BlockWise128x128 RHS scale layout by padding the K-block dimension to a multiple of 4, as required by cuBLASLt
  • fix the BlockWise1x128 RHS scale orientation in grad_weight by transposing the activation scales before the matmul
  • use aten._scaled_mm_v2 only under torch.compile(fullgraph=True) so the compiler can trace the op without graph-breaking on the Python F.scaled_mm wrapper

What was going wrong?

The issue was not just that we were calling torch._scaled_mm, but that this path in our blockwise FP8 linear code was not matching the cuBLASLt scale-layout contract for blockwise scaling. In particular:

  • grad_x = grad_output @ weight uses RHS BlockWise128x128 scales
  • cuBLASLt requires those scales to be K-major, with the K-block dimension padded to a multiple of 4
  • our code was passing the unpadded layout, which caused incorrect behavior in the scaled-mm backend
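The padding requirement above comes down to simple arithmetic. The following is a minimal sketch, not the torchao implementation: it assumes a block edge of 128 (inferred from the BlockWise128x128 name) and uses a hypothetical helper, padded_k_blocks, to show how many K-blocks a contraction dimension produces and what that count becomes once rounded up to a multiple of 4.

```python
import math

BLOCK = 128           # assumed block edge, inferred from the "128x128" recipe name
K_BLOCK_MULTIPLE = 4  # padding multiple stated in the PR summary


def padded_k_blocks(k: int) -> tuple[int, int]:
    """Return (k_blocks, padded_k_blocks) for a contraction dimension of size k."""
    k_blocks = math.ceil(k / BLOCK)
    # Round the block count up to the next multiple of K_BLOCK_MULTIPLE.
    padded = math.ceil(k_blocks / K_BLOCK_MULTIPLE) * K_BLOCK_MULTIPLE
    return k_blocks, padded


for k in (128, 512, 640, 4096):
    blocks, padded = padded_k_blocks(k)
    print(f"K={k}: {blocks} K-blocks, padded to {padded}")
```

A scale tensor built for K=640 would therefore carry 5 meaningful K-blocks plus 3 padding entries along that dimension. How the padding entries are filled is not stated in the PR notes.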

There was a second layout bug in grad_weight = grad_output^T @ x:

  • the RHS uses BlockWise1x128 scaling
  • the quantized activation scales had the right values but the wrong orientation for the RHS scaled-mm call
  • transposing those scales fixes the contract mismatch

Also to note:

  • eager mode uses torch.nn.functional.scaled_mm
  • compile mode calls aten._scaled_mm_v2 directly because, in this torch build, Dynamo fullgraph cannot trace through the Python F.scaled_mm wrapper even though it lowers to the same underlying op

Testing

 pytest -q test/prototype/blockwise_fp8_training/test_blockwise_linear.py

cc @slayton58 @drisspg as part of #4209

Changed files

  • torchao/prototype/blockwise_fp8_training/linear.py (modified, +104/-28)

PR #180668: Add torch._scaled_mm_v2 to the trace rules set

Description (problem / solution / changelog)

Fixes #180604 by adding torch._scaled_mm_v2 to the rule set.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @chauhang @amjames @Lucaskabela @jataylo @azahed98

Changed files

  • test/dynamo/test_repros.py (modified, +26/-0)
  • torch/_dynamo/trace_rules.py (modified, +1/-0)


🐛 Describe the bug

In eager mode, torch.nn.functional.scaled_mm(...) works.

In compile mode with torch.compile(..., fullgraph=True), the same F.scaled_mm(...) call fails in Dynamo, even though it lowers to the same underlying operator as a direct torch.ops.aten._scaled_mm_v2.default(...) call. The direct aten overload compiles successfully on the same inputs.

Downstream code currently needs a compile-only workaround that bypasses F.scaled_mm and calls the aten op directly. I don't think this should be necessary.

Observed Behavior

On the affected build:

  1. Eager F.scaled_mm(...) succeeds.
  2. torch.compile(via_aten, fullgraph=True) succeeds.
  3. torch.compile(via_functional, fullgraph=True) fails in Dynamo.

The failure is:

torch._dynamo.exc.Unsupported: Attempted to call function marked as skipped
  Explanation: Dynamo does not know how to trace the builtin `torch._VariableFunctionsClass._scaled_mm_v2.`

The traceback points into torch/nn/functional.py, where scaled_mm calls torch._scaled_mm_v2(...).
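Since the failure surfaces as torch._dynamo.exc.Unsupported during tracing, it can be probed without running any compiled kernels. The helper below, traces_fullgraph, is a hypothetical sketch: it compiles with backend="eager" so that only Dynamo's tracing is exercised, and reports whether fullgraph capture succeeded.

```python
import torch
from torch._dynamo.exc import Unsupported


def traces_fullgraph(fn, *args) -> bool:
    """Return True if Dynamo can trace fn(*args) as a single graph."""
    torch._dynamo.reset()  # drop cached compilations so each probe starts fresh
    try:
        # backend="eager" runs the captured graph without Inductor codegen,
        # so a failure here points at tracing rather than lowering.
        torch.compile(fn, fullgraph=True, backend="eager")(*args)
        return True
    except Unsupported:
        return False


print(traces_fullgraph(lambda x: x * 2 + 1, torch.ones(4)))
```

With the functions from the reproduction script, traces_fullgraph(via_functional, a, b, scale_a, scale_b) would be expected to return False on the affected build and traces_fullgraph(via_aten, a, b, scale_a, scale_b) to return True, matching the observed behavior listed above.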

Expected Behavior

torch.compile(..., fullgraph=True) should be able to trace torch.nn.functional.scaled_mm(...) without requiring downstream users to special-case compile mode and manually call torch.ops.aten._scaled_mm_v2.default(...).

From further investigation, we discovered that F.scaled_mm currently normalizes Python-side arguments and then calls torch._scaled_mm_v2(...) from torch/nn/functional.py. In the affected build, Dynamo treats that builtin call as skipped / unsupported during fullgraph tracing.

However, the equivalent dispatcher entrypoint works under fullgraph tracing:

torch.ops.aten._scaled_mm_v2.default(...)

So the issue does not appear to be the underlying _scaled_mm_v2 operator itself. The problem appears to be the specific torch._scaled_mm_v2(...) builtin path that F.scaled_mm uses.

Minimal Reproduce Script:

Reproduced with:

  1. PyTorch: 2.12.0a0+git9594ae4
  2. CUDA: 13.0
  3. GPU: NVIDIA H100 80GB HBM3
  4. CUDA capability: (9, 0)
  5. Platform: Linux-6.12.77-99.140.amzn2023.x86_64-x86_64-with-glibc2.39

import torch
import torch.nn.functional as F


def make_inputs():
    m, n, k = 15, 32, 16
    a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
    b = (
        torch.randn(k, n, device="cuda", dtype=torch.bfloat16)
        .to(torch.float8_e4m3fn)
        .t()
        .contiguous()
        .t()
    )
    scale_a = torch.ones(1, device="cuda", dtype=torch.float32)
    scale_b = torch.ones(1, device="cuda", dtype=torch.float32)
    return a, b, scale_a, scale_b


def via_functional(a, b, scale_a, scale_b):
    return F.scaled_mm(
        a,
        b,
        scale_a=scale_a,
        scale_recipe_a=F.ScalingType.TensorWise,
        swizzle_a=F.SwizzleType.NO_SWIZZLE,
        scale_b=scale_b,
        scale_recipe_b=F.ScalingType.TensorWise,
        swizzle_b=F.SwizzleType.NO_SWIZZLE,
        output_dtype=torch.bfloat16,
    )


def via_aten(a, b, scale_a, scale_b):
    return torch.ops.aten._scaled_mm_v2.default(
        a,
        b,
        [scale_a],
        [F.ScalingType.TensorWise.value],
        [F.SwizzleType.NO_SWIZZLE.value],
        [scale_b],
        [F.ScalingType.TensorWise.value],
        [F.SwizzleType.NO_SWIZZLE.value],
        None,
        torch.bfloat16,
        [],
        False,
    )


a, b, scale_a, scale_b = make_inputs()

print(via_functional(a, b, scale_a, scale_b).shape)
print(torch.compile(via_aten, fullgraph=True)(a, b, scale_a, scale_b).shape)
print(torch.compile(via_functional, fullgraph=True)(a, b, scale_a, scale_b).shape)

Error logs

No response

Versions

Collecting environment information...
PyTorch version: 2.12.0a0+git9594ae4
Is debug build: False
CUDA used to build PyTorch: 13.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.4 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04.1) 13.3.0
Clang version: Could not collect
CMake version: version 3.31.10
Libc version: glibc-2.39

Python version: 3.12.13 | packaged by conda-forge | (main, Mar 5 2026, 16:50:00) [GCC 14.3.0] (64-bit runtime)
Python platform: Linux-6.12.77-99.140.amzn2023.x86_64-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 13.2.78
CUDA_MODULE_LOADING set to:
GPU models and configuration:
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3

Nvidia driver version: 595.58.03
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 192
On-line CPU(s) list: 0-191
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7R13 Processor
CPU family: 25
Model: 1
Thread(s) per core: 2
Core(s) per socket: 48
Socket(s): 2
Stepping: 1
BogoMIPS: 5300.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext perfctr_core ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save vaes vpclmulqdq rdpid
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 3 MiB (96 instances)
L1i cache: 3 MiB (96 instances)
L2 cache: 48 MiB (96 instances)
L3 cache: 384 MiB (12 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-47,96-143
NUMA node1 CPU(s): 48-95,144-191
Vulnerability Gather data sampling: Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; Safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsa: Mitigation; Clear CPU buffers
Vulnerability Tsx async abort: Not affected
Vulnerability Vmscape: Not affected

Versions of relevant libraries:
[pip3] flake8==7.3.0
[pip3] intel-cmplr-lib-ur==2025.3.2
[pip3] intel-openmp==2025.3.2
[pip3] mkl-include==2025.3.1
[pip3] mkl-static==2025.3.1
[pip3] numpy==2.4.3
[pip3] onemkl-license==2025.3.1
[pip3] optree==0.19.0
[pip3] tbb==2022.3.1
[pip3] tbb-devel==2022.3.1
[pip3] tcmlib==1.4.1
[pip3] torch==2.12.0a0+git9594ae4
[pip3] torchao==0.17.0+gitcf0b50ae1
[pip3] triton==3.6.0+git9844da95
[pip3] umf==1.0.3
[conda] intel-cmplr-lib-ur 2025.3.2 pypi_0 pypi
[conda] intel-openmp 2025.3.2 pypi_0 pypi
[conda] mkl-include 2025.3.1 pypi_0 pypi
[conda] mkl-static 2025.3.1 pypi_0 pypi
[conda] numpy 2.4.3 pypi_0 pypi
[conda] onemkl-license 2025.3.1 pypi_0 pypi
[conda] optree 0.19.0 pypi_0 pypi
[conda] tbb 2022.3.1 pypi_0 pypi
[conda] tbb-devel 2022.3.1 pypi_0 pypi
[conda] tcmlib 1.4.1 pypi_0 pypi
[conda] torch 2.12.0a0+git9594ae4 pypi_0 pypi
[conda] torchao 0.17.0+gitcf0b50ae1 pypi_0 pypi
[conda] triton 3.6.0+git9844da95 pypi_0 pypi
[conda] umf 1.0.3 pypi_0 pypi

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @amjames @Lucaskabela @jataylo @azahed98

extent analysis

TL;DR

The issue can be worked around by directly calling torch.ops.aten._scaled_mm_v2.default instead of F.scaled_mm when using torch.compile with fullgraph=True.

Guidance

  • The problem appears to be specific to the torch._scaled_mm_v2 builtin path used by F.scaled_mm, which is treated as skipped/unsupported by Dynamo during fullgraph tracing.
  • To verify the issue, run the provided minimal reproduce script and check if the error occurs when calling torch.compile(via_functional, fullgraph=True).
  • As a temporary workaround, modify the code to use torch.ops.aten._scaled_mm_v2.default directly, as shown in the via_aten function in the reproduce script.
  • The root cause lies in how Dynamo handles the torch._scaled_mm_v2 builtin: PR #180668 fixes #180604 by adding torch._scaled_mm_v2 to Dynamo's trace rules (torch/_dynamo/trace_rules.py), so the fix is in PyTorch itself rather than in user code.

Example

# Workaround: use torch.ops.aten._scaled_mm_v2.default directly
def via_aten(a, b, scale_a, scale_b):
    return torch.ops.aten._scaled_mm_v2.default(
        a,
        b,
        [scale_a],
        [F.ScalingType.TensorWise.value],
        [F.SwizzleType.NO_SWIZZLE.value],
        [scale_b],
        [F.ScalingType.TensorWise.value],
        [F.SwizzleType.NO_SWIZZLE.value],
        None,
        torch.bfloat16,
        [],
        False,
    )

Notes

  • This workaround is not desirable in the long term, because it requires modifying user code to work around an issue in PyTorch's Dynamo.
  • PR #180668 adds torch._scaled_mm_v2 to Dynamo's trace rules; on builds that include that change, the workaround should no longer be needed.

Recommendation

Apply the workaround by using torch.ops.aten._scaled_mm_v2.default directly, as it allows the code to compile and run successfully with torch.compile and fullgraph=True.
