pytorch: [sm_120] Non-deterministic segfault with sustained complex linear-algebra workload + allocator churn on Blackwell [3 comments, 2 participants]

GitHub stats
pytorch/pytorch#178367 (fetched 2026-04-08 01:26:01)
Comments: 3 · Participants: 2 · Timeline events: 93 · Reactions: 0
Timeline (top): mentioned ×33, subscribed ×33, unsubscribed ×15, labeled ×6


🐛 Describe the bug

A sustained workload of batched complex64 einsum, torch.linalg.solve, in-place tensor slice updates, and repeated GPU memory allocation/deallocation (del, then gc.collect(), then torch.cuda.empty_cache()) causes a non-deterministic segfault (SIGSEGV, exit code 139) on Blackwell GPUs (sm_120). The same code completes all 500 samples without issue on Ampere (sm_86).

The root cause has not been isolated to a single operation. The crash requires both the complex linear-algebra workload and allocator churn (repeated large tensor alloc/free cycles). Without the allocator churn, the crash may not occur within 500 samples; with it, the crash occurs within 2–11 samples.

Environment

  • Crashes on: RTX PRO 6000 Blackwell Max-Q (sm_120), PyTorch 2.7.1+cu128 and 2.10.0+cu128, Python 3.10 and 3.12, Driver 580.126.20, CUDA 12.8, Ubuntu 24.04 kernel 6.17.0-19
  • Passes on: RTX 3090 (sm_86), PyTorch 2.5.1+cu124, CUDA 12.4 — same script, all 500 samples pass

Reproduction

"""
Minimal repro: Blackwell (sm_120) segfault with sustained complex linalg + allocator churn.
Crashes after ~2-11 samples on sm_120. Runs all 500 on sm_86.
"""
import torch, gc

class WorkloadState:
    """Pure allocator churn — allocates GPU tensors on init, freed on del.
    Does NOT participate in computation."""
    def __init__(self, F, T, M, N, K, device):
        self.W = torch.rand(N, F, K, device=device) + 1e-10
        self.H = torch.rand(N, K, T, device=device) + 1e-10
        self.G = torch.rand(N, M, device=device) + 1e-10
        self.X = torch.randn(F, T, M, dtype=torch.complex64, device=device)
        self.PSD = torch.einsum('nfk, nkt -> ftn', self.W, self.H)
        self.Y = torch.einsum('ftn, nm -> ftm', self.PSD, self.G) + 1e-10
        self.X_pow = (self.X * self.X.conj()).real + 1e-10
        self.Q = torch.eye(M, dtype=torch.complex64, device=device).unsqueeze(0).expand(F, -1, -1).contiguous()
        self.XX = torch.einsum('ftm, ftn -> ftmn', self.X, self.X.conj())
        self.eye = torch.eye(M, dtype=torch.complex64, device=device).unsqueeze(0).expand(F, -1, -1)

def iteration(F, T, M, N, K, device):
    W = torch.rand(N, F, K, device=device) + 1e-10
    H = torch.rand(N, K, T, device=device) + 1e-10
    G = torch.rand(N, M, device=device) + 1e-10
    X = torch.randn(F, T, M, dtype=torch.complex64, device=device)
    PSD = torch.einsum('nfk, nkt -> ftn', W, H)
    Y = torch.einsum('ftn, nm -> ftm', PSD, G) + 1e-10
    X_pow = (X * X.conj()).real + 1e-10

    for n in range(N):
        num = torch.einsum('kt, m, ftm, ftm -> fk', H[n], G[n], X_pow, Y.pow(-2))
        den = torch.einsum('kt, m, ftm -> fk', H[n], G[n], Y.pow(-1))
        W[n] = W[n] * torch.sqrt(num / (den + 1e-10)) + 1e-10
        PSD[:, :, n] = torch.einsum('fk, kt -> ft', W[n], H[n])
        Y = torch.einsum('ftn, nm -> ftm', PSD, G) + 1e-10
        num = torch.einsum('fk, m, ftm, ftm -> kt', W[n], G[n], X_pow, Y.pow(-2))
        den = torch.einsum('fk, m, ftm -> kt', W[n], G[n], Y.pow(-1))
        H[n] = H[n] * torch.sqrt(num / (den + 1e-10)) + 1e-10
        PSD[:, :, n] = torch.einsum('fk, kt -> ft', W[n], H[n])
        Y = torch.einsum('ftn, nm -> ftm', PSD, G) + 1e-10

    Q = torch.eye(M, dtype=torch.complex64, device=device).unsqueeze(0).expand(F, -1, -1).contiguous()
    XX = torch.einsum('ftm, ftn -> ftmn', X, X.conj())
    eye = torch.eye(M, dtype=torch.complex64, device=device).unsqueeze(0).expand(F, -1, -1)
    for m in range(M):
        V = torch.einsum('ftmn, ft -> fmn', XX, (1.0 / Y[..., m]).to(torch.complex64)) / T
        tmp = torch.linalg.solve(Q @ V, eye)
        t_m = tmp[..., m]
        d = torch.sqrt(torch.einsum('fi, fij, fj -> f', t_m.conj(), V, t_m))
        Q[:, m] = (t_m / d[:, None]).conj()

device = 'cuda'
print(f'PyTorch {torch.__version__}, CUDA {torch.version.cuda}')
print(f'GPU: {torch.cuda.get_device_name()} (sm_{torch.cuda.get_device_capability()[0]}{torch.cuda.get_device_capability()[1]})')
F, T, M, N, K = 256, 500, 6, 3, 64
for i in range(500):
    state = WorkloadState(F, T, M, N, K, device)
    for _ in range(200):
        iteration(F, T, M, N, K, device)
    del state; gc.collect(); torch.cuda.empty_cache()
    print(f'Sample {i+1}: OK', flush=True)
print('ALL PASSED')

Results

sm_120 (RTX PRO 6000 Blackwell) — crashes within 2–11 samples:

Run 1 (exit 139):

PyTorch 2.7.1+cu128, CUDA 12.8
GPU: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition (sm_120)
Sample 1: OK
Sample 2: OK

Run 2 (exit 139):

PyTorch 2.7.1+cu128, CUDA 12.8
GPU: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition (sm_120)
Sample 1: OK
...
Sample 11: OK

sm_86 (RTX 3090) — all 500 pass:

PyTorch 2.5.1, CUDA 12.4
GPU: NVIDIA GeForce RTX 3090 (sm_86)
Sample 1: OK
...
Sample 500: OK
ALL 500 PASSED — no segfault

Note: the sm_86 machine runs PyTorch 2.5.1+cu124, so this control does not isolate architecture from PyTorch/CUDA version. It confirms the workload itself is not inherently broken.

Mitigations tested (none fully fix)

Mitigation                          Effect
PYTORCH_NO_CUDA_MEMORY_CACHING=1    Delays crash (more samples before segfault)
CUDA_LAUNCH_BLOCKING=1              Delays crash further
PyTorch 2.7.1 ↔ 2.10.0              No change
Python 3.10 ↔ 3.12                  No change
Driver 575 → 580                    No change
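
The environment-variable mitigations above can be compared without editing the repro by launching it in a child process with a modified environment. The sketch below is one way to do that, not the reporter's method: run_with_env and MITIGATIONS are hypothetical names, and it assumes the repro is saved as a standalone script that prints "Sample N: OK" lines as shown in the results.

```python
import os
import re
import signal
import subprocess
import sys


def run_with_env(script_path, extra_env, timeout=None):
    """Run script_path in a child Python process with extra_env applied.

    Returns (samples_ok, returncode, segfaulted).
    """
    env = dict(os.environ)
    env.update(extra_env)
    proc = subprocess.run(
        [sys.executable, script_path],
        env=env,
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    # Count the progress lines the repro printed before it exited.
    samples_ok = len(
        re.findall(r"^Sample \d+: OK$", proc.stdout, flags=re.MULTILINE)
    )
    # subprocess reports death by signal as a negative return code (-11 for
    # SIGSEGV on Linux); a shell would show the same event as 128 + 11 = 139.
    segfaulted = proc.returncode == -signal.SIGSEGV
    return samples_ok, proc.returncode, segfaulted


# Hypothetical labels for the settings from the table above.
MITIGATIONS = {
    "baseline": {},
    "no_caching": {"PYTORCH_NO_CUDA_MEMORY_CACHING": "1"},
    "launch_blocking": {"CUDA_LAUNCH_BLOCKING": "1"},
}

# Example use (the script name is an assumption):
#   for name, env in MITIGATIONS.items():
#       print(name, run_with_env("repro.py", env))
```

Setting the variables in the child's environment means they are already in place when the child imports torch, which avoids any question about when each variable is read.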

What has NOT been isolated

  • Whether torch.linalg.solve is necessary, or the crash also occurs with only einsum + in-place updates
  • Whether empty_cache() / gc.collect() are necessary, or just accelerate an underlying issue
  • Which CUDA library (cuSOLVER, cuBLAS, or other) is involved
  • Sample counts are approximate — no torch.cuda.synchronize() before print, so async ops may still be in flight
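
One inexpensive way to narrow down the open questions above is to capture the Python stack at the moment of the SIGSEGV with the standard-library faulthandler module. The sketch below runs a script with `-X faulthandler` and extracts the dump from stderr; crash_traceback is a hypothetical helper name, and the approach is a suggestion rather than something the issue reports having tried. The frame it shows is where the host process faulted, which is not necessarily the operation that caused the underlying corruption.

```python
import subprocess
import sys

# faulthandler starts its dump with this text when a fatal signal arrives.
FATAL_MARKER = "Fatal Python error"


def crash_traceback(script_path, timeout=None):
    """Run script_path with faulthandler enabled.

    Returns (returncode, traceback_text); traceback_text is empty if the
    child exited without a fatal signal.
    """
    proc = subprocess.run(
        [sys.executable, "-X", "faulthandler", script_path],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    start = proc.stderr.find(FATAL_MARKER)
    traceback_text = proc.stderr[start:] if start != -1 else ""
    return proc.returncode, traceback_text


# Example use (the script name is an assumption):
#   rc, tb = crash_traceback("repro.py")
#   print(rc)
#   print(tb)
```

If repeated crashes always show torch.linalg.solve, or always show an einsum call, that is a hint about which library to look at first, although it is not proof.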

Related issues

  • #145949 (Blackwell tracking)
  • #176426 (Triton codegen segfault on sm_120)

cc @malfet @eqy

extent analysis

Fix Plan

The issue reports no confirmed fix, so the steps below are mitigations and diagnostics for the non-deterministic segfault on Blackwell GPUs (sm_120), not a resolution:

  • Disable CUDA memory caching: set PYTORCH_NO_CUDA_MEMORY_CACHING=1 so that allocations bypass PyTorch's caching allocator. In the reporter's tests this delayed the crash but did not prevent it.
  • Use CUDA launch blocking: set CUDA_LAUNCH_BLOCKING=1 so that kernel launches run synchronously. This delayed the crash further but also did not prevent it.
  • Synchronize CUDA operations: add torch.cuda.synchronize() before printing sample status, so that "Sample N: OK" is printed only after the GPU work queued for that sample has completed and the sample count at the crash is reliable.
  • Update PyTorch and CUDA: the reporter saw no change between PyTorch 2.7.1 and 2.10.0 (both cu128) or between drivers 575 and 580, so upgrading is only worth retrying on releases newer than those.
  • Monitor memory use: torch.cuda.memory_allocated() and torch.cuda.memory_reserved() report allocator state but do not manage it. Logging them per sample may show whether the crash correlates with allocator behavior.

Example code changes:

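import gc
import os

# Set both variables before importing torch so they are in place before CUDA initializes.
os.environ['PYTORCH_NO_CUDA_MEMORY_CACHING'] = '1'  # disable the caching allocator
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # make kernel launches synchronous

import torch

# ... (WorkloadState and iteration exactly as defined in the reproduction script)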

device = 'cuda'
print(f'PyTorch {torch.__version__}, CUDA {torch.version.cuda}')
print(f'GPU: {torch.cuda.get_device_name()} (sm_{torch.cuda.get_device_capability()[0]}{torch.cuda.get_device_capability()[1]})')
F, T, M, N, K = 256, 500, 6, 3, 64
for i in range(500):
    state = WorkloadState(F, T, M, N, K, device)
    for _ in range(200):
        iteration(F, T, M, N, K, device)
    del state; gc.collect(); torch.cuda.empty_cache()
    torch.cuda.synchronize()  # Synchronize CUDA operations
    print(f'Sample {i+1}: OK', flush=True)
print('ALL PASSED')

Verification

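To check whether a mitigation helps, run the modified script on the Blackwell GPU (sm_120) several times. Because the crash is non-deterministic and the reporter saw it within 2 to 11 samples, a single clean run is weak evidence; a change is convincing only if all 500 samples complete across repeated runs. In the reporter's tests both environment variables only delayed the crash. Logging GPU memory counters per sample can show whether the crash correlates with allocator state.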
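
For the memory monitoring mentioned above, a minimal sketch of a per-sample logger follows. memory_snapshot and format_snapshot are hypothetical helper names; the helper takes the CUDA API object as a parameter (torch.cuda in real use) so that it can be exercised on a machine without a GPU. It only reads the counters returned by memory_allocated() and memory_reserved() and does not change allocator behavior.

```python
def memory_snapshot(cuda_api, sample_index):
    """Return allocator counters for one sample as a dict, in MiB.

    cuda_api is expected to expose memory_allocated() and memory_reserved(),
    both returning bytes, as torch.cuda does.
    """
    mib = 1024 ** 2
    return {
        "sample": sample_index,
        "allocated_mib": cuda_api.memory_allocated() / mib,
        "reserved_mib": cuda_api.memory_reserved() / mib,
    }


def format_snapshot(snap):
    """Render a snapshot as a single log line."""
    return (
        f"Sample {snap['sample']}: "
        f"allocated={snap['allocated_mib']:.1f} MiB, "
        f"reserved={snap['reserved_mib']:.1f} MiB"
    )


# Intended use inside the repro's outer loop (needs torch and a CUDA device):
#   torch.cuda.synchronize()
#   print(format_snapshot(memory_snapshot(torch.cuda, i + 1)), flush=True)
```

Printing with flush=True matters here: if the process segfaults, unflushed output is lost and the last logged sample would be misleading.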

Extra Tips

  • Retest on newer PyTorch, CUDA, and driver releases as they appear; the versions listed in the issue did not change the behavior.
  • Use nvidia-smi to watch GPU memory while the repro runs. nvprof does not support GPUs of compute capability 8.0 and higher, which includes sm_120; NVIDIA's current tools are Nsight Systems (nsys) for profiling and compute-sanitizer for memory-error checking.
  • While debugging, call torch.cuda.synchronize() at sample boundaries so that a crash is attributed to the correct sample.
