pytorch - flex_attention with head_dim=192 fails on SM86 GPUs (Triton shared memory OOM)

GitHub stats
pytorch/pytorch#180278 · Fetched 2026-04-15 06:18:54


Error Message

torch._inductor.exc.InductorError: RuntimeError: No valid triton configs.
OutOfMemoryError: out of resource: triton_tem_fused_flex_attention_0
Required: 139776  Hardware limit: 101376
Reducing block sizes or num_stages may help.

Root Cause

With head_dim=192, torch.compile(flex_attention) generates a Triton kernel (triton_tem_fused_flex_attention_0) that requires 139,776 bytes of shared memory. SM86 GPUs such as the A10G report a hardware limit of 101,376 bytes, so the config is rejected. Instead of falling back to smaller block sizes or fewer stages, Inductor raises "No valid triton configs."


🐛 Describe the bug

torch.compile(flex_attention) fails on SM86 GPUs (e.g. A10G) when head_dim=192 (as used by DeepSeek V3's MLA attention: qk_nope_head_dim=128 + qk_rope_head_dim=64). The generated Triton kernel exceeds the GPU's shared memory limit.

Error:

torch._inductor.exc.InductorError: RuntimeError: No valid triton configs.
OutOfMemoryError: out of resource: triton_tem_fused_flex_attention_0
Required: 139776  Hardware limit: 101376
Reducing block sizes or num_stages may help.

The kernel needs 139,776 bytes (136.5 KB) of shared memory, but SM86 allows at most 101,376 bytes (99 KB) per thread block, a shortfall of 38,400 bytes. The same kernel works on SM90+ (H100), which has a larger shared memory limit.
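For CI triage, the two byte counts can be pulled out of the error text and compared. The snippet below is a minimal sketch: `parse_shared_memory_error` is a hypothetical helper, not part of PyTorch or Triton, and it assumes the message keeps the `Required: … Hardware limit: …` wording quoted above.

```python
import re
from typing import Optional, Tuple

# Hypothetical helper (not a PyTorch/Triton API): matches the two byte counts
# in the "out of resource" message quoted in this issue.
_OOR_PATTERN = re.compile(r"Required:\s*(\d+)\s+Hardware limit:\s*(\d+)")


def parse_shared_memory_error(message: str) -> Optional[Tuple[int, int]]:
    """Return (required_bytes, limit_bytes), or None if the text does not match."""
    match = _OOR_PATTERN.search(message)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


message = (
    "OutOfMemoryError: out of resource: triton_tem_fused_flex_attention_0\n"
    "Required: 139776  Hardware limit: 101376"
)
required, limit = parse_shared_memory_error(message)
shortfall = required - limit
print(f"required={required} limit={limit} shortfall={shortfall}")
# prints: required=139776 limit=101376 shortfall=38400
```

Matching on message text is fragile by nature; if the wording changes in a later Triton release, the helper returns None rather than guessing.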

Repro

https://github.com/pytorch/torchtitan/actions/runs/24355616606/job/71121682738?pr=2947

<details> <summary>minified repro (not sure if OOMs on A10G)</summary>
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# DeepSeek V3 MLA attention dimensions:
#   qk_head_dim = qk_nope_head_dim(128) + qk_rope_head_dim(64) = 192
#   v_head_dim = 128
B, H, S = 1, 16, 2048
QK_DIM = 192
V_DIM = 128

device = "cuda"

q = torch.randn(B, H, S, QK_DIM, device=device, dtype=torch.bfloat16)
k = torch.randn(B, H, S, QK_DIM, device=device, dtype=torch.bfloat16)
v = torch.randn(B, H, S, V_DIM, device=device, dtype=torch.bfloat16)

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal_mask, B=B, H=None, Q_LEN=S, KV_LEN=S, device=device)

compiled_flex = torch.compile(flex_attention)
out = compiled_flex(q, k, v, block_mask=block_mask)
print(f"Success: output shape = {out.shape}")

Run on an A10G (SM86) GPU. It is expected to fail on the first call to compiled_flex, which is when torch.compile actually compiles the kernel.

</details>

Environment

  • GPU: NVIDIA A10G (SM86, 101,376 bytes maximum shared memory per thread block)
  • PyTorch: nightly (2026-04-13)
  • CUDA: 12.8

Context

Discovered via pytorch/torchtitan CI running the deepseek_v3_flex+pp+fsdp+ep+sacop integration test on linux.g5.48xlarge (A10G) runners. The test uses the DeepSeek V3 debug model with FlexAttention and pipeline parallelism.

Full stack trace from CI:

File "torchtitan/models/common/attention.py", line 321, in forward
    out, aux = FlexAttention._compiled_flex_attn(
File "torch/_inductor/runtime/triton_heuristics.py", line 681, in _make_launchers
    raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}")
torch._inductor.exc.InductorError: RuntimeError: No valid triton configs.
OutOfMemoryError: out of resource: triton_tem_fused_flex_attention_0
Required: 139776  Hardware limit: 101376

Ideally the Triton autotuner would fall back to smaller block sizes / fewer stages when the default config exceeds the target GPU's shared memory, rather than raising with "No valid triton configs."
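Until such a fallback exists in the autotuner, the same idea can be approximated at the call site. The following is a user-level sketch, not how Inductor or Triton behave internally: `run_with_fallback` and `fake_compile_and_run` are hypothetical names, the candidate values are illustrative, and detecting the failure by the "out of resource" substring is an assumption based on the message above.

```python
from typing import Any, Callable, Dict, Iterable, Tuple


def run_with_fallback(
    run: Callable[[Dict[str, int]], Any],
    candidates: Iterable[Dict[str, int]],
) -> Tuple[Dict[str, int], Any]:
    """Try candidates in order; return the first one that runs, with its result.

    Only errors whose text contains "out of resource" trigger a retry
    (an assumption about the message); anything else is re-raised.
    """
    last_error = None
    for options in candidates:
        try:
            return options, run(options)
        except Exception as exc:
            if "out of resource" not in str(exc):
                raise
            last_error = exc
    raise RuntimeError("no candidate config fit in shared memory") from last_error


def fake_compile_and_run(options: Dict[str, int]) -> str:
    # Stand-in for the compiled kernel: pretends tiles with 64+ rows do not fit.
    if options["BLOCK_M"] >= 64:
        raise RuntimeError("out of resource: Required: 139776  Hardware limit: 101376")
    return "ok"


# Illustrative candidates, ordered from largest to smallest footprint.
candidates = [
    {"BLOCK_M": 128, "BLOCK_N": 64, "num_stages": 3},
    {"BLOCK_M": 64, "BLOCK_N": 64, "num_stages": 2},
    {"BLOCK_M": 32, "BLOCK_N": 32, "num_stages": 1},
]
chosen, result = run_with_fallback(fake_compile_and_run, candidates)
print(chosen, result)
```

In real use, `run` would wrap the compiled call, for example `lambda opts: compiled_flex(q, k, v, block_mask=block_mask, kernel_options=opts)`. Each new option set triggers a recompile, so this is better suited to a one-off probe than to a hot path.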

cc @ezyang @gchanan @kadeng @msaroufim @ptrblck @eqy @jerryzh168 @tinglvv @nWEIdia @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @muchulee8 @amjames @aakhundov @coconutruben @jataylo @Chillee @drisspg @yanboliang @BoyuanFeng @liangel-02 @howardzhang-cv

extent analysis

TL;DR

Reducing block sizes or num_stages in the Triton kernel configuration may help alleviate the out-of-memory error when compiling flex_attention on SM86 GPUs.

Guidance

  • Verify that the error occurs due to exceeding the shared memory limit of 101,376 bytes on SM86 GPUs.
  • Consider reducing block sizes or num_stages in the Triton kernel configuration to decrease the required shared memory.
  • Test the repro code on an SM90+ (H100) GPU to confirm that the issue is specific to SM86 GPUs.
  • Pass custom Triton kernel settings through flex_attention's kernel_options argument (for example smaller BLOCK_M / BLOCK_N or a lower num_stages) so the kernel fits within the SM86 shared memory limit.
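The guidance above can be sketched as follows. `flex_attention` takes a `kernel_options` dictionary that is passed to the Triton kernels. The policy function, the head_dim threshold, and the tile values below are assumptions for illustration and have not been verified against this issue. Option key names have also varied across PyTorch releases, so check the documentation for the installed version.

```python
from typing import Dict, Optional, Tuple

# Illustrative values only: an untested guess, not a configuration
# confirmed by the issue.
SMALL_TILE_OPTIONS: Dict[str, int] = {"BLOCK_M": 32, "BLOCK_N": 32, "num_stages": 1}


def pick_kernel_options(
    capability: Tuple[int, int], qk_head_dim: int
) -> Optional[Dict[str, int]]:
    """Hypothetical policy: shrink tiles only on SM86 with head_dim above 128."""
    if capability == (8, 6) and qk_head_dim > 128:
        return dict(SMALL_TILE_OPTIONS)
    return None  # None leaves flex_attention on its default configuration


def run_flex(q, k, v, block_mask=None):
    # Imported lazily so the policy above can be exercised without a GPU.
    import torch
    from torch.nn.attention.flex_attention import flex_attention

    options = pick_kernel_options(torch.cuda.get_device_capability(), q.shape[-1])
    # In real code, compile once at module scope and reuse the compiled function.
    compiled_flex = torch.compile(flex_attention)
    return compiled_flex(q, k, v, block_mask=block_mask, kernel_options=options)


print(pick_kernel_options((8, 6), 192))
print(pick_kernel_options((9, 0), 192))
```

Keeping the policy separate from the attention call makes it possible to unit-test the selection logic on a machine without an A10G. `run_flex` itself still needs a CUDA device.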

Example

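flex_attention accepts a kernel_options dictionary that is passed to the Triton kernels, so the workaround is a change at the call site rather than in the model code. Which block sizes or num_stages fit head_dim=192 on SM86 is not established in this issue and has to be found by testing.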

Notes

The solution may not be applicable to all use cases, and further investigation is required to determine the optimal block sizes or num_stages for the specific workload.

Recommendation

Apply a workaround by reducing block sizes or num_stages in the Triton kernel configuration, as this may help alleviate the out-of-memory error on SM86 GPUs. This approach is recommended because it directly addresses the root cause of the issue, which is the excessive shared memory requirement of the default Triton kernel configuration.
