pytorch: ROCm `F.layer_norm` silently produces wrong output for large inputs (even below the `M*N > 2^32` threshold)


🐛 Describe the bug

This debugging and this issue report were largely created by Claude Code. I have checked it's

On ROCm, torch.nn.functional.layer_norm silently produces incorrect output (zeros, or values that look like uninitialized memory) for the trailing rows of large inputs, where the kernel views the input as a 2D M × N matrix. Two related fixes have already landed for this same kernel (#144007, #181600), but neither resolves the failures shown below.

The bug is deterministic with torch.manual_seed(0), reproduces on two different AMD GPU architectures, and reproduces on the current PyTorch nightly that already contains the #181600 fix.

<details> <summary>Minimal reproduction</summary>

```python
"""ROCm F.layer_norm correctness repro.

Compares F.layer_norm(x) against a chunked reference (same op, slice-by-slice
along the leading dim) and counts how many rows differ by more than 1e-2.

Two shapes are tested:
  A. (1,1,2048,5247,267): M*N = 2.87 B (BELOW 2^32)
  B. (1,1,5247,5247,267): M*N = 7.35 B (ABOVE 2^32, matches #181555)
"""
import torch
import torch.nn.functional as F


def check(shape, C, chunk=256):
    # Total element count, printed relative to the 2^32 threshold.
    numel = 1
    for s in shape:
        numel *= s
    print(f"shape={shape} M*N={numel:,} ({numel / 2**32:.2f}x 2^32)")
    torch.manual_seed(0)
    x = torch.randn(*shape, device="cuda", dtype=torch.float32)
    w = torch.ones(C, device="cuda", dtype=torch.float32)

    # Single full-size launch: this is the output under test.
    out_full = F.layer_norm(x, (C,), weight=w, bias=None, eps=1e-5)
    torch.cuda.synchronize()

    bad = 0
    first_bad = -1
    # Reference: the same op applied to slices of `chunk` rows along dim 2.
    for i in range(0, shape[2], chunk):
        ci = x[:, :, i:i + chunk].contiguous()
        cr = F.layer_norm(ci, (C,), weight=w, bias=None, eps=1e-5)
        diff = (cr - out_full[:, :, i:i + chunk]).abs()
        # Max abs difference per index along dim 2 (one "row" per d2 index).
        per_row = diff.amax(dim=(0, 1, 3, 4))
        for r in (per_row > 0.01).nonzero().flatten().tolist():
            bad += 1
            if first_bad < 0:
                first_bad = i + r
        del ci, cr, diff, per_row
    del x, w, out_full
    torch.cuda.empty_cache()
    print(f"  bad_rows={bad}/{shape[2]} first_bad={first_bad}")


if __name__ == "__main__":
    print(f"PyTorch: {torch.__version__}")
    print(f"HIP: {torch.version.hip}")
    print(f"Device: {torch.cuda.get_device_name(0)}\n")
    check((1, 1, 2048, 5247, 267), 267)
    check((1, 1, 5247, 5247, 267), 267)
```

</details>

Expected output on a correct implementation: bad_rows=0/... for both shapes. Actual output on every configuration tested (see below):

shape=(1, 1, 2048, 5247, 267) M*N=2,869,143,552 (0.67x 2^32)
  bad_rows=1599/2048 first_bad=449
shape=(1, 1, 5247, 5247, 267) M*N=7,350,779,403 (1.71x 2^32)
  bad_rows=4797/5247 first_bad=450

Tested configurations

| GPU | Arch | PyTorch | Shape A (<2^32) | Shape B (>2^32) |
| --- | --- | --- | --- | --- |
| AMD Instinct MI300X | gfx942 | 2.10.0+rocm7.2.0.gitb6ee5fde | bad | bad |
| AMD Instinct MI300X | gfx942 | 2.13.0.dev20260521+rocm7.2 | bad | bad |
| AMD Radeon 8060S (Strix Halo) | gfx1151 | 2.11.0+rocm7.2 | bad | not tested (OOM at 60 GB) |

The nightly 2.13.0.dev20260521+rocm7.2 was built ~17 days after #181600 was merged (commit 8eae40c5), and the fix is present in its sources, but both shapes still fail.

The failure pattern is identical across all three configurations: the bad-row count is a multiple of 1599, the bad rows always end at d2-1, and the first bad row is 449 or 450. The failure is independent of dtype: fp32 and bf16 both fail at the same shape.
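The reported counts can be cross-checked with a few lines of arithmetic. The sketch below uses only the `bad_rows` and `first_bad` values printed by the repro; the name `reported` is introduced here for illustration. Because the count of bad rows equals `d2 - first_bad` in both cases, every row from `first_bad` through `d2-1` must be bad, so the corruption forms one contiguous tail. This is an inference from the printed numbers, not something the repro checks directly.

```python
# (d2, bad_rows, first_bad) copied from the repro output in this report.
reported = [
    (2048, 1599, 449),  # shape A: (1, 1, 2048, 5247, 267)
    (5247, 4797, 450),  # shape B: (1, 1, 5247, 5247, 267)
]

for d2, bad_rows, first_bad in reported:
    # Number of rows from the first bad row to the last row, inclusive.
    tail_len = d2 - first_bad
    print(
        f"d2={d2} first_bad={first_bad} tail_len={tail_len} "
        f"bad_rows={bad_rows} contiguous_tail={tail_len == bad_rows} "
        f"bad_rows/1599={bad_rows / 1599:g}"
    )
```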

Specific features

  • CPU F.layer_norm on the same saved input produces correct output (range [-26.3, 26.5]).

  • A chunked GPU F.layer_norm (slice-by-slice along a leading dim) on the same input produces correct output, matching CPU.

  • Bad rows in the output are either zero or filled with values that look like uninitialized memory — e.g. consecutive entries differing by exactly 16384:

    [101877579776.0, 101877596160.0, 101877612544.0, 101877628928.0, ...]
  • The bug requires both normalized_shape > 128 AND d2 ≥ ~2048 (with d3=5247). Tested:

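    | shape (B, S, d2, d3, C) | dtype | M·N | bad rows |
    | --- | --- | --- | --- |
    | (1,1,5247,5247,267) | fp32 | 7.35 B | 4797 |
    | (1,1,5247,5247,267) | bf16 | 7.35 B | 4797 |
    | (1,1,2048,5247,267) | fp32 | 2.87 B | 1599 |
    | (1,1,1024,5247,267) | fp32 | 1.43 B | 0 |
    | (1,1,5247,5247,128) | fp32 | 3.52 B | 0 |
    | (1,1,5247,5247,64) | fp32 | 1.76 B | 0 |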
  • Wrapping the call in a chunked-launch wrapper (split input along a leading dim into pieces of ≤ 2^30 elements before calling F.layer_norm) makes the output correct.
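The chunked-launch workaround from the last bullet can be sketched as follows. This is a minimal illustration, not the wrapper used in the original debugging session: the function name `chunked_layer_norm`, its `max_elems` parameter, and the choice to split along the first leading dim of size greater than 1 are all assumptions made here. Only the 2^30-element limit per call comes from the report.

```python
import torch
import torch.nn.functional as F


def chunked_layer_norm(x, normalized_shape, weight=None, bias=None,
                       eps=1e-5, max_elems=2**30):
    """Hypothetical wrapper: call F.layer_norm on pieces of <= max_elems elements."""
    lead = x.dim() - len(normalized_shape)
    # First leading (non-normalized) dim that can actually be split.
    split_dim = next((d for d in range(lead) if x.shape[d] > 1), None)
    if x.numel() <= max_elems or split_dim is None:
        return F.layer_norm(x, normalized_shape, weight, bias, eps)

    size = x.shape[split_dim]
    per_slice = x.numel() // size  # elements in one index of split_dim
    # Slices per call; falls back to 1 if a single slice already exceeds the limit.
    step = max(1, max_elems // per_slice)

    out = torch.empty_like(x)
    for start in range(0, size, step):
        length = min(step, size - start)
        piece = x.narrow(split_dim, start, length).contiguous()
        result = F.layer_norm(piece, normalized_shape, weight, bias, eps)
        out.narrow(split_dim, start, length).copy_(result)
    return out


# Small demo that also runs on CPU; max_elems is tiny only to force chunking.
device = "cuda" if torch.cuda.is_available() else "cpu"
demo_x = torch.randn(1, 1, 37, 5, 16, device=device)
demo_out = chunked_layer_norm(demo_x, (16,), max_elems=400)
print(demo_out.shape)
```

With the default `max_elems=2**30`, shape A from the repro would be split along d2 into pieces of 766 rows (2^30 // (5247 * 267)). Splitting along the first non-trivial leading dim keeps each piece a contiguous block of a contiguous input, so the `.contiguous()` call is a no-op in that case.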

Speculation by Claude

  • Why doesn't the #181600 int64_t cast in vectorized_layer_norm_kernel help on ROCm?
    • The ROCm chunked-launch path added by #144007 does X_data2 += N * blocks.x with int N and uint blocks.x — that arithmetic is in uint32_t and could wrap, but for our tested blocks.x values it shouldn't. This is a guess.
    • For our shape A, M * N = 2.87 B exceeds signed-int32 max but not unsigned. If any compile-time path on ROCm picks signed promotion for a blockIdx.x * N site that #181600 didn't change, we'd see overflow at exactly our threshold. Have not looked at the compiled SASS/GCN code to confirm.
    • The first bad row in M-space is ~449 * 5247 = 2,355,903 for shape A and ~450 * 5247 = 2,361,150 for shape B. These don't sit on any obvious power-of-two boundary.
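The 32-bit arithmetic behind these guesses can be checked directly. The sketch below is plain integer arithmetic on the shapes from the table in this report; `wrap_int32` is a helper introduced here to show what a value becomes if it is truncated to a signed 32-bit integer. It does not inspect the kernel, so it cannot confirm or rule out any of the hypotheses above.

```python
INT32_MAX = 2**31 - 1
UINT32_MAX = 2**32 - 1


def numel(shape):
    """Total element count of a shape."""
    n = 1
    for s in shape:
        n *= s
    return n


def wrap_int32(v):
    """Value of v after truncation to a signed 32-bit two's-complement integer."""
    v &= 0xFFFFFFFF
    return v - 2**32 if v > INT32_MAX else v


# (shape, bad rows) for the fp32 rows of the table in this report.
cases = [
    ((1, 1, 5247, 5247, 267), 4797),
    ((1, 1, 2048, 5247, 267), 1599),
    ((1, 1, 1024, 5247, 267), 0),
    ((1, 1, 5247, 5247, 128), 0),
    ((1, 1, 5247, 5247, 64), 0),
]
for shape, bad in cases:
    n = numel(shape)
    print(f"{shape}: M*N={n:,} exceeds_int32={n > INT32_MAX} "
          f"exceeds_uint32={n > UINT32_MAX} as_int32={wrap_int32(n):,} bad_rows={bad}")

# Flat element offset of the first bad row for shape A (row 449 along d2).
first_bad_offset = 449 * 5247 * 267
print(f"first bad element offset (shape A): {first_bad_offset:,}")
```

Two things stand out in these numbers. The (1,1,5247,5247,128) case also exceeds the signed 32-bit maximum yet reports zero bad rows, so total size alone does not separate failing from passing shapes, which agrees with the observation that normalized_shape > 128 is also required. And the flat element offset of the first bad row in shape A is about 629 million, well below 2^31, so a flat index wrapping exactly at the onset of the bad rows would not by itself explain where the corruption starts.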

Related

  • #136291: original ROCm layer_norm crash on large tensors
  • #144007: ROCm chunked-launch fix for the crash above
  • #181555: CUDA layer_norm wrong-output bug for M*N > 2^32
  • #181600: int64_t cast fix for #181555 (merged; the nightly we tested includes it)

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @jataylo @hongxiayang @naromero77amd @pragupta @jerrymannil @xinyazhang

Versions

<details><summary>For latest nightly on MI300X</summary>
Collecting environment information...
PyTorch version: 2.13.0.dev20260521+rocm7.2
Is debug build: False
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: 7.2.53211

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04.3) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.10.12 (main, Jan 26 2026, 14:55:28) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-1074-oracle-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to:
GPU models and configuration: AMD Instinct MI300X (gfx942:sramecc+:xnack-)
Nvidia driver version: Could not collect
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: 7.2.53211
MIOpen runtime version: 3.5.1
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        46 bits physical, 57 bits virtual
Byte Order:                           Little Endian
CPU(s):                               224
On-line CPU(s) list:                  0-223
Vendor ID:                            GenuineIntel
Model name:                           Intel(R) Xeon(R) Platinum 8480+
CPU family:                           6
Model:                                143
Thread(s) per core:                   2
Core(s) per socket:                   56
Socket(s):                            2
Stepping:                             8
CPU max MHz:                          3800.0000
CPU min MHz:                          800.0000
BogoMIPS:                             4000.00
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx
fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology
nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid
dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb
cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid
ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb
intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local
split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku
ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect
cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8
flush_l1d arch_capabilities
Virtualization:                       VT-x
L1d cache:                            5.3 MiB (112 instances)
L1i cache:                            3.5 MiB (112 instances)
L2 cache:                             224 MiB (112 instances)
L3 cache:                             210 MiB (2 instances)
NUMA node(s):                         2
NUMA node0 CPU(s):                    0-55,112-167
NUMA node1 CPU(s):                    56-111,168-223
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Vulnerable
Vulnerability Spectre v1:             Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2:             Vulnerable; IBPB: disabled; STIBP: disabled; PBRSB-eIBRS: Vulnerable; BHI: Vulnerable
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] torch==2.13.0.dev20260521+rocm7.2
[pip3] triton-rocm==3.7.0+git88b227e2
[conda] Could not collect
</details>
