pytorch - 💡(How to fix) Fix `flex_attention` + `mask_mod` closure over a derived sym-int [1 participants]

GitHub stats: pytorch/pytorch#182606 (fetched 2026-05-07 03:31:13) · Comments: 0 · Participants: 1 · Timeline events: 121 · Reactions: 0
Timeline (top): mentioned ×56, subscribed ×56, labeled ×8, cross-referenced ×1

Error Message

InductorError: LoweringException: AttributeError: 'ShapeAsConstantBuffer' object has no attribute 'dtype' target: flex_attention



🐛 Describe the bug

When a function compiled with torch.compile takes a Python int argument, computes a derived expression from it (e.g. + 1), and a mask_mod closure captures that derived sym-int and is used (via create_block_mask) in flex_attention(...), Inductor crashes on the second call, made with a different int value, with:

InductorError: LoweringException: AttributeError:
    'ShapeAsConstantBuffer' object has no attribute 'dtype'
  target: flex_attention

This appears to be a sibling of #157833 (closed 2025-07-11 as fixed). The simple-int-closure case from #157833 does work on the latest nightly, but a derived sym-int closure — a more general pattern that arises naturally e.g. in ring-buffer KV-cache decode — still hits the same ShapeAsConstantBuffer.get_dtype() failure.

Repro

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def fn(q, k, v, current_pos: int):
    P_plus = current_pos + 1                  # derived sym-int OUTSIDE the closure
    def mask_mod(b, h, q_idx, kv_idx):
        return (kv_idx >= 0) & (kv_idx <= P_plus)
    bm = create_block_mask(mask_mod, B=None, H=None,
                            Q_LEN=q.shape[2], KV_LEN=k.shape[2],
                            device=q.device)
    return flex_attention(q, k, v, block_mask=bm)

B, H, M, KV_LEN, D = 1, 4, 1, 13, 64          # decode shapes; prefill also fails
q = torch.randn(B, H, M, D, device='cuda', dtype=torch.bfloat16)
k = torch.randn(B, H, KV_LEN, D, device='cuda', dtype=torch.bfloat16)
v = torch.randn(B, H, KV_LEN, D, device='cuda', dtype=torch.bfloat16)

compiled = torch.compile(fn, fullgraph=False)
compiled(q, k, v, current_pos=0)              # PASS — dynamo specializes on value
compiled(q, k, v, current_pos=5)              # FAIL — recompile with sym-int

Output:

call 1 (current_pos=0): OK
call 2 (current_pos=5): FAIL  InductorError: LoweringException:
    AttributeError: 'ShapeAsConstantBuffer' object has no attribute 'dtype'

The same failure occurs with current_pos % 12, max(0, current_pos - 7), etc. Capturing current_pos itself inside the closure instead of P_plus (i.e., a raw sym-int with no arithmetic) does not trigger the bug; that is the case fixed in #157833.
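When checking whether a workaround computes the intended mask, it helps to pin down the expected semantics independently of Inductor. The following is a minimal pure-Python sketch (the helpers reference_mask and derived_bounds are hypothetical, not torch APIs) that evaluates the repro's predicate (kv_idx >= 0) & (kv_idx <= bound) for one query row, using the derived bounds listed above:

```python
# Hypothetical pure-Python reference for the repro's mask_mod predicate.
# It evaluates (kv_idx >= 0) & (kv_idx <= bound) for a single query row;
# it does not call torch and is only meant as a ground truth for comparisons.


def reference_mask(kv_len: int, bound: int) -> list[bool]:
    """Dense boolean mask over kv positions 0 .. kv_len - 1."""
    return [(kv_idx >= 0) and (kv_idx <= bound) for kv_idx in range(kv_len)]


def derived_bounds(current_pos: int) -> dict[str, int]:
    """The derived expressions reported to trigger the crash."""
    return {
        "current_pos + 1": current_pos + 1,
        "current_pos % 12": current_pos % 12,
        "max(0, current_pos - 7)": max(0, current_pos - 7),
    }


if __name__ == "__main__":
    KV_LEN = 13  # decode KV length used in the repro
    for pos in (0, 5):
        for name, bound in derived_bounds(pos).items():
            visible = sum(reference_mask(KV_LEN, bound))
            print(f"current_pos={pos}  {name:<24} visible kv slots: {visible}")
```

For current_pos=5, for example, the current_pos + 1 bound leaves 7 of the 13 KV slots visible (positions 0 through 6).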

Both prefill (q_seqlen > 1) and decode (q_seqlen = 1) reproduce the crash; the failure surfaces in the corresponding mask_mod_other_buffers = maybe_realize(...) call of each lowering:

  • torch/_inductor/kernel/flex/flex_attention.py:303 (prefill)
  • torch/_inductor/kernel/flex/flex_decoding.py:232 (decode)

Trigger conditions (all required)

  1. The function passed to torch.compile takes a Python int argument.
  2. Function uses arithmetic on that int outside the closure to produce a derived sym-int (e.g., + 1, % R, max(0, x - W + 1)).
  3. mask_mod closure captures the derived sym-int.
  4. The closure is passed to create_block_mask; the BlockMask is used in flex_attention(...) inside the same compiled function.
  5. Function is called twice with different int values. First call: dynamo value-specializes (compiles successfully). Second call: dynamo retraces with a dynamic sym-int → crash.
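The derived expressions in condition 2 are what a ring-buffer KV-cache decode step naturally computes from the token position. A minimal sketch, assuming a hypothetical cache of R slots and a sliding window of W positions (decode_step_indices and its keys are illustrative names, not part of PyTorch):

```python
# Hypothetical decode-step bookkeeping for a ring-buffer KV cache.
# Every value is an int derived from current_pos; once dynamo treats
# current_pos as dynamic, each of these becomes a derived sym-int, and
# capturing any of them in a mask_mod closure matches condition 3.


def decode_step_indices(current_pos: int, R: int, W: int) -> dict[str, int]:
    return {
        "write_slot": current_pos % R,                # slot overwritten by this token
        "window_start": max(0, current_pos - W + 1),  # oldest position still attended
        "p_plus": current_pos + 1,                    # P_plus from the repro
    }


if __name__ == "__main__":
    for pos in (0, 5, 20):
        print(pos, decode_step_indices(pos, R=12, W=8))
```

With R=12 and W=8 these are the same shapes of expression as current_pos % 12 and max(0, current_pos - 7) from the report.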

A closure that captures the int directly (mask_mod body contains kv_idx <= current_pos) does not crash — the bug appears specific to derived sym-int closures.

Hypothesised root cause

The derived expression (s34 + 1, PythonMod(s34, 12), Max(0, s34 - 7)) is materialized as a separate ShapeAsConstantBuffer in mask_mod_other_buffers. realize_inputs(...) calls ExternKernel.copy_input(x), which invokes x.get_dtype(). ShapeAsConstantBuffer has no dtype attribute → AttributeError. The same path works for tensor entries of mask_mod_other_buffers, because those are real TensorBox / StorageBox nodes that carry a dtype.

A direct sym-int (raw closure capture of current_pos) is propagated as a graph input that already exists in the FX graph signature — so it doesn't end up as a new mask_mod_other_buffers entry. The derived expression has no graph-input identity, so it's "fresh" and gets routed through the realize path.
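The hypothesised mechanism can be illustrated with a toy model. The classes and functions below are hypothetical stand-ins written for this illustration, not Inductor's actual ShapeAsConstantBuffer, TensorBox, or realize_inputs code; the guarded variant only sketches one possible fix direction (letting symbolic scalars bypass the copy), not a confirmed patch.

```python
# Toy model (NOT Inductor source) of the hypothesised failure: a realize step
# that treats every extra mask_mod buffer as tensor-like and asks for its dtype.


class TensorLike:
    """Stand-in for a TensorBox / StorageBox: carries a dtype."""

    def __init__(self, dtype: str) -> None:
        self.dtype = dtype

    def get_dtype(self) -> str:
        return self.dtype


class ShapeConstantLike:
    """Stand-in for ShapeAsConstantBuffer: wraps a symbolic expression only."""

    def __init__(self, expr: str) -> None:
        self.expr = expr

    def get_dtype(self) -> str:
        return self.dtype  # never set, so this raises AttributeError


def copy_input(x):
    """Naive copy: needs the input's dtype to allocate the destination."""
    return ("copy", x.get_dtype())


def realize_inputs_naive(buffers):
    return [copy_input(b) for b in buffers]


def realize_inputs_guarded(buffers):
    """Hypothetical fix direction: pass symbolic scalars through untouched."""
    return [b if isinstance(b, ShapeConstantLike) else copy_input(b) for b in buffers]


def try_realize(realize, buffers):
    """Return (result, error message) instead of raising."""
    try:
        return realize(buffers), None
    except AttributeError as exc:
        return None, str(exc)


if __name__ == "__main__":
    bufs = [TensorLike("int32"), ShapeConstantLike("s34 + 1")]
    print(try_realize(realize_inputs_naive, bufs))    # error: no attribute 'dtype'
    print(try_realize(realize_inputs_guarded, bufs))  # tensor copied, expression kept
```

The naive path fails only because the list mixes a tensor-like entry with a dtype-less scalar expression, which matches the report's observation that purely tensor mask_mod_other_buffers are unaffected.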

Workaround for users

Pass the int as a [1]-shape int64 tensor instead. Dynamo treats it as a runtime input from the very first call (no value specialization → no recompile → no derived-sym-int closure). The kernel compiles once and is reused across all calls:

def fn(q, k, v, current_pos_t: torch.Tensor):
    def mask_mod(b, h, q_idx, kv_idx):
        p_plus = current_pos_t[0] + 1
        return (kv_idx >= 0) & (kv_idx <= p_plus)
    bm = create_block_mask(mask_mod, ...)  # remaining arguments as in the repro
    return flex_attention(q, k, v, block_mask=bm)

compiled = torch.compile(fn)
for pos in range(5):
    compiled(q, k, v, current_pos_t=torch.tensor([pos], dtype=torch.int64, device='cuda'))
# all 5 PASS, single graph compiled

Validated: 5 consecutive calls with different current_pos_t values take ~3.4 s for the first call (compile) and ~0.3 ms for each of the remaining 4 (cached launches), all on a single Triton kernel.

Related

  • #157833 — closed as fixed 2025-07-11; appears to have addressed the simple-int-closure case but not the derived-sym-int case.
  • #182275 — open, related ShapeAsConstantBuffer family bug in nested_compile_region.

Versions

PyTorch 2.13.0a0+gitdbeb1a8 (dbeb1a8baf2b94948cfc55ad590570602bcb29cb)

cc @chauhang @penguinwu @ezyang @bobrenjc93 @aditvenk @laithsakka @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo @Chillee @drisspg @yanboliang @BoyuanFeng @liangel-02 @howardzhang-cv
