pytorch - 💡(How to fix) Fix [Bug] NaN gradients in varlen_attn backward pass when input length exceeds cu_seqlens[-1] [1 comment, 1 participant]

GitHub stats
pytorch/pytorch#176793 (fetched 2026-04-08 00:24:25)
Comments: 1 · Participants: 1 · Timeline events: 222 · Reactions: 0
Timeline (top): mentioned ×105, subscribed ×105, labeled ×9, assigned ×1


🐛 Describe the bug

When using torch.nn.attention.varlen.varlen_attn, padding the input tensor so that its total length is greater than the total number of tokens defined by cu_seqlens[-1] causes NaN gradients during the backward pass.

The forward pass executes without raising any shape mismatch or out-of-bounds errors, but calling .backward() results in NaN values propagating through the gradients of the model's parameters.
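The packed layout that varlen_attn expects can be checked with plain arithmetic: cu_seqlens holds the cumulative start offsets of the packed sequences, so cu_seqlens[-1] is the number of rows treated as real tokens. A minimal sketch using the repro's sequence lengths (plain Python, no torch needed):

```python
from itertools import accumulate

SEQ_LENGTHS = [144, 288, 512]

# cu_seqlens[i] is where sequence i starts in the packed tensor; the last
# entry is the total number of tokens covered by all sequences.
cu_seqlens = [0] + list(accumulate(SEQ_LENGTHS))
num_rows = sum(SEQ_LENGTHS) + 2  # the repro packs 2 extra padding rows

uncovered = num_rows - cu_seqlens[-1]
print(cu_seqlens)  # [0, 144, 432, 944]
print(uncovered)   # 2
```

Rows 944 and 945 of the input belong to no sequence; these are the rows whose presence triggers the NaN gradients in the repro.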

To Reproduce

Here is a minimal reproducible example demonstrating the issue. By adding just 2 padding tokens to the total token count (TOTAL_TOKENS + 2), NaN gradients are triggered.

import torch
import torch.nn as nn
from torch.nn.attention.varlen import varlen_attn

class Attn(nn.Module):
    def __init__(self, d_model, nhead):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.head_dim = d_model // nhead
        self.to_qkv = nn.Linear(d_model, 3 * d_model)

    def forward(self, x, cu_seqlens, max_seqlen):
        q, k, v = (
            self.to_qkv(x).view(-1, self.nhead * 3, self.head_dim).chunk(3, dim=-2)
        )

        out = varlen_attn(
            q,
            k,
            v,
            cu_seq_q=cu_seqlens,
            cu_seq_k=cu_seqlens,
            max_q=max_seqlen,
            max_k=max_seqlen,
            is_causal=False,
        )
        out = out.view(-1, self.d_model)
        return out


d_model = 1024
nhead = 16
device = "cuda"

# Create random input data
SEQ_LENGTHS = [144, 288, 512]
TOTAL_TOKENS = sum(SEQ_LENGTHS)
cu_seqlens = torch.tensor(
    [0] + list(torch.tensor(SEQ_LENGTHS).cumsum(0).tolist()), dtype=torch.int32
)
max_seqlen = max(SEQ_LENGTHS)

# Adding padding tokens causes NaNs in the backward pass
x = torch.randn(TOTAL_TOKENS + 2, d_model) 

model = Attn(d_model, nhead).to(device)

print("Input shape:", x.shape)
print("cu_seqlens:", cu_seqlens)
print("max_seqlen:", max_seqlen)

with torch.autocast(device):
    out = model(x.to(device), cu_seqlens.to(device), max_seqlen)

    # Compute a dummy loss
    loss = out[: cu_seqlens[-1]].abs().sum()
    loss.backward()
    
    # Check for NaNs
    for name, param in model.named_parameters():
        if param.grad is not None and torch.isnan(param.grad).any():
            print(f"NaN detected in gradients for {name}!")
            break

Output:

Input shape: torch.Size([946, 1024])
cu_seqlens: tensor([  0, 144, 432, 944], dtype=torch.int32)
max_seqlen: 512
NaN detected in gradients for to_qkv.weight!

Expected behavior

There are two acceptable ways to handle this that would prevent silent gradient corruption:

  1. Graceful Handling: the function should ignore the extra padded tokens (indices at or beyond cu_seqlens[-1]) during the backward pass instead of producing NaNs.
  2. Explicit Error: if the underlying SDPA/FlashAttention backend does not allow padding beyond cu_seqlens[-1], the forward pass should raise a RuntimeError that reports the mismatch between the number of rows in q, k, v and the token count given by cu_seqlens[-1].
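Until the library does either, a caller can get behaviour 2 on their own side by validating the packed shapes before calling varlen_attn. The sketch below is a user-side guard, not part of torch; check_varlen_lengths is a hypothetical name. It uses len() so it works for tensors (where len(t) equals t.shape[0]) as well as plain lists, and int() on the last cu_seqlens entry, which forces a host-device sync when cu_seqlens lives on the GPU.

```python
def check_varlen_lengths(q, k, v, cu_seq_q, cu_seq_k):
    """Hypothetical pre-flight check: packed rows must equal cu_seqlens[-1].

    Accepts torch tensors (len(t) == t.shape[0]) or plain sequences.
    """
    total_q = int(cu_seq_q[-1])
    total_k = int(cu_seq_k[-1])
    if len(q) != total_q:
        raise RuntimeError(f"q has {len(q)} tokens but cu_seq_q[-1] is {total_q}")
    # k and v share the key-side offsets.
    for name, t in (("k", k), ("v", v)):
        if len(t) != total_k:
            raise RuntimeError(
                f"{name} has {len(t)} tokens but cu_seq_k[-1] is {total_k}"
            )
```

With the repro's inputs (946 rows, cu_seqlens[-1] == 944) this raises "RuntimeError: q has 946 tokens but cu_seq_q[-1] is 944" before any kernel runs.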

Versions

Collecting environment information...
PyTorch version: 2.10.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 13 (trixie) (x86_64)
GCC version: (Debian 14.2.0-19) 14.2.0
Clang version: Could not collect
CMake version: version 3.31.6
Libc version: glibc-2.41

Python version: 3.12.11 (main, Aug 28 2025, 17:07:59) [Clang 20.1.4 ] (64-bit runtime)
Python platform: Linux-6.6.87.2-microsoft-standard-WSL2-x86_64-with-glibc2.41
Is CUDA available: True
CUDA runtime version: 13.0.88
CUDA_MODULE_LOADING set to:
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 5060 Ti
Nvidia driver version: 576.88
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 7 7700 8-Core Processor
CPU family: 25
Model: 97
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
Stepping: 2
BogoMIPS: 7585.56
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl tsc_reliable nonstop_tsc cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext perfctr_core ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 clzero xsaveerptr arat npt nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload avx512vbmi umip avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm
Virtualization: AMD-V
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 256 KiB (8 instances)
L1i cache: 256 KiB (8 instances)
L2 cache: 8 MiB (8 instances)
L3 cache: 32 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-15
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Vulnerable: Safe RET, no microcode
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] Could not collect
[conda] Could not collect

cc @ezyang @albanD @gqchen @nikitaved @soulitzer @Varal7 @bobrenjc93 @mruberry @jbschlosser @walterddr @mikaylagawarecki @ptrblck @msaroufim @eqy @jerryzh168 @tinglvv @nWEIdia @malfet @drisspg @liangel-02 @howardzhang-cv

Extent Analysis

Problem Summary

The issue is caused by padding the input tensor beyond the total number of tokens defined by cu_seqlens[-1] when using torch.nn.attention.varlen.varlen_attn. This results in NaN gradients during the backward pass.

Root Cause Analysis

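The root cause has not been confirmed. The most likely explanation is that varlen_attn does not handle rows beyond cu_seqlens[-1] correctly in the backward pass. One plausible (unverified) mechanism: the kernel only processes the rows covered by cu_seqlens, so the gradient rows for the padding tokens in q, k and v may be left uninitialized rather than zeroed. Because to_qkv.weight.grad is computed from every row of that gradient, a single NaN there is enough to corrupt the weight gradient, even though the padded rows never enter the loss.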
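One way to test the padding hypothesis is to look at which rows of the gradient flowing into the packed qkv projection contain NaNs. The sketch assumes you can attach a hook inside the repro's Attn.forward; nan_rows is a hypothetical helper, not a PyTorch API. If the padding is responsible, the reported indices should all be at or beyond cu_seqlens[-1] (944 in the repro).

```python
import torch


def nan_rows(t: torch.Tensor) -> list[int]:
    """Hypothetical helper: indices along dim 0 whose slices contain a NaN."""
    # Flatten everything except dim 0, then ask "any NaN in this row?".
    mask = torch.isnan(t).reshape(t.shape[0], -1).any(dim=1)
    return mask.nonzero(as_tuple=True)[0].tolist()


# Usage inside Attn.forward (sketch):
#     qkv = self.to_qkv(x)
#     qkv.register_hook(lambda g: print("NaN rows:", nan_rows(g)))
#     q, k, v = qkv.view(-1, self.nhead * 3, self.head_dim).chunk(3, dim=-2)
```

A hook on a non-leaf tensor receives its gradient during backward; returning None (as print does) leaves the gradient unchanged.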

Fix Plan

To fix this, varlen_attn should either ignore the padded rows (indices at or beyond cu_seqlens[-1]) during the backward pass, or check that the number of rows in q, k and v does not exceed the boundary defined by cu_seqlens[-1] and raise otherwise. Until then, the user-side workaround is to trim q, k and v to cu_seqlens[-1] rows before calling the kernel.

Step-by-Step Solution Plan

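  1. Trim the padding rows before calling the kernel. The number of real tokens is cu_seq_q[-1] (944 in the repro), not max_q, which is the length of the longest single sequence (512); trimming to max_q would silently drop real tokens. Rather than editing torch's varlen_attn, this sketch wraps it; varlen_attn_trimmed is a hypothetical name, and the zero-padding assumes downstream code still expects q's original row count.
import torch.nn.functional as F
from torch.nn.attention.varlen import varlen_attn

def varlen_attn_trimmed(q, k, v, cu_seq_q, cu_seq_k, max_q, max_k, is_causal=False):
    # Real token counts come from the last cu_seqlens entry, not from max_q/max_k.
    # int() on a CUDA tensor forces a host-device sync.
    total_q = int(cu_seq_q[-1])
    total_k = int(cu_seq_k[-1])
    out = varlen_attn(
        q[:total_q],
        k[:total_k],
        v[:total_k],
        cu_seq_q=cu_seq_q,
        cu_seq_k=cu_seq_k,
        max_q=max_q,
        max_k=max_k,
        is_causal=is_causal,
    )
    # Zero-pad the output back to q's row count so downstream shapes are unchanged.
    # F.pad takes (left, right) pairs starting from the last dimension.
    pad_rows = q.shape[0] - total_q
    return F.pad(out, [0, 0] * (out.dim() - 1) + [0, pad_rows])
  2. Update the Attn class to call the wrapper (__init__ is unchanged from the repro).
class Attn(nn.Module):
    def forward(self, x, cu_seqlens, max_seqlen):
        q, k, v = (
            self.to_qkv(x).view(-1, self.nhead * 3, self.head_dim).chunk(3, dim=-2)
        )

        out = varlen_attn_trimmed(
            q,
            k,
            v,
            cu_seq_q=cu_seqlens,
            cu_seq_k=cu_seqlens,
            max_q=max_seqlen,
            max_k=max_seqlen,
            is_causal=False,
        )
        out = out.view(-1, self.d_model)
        return out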
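Whichever fix is applied, a simple regression check is to compare parameter gradients from the padded input against a run on the unpadded input x[:cu_seqlens[-1]]. They should be finite and, up to floating-point noise, identical, because the padded rows never reach the loss. The helper below is a hypothetical sketch: it assumes the model has the repro's forward(x, cu_seqlens, max_seqlen) signature and reuses the repro's loss out[:cu_seqlens[-1]].abs().sum(). The tolerances may need loosening under autocast.

```python
import torch


def padded_grads_match(model, x, cu_seqlens, max_seqlen, rtol=1e-3, atol=1e-3):
    """Hypothetical regression check: padding rows must not change the gradients.

    Runs the repro's loss twice (padded input, then x[:cu_seqlens[-1]]) and
    returns True only if every padded-run gradient is finite and close to the
    unpadded one.
    """
    total = int(cu_seqlens[-1])

    def param_grads(inp):
        model.zero_grad(set_to_none=True)
        out = model(inp, cu_seqlens, max_seqlen)
        out[:total].abs().sum().backward()
        return [p.grad.detach().clone() for p in model.parameters() if p.grad is not None]

    padded = param_grads(x)
    unpadded = param_grads(x[:total])
    if len(padded) != len(unpadded):
        return False
    # Any NaN/inf in the padded run is exactly the failure reported in this issue.
    if not all(torch.isfinite(g).all() for g in padded):
        return False
    return all(
        torch.allclose(a, b, rtol=rtol, atol=atol) for a, b in zip(padded, unpadded)
    )
```

On the GPU, calling padded_grads_match(model, x.to(device), cu_seqlens.to(device), max_seqlen) with the repro's model should return False while the bug is present.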
