pytorch - 💡(How to fix) [MPS] Incorrect gradients for `torch.bmm` with `complex64` tensors [1 comment, 2 participants]

GitHub stats
pytorch/pytorch#177734 · Fetched 2026-04-08 00:58:03
Comments: 1 · Participants: 2 · Timeline events: 70 · Reactions: 0
Timeline (top): mentioned ×30, subscribed ×30, labeled ×7, closed ×1

Root Cause

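The issue does not identify the underlying defect in the MPS backend. What it shows is the symptom: the forward pass of complex64 torch.bmm matches the CPU, but the backward pass on MPS returns gradients that point in the wrong direction, so an SGD run that converges on the CPU diverges on MPS.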

Fix Action

Workaround

Split complex matrices into real/imaginary float32 parts and compute the complex product manually:

def complex_bmm_real(M_re, M_im, v_re, v_im):
    vr, vi = v_re.unsqueeze(-1), v_im.unsqueeze(-1)
    out_re = torch.bmm(M_re, vr) - torch.bmm(M_im, vi)
    out_im = torch.bmm(M_re, vi) + torch.bmm(M_im, vr)
    return out_re.squeeze(-1), out_im.squeeze(-1)

This gives correct gradients on MPS since float32 bmm backward is correct.

Code Example

import torch

torch.manual_seed(42)
B, N = 4, 64

A_cpu = torch.randn(B, N, N, dtype=torch.complex64)
x_cpu = torch.randn(B, N, 1, dtype=torch.complex64)

# CPU reference
A_c = A_cpu.clone().requires_grad_(True)
x_c = x_cpu.clone().requires_grad_(True)
loss_c = (torch.bmm(A_c, x_c).abs() ** 2).sum()
loss_c.backward()

# MPS
A_m = A_cpu.clone().to("mps").requires_grad_(True)
x_m = x_cpu.clone().to("mps").requires_grad_(True)
loss_m = (torch.bmm(A_m, x_m).abs() ** 2).sum()
loss_m.backward()

# Compare
print("Forward max error:", (torch.bmm(A_m, x_m).cpu() - torch.bmm(A_c, x_c)).abs().max().item())
print("grad_x rel error:", (x_m.grad.cpu() - x_c.grad).abs().max().item() / x_c.grad.abs().max().item())
print("grad_A rel error:", (A_m.grad.cpu() - A_c.grad).abs().max().item() / A_c.grad.abs().max().item())


🐛 Describe the bug

[MPS] Incorrect gradients for torch.bmm with complex64 tensors

Bug

torch.bmm with complex64 (or complex128) tensors on MPS produces correct forward results but completely wrong gradients. The relative error in the backward pass is ~100-200%, making any gradient-based optimization with complex matrix operations on MPS silently produce garbage.

Real-valued (float32) bmm gradients on MPS are correct (relative error ~1e-7).

Reproducer

import torch

torch.manual_seed(42)
B, N = 4, 64

A_cpu = torch.randn(B, N, N, dtype=torch.complex64)
x_cpu = torch.randn(B, N, 1, dtype=torch.complex64)

# CPU reference
A_c = A_cpu.clone().requires_grad_(True)
x_c = x_cpu.clone().requires_grad_(True)
loss_c = (torch.bmm(A_c, x_c).abs() ** 2).sum()
loss_c.backward()

# MPS
A_m = A_cpu.clone().to("mps").requires_grad_(True)
x_m = x_cpu.clone().to("mps").requires_grad_(True)
loss_m = (torch.bmm(A_m, x_m).abs() ** 2).sum()
loss_m.backward()

# Compare
print("Forward max error:", (torch.bmm(A_m, x_m).cpu() - torch.bmm(A_c, x_c)).abs().max().item())
print("grad_x rel error:", (x_m.grad.cpu() - x_c.grad).abs().max().item() / x_c.grad.abs().max().item())
print("grad_A rel error:", (A_m.grad.cpu() - A_c.grad).abs().max().item() / A_c.grad.abs().max().item())

Output:

Forward max error: 2.384185791015625e-06
grad_x rel error: 1.14
grad_A rel error: 1.90

The forward error is at float32 precision (~1e-6). The relative gradient error is 114% for x and 190% for A, so the MPS gradients are essentially unrelated to the correct values.
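The CPU gradients used as the reference can themselves be cross-checked against a closed form. The sketch below assumes PyTorch's documented complex autograd convention, under which loss = sum(|A x|^2) has gradients 2 A^H A x with respect to x and 2 (A x) x^H with respect to A. The helper names analytic_bmm_grads and max_rel_err are made up for this illustration and are not part of the issue.

```python
import torch

def analytic_bmm_grads(A, x):
    """Closed-form gradients of sum(|A @ x|^2) under PyTorch's complex convention."""
    y = torch.bmm(A, x)
    A_h = A.transpose(1, 2).conj()      # conjugate transpose A^H
    x_h = x.transpose(1, 2).conj()      # conjugate transpose x^H
    grad_x = 2 * torch.bmm(A_h, y)      # 2 A^H (A x)
    grad_A = 2 * torch.bmm(y, x_h)      # 2 (A x) x^H
    return grad_A, grad_x

def max_rel_err(got, ref):
    """Same metric as the reproducer: max abs error over max abs reference."""
    return ((got - ref).abs().max() / ref.abs().max()).item()

torch.manual_seed(42)
A = torch.randn(4, 64, 64, dtype=torch.complex64, requires_grad=True)
x = torch.randn(4, 64, 1, dtype=torch.complex64, requires_grad=True)

# Autograd gradients on the CPU.
(torch.bmm(A, x).abs() ** 2).sum().backward()

# Closed-form gradients, computed without tracking.
with torch.no_grad():
    ref_A, ref_x = analytic_bmm_grads(A, x)

err_A = max_rel_err(A.grad, ref_A)
err_x = max_rel_err(x.grad, ref_x)
print("grad_A rel error vs closed form:", err_A)
print("grad_x rel error vs closed form:", err_x)
```

Because the closed form only uses forward bmm calls, which the issue reports as correct on MPS, the same check could in principle be run on the MPS device to judge its gradients without a CPU copy; that use is an untested suggestion.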

Extended results

Full script: repro.py (also pasted below)

1. Gradient correctness: complex64 bmm

| Batch | N | Forward max err | grad_x rel err | grad_A rel err |
|---|---|---|---|---|
| 1 | 16 | 7.54e-07 | 1.18e+00 | 1.98e+00 |
| 4 | 64 | 2.38e-06 | 1.14e+00 | 1.90e+00 |
| 8 | 128 | 3.81e-06 | 1.20e+00 | 1.47e+00 |

2. Gradient correctness: float32 bmm (control)

| Batch | N | grad_x rel err | grad_A rel err |
|---|---|---|---|
| 1 | 16 | 1.60e-07 | 1.67e-07 |
| 4 | 64 | 1.07e-07 | 8.83e-08 |
| 8 | 128 | 1.97e-07 | 1.62e-07 |

3. Training divergence

Simple SGD minimizing |A @ x - target|^2 with fixed complex A and target, optimizing x:

| Step | CPU loss | MPS loss |
|---|---|---|
| 0 | 5.8277e+03 | 8.8978e+03 |
| 10 | 3.2032e+03 | 9.6224e+03 |
| 20 | 1.9525e+03 | 1.0719e+04 |
| 30 | 1.3085e+03 | 1.2230e+04 |
| 40 | 9.4771e+02 | 1.4217e+04 |

CPU converges normally. MPS diverges because the gradients point in the wrong direction.
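One way to quantify "wrong direction" is the cosine of the angle between the device gradient and the CPU reference, treating each complex entry as a 2-D real vector. A gradient step only decreases the loss to first order when this cosine is positive. The function grad_cosine below is a hypothetical helper, not something from the issue, and the demo uses synthetic tensors so that it runs without MPS.

```python
import torch

def grad_cosine(g_ref, g_test):
    """Cosine similarity between two gradients of the same shape.

    Works for real or complex tensors: the inner product is
    Re(sum(conj(g_ref) * g_test)), which equals the ordinary dot product
    of the stacked real and imaginary parts.
    """
    inner = (g_ref.conj() * g_test).sum().real
    norm_ref = g_ref.abs().pow(2).sum().sqrt()
    norm_test = g_test.abs().pow(2).sum().sqrt()
    return (inner / (norm_ref * norm_test)).item()

torch.manual_seed(0)
g = torch.randn(1, 64, 1, dtype=torch.complex64)

same = grad_cosine(g, g)            # identical direction
opposite = grad_cosine(g, -g)       # exactly reversed direction
rotated = grad_cosine(g, 1j * g)    # every entry rotated by 90 degrees
print(same, opposite, rotated)
```

With gradients from the reproducer, the call would be grad_cosine(x_c.grad, x_m.grad.cpu()); the issue does not report this value, so no expected number is given here.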

Impact

This makes MPS unusable for any training involving complex-valued linear algebra, including:

  • Complex-valued neural networks
  • Any model using complex matrix multiplications in the loss

The bug is silent -- no error or warning is raised. The forward pass is correct, so inference works fine. Only training is affected.

Workaround

Split complex matrices into real/imaginary float32 parts and compute the complex product manually:

def complex_bmm_real(M_re, M_im, v_re, v_im):
    vr, vi = v_re.unsqueeze(-1), v_im.unsqueeze(-1)
    out_re = torch.bmm(M_re, vr) - torch.bmm(M_im, vi)
    out_im = torch.bmm(M_re, vi) + torch.bmm(M_im, vr)
    return out_re.squeeze(-1), out_im.squeeze(-1)

This gives correct gradients on MPS since float32 bmm backward is correct.
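For training, the same idea can be packaged so that the learnable matrix never exists as a complex tensor. The following is a minimal sketch under that assumption: SplitComplexMatrix is a hypothetical module name, it stores the real and imaginary parts as two float32 parameters, and it accepts inputs of shape (B, N, K) rather than the (B, N) vectors used by complex_bmm_real. It is checked here only on the CPU.

```python
import torch
from torch import nn

class SplitComplexMatrix(nn.Module):
    """Batch of learnable complex matrices kept as two float32 parameters."""

    def __init__(self, batch, n):
        super().__init__()
        self.re = nn.Parameter(torch.randn(batch, n, n))
        self.im = nn.Parameter(torch.randn(batch, n, n))

    def forward(self, v_re, v_im):
        # (a + ib)(c + id) = (ac - bd) + i(ad + bc), using only real bmm calls.
        out_re = torch.bmm(self.re, v_re) - torch.bmm(self.im, v_im)
        out_im = torch.bmm(self.re, v_im) + torch.bmm(self.im, v_re)
        return out_re, out_im

torch.manual_seed(0)
layer = SplitComplexMatrix(batch=2, n=8)
v = torch.randn(2, 8, 3, dtype=torch.complex64)
out_re, out_im = layer(v.real, v.imag)

# Cross-check the forward result against a complex bmm on the CPU.
W = torch.complex(layer.re.detach(), layer.im.detach())
expected = torch.bmm(W, v)
got = torch.complex(out_re.detach(), out_im.detach())
fwd_err = (got - expected).abs().max().item()
print("forward max error:", fwd_err)
```

Keeping the parameters real means the optimizer and the backward pass only ever see float32 tensors, which is the property the workaround relies on.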

Environment

  • PyTorch: 2.10.0 (pip, commit 449b176)
  • Python: 3.13.12
  • macOS: 26.2 (build 25C56)
  • Hardware: Apple M1 Max
  • Install: pip install torch

Full reproducer script

<details> <summary>repro.py</summary>
"""Minimal reproducer: MPS complex64 bmm backward gives wrong gradients.

Compares gradients of torch.bmm with complex64 tensors on MPS vs CPU.
Forward pass matches, backward pass diverges significantly.
"""

import torch
import sys

def check_complex_bmm_grads(B, N, seed=42):
    torch.manual_seed(seed)

    A_cpu = torch.randn(B, N, N, dtype=torch.complex64)
    x_cpu = torch.randn(B, N, 1, dtype=torch.complex64)

    # --- CPU reference ---
    A_c = A_cpu.clone().requires_grad_(True)
    x_c = x_cpu.clone().requires_grad_(True)
    y_c = torch.bmm(A_c, x_c)
    loss_c = (y_c.real ** 2 + y_c.imag ** 2).sum()
    loss_c.backward()

    # --- MPS ---
    A_m = A_cpu.clone().to("mps").requires_grad_(True)
    x_m = x_cpu.clone().to("mps").requires_grad_(True)
    y_m = torch.bmm(A_m, x_m)
    loss_m = (y_m.real ** 2 + y_m.imag ** 2).sum()
    loss_m.backward()

    # --- Compare ---
    fwd_err = (y_m.cpu() - y_c.detach()).abs().max().item()
    loss_err = abs(loss_m.item() - loss_c.item())

    gx_cpu = x_c.grad
    gx_mps = x_m.grad.cpu()
    gA_cpu = A_c.grad
    gA_mps = A_m.grad.cpu()

    gx_abs = (gx_mps - gx_cpu).abs().max().item()
    gx_rel = gx_abs / gx_cpu.abs().max().item()
    gA_abs = (gA_mps - gA_cpu).abs().max().item()
    gA_rel = gA_abs / gA_cpu.abs().max().item()

    return {
        "B": B, "N": N,
        "fwd_max_err": fwd_err,
        "loss_abs_err": loss_err,
        "grad_x_max_abs_err": gx_abs,
        "grad_x_max_rel_err": gx_rel,
        "grad_A_max_abs_err": gA_abs,
        "grad_A_max_rel_err": gA_rel,
    }


def check_real_bmm_grads(B, N, seed=42):
    """Same test but with float32 -- should be fine."""
    torch.manual_seed(seed)

    A_cpu = torch.randn(B, N, N, dtype=torch.float32)
    x_cpu = torch.randn(B, N, 1, dtype=torch.float32)

    A_c = A_cpu.clone().requires_grad_(True)
    x_c = x_cpu.clone().requires_grad_(True)
    loss_c = (torch.bmm(A_c, x_c) ** 2).sum()
    loss_c.backward()

    A_m = A_cpu.clone().to("mps").requires_grad_(True)
    x_m = x_cpu.clone().to("mps").requires_grad_(True)
    loss_m = (torch.bmm(A_m, x_m) ** 2).sum()
    loss_m.backward()

    gx_abs = (x_m.grad.cpu() - x_c.grad).abs().max().item()
    gx_rel = gx_abs / x_c.grad.abs().max().item()
    gA_abs = (A_m.grad.cpu() - A_c.grad).abs().max().item()
    gA_rel = gA_abs / A_c.grad.abs().max().item()

    return {
        "B": B, "N": N,
        "grad_x_max_rel_err": gx_rel,
        "grad_A_max_rel_err": gA_rel,
    }


def training_divergence_demo():
    """Show that training a trivial linear model via complex bmm diverges on MPS."""
    torch.manual_seed(0)
    N = 64

    A = torch.randn(1, N, N, dtype=torch.complex64)
    target = torch.randn(1, N, 1, dtype=torch.complex64)

    results = {}
    for dev_name in ["cpu", "mps"]:
        A_d = A.to(dev_name)
        tgt_d = target.to(dev_name)
        x = torch.randn(1, N, 1, dtype=torch.complex64, device=dev_name, requires_grad=True)
        opt = torch.optim.SGD([x], lr=1e-4)

        losses = []
        for step in range(50):
            opt.zero_grad()
            y = torch.bmm(A_d, x)
            loss = ((y - tgt_d).real ** 2 + (y - tgt_d).imag ** 2).sum()
            loss.backward()
            opt.step()
            losses.append(loss.item())

        results[dev_name] = losses

    return results


if __name__ == "__main__":
    print(f"PyTorch version : {torch.__version__}")
    print(f"Python version  : {sys.version.split()[0]}")
    print(f"MPS available   : {torch.backends.mps.is_available()}")
    print()

    if not torch.backends.mps.is_available():
        print("MPS not available, cannot reproduce.")
        sys.exit(1)

    # 1. Gradient correctness
    print("=" * 60)
    print("1. Gradient correctness: complex64 bmm")
    print("=" * 60)
    for B, N in [(1, 16), (4, 64), (8, 128)]:
        r = check_complex_bmm_grads(B, N)
        print(f"  B={r['B']:2d}, N={r['N']:3d}  |  "
              f"fwd_err={r['fwd_max_err']:.2e}  "
              f"grad_x_rel={r['grad_x_max_rel_err']:.2e}  "
              f"grad_A_rel={r['grad_A_max_rel_err']:.2e}")

    print()
    print("=" * 60)
    print("2. Gradient correctness: float32 bmm (control)")
    print("=" * 60)
    for B, N in [(1, 16), (4, 64), (8, 128)]:
        r = check_real_bmm_grads(B, N)
        print(f"  B={r['B']:2d}, N={r['N']:3d}  |  "
              f"grad_x_rel={r['grad_x_max_rel_err']:.2e}  "
              f"grad_A_rel={r['grad_A_max_rel_err']:.2e}")

    # 3. Training divergence
    print()
    print("=" * 60)
    print("3. Training demo: SGD on loss = |A @ x - target|^2")
    print("=" * 60)
    tr = training_divergence_demo()
    print(f"  {'step':>5s}  {'CPU loss':>12s}  {'MPS loss':>12s}")
    for i in range(0, 50, 5):
        print(f"  {i:5d}  {tr['cpu'][i]:12.4e}  {tr['mps'][i]:12.4e}")
</details>

cc

@kulinseth @malfet (MPS backend maintainers)

Versions

Collecting environment information...
PyTorch version: 2.10.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 26.2 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.6.3.2)
CMake version: Could not collect
Libc version: N/A
Python version: 3.13.12 (v3.13.12:1cbe4818347, Feb 3 2026, 13:36:53) [Clang 16.0.0 (clang-1600.0.26.6)] (64-bit runtime)
Python platform: macOS-26.2-arm64-arm-64bit-Mach-O
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A
CPU: Apple M1 Max
Versions of relevant libraries:
[pip3] numpy==2.4.2
[pip3] torch==2.10.0
[conda] Could not collect

cc @ezyang @albanD @gqchen @nikitaved @soulitzer @Varal7 @bobrenjc93 @anjali411 @dylanbespalko @mruberry @amjames @kulinseth @malfet @DenisVieriu97 @jhavukainen @aditvenk

extent analysis

Fix Plan

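This plan does not fix the MPS backend itself. It works around the incorrect gradients for torch.bmm with complex64 tensors on MPS by splitting the complex matrices into float32 real and imaginary parts and computing the complex product manually, so that only the float32 bmm backward pass (which the issue reports as correct) is used.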

Step-by-Step Solution

  1. Split complex matrices: Split the complex matrices A and x into their real and imaginary parts.
  2. Compute complex product manually: Use the following formula to compute the complex product:
    • out_re = torch.bmm(A_re, x_re) - torch.bmm(A_im, x_im)
    • out_im = torch.bmm(A_re, x_im) + torch.bmm(A_im, x_re)

Here's an example code snippet:

def complex_bmm_real(M_re, M_im, v_re, v_im):
    """
    Compute the batched complex matrix-vector product from float32 parts.

    Parameters:
    M_re (torch.Tensor): Real part of the matrix M, shape (B, N, N).
    M_im (torch.Tensor): Imaginary part of the matrix M, shape (B, N, N).
    v_re (torch.Tensor): Real part of the vector v, shape (B, N).
    v_im (torch.Tensor): Imaginary part of the vector v, shape (B, N).

    Returns:
    out_re (torch.Tensor): Real part of M @ v, shape (B, N).
    out_im (torch.Tensor): Imaginary part of M @ v, shape (B, N).
    """
    vr, vi = v_re.unsqueeze(-1), v_im.unsqueeze(-1)
    out_re = torch.bmm(M_re, vr) - torch.bmm(M_im, vi)
    out_im = torch.bmm(M_re, vi) + torch.bmm(M_im, vr)
    return out_re.squeeze(-1), out_im.squeeze(-1)

Example Usage

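import torch

A = torch.randn(1, 64, 64, dtype=torch.complex64)
# complex_bmm_real unsqueezes the vector itself, so x must have shape (B, N), not (B, N, 1).
x = torch.randn(1, 64, dtype=torch.complex64)

# Split into float32 real and imaginary parts.
A_re, A_im = A.real, A.imag
x_re, x_im = x.real, x.imag

out_re, out_im = complex_bmm_real(A_re, A_im, x_re, x_im)  # each has shape (1, 64)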

Verification

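To verify that the workaround behaves correctly, compare the gradients it produces on MPS with the gradients of the original complex torch.bmm computed on the CPU. The relative error should be close to float32 precision, as in the float32 control results, rather than the 100-200% seen with complex bmm on MPS.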
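The comparison described above can be sketched as follows. This is an illustration under stated assumptions, not the issue author's script: the four float32 parts are created as separate leaf tensors, the device falls back to the CPU when MPS is unavailable, and the complex gradient is reassembled as torch.complex(re.grad, im.grad), which matches PyTorch's convention for a real-valued loss.

```python
import torch

def complex_bmm_real(M_re, M_im, v_re, v_im):
    # Workaround from the issue: complex product via four float32 bmm calls.
    vr, vi = v_re.unsqueeze(-1), v_im.unsqueeze(-1)
    out_re = torch.bmm(M_re, vr) - torch.bmm(M_im, vi)
    out_im = torch.bmm(M_re, vi) + torch.bmm(M_im, vr)
    return out_re.squeeze(-1), out_im.squeeze(-1)

# Assumption: fall back to the CPU so the sketch also runs without MPS.
device = "mps" if torch.backends.mps.is_available() else "cpu"

torch.manual_seed(42)
B, N = 4, 64
A = torch.randn(B, N, N, dtype=torch.complex64)
x = torch.randn(B, N, dtype=torch.complex64)  # vectors of shape (B, N)

# Reference: complex bmm on the CPU.
A_c = A.clone().requires_grad_(True)
x_c = x.clone().requires_grad_(True)
(torch.bmm(A_c, x_c.unsqueeze(-1)).abs() ** 2).sum().backward()

# Workaround: four float32 leaf tensors on the target device.
A_re, A_im, x_re, x_im = (
    t.clone().to(device).requires_grad_(True)
    for t in (A.real, A.imag, x.real, x.imag)
)
out_re, out_im = complex_bmm_real(A_re, A_im, x_re, x_im)
(out_re ** 2 + out_im ** 2).sum().backward()

# Reassemble complex gradients and compare with the reference.
grad_A = torch.complex(A_re.grad.cpu(), A_im.grad.cpu())
grad_x = torch.complex(x_re.grad.cpu(), x_im.grad.cpu())
rel_A = ((grad_A - A_c.grad).abs().max() / A_c.grad.abs().max()).item()
rel_x = ((grad_x - x_c.grad).abs().max() / x_c.grad.abs().max()).item()
print(f"device={device}  grad_A rel err={rel_A:.2e}  grad_x rel err={rel_x:.2e}")
```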

Extra Tips

  • Keep the real and imaginary parts as float32 tensors, since the workaround relies on the float32 bmm backward pass being correct on MPS.
  • Expect a performance cost: each complex product becomes four real torch.bmm calls, plus the work of splitting and recombining the parts.
  • The issue was reported against PyTorch 2.10.0. After upgrading PyTorch, rerun the reproducer to check whether complex bmm gradients on MPS are correct before removing the workaround.
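The last tip can be turned into a small runtime probe that decides whether the native complex path is safe on the installed PyTorch build. This is a sketch: the function names complex_bmm_grad_rel_err and use_native_complex_bmm, the problem size, and the 1e-4 tolerance are all assumptions chosen for illustration.

```python
import torch

def _complex_bmm_grads(A, x, device):
    """Gradients of sum(|A @ x|^2) computed on `device`, returned on the CPU."""
    A_d = A.clone().to(device).requires_grad_(True)
    x_d = x.clone().to(device).requires_grad_(True)
    y = torch.bmm(A_d, x_d)
    (y.real ** 2 + y.imag ** 2).sum().backward()
    return A_d.grad.cpu(), x_d.grad.cpu()

def complex_bmm_grad_rel_err(device, batch=2, n=16, seed=0):
    """Worst relative gradient error of complex64 bmm on `device` versus the CPU."""
    torch.manual_seed(seed)
    A = torch.randn(batch, n, n, dtype=torch.complex64)
    x = torch.randn(batch, n, 1, dtype=torch.complex64)
    ref_A, ref_x = _complex_bmm_grads(A, x, "cpu")
    got_A, got_x = _complex_bmm_grads(A, x, device)
    err_A = ((got_A - ref_A).abs().max() / ref_A.abs().max()).item()
    err_x = ((got_x - ref_x).abs().max() / ref_x.abs().max()).item()
    return max(err_A, err_x)

def use_native_complex_bmm(device, tol=1e-4):
    """True if complex bmm gradients on `device` agree with the CPU within `tol`."""
    return complex_bmm_grad_rel_err(device) < tol

# Assumption: probe MPS when present, otherwise probe the CPU against itself.
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"{device}: native complex bmm gradients trusted = {use_native_complex_bmm(device)}")
```

A training script could call use_native_complex_bmm once at startup and route to the split real/imaginary implementation when it returns False.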
