pytorch - 💡(How to fix) Fix [HOP] Alias/mutation fake traces should ignore discarded unbacked symbols


FlexAttention can fail during higher-order operator (HOP) metadata and functionalization passes when it is used with dynamic block-mask inputs.

The relevant path is:

torch.compile(flex_attention) -> flex_attention_hop -> FlexAttention functionalization -> score_mod alias/mutation query -> nested fake trace.

That nested trace is only a diagnostic query: it answers whether score_mod may mutate captured inputs, and its graph is then discarded. With dynamic block-mask metadata, fake propagation can leave fresh unbacked symbols pending in the surrounding ShapeEnv. Those symbols are not semantic outputs of the score_mod alias/mutation query, but compute_unbacked_bindings can still require them to appear in the returned value, and it raises PendingUnbackedSymbolNotFound when they do not.
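To make the failure mode concrete, here is a toy model in plain Python. It is not PyTorch code: `ToyShapeEnv`, `fresh_symbol`, `check_bindings`, and `PendingSymbolNotFound` are hypothetical stand-ins for ShapeEnv, unbacked-symbol allocation, compute_unbacked_bindings, and PendingUnbackedSymbolNotFound. It only illustrates the shape of the problem: a query that returns a plain bool leaves a freshly allocated symbol pending, so a strict binding check rejects it.

```python
# Toy model of the failure. All names are hypothetical stand-ins, not PyTorch APIs:
# ToyShapeEnv ~ ShapeEnv, check_bindings ~ compute_unbacked_bindings,
# PendingSymbolNotFound ~ PendingUnbackedSymbolNotFound.


class PendingSymbolNotFound(RuntimeError):
    """Raised when a pending symbol does not appear in the returned value."""


class ToyShapeEnv:
    """Records symbols that were allocated but not yet bound to an output."""

    def __init__(self):
        self.pending = []
        self._counter = 0

    def fresh_symbol(self):
        # Allocate a new data-dependent ("unbacked") symbol and mark it pending.
        name = f"u{self._counter}"
        self._counter += 1
        self.pending.append(name)
        return name


def check_bindings(env, returned):
    # Strict check: every pending symbol must appear in the returned value.
    found = set(returned) if isinstance(returned, (list, tuple)) else {returned}
    missing = [s for s in env.pending if s not in found]
    env.pending.clear()
    if missing:
        raise PendingSymbolNotFound(f"Pending symbols {missing} not in returned outputs")
    return found


def alias_mutation_query(env):
    # Diagnostic query: fake propagation allocates a symbol as a side effect,
    # but the query only returns a bool ("does score_mod mutate inputs?").
    env.fresh_symbol()
    return False


env = ToyShapeEnv()
result = alias_mutation_query(env)
try:
    check_bindings(env, result)
except PendingSymbolNotFound as exc:
    print(exc)  # Pending symbols ['u0'] not in returned outputs
```

The symbol `u0` has nothing to bind to because the query's result is a bool, which mirrors why the pending unbacked symbols in the real trace cannot appear in its returned value.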


Reproducer

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


if not torch.cuda.is_available():
    raise RuntimeError("This repro needs CUDA because FlexAttention backward is CUDA-only.")


device = "cuda"
dtype = torch.float16


# Identity score_mod. It is still traced by the HOP alias/mutation query.
def score_mod(score, batch, head, q_idx, kv_idx):
    return score


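# Compile create_block_mask with dynamic shapes so the block-mask
# metadata is traced with symbolic sizes.
compiled_create_block_mask = torch.compile(
    create_block_mask,
    dynamic=True,
    fullgraph=True,
)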


def create_dynamic_block_mask(q_batch, kv_batch):
    q_len = q_batch.size(0)
    kv_len = kv_batch.size(0)

    def mask_mod(batch, head, q_idx, kv_idx):
        q_group = q_batch[q_idx]
        kv_group = kv_batch[kv_idx]
        return (q_group == kv_group) & (q_group != -1) & (kv_group != -1)

    return compiled_create_block_mask(
        mask_mod,
        B=None,
        H=None,
        Q_LEN=q_len,
        KV_LEN=kv_len,
        device=device,
        BLOCK_SIZE=128,
    )


# Every position is in group 0, so the mask allows all query/key pairs.
groups = torch.zeros(128, dtype=torch.int64, device=device)
block_mask = create_dynamic_block_mask(groups, groups)

q = torch.randn(1, 1, 128, 64, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(1, 1, 128, 64, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(1, 1, 128, 64, device=device, dtype=dtype, requires_grad=True)

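# aot_eager runs AOTAutograd (including functionalization) without Inductor,
# which is enough to reach the FlexAttention HOP functionalization path.
compiled_flex_attention = torch.compile(
    flex_attention,
    fullgraph=True,
    dynamic=True,
    backend="aot_eager",
)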

out = compiled_flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask)
out.sum().backward()
torch.cuda.synchronize()
print("ok", out.shape)



Expected behavior

The compiled FlexAttention forward and backward passes should complete. The score_mod alias/mutation query should be allowed to discard fresh unbacked symbols that are created only for HOP metadata and fake propagation.

Actual behavior

The FlexAttention HOP metadata path can fail with:

PendingUnbackedSymbolNotFound: Pending unbacked symbols {...} not in returned outputs ...

The pending symbols are not part of the user-visible output. They are created during the internal fake/metadata trace used by the HOP alias/mutation check.

Notes

The internal HOP trace should run under the existing "ignore fresh unbacked symbols" contract, and compute_unbacked_bindings should honor that contract by treating pending symbols as ignorable while the corresponding thread-local (TLS) flag is active. This prevents diagnostic fake traces from imposing output-binding requirements on symbols that are intentionally discarded along with the diagnostic graph.
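The proposed contract can be sketched with the same kind of toy model. Again, this is not PyTorch code: `ToyShapeEnv`, `ignore_fresh_symbols`, `ignoring_fresh_symbols`, and `check_bindings` are hypothetical names that stand in for ShapeEnv, the "ignore fresh unbacked symbols" contract, and compute_unbacked_bindings. The sketch assumes the flag is stored in thread-local state and that the binding check drops unbound pending symbols while the flag is set.

```python
import threading
from contextlib import contextmanager

# Toy model of the proposed contract. All names are hypothetical stand-ins,
# not PyTorch APIs.


class PendingSymbolNotFound(RuntimeError):
    """Raised when a pending symbol is missing from the returned value."""


class ToyShapeEnv:
    def __init__(self):
        self.pending = []
        self._counter = 0
        self._tls = threading.local()  # per-thread "ignore fresh symbols" flag

    def ignoring_fresh_symbols(self):
        return getattr(self._tls, "ignore_fresh", False)

    @contextmanager
    def ignore_fresh_symbols(self):
        # Set the thread-local flag for the duration of the block, then restore it.
        prev = self.ignoring_fresh_symbols()
        self._tls.ignore_fresh = True
        try:
            yield
        finally:
            self._tls.ignore_fresh = prev

    def fresh_symbol(self):
        # Allocate a new unbacked symbol and mark it pending.
        name = f"u{self._counter}"
        self._counter += 1
        self.pending.append(name)
        return name


def check_bindings(env, returned):
    found = set(returned) if isinstance(returned, (list, tuple)) else {returned}
    missing = [s for s in env.pending if s not in found]
    env.pending.clear()
    # Honor the contract: while the flag is active, unbound pending symbols are dropped.
    if missing and not env.ignoring_fresh_symbols():
        raise PendingSymbolNotFound(f"Pending symbols {missing} not in returned outputs")
    return found


def alias_mutation_query(env):
    # The discarded diagnostic trace runs under the ignore contract.
    with env.ignore_fresh_symbols():
        env.fresh_symbol()           # created only by fake propagation
        result = False               # "score_mod does not mutate captured inputs"
        check_bindings(env, result)  # tolerated while the flag is active
    return result


env = ToyShapeEnv()
print(alias_mutation_query(env), env.pending)  # False []
```

Outside the `with` block the check stays strict, so symbols that a real output is expected to bind are still caught; only the diagnostic query is exempted.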



