pytorch - `scaled_dot_product_attention` bf16 CPU output changes between PyTorch 2.11 and 2.12 on Arm CPUs


---

torch=2.12.0+cu130
cuda_available=True
cuda_device_count=1
cuda:0_name=NVIDIA GB10
cuda:0_free_memory=0.53 GiB
cuda:0_total_memory=121.69 GiB

torch=2.12.0+cu130
device=cpu
sum=-81.518699645996
mean=-0.039804052562
norm=12.513515472412
out[0,0,0,:8]=
  0.003631591797
  0.198242187500
  -0.357421875000
  0.077148437500
  0.621093750000
  0.030151367188
  -0.458984375000
  0.308593750000

torch=2.12.0+cu130
device=cuda:0
sum=-81.527687072754
mean=-0.039808440953
norm=12.513469696045
out[0,0,0,:8]=
  0.003967285156
  0.198242187500
  -0.357421875000
  0.077148437500
  0.621093750000
  0.029907226562
  -0.460937500000
  0.308593750000


torch=2.11.0+cu130
cuda_available=True
cuda_device_count=1
cuda:0_name=NVIDIA GB10
cuda:0_free_memory=0.63 GiB
cuda:0_total_memory=121.69 GiB

torch=2.11.0+cu130
device=cpu
sum=-81.527687072754
mean=-0.039808440953
norm=12.513469696045
out[0,0,0,:8]=
  0.003967285156
  0.198242187500
  -0.357421875000
  0.077148437500
  0.621093750000
  0.029907226562
  -0.460937500000
  0.308593750000

torch=2.11.0+cu130
device=cuda:0
sum=-81.527687072754
mean=-0.039808440953
norm=12.513469696045
out[0,0,0,:8]=
  0.003967285156
  0.198242187500
  -0.357421875000
  0.077148437500
  0.621093750000
  0.029907226562
  -0.460937500000
  0.308593750000

---

Collecting environment information...
PyTorch version: 2.11.0+cu130
Is debug build: False
CUDA used to build PyTorch: 13.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.4 LTS (aarch64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04.1) 13.3.0
Clang version: Could not collect
CMake version: version 3.31.10
Libc version: glibc-2.39

Python version: 3.12.13 | packaged by conda-forge | (main, Mar  5 2026, 16:39:32) [GCC 14.3.0] (64-bit runtime)
Python platform: Linux-6.17.0-1014-nvidia-aarch64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 13.0.88
CUDA_MODULE_LOADING set to: 
GPU models and configuration: GPU 0: NVIDIA GB10
Nvidia driver version: 580.142
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture:                            aarch64
CPU op-mode(s):                          64-bit
Byte Order:                              Little Endian
CPU(s):                                  20
On-line CPU(s) list:                     0-19
Vendor ID:                               ARM
Model name:                              Cortex-X925
Model:                                   1
Thread(s) per core:                      1
Core(s) per socket:                      10
Socket(s):                               1
Stepping:                                r0p1
Frequency boost:                         disabled
CPU(s) scaling MHz:                      100%
CPU max MHz:                             3900.0000
CPU min MHz:                             1378.0000
BogoMIPS:                                2000.00
Flags:                                   fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm jscvt fcma lrcpc dcpop sha3 sm3 sm4 asimddp sha512 sve asimdfhm dit uscat ilrcpc flagm sb paca pacg dcpodp sve2 sveaes svepmull svebitperm svesha3 svesm4 flagm2 frint svei8mm svebf16 i8mm bf16 dgh bti ecv afp wfxt
Model name:                              Cortex-A725
Model:                                   1
Thread(s) per core:                      1
Core(s) per socket:                      10
Socket(s):                               1
Stepping:                                r0p1
CPU(s) scaling MHz:                      100%
CPU max MHz:                             2808.0000
CPU min MHz:                             338.0000
BogoMIPS:                                2000.00
Flags:                                   fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm jscvt fcma lrcpc dcpop sha3 sm3 sm4 asimddp sha512 sve asimdfhm dit uscat ilrcpc flagm sb paca pacg dcpodp sve2 sveaes svepmull svebitperm svesha3 svesm4 flagm2 frint svei8mm svebf16 i8mm bf16 dgh bti ecv afp wfxt
L1d cache:                               1.3 MiB (20 instances)
L1i cache:                               1.3 MiB (20 instances)
L2 cache:                                25 MiB (20 instances)
L3 cache:                                24 MiB (2 instances)
NUMA node(s):                            1
NUMA node0 CPU(s):                       0-19
Vulnerability Gather data sampling:      Not affected
Vulnerability Ghostwrite:                Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit:             Not affected
Vulnerability L1tf:                      Not affected
Vulnerability Mds:                       Not affected
Vulnerability Meltdown:                  Not affected
Vulnerability Mmio stale data:           Not affected
Vulnerability Old microcode:             Not affected
Vulnerability Reg file data sampling:    Not affected
Vulnerability Retbleed:                  Not affected
Vulnerability Spec rstack overflow:      Not affected
Vulnerability Spec store bypass:         Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:                Mitigation; __user pointer sanitization
Vulnerability Spectre v2:                Mitigation; CSV2, BHB
Vulnerability Srbds:                     Not affected
Vulnerability Tsa:                       Not affected
Vulnerability Tsx async abort:           Not affected
Vulnerability Vmscape:                   Not affected

Versions of relevant libraries:
[pip3] numpy==2.1.3
[pip3] nvidia-cublas==13.1.0.3
[pip3] nvidia-cuda-cupti==13.0.85
[pip3] nvidia-cuda-nvrtc==13.0.88
[pip3] nvidia-cuda-runtime==13.0.96
[pip3] nvidia-cudnn-cu13==9.19.0.56
[pip3] nvidia-cudnn-frontend==1.23.0
[pip3] nvidia-cufft==12.0.0.61
[pip3] nvidia-curand==10.4.0.35
[pip3] nvidia-cusolver==12.0.4.66
[pip3] nvidia-cusparse==12.6.3.3
[pip3] nvidia-cusparselt-cu13==0.8.0
[pip3] nvidia-nccl-cu13==2.28.9
[pip3] nvidia-nvjitlink==13.0.88
[pip3] nvidia-nvtx==13.0.85
[pip3] torch==2.11.0
[pip3] torch_c_dlpack_ext==0.1.5
[pip3] torch_memory_saver==0.0.9.post1
[pip3] torchao==0.18.0+gitdc538c33
[pip3] torchaudio==2.11.0
[pip3] torchvision==0.26.0
[pip3] tosa2torch==0.0.1
[pip3] triton==3.6.0
[conda] numpy                                 2.1.3               pypi_0                pypi
[conda] nvidia-cublas                         13.1.0.3            pypi_0                pypi
[conda] nvidia-cuda-cupti                     13.0.85             pypi_0                pypi
[conda] nvidia-cuda-nvrtc                     13.0.88             pypi_0                pypi
[conda] nvidia-cuda-runtime                   13.0.96             pypi_0                pypi
[conda] nvidia-cudnn-cu13                     9.19.0.56           pypi_0                pypi
[conda] nvidia-cudnn-frontend                 1.23.0              pypi_0                pypi
[conda] nvidia-cufft                          12.0.0.61           pypi_0                pypi
[conda] nvidia-curand                         10.4.0.35           pypi_0                pypi
[conda] nvidia-cusolver                       12.0.4.66           pypi_0                pypi
[conda] nvidia-cusparse                       12.6.3.3            pypi_0                pypi
[conda] nvidia-cusparselt-cu13                0.8.0               pypi_0                pypi
[conda] nvidia-nccl-cu13                      2.28.9              pypi_0                pypi
[conda] nvidia-nvjitlink                      13.0.88             pypi_0                pypi
[conda] nvidia-nvtx                           13.0.85             pypi_0                pypi
[conda] torch                                 2.11.0              pypi_0                pypi
[conda] torch-c-dlpack-ext                    0.1.5               pypi_0                pypi
[conda] torch-memory-saver                    0.0.9.post1         pypi_0                pypi
[conda] torchao                               0.18.0+gitdc538c33  pypi_0                pypi
[conda] torchaudio                            2.11.0              pypi_0                pypi
[conda] torchvision                           0.26.0              pypi_0                pypi
[conda] tosa2torch                            0.0.1               pypi_0                pypi
[conda] triton                                3.6.0               pypi_0                pypi

---

Collecting environment information...
PyTorch version: 2.12.0+cu130
Is debug build: False
CUDA used to build PyTorch: 13.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.4 LTS (aarch64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04.1) 13.3.0
Clang version: Could not collect
CMake version: version 3.28.3
Libc version: glibc-2.39

Python version: 3.12.13 | packaged by conda-forge | (main, Mar  5 2026, 16:39:32) [GCC 14.3.0] (64-bit runtime)
Python platform: Linux-6.17.0-1014-nvidia-aarch64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 13.0.88
CUDA_MODULE_LOADING set to: 
GPU models and configuration: GPU 0: NVIDIA GB10
Nvidia driver version: 580.142
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture:                            aarch64
CPU op-mode(s):                          64-bit
Byte Order:                              Little Endian
CPU(s):                                  20
On-line CPU(s) list:                     0-19
Vendor ID:                               ARM
Model name:                              Cortex-X925
Model:                                   1
Thread(s) per core:                      1
Core(s) per socket:                      10
Socket(s):                               1
Stepping:                                r0p1
Frequency boost:                         disabled
CPU(s) scaling MHz:                      100%
CPU max MHz:                             3900.0000
CPU min MHz:                             1378.0000
BogoMIPS:                                2000.00
Flags:                                   fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm jscvt fcma lrcpc dcpop sha3 sm3 sm4 asimddp sha512 sve asimdfhm dit uscat ilrcpc flagm sb paca pacg dcpodp sve2 sveaes svepmull svebitperm svesha3 svesm4 flagm2 frint svei8mm svebf16 i8mm bf16 dgh bti ecv afp wfxt
Model name:                              Cortex-A725
Model:                                   1
Thread(s) per core:                      1
Core(s) per socket:                      10
Socket(s):                               1
Stepping:                                r0p1
CPU(s) scaling MHz:                      100%
CPU max MHz:                             2808.0000
CPU min MHz:                             338.0000
BogoMIPS:                                2000.00
Flags:                                   fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm jscvt fcma lrcpc dcpop sha3 sm3 sm4 asimddp sha512 sve asimdfhm dit uscat ilrcpc flagm sb paca pacg dcpodp sve2 sveaes svepmull svebitperm svesha3 svesm4 flagm2 frint svei8mm svebf16 i8mm bf16 dgh bti ecv afp wfxt
L1d cache:                               1.3 MiB (20 instances)
L1i cache:                               1.3 MiB (20 instances)
L2 cache:                                25 MiB (20 instances)
L3 cache:                                24 MiB (2 instances)
NUMA node(s):                            1
NUMA node0 CPU(s):                       0-19
Vulnerability Gather data sampling:      Not affected
Vulnerability Ghostwrite:                Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit:             Not affected
Vulnerability L1tf:                      Not affected
Vulnerability Mds:                       Not affected
Vulnerability Meltdown:                  Not affected
Vulnerability Mmio stale data:           Not affected
Vulnerability Old microcode:             Not affected
Vulnerability Reg file data sampling:    Not affected
Vulnerability Retbleed:                  Not affected
Vulnerability Spec rstack overflow:      Not affected
Vulnerability Spec store bypass:         Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:                Mitigation; __user pointer sanitization
Vulnerability Spectre v2:                Mitigation; CSV2, BHB
Vulnerability Srbds:                     Not affected
Vulnerability Tsa:                       Not affected
Vulnerability Tsx async abort:           Not affected
Vulnerability Vmscape:                   Not affected

Versions of relevant libraries:
[pip3] nvidia-cublas==13.1.1.3
[pip3] nvidia-cuda-cupti==13.0.85
[pip3] nvidia-cuda-nvrtc==13.0.88
[pip3] nvidia-cuda-runtime==13.0.96
[pip3] nvidia-cudnn-cu13==9.20.0.48
[pip3] nvidia-cufft==12.0.0.61
[pip3] nvidia-curand==10.4.0.35
[pip3] nvidia-cusolver==12.0.4.66
[pip3] nvidia-cusparse==12.6.3.3
[pip3] nvidia-cusparselt-cu13==0.8.1
[pip3] nvidia-nccl-cu13==2.29.7
[pip3] nvidia-nvjitlink==13.0.88
[pip3] nvidia-nvtx==13.0.85
[pip3] torch==2.12.0
[pip3] triton==3.7.0
[conda] nvidia-cublas              13.1.1.3         pypi_0                pypi
[conda] nvidia-cuda-cupti          13.0.85          pypi_0                pypi
[conda] nvidia-cuda-nvrtc          13.0.88          pypi_0                pypi
[conda] nvidia-cuda-runtime        13.0.96          pypi_0                pypi
[conda] nvidia-cudnn-cu13          9.20.0.48        pypi_0                pypi
[conda] nvidia-cufft               12.0.0.61        pypi_0                pypi
[conda] nvidia-curand              10.4.0.35        pypi_0                pypi
[conda] nvidia-cusolver            12.0.4.66        pypi_0                pypi
[conda] nvidia-cusparse            12.6.3.3         pypi_0                pypi
[conda] nvidia-cusparselt-cu13     0.8.1            pypi_0                pypi
[conda] nvidia-nccl-cu13           2.29.7           pypi_0                pypi
[conda] nvidia-nvjitlink           13.0.88          pypi_0                pypi
[conda] nvidia-nvtx                13.0.85          pypi_0                pypi
[conda] torch                      2.12.0           pypi_0                pypi
[conda] triton                     3.7.0            pypi_0                pypi

---

Collecting environment information...
PyTorch version: 2.11.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.7.5 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.6.4.2)
CMake version: version 4.3.1
Libc version: N/A

Python version: 3.12.13 | packaged by conda-forge | (main, Mar  5 2026, 17:06:14) [Clang 19.1.7 ] (64-bit runtime)
Python platform: macOS-15.7.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Apple M4 Max

Versions of relevant libraries:
[pip3] numpy==2.4.4
[pip3] torch==2.11.0
[conda] executorch                            1.3.0a0+502d2de     pypi_0              pypi
[conda] numpy                                 2.1.3               pypi_0              pypi
[conda] pytorch-tokenizers                    1.1.0               pypi_0              pypi
[conda] torch                                 2.12.0              pypi_0              pypi
[conda] torchao                               0.18.0+gitb9f7744c  pypi_0              pypi
[conda] torchdata                             0.11.0+cpu          pypi_0              pypi
[conda] torchfix                              0.6.0               pypi_0              pypi
[conda] torchvision                           0.27.0              pypi_0              pypi

---

Collecting environment information...
PyTorch version: 2.12.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.7.5 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.6.4.2)
CMake version: version 4.3.1
Libc version: N/A

Python version: 3.12.13 | packaged by conda-forge | (main, Mar  5 2026, 17:06:14) [Clang 19.1.7 ] (64-bit runtime)
Python platform: macOS-15.7.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Apple M4 Max

Versions of relevant libraries:
[pip3] executorch==1.3.0a0+502d2de
[pip3] flake8==6.1.0
[pip3] flake8-breakpoint==1.1.0
[pip3] flake8-bugbear==24.4.26
[pip3] flake8-comprehensions==3.14.0
[pip3] flake8-plugin-utils==1.3.3
[pip3] flake8-pyi==23.5.0
[pip3] mypy==1.14.1
[pip3] mypy_extensions==1.1.0
[pip3] numpy==2.1.3
[pip3] pytorch_tokenizers==1.1.0
[pip3] torch==2.12.0
[pip3] torchao==0.18.0+gitb9f7744c
[pip3] torchdata==0.11.0+cpu
[pip3] torchvision==0.27.0
[conda] executorch                            1.3.0a0+502d2de     pypi_0              pypi
[conda] numpy                                 2.1.3               pypi_0              pypi
[conda] pytorch-tokenizers                    1.1.0               pypi_0              pypi
[conda] torch                                 2.12.0              pypi_0              pypi
[conda] torchao                               0.18.0+gitb9f7744c  pypi_0              pypi
[conda] torchdata                             0.11.0+cpu          pypi_0              pypi
[conda] torchfix                              0.6.0               pypi_0              pypi
[conda] torchvision                           0.27.0              pypi_0              pypi

🐛 Describe the bug

`torch.nn.functional.scaled_dot_product_attention` produces different bfloat16 results on Arm CPUs in PyTorch 2.12 than in PyTorch 2.11, given the same seeded inputs.

The issue is visible on:

  • macOS M4 Max / Apple Silicon CPU
  • DGX Spark / NVIDIA GB10 Arm CPU

Relative to the PyTorch 2.11 baseline, the PyTorch 2.12 CPU output changes on both hosts, while PyTorch 2.12 CUDA on DGX Spark still reproduces the PyTorch 2.11 CPU/CUDA values exactly. MPS output on the M4 Max is the same in 2.11 and 2.12.

Possibly related PR: https://github.com/pytorch/pytorch/pull/176881
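The report treats 2.11 as the baseline, but neither version is necessarily closer to the exact result. A minimal stdlib sketch, not part of the original report, of a float64 reference for one (batch, head) slice: it assumes the same call as the reproducer below (no mask, no dropout, default scale `1/sqrt(head_dim)`) and deliberately ignores every other SDPA option. The name `sdpa_reference` is hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]


def sdpa_reference(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """softmax(q @ k^T / sqrt(d)) @ v for one (batch, head) slice, in float64.

    Hypothetical helper: mirrors only the non-causal, dropout-free,
    default-scale call used in the reproducer.
    """
    scale = 1.0 / math.sqrt(len(q[0]))  # default SDPA scale: 1/sqrt(head_dim)
    out: Matrix = []
    for q_row in q:
        # Scaled dot products of this query row against every key row
        scores = [scale * math.fsum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        peak = max(scores)  # subtract the max before exp for numerical stability
        weights = [math.exp(s - peak) for s in scores]
        total = math.fsum(weights)
        # Softmax-weighted sum of value rows, one output column at a time
        out.append([
            math.fsum(w * v_row[c] for w, v_row in zip(weights, v)) / total
            for c in range(len(v[0]))
        ])
    return out
```

With the reproducer's tensors in scope, `sdpa_reference(q[0, 0].double().tolist(), k[0, 0].double().tolist(), v[0, 0].double().tolist())[0][0]` can be compared with 0.003967285156 (2.11) and 0.003631591797 (2.12). Which one lands closer says something about accuracy, not about whether the change was intended.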

To reproduce

import torch
import torch.nn.functional as F


def make_bf16(seed, shape, device):
    g = torch.Generator(device="cpu").manual_seed(seed)
    x = torch.randn(shape, generator=g, dtype=torch.float32)
    return x.to(device=device, dtype=torch.bfloat16)


def run(device):
    shape = (1, 4, 32, 16)

    try:
        q = make_bf16(1, shape, device)
        k = make_bf16(2, shape, device)
        v = make_bf16(3, shape, device)

        with torch.no_grad():
            out = F.scaled_dot_product_attention(
                q, k, v, dropout_p=0.0, is_causal=False
            )

        if device.type == "cuda":
            torch.cuda.synchronize(device)
        elif device.type == "mps":
            torch.mps.synchronize()

        out = out.float().cpu()

        print(f"\ntorch={torch.__version__}")
        print(f"device={device}")
        print(f"sum={out.sum().item():.12f}")
        print(f"mean={out.mean().item():.12f}")
        print(f"norm={torch.linalg.vector_norm(out).item():.12f}")
        print("out[0,0,0,:8]=")

        for x in out[0, 0, 0, :8]:
            print(f"  {x.item():.12f}")

    except Exception as e:
        print(f"\ntorch={torch.__version__}")
        print(f"device={device}")
        print(f"error={type(e).__name__}: {e}")


print(f"torch={torch.__version__}")
print(f"cuda_available={torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"cuda_device_count={torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"cuda:{i}_name={torch.cuda.get_device_name(i)}")
        free, total = torch.cuda.mem_get_info(i)
        print(f"cuda:{i}_free_memory={free / 1024**3:.2f} GiB")
        print(f"cuda:{i}_total_memory={total / 1024**3:.2f} GiB")

devices = [torch.device("cpu")]

if torch.cuda.is_available():
    devices.append(torch.device("cuda:0"))

if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    devices.append(torch.device("mps"))

for device in devices:
    run(device)

Observed result

One representative output element, out[0,0,0,0], changes between PyTorch 2.11 and PyTorch 2.12 on Arm CPUs:

Torch 2.11 baseline: 0.003967285156
Torch 2.12 CPU:      0.003631591797
Difference:         -0.000335693359
BF16 ULP diff:       11 ULP (in units of the 2.11 value's bf16 ULP, 2^-15)
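Because the two values sit on either side of 2^-8, where the bf16 spacing halves, an "ULP" count depends on whose ULP is used; counting the representable bf16 values between them gives 20 steps. A small stdlib sketch that computes both measures from the printed values, assuming finite inputs and round-to-nearest-even conversion to bf16 (the helper names are hypothetical):

```python
import math
import struct


def to_bf16_bits(x: float) -> int:
    """Round a Python float to the nearest bfloat16 and return its 16-bit pattern.

    Assumes finite inputs; NaN handling is omitted in this sketch.
    """
    # float -> float32 bit pattern (struct rounds to the nearest float32)
    (f32,) = struct.unpack("<I", struct.pack("<f", x))
    # bfloat16 keeps the upper 16 bits; round to nearest, ties to even
    lsb = (f32 >> 16) & 1
    return ((f32 + 0x7FFF + lsb) >> 16) & 0xFFFF


def _ordered(bits: int) -> int:
    """Map a sign-magnitude bf16 pattern onto a monotonically ordered integer."""
    magnitude = bits & 0x7FFF
    return -magnitude if bits & 0x8000 else magnitude


def bf16_steps(a: float, b: float) -> int:
    """Number of bf16 representable-value steps between a and b (0 if equal)."""
    return abs(_ordered(to_bf16_bits(a)) - _ordered(to_bf16_bits(b)))


def bf16_ulp(x: float) -> float:
    """Spacing of bf16 values in the binade containing x (normal, nonzero x)."""
    _, exp = math.frexp(abs(x))  # abs(x) = m * 2**exp with 0.5 <= m < 1
    return math.ldexp(1.0, exp - 8)  # 7 explicit mantissa bits


# out[0,0,0,:8] copied from the CPU logs in this report
torch_211 = [0.003967285156, 0.198242187500, -0.357421875000, 0.077148437500,
             0.621093750000, 0.029907226562, -0.460937500000, 0.308593750000]
torch_212 = [0.003631591797, 0.198242187500, -0.357421875000, 0.077148437500,
             0.621093750000, 0.030151367188, -0.458984375000, 0.308593750000]

diff = torch_212[0] - torch_211[0]
print(f"diff in ULPs of the 2.11 value: {diff / bf16_ulp(torch_211[0]):.1f}")
print("bf16 steps per element:", [bf16_steps(a, b) for a, b in zip(torch_211, torch_212)])
```

Applied to the first eight elements, this shows that elements 5 and 6 also move (by 2 and 1 bf16 steps), not only element 0.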

The aggregate output also changes:

Torch 2.11 baseline:
sum  = -81.527687072754
mean = -0.039808440953
norm = 12.513469696045

Torch 2.12 CPU, macOS M4 Max:
sum  = -81.517723083496
mean = -0.039803575724
norm = 12.513530731201

Torch 2.12 CPU, DGX Spark / NVIDIA GB10:
sum  = -81.518699645996
mean = -0.039804052562
norm = 12.513515472412
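To put the aggregate changes in scale, a small stdlib sketch computing each statistic's relative change against the 2.11 baseline, using the values copied from above. Comparing against the bf16 unit roundoff (2^-8) is an illustrative yardstick chosen here, not a tolerance PyTorch promises.

```python
# Aggregate statistics copied from the report
BASELINE_211 = {"sum": -81.527687072754, "mean": -0.039808440953, "norm": 12.513469696045}
RUNS_212 = {
    "macOS M4 Max CPU": {"sum": -81.517723083496, "mean": -0.039803575724, "norm": 12.513530731201},
    "DGX Spark / GB10 CPU": {"sum": -81.518699645996, "mean": -0.039804052562, "norm": 12.513515472412},
}
BF16_UNIT_ROUNDOFF = 2.0 ** -8  # bf16 has 8 significant bits (7 stored + 1 implicit)


def relative_changes(run, baseline):
    """|run - baseline| / |baseline| for every statistic in the baseline."""
    return {key: abs(run[key] - baseline[key]) / abs(baseline[key]) for key in baseline}


for host, stats in RUNS_212.items():
    rel = relative_changes(stats, BASELINE_211)
    print(host, {key: f"{value:.2e}" for key, value in rel.items()})
```

With these numbers, sum and mean move by roughly 1.1e-4 to 1.2e-4 relative and the norm by under 5e-6, all well below 2^-8.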
<details> <summary>Full comparison table</summary>

| Host | CPU arch | Torch | Backend | Sum | Mean | Norm | out[0,0,0,0] |
|---|---|---|---|---|---|---|---|
| macOS M4 Max | Arm / Apple Silicon | 2.11.0 | CPU | -81.527687072754 | -0.039808440953 | 12.513469696045 | 0.003967285156 |
| DGX Spark / NVIDIA GB10 | Arm | 2.11.0+cu130 | CPU | -81.527687072754 | -0.039808440953 | 12.513469696045 | 0.003967285156 |
| DGX Spark / NVIDIA GB10 | Arm | 2.11.0+cu130 | CUDA | -81.527687072754 | -0.039808440953 | 12.513469696045 | 0.003967285156 |
| DGX Spark / NVIDIA GB10 | Arm | 2.12.0+cu130 | CUDA | -81.527687072754 | -0.039808440953 | 12.513469696045 | 0.003967285156 |
| macOS M4 Max | Arm / Apple Silicon | 2.12.0 | CPU | -81.517723083496 | -0.039803575724 | 12.513530731201 | 0.003631591797 |
| DGX Spark / NVIDIA GB10 | Arm | 2.12.0+cu130 | CPU | -81.518699645996 | -0.039804052562 | 12.513515472412 | 0.003631591797 |
| macOS M4 Max | Arm / Apple Silicon | 2.11.0 | MPS | -81.540939331055 | -0.039814911783 | 12.513702392578 | 0.003540039062 |
| macOS M4 Max | Arm / Apple Silicon | 2.12.0 | MPS | -81.540939331055 | -0.039814911783 | 12.513702392578 | 0.003540039062 |

</details> <details> <summary>Full macOS M4 Max logs</summary>
torch=2.12.0
cuda_available=False

torch=2.12.0
device=cpu
sum=-81.517723083496
mean=-0.039803575724
norm=12.513530731201
out[0,0,0,:8]=
  0.003631591797
  0.198242187500
  -0.357421875000
  0.077148437500
  0.621093750000
  0.030151367188
  -0.458984375000
  0.308593750000

torch=2.12.0
device=mps
sum=-81.540939331055
mean=-0.039814911783
norm=12.513702392578
out[0,0,0,:8]=
  0.003540039062
  0.198242187500
  -0.357421875000
  0.077148437500
  0.621093750000
  0.030151367188
  -0.460937500000
  0.308593750000


torch=2.11.0
cuda_available=False

torch=2.11.0
device=cpu
sum=-81.527687072754
mean=-0.039808440953
norm=12.513469696045
out[0,0,0,:8]=
  0.003967285156
  0.198242187500
  -0.357421875000
  0.077148437500
  0.621093750000
  0.029907226562
  -0.460937500000
  0.308593750000

torch=2.11.0
device=mps
sum=-81.540939331055
mean=-0.039814911783
norm=12.513702392578
out[0,0,0,:8]=
  0.003540039062
  0.198242187500
  -0.357421875000
  0.077148437500
  0.621093750000
  0.030151367188
  -0.460937500000
  0.308593750000
</details> <details> <summary>Full DGX Spark / NVIDIA GB10 logs</summary>
torch=2.12.0+cu130
cuda_available=True
cuda_device_count=1
cuda:0_name=NVIDIA GB10
cuda:0_free_memory=0.53 GiB
cuda:0_total_memory=121.69 GiB

torch=2.12.0+cu130
device=cpu
sum=-81.518699645996
mean=-0.039804052562
norm=12.513515472412
out[0,0,0,:8]=
  0.003631591797
  0.198242187500
  -0.357421875000
  0.077148437500
  0.621093750000
  0.030151367188
  -0.458984375000
  0.308593750000

torch=2.12.0+cu130
device=cuda:0
sum=-81.527687072754
mean=-0.039808440953
norm=12.513469696045
out[0,0,0,:8]=
  0.003967285156
  0.198242187500
  -0.357421875000
  0.077148437500
  0.621093750000
  0.029907226562
  -0.460937500000
  0.308593750000


torch=2.11.0+cu130
cuda_available=True
cuda_device_count=1
cuda:0_name=NVIDIA GB10
cuda:0_free_memory=0.63 GiB
cuda:0_total_memory=121.69 GiB

torch=2.11.0+cu130
device=cpu
sum=-81.527687072754
mean=-0.039808440953
norm=12.513469696045
out[0,0,0,:8]=
  0.003967285156
  0.198242187500
  -0.357421875000
  0.077148437500
  0.621093750000
  0.029907226562
  -0.460937500000
  0.308593750000

torch=2.11.0+cu130
device=cuda:0
sum=-81.527687072754
mean=-0.039808440953
norm=12.513469696045
out[0,0,0,:8]=
  0.003967285156
  0.198242187500
  -0.357421875000
  0.077148437500
  0.621093750000
  0.029907226562
  -0.460937500000
  0.308593750000
</details>

Expected behavior

For the same seeded inputs, shape, dtype, and SDPA configuration, I expected PyTorch 2.12 CPU on Arm to remain numerically consistent with PyTorch 2.11 CPU, or for the numerical behavior change to be documented.

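Notably, PyTorch 2.12 CPU output differs from the 2.11 baseline on both tested Arm CPU platforms. The 2.12 CPU aggregates on the two platforms also differ slightly from each other. By contrast, PyTorch 2.12 CUDA on DGX Spark still matches the PyTorch 2.11 baseline in every printed value.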

A single element, out[0,0,0,0], changes by 11 bf16 ULP.
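Neither release is automatically the "correct" one. One way to judge them is to compare each bf16 result against a float32 reference computed from the same bf16-rounded inputs. The following is a minimal sketch with these assumptions: it reuses the repro's seeding scheme, uses the default SDPA scale 1/sqrt(E) with no mask and no dropout, and runs on CPU only. `make_bf16_cpu` and `reference_sdpa_fp32` are hypothetical helpers written for this check, not PyTorch APIs.

```python
import math

import torch
import torch.nn.functional as F

def make_bf16_cpu(seed, shape):
    # Same scheme as the repro script: float32 randn on a seeded CPU generator, then cast to bf16.
    g = torch.Generator(device="cpu").manual_seed(seed)
    return torch.randn(shape, generator=g, dtype=torch.float32).to(torch.bfloat16)

def reference_sdpa_fp32(q, k, v):
    """Non-causal, dropout-free attention computed entirely in float32 from the bf16 inputs."""
    q, k, v = q.float(), k.float(), v.float()
    scale = 1.0 / math.sqrt(q.size(-1))  # default SDPA scale
    weights = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    return weights @ v

shape = (1, 4, 32, 16)
q, k, v = (make_bf16_cpu(seed, shape) for seed in (1, 2, 3))

with torch.no_grad():
    out_bf16 = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    ref = reference_sdpa_fp32(q, k, v)

err = (out_bf16.float() - ref).abs()
print(f"torch={torch.__version__}")
print(f"max_abs_err={err.max().item():.6e}")
print(f"out[0,0,0,0]={out_bf16[0, 0, 0, 0].item():.12f} ref={ref[0, 0, 0, 0].item():.12f}")
```

Running the sketch once under 2.11 and once under 2.12 shows which release lands closer to the float32 reference for out[0,0,0,0] and overall. It does not reveal which CPU kernel either release dispatches to.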

Versions

<details> <summary>NVIDIA GB10 torch 2.11 </summary>

Collecting environment information...
PyTorch version: 2.11.0+cu130
Is debug build: False
CUDA used to build PyTorch: 13.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.4 LTS (aarch64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04.1) 13.3.0
Clang version: Could not collect
CMake version: version 3.31.10
Libc version: glibc-2.39

Python version: 3.12.13 | packaged by conda-forge | (main, Mar  5 2026, 16:39:32) [GCC 14.3.0] (64-bit runtime)
Python platform: Linux-6.17.0-1014-nvidia-aarch64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 13.0.88
CUDA_MODULE_LOADING set to: 
GPU models and configuration: GPU 0: NVIDIA GB10
Nvidia driver version: 580.142
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture:                            aarch64
CPU op-mode(s):                          64-bit
Byte Order:                              Little Endian
CPU(s):                                  20
On-line CPU(s) list:                     0-19
Vendor ID:                               ARM
Model name:                              Cortex-X925
Model:                                   1
Thread(s) per core:                      1
Core(s) per socket:                      10
Socket(s):                               1
Stepping:                                r0p1
Frequency boost:                         disabled
CPU(s) scaling MHz:                      100%
CPU max MHz:                             3900.0000
CPU min MHz:                             1378.0000
BogoMIPS:                                2000.00
Flags:                                   fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm jscvt fcma lrcpc dcpop sha3 sm3 sm4 asimddp sha512 sve asimdfhm dit uscat ilrcpc flagm sb paca pacg dcpodp sve2 sveaes svepmull svebitperm svesha3 svesm4 flagm2 frint svei8mm svebf16 i8mm bf16 dgh bti ecv afp wfxt
Model name:                              Cortex-A725
Model:                                   1
Thread(s) per core:                      1
Core(s) per socket:                      10
Socket(s):                               1
Stepping:                                r0p1
CPU(s) scaling MHz:                      100%
CPU max MHz:                             2808.0000
CPU min MHz:                             338.0000
BogoMIPS:                                2000.00
Flags:                                   fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm jscvt fcma lrcpc dcpop sha3 sm3 sm4 asimddp sha512 sve asimdfhm dit uscat ilrcpc flagm sb paca pacg dcpodp sve2 sveaes svepmull svebitperm svesha3 svesm4 flagm2 frint svei8mm svebf16 i8mm bf16 dgh bti ecv afp wfxt
L1d cache:                               1.3 MiB (20 instances)
L1i cache:                               1.3 MiB (20 instances)
L2 cache:                                25 MiB (20 instances)
L3 cache:                                24 MiB (2 instances)
NUMA node(s):                            1
NUMA node0 CPU(s):                       0-19
Vulnerability Gather data sampling:      Not affected
Vulnerability Ghostwrite:                Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit:             Not affected
Vulnerability L1tf:                      Not affected
Vulnerability Mds:                       Not affected
Vulnerability Meltdown:                  Not affected
Vulnerability Mmio stale data:           Not affected
Vulnerability Old microcode:             Not affected
Vulnerability Reg file data sampling:    Not affected
Vulnerability Retbleed:                  Not affected
Vulnerability Spec rstack overflow:      Not affected
Vulnerability Spec store bypass:         Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:                Mitigation; __user pointer sanitization
Vulnerability Spectre v2:                Mitigation; CSV2, BHB
Vulnerability Srbds:                     Not affected
Vulnerability Tsa:                       Not affected
Vulnerability Tsx async abort:           Not affected
Vulnerability Vmscape:                   Not affected

Versions of relevant libraries:
[pip3] numpy==2.1.3
[pip3] nvidia-cublas==13.1.0.3
[pip3] nvidia-cuda-cupti==13.0.85
[pip3] nvidia-cuda-nvrtc==13.0.88
[pip3] nvidia-cuda-runtime==13.0.96
[pip3] nvidia-cudnn-cu13==9.19.0.56
[pip3] nvidia-cudnn-frontend==1.23.0
[pip3] nvidia-cufft==12.0.0.61
[pip3] nvidia-curand==10.4.0.35
[pip3] nvidia-cusolver==12.0.4.66
[pip3] nvidia-cusparse==12.6.3.3
[pip3] nvidia-cusparselt-cu13==0.8.0
[pip3] nvidia-nccl-cu13==2.28.9
[pip3] nvidia-nvjitlink==13.0.88
[pip3] nvidia-nvtx==13.0.85
[pip3] torch==2.11.0
[pip3] torch_c_dlpack_ext==0.1.5
[pip3] torch_memory_saver==0.0.9.post1
[pip3] torchao==0.18.0+gitdc538c33
[pip3] torchaudio==2.11.0
[pip3] torchvision==0.26.0
[pip3] tosa2torch==0.0.1
[pip3] triton==3.6.0
[conda] numpy                                 2.1.3               pypi_0                pypi
[conda] nvidia-cublas                         13.1.0.3            pypi_0                pypi
[conda] nvidia-cuda-cupti                     13.0.85             pypi_0                pypi
[conda] nvidia-cuda-nvrtc                     13.0.88             pypi_0                pypi
[conda] nvidia-cuda-runtime                   13.0.96             pypi_0                pypi
[conda] nvidia-cudnn-cu13                     9.19.0.56           pypi_0                pypi
[conda] nvidia-cudnn-frontend                 1.23.0              pypi_0                pypi
[conda] nvidia-cufft                          12.0.0.61           pypi_0                pypi
[conda] nvidia-curand                         10.4.0.35           pypi_0                pypi
[conda] nvidia-cusolver                       12.0.4.66           pypi_0                pypi
[conda] nvidia-cusparse                       12.6.3.3            pypi_0                pypi
[conda] nvidia-cusparselt-cu13                0.8.0               pypi_0                pypi
[conda] nvidia-nccl-cu13                      2.28.9              pypi_0                pypi
[conda] nvidia-nvjitlink                      13.0.88             pypi_0                pypi
[conda] nvidia-nvtx                           13.0.85             pypi_0                pypi
[conda] torch                                 2.11.0              pypi_0                pypi
[conda] torch-c-dlpack-ext                    0.1.5               pypi_0                pypi
[conda] torch-memory-saver                    0.0.9.post1         pypi_0                pypi
[conda] torchao                               0.18.0+gitdc538c33  pypi_0                pypi
[conda] torchaudio                            2.11.0              pypi_0                pypi
[conda] torchvision                           0.26.0              pypi_0                pypi
[conda] tosa2torch                            0.0.1               pypi_0                pypi
[conda] triton                                3.6.0               pypi_0                pypi
</details> <details> <summary>NVIDIA GB10 torch 2.12 </summary>
Collecting environment information...
PyTorch version: 2.12.0+cu130
Is debug build: False
CUDA used to build PyTorch: 13.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.4 LTS (aarch64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04.1) 13.3.0
Clang version: Could not collect
CMake version: version 3.28.3
Libc version: glibc-2.39

Python version: 3.12.13 | packaged by conda-forge | (main, Mar  5 2026, 16:39:32) [GCC 14.3.0] (64-bit runtime)
Python platform: Linux-6.17.0-1014-nvidia-aarch64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 13.0.88
CUDA_MODULE_LOADING set to: 
GPU models and configuration: GPU 0: NVIDIA GB10
Nvidia driver version: 580.142
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture:                            aarch64
CPU op-mode(s):                          64-bit
Byte Order:                              Little Endian
CPU(s):                                  20
On-line CPU(s) list:                     0-19
Vendor ID:                               ARM
Model name:                              Cortex-X925
Model:                                   1
Thread(s) per core:                      1
Core(s) per socket:                      10
Socket(s):                               1
Stepping:                                r0p1
Frequency boost:                         disabled
CPU(s) scaling MHz:                      100%
CPU max MHz:                             3900.0000
CPU min MHz:                             1378.0000
BogoMIPS:                                2000.00
Flags:                                   fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm jscvt fcma lrcpc dcpop sha3 sm3 sm4 asimddp sha512 sve asimdfhm dit uscat ilrcpc flagm sb paca pacg dcpodp sve2 sveaes svepmull svebitperm svesha3 svesm4 flagm2 frint svei8mm svebf16 i8mm bf16 dgh bti ecv afp wfxt
Model name:                              Cortex-A725
Model:                                   1
Thread(s) per core:                      1
Core(s) per socket:                      10
Socket(s):                               1
Stepping:                                r0p1
CPU(s) scaling MHz:                      100%
CPU max MHz:                             2808.0000
CPU min MHz:                             338.0000
BogoMIPS:                                2000.00
Flags:                                   fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm jscvt fcma lrcpc dcpop sha3 sm3 sm4 asimddp sha512 sve asimdfhm dit uscat ilrcpc flagm sb paca pacg dcpodp sve2 sveaes svepmull svebitperm svesha3 svesm4 flagm2 frint svei8mm svebf16 i8mm bf16 dgh bti ecv afp wfxt
L1d cache:                               1.3 MiB (20 instances)
L1i cache:                               1.3 MiB (20 instances)
L2 cache:                                25 MiB (20 instances)
L3 cache:                                24 MiB (2 instances)
NUMA node(s):                            1
NUMA node0 CPU(s):                       0-19
Vulnerability Gather data sampling:      Not affected
Vulnerability Ghostwrite:                Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit:             Not affected
Vulnerability L1tf:                      Not affected
Vulnerability Mds:                       Not affected
Vulnerability Meltdown:                  Not affected
Vulnerability Mmio stale data:           Not affected
Vulnerability Old microcode:             Not affected
Vulnerability Reg file data sampling:    Not affected
Vulnerability Retbleed:                  Not affected
Vulnerability Spec rstack overflow:      Not affected
Vulnerability Spec store bypass:         Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:                Mitigation; __user pointer sanitization
Vulnerability Spectre v2:                Mitigation; CSV2, BHB
Vulnerability Srbds:                     Not affected
Vulnerability Tsa:                       Not affected
Vulnerability Tsx async abort:           Not affected
Vulnerability Vmscape:                   Not affected

Versions of relevant libraries:
[pip3] nvidia-cublas==13.1.1.3
[pip3] nvidia-cuda-cupti==13.0.85
[pip3] nvidia-cuda-nvrtc==13.0.88
[pip3] nvidia-cuda-runtime==13.0.96
[pip3] nvidia-cudnn-cu13==9.20.0.48
[pip3] nvidia-cufft==12.0.0.61
[pip3] nvidia-curand==10.4.0.35
[pip3] nvidia-cusolver==12.0.4.66
[pip3] nvidia-cusparse==12.6.3.3
[pip3] nvidia-cusparselt-cu13==0.8.1
[pip3] nvidia-nccl-cu13==2.29.7
[pip3] nvidia-nvjitlink==13.0.88
[pip3] nvidia-nvtx==13.0.85
[pip3] torch==2.12.0
[pip3] triton==3.7.0
[conda] nvidia-cublas              13.1.1.3         pypi_0                pypi
[conda] nvidia-cuda-cupti          13.0.85          pypi_0                pypi
[conda] nvidia-cuda-nvrtc          13.0.88          pypi_0                pypi
[conda] nvidia-cuda-runtime        13.0.96          pypi_0                pypi
[conda] nvidia-cudnn-cu13          9.20.0.48        pypi_0                pypi
[conda] nvidia-cufft               12.0.0.61        pypi_0                pypi
[conda] nvidia-curand              10.4.0.35        pypi_0                pypi
[conda] nvidia-cusolver            12.0.4.66        pypi_0                pypi
[conda] nvidia-cusparse            12.6.3.3         pypi_0                pypi
[conda] nvidia-cusparselt-cu13     0.8.1            pypi_0                pypi
[conda] nvidia-nccl-cu13           2.29.7           pypi_0                pypi
[conda] nvidia-nvjitlink           13.0.88          pypi_0                pypi
[conda] nvidia-nvtx                13.0.85          pypi_0                pypi
[conda] torch                      2.12.0           pypi_0                pypi
[conda] triton                     3.7.0            pypi_0                pypi
</details> <details> <summary> macOS M4 Max torch 2.11 </summary>
Collecting environment information...
PyTorch version: 2.11.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.7.5 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.6.4.2)
CMake version: version 4.3.1
Libc version: N/A

Python version: 3.12.13 | packaged by conda-forge | (main, Mar  5 2026, 17:06:14) [Clang 19.1.7 ] (64-bit runtime)
Python platform: macOS-15.7.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Apple M4 Max

Versions of relevant libraries:
[pip3] numpy==2.4.4
[pip3] torch==2.11.0
[conda] executorch                            1.3.0a0+502d2de     pypi_0              pypi
[conda] numpy                                 2.1.3               pypi_0              pypi
[conda] pytorch-tokenizers                    1.1.0               pypi_0              pypi
[conda] torch                                 2.12.0              pypi_0              pypi
[conda] torchao                               0.18.0+gitb9f7744c  pypi_0              pypi
[conda] torchdata                             0.11.0+cpu          pypi_0              pypi
[conda] torchfix                              0.6.0               pypi_0              pypi
[conda] torchvision                           0.27.0              pypi_0              pypi
</details> <details> <summary> macOS M4 Max torch 2.12 </summary>
Collecting environment information...
PyTorch version: 2.12.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.7.5 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.6.4.2)
CMake version: version 4.3.1
Libc version: N/A

Python version: 3.12.13 | packaged by conda-forge | (main, Mar  5 2026, 17:06:14) [Clang 19.1.7 ] (64-bit runtime)
Python platform: macOS-15.7.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Apple M4 Max

Versions of relevant libraries:
[pip3] executorch==1.3.0a0+502d2de
[pip3] flake8==6.1.0
[pip3] flake8-breakpoint==1.1.0
[pip3] flake8-bugbear==24.4.26
[pip3] flake8-comprehensions==3.14.0
[pip3] flake8-plugin-utils==1.3.3
[pip3] flake8-pyi==23.5.0
[pip3] mypy==1.14.1
[pip3] mypy_extensions==1.1.0
[pip3] numpy==2.1.3
[pip3] pytorch_tokenizers==1.1.0
[pip3] torch==2.12.0
[pip3] torchao==0.18.0+gitb9f7744c
[pip3] torchdata==0.11.0+cpu
[pip3] torchvision==0.27.0
[conda] executorch                            1.3.0a0+502d2de     pypi_0              pypi
[conda] numpy                                 2.1.3               pypi_0              pypi
[conda] pytorch-tokenizers                    1.1.0               pypi_0              pypi
[conda] torch                                 2.12.0              pypi_0              pypi
[conda] torchao                               0.18.0+gitb9f7744c  pypi_0              pypi
[conda] torchdata                             0.11.0+cpu          pypi_0              pypi
[conda] torchfix                              0.6.0               pypi_0              pypi
[conda] torchvision                           0.27.0              pypi_0              pypi
</details>
