pytorch - ✅(Solved) Fix Dynamo does not guard on function.__defaults__, causing silent correctness bugs with flex_attention [1 pull requests, 3 comments, 2 participants]

GitHub stats
pytorch/pytorch#178365 · Fetched 2026-04-08 01:26:02
Comments: 3 · Participants: 2 · Timeline events: 65 · Reactions: 0
Timeline (top): mentioned ×25, subscribed ×25, labeled ×8, commented ×3

Root Cause

When Dynamo traces flex_attention, it extracts mask_mod from BlockMask.mask_mod and passes it to speculate_subgraph. The guard installed is only:

ID_MATCH: ___check_obj_id(L['block_mask'].mask_mod.__code__, <id>)

No guard is installed on mask_mod.__defaults__. Since all mask_mod closures in the loop are defined at the same source location, they share the same __code__ object. The guard passes, and the first graph (with _offset=0 baked in as a constant) is reused for all subsequent calls.
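
The claim that loop-defined functions share one __code__ object while differing in __defaults__ can be checked in plain Python, without PyTorch. The sketch below is only an illustration: `make_mask_mods` is a hypothetical helper, and the offsets mirror the repro's four chunks of 128.

```python
def make_mask_mods(offsets):
    """Define one mask_mod per offset, all from the same `def` statement."""
    mods = []
    for kv_offset in offsets:
        # Same pattern as the repro: the offset is frozen as a default argument.
        def mask_mod(b, h, q_idx, kv_idx, _offset=kv_offset):
            return q_idx >= kv_idx + _offset

        mods.append(mask_mod)
    return mods


mods = make_mask_mods([0, 128, 256, 384])

# Every function object is distinct...
all_distinct = len({id(m) for m in mods}) == len(mods)
# ...but they all share a single code object,
shared_code = all(m.__code__ is mods[0].__code__ for m in mods)
# and the per-iteration state lives only in __defaults__.
defaults = [m.__defaults__ for m in mods]

print(all_distinct, shared_code, defaults)
# True True [(0,), (128,), (256,), (384,)]
```

A guard that compares only the identity of __code__ therefore cannot tell these four functions apart.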

Fix Action

Workaround

Call torch._dynamo.reset() before each invocation with a different mask_mod to force a fresh compilation.

PR fix notes

PR #178401: Guard on function.__defaults__ in CLOSURE_MATCH

Description (problem / solution / changelog)

Summary

Fixes #178365 - Dynamo does not guard on function.__defaults__, causing silent correctness bugs with flex_attention.

When a function is defined in a loop with different default argument values, each iteration creates a new function object with the same __code__ but different __defaults__. Previously, CLOSURE_MATCH guarded only on __code__ identity, so Dynamo reused the first cached graph, in which the constants from the first call's defaults were baked in.

Fix

Added an ID_MATCH guard on __defaults__ in CLOSURE_MATCH to ensure recompilation when default args change.

Test

Added test test_function_defaults_guarded that verifies recompilation when functions with different defaults are passed.

Authored with Claude

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @chauhang @amjames @Lucaskabela @jataylo

Changed files

  • test/dynamo/test_misc.py (modified, +20/-0)
  • torch/_dynamo/guards.py (modified, +6/-1)

Issue Body

Human Note: I found this while working on this PR -> https://github.com/pytorch/pytorch/pull/178357

Bug

Dynamo does not install guards on function.__defaults__ when a callable is extracted from a BlockMask and traced via speculate_subgraph. It only guards on function.__code__ (identity match). When multiple mask_mod closures share the same bytecode but differ in their default parameter values, Dynamo reuses the first compiled graph with a stale constant baked in, producing silently wrong results.

This affects any pattern where flex_attention is called in a loop with different BlockMask objects whose mask_mod functions differ only in default-arg-captured state (e.g., a KV offset for ring attention).

Repro

Reproduces on nightly (2.12.0.dev20260310+cu130) and main.

"""
Minimal repro: Dynamo does not guard on function.__defaults__,
causing stale compiled graphs when mask_mod closures differ only
in their default-arg values.

flex_attention's mask_mod is inlined during compilation.  When a
compiled function is re-invoked with a new BlockMask whose mask_mod
has the same __code__ but different __defaults__, Dynamo reuses the
first graph (wrong offset baked in) instead of recompiling.
"""

import argparse

import torch
import torch._dynamo
from torch.nn.attention.flex_attention import (
    AuxRequest,
    create_block_mask,
    flex_attention,
)


@torch.compile(fullgraph=True)
def flex_chunk(q, k, v, block_mask, scale):
    out, aux = flex_attention(
        q, k, v, block_mask=block_mask, scale=scale, return_aux=AuxRequest(lse=True)
    )
    return out, aux.lse


def merge(out, lse, new_out, new_lse):
    lse, new_lse = lse.unsqueeze(-1), new_lse.unsqueeze(-1)
    mx = torch.maximum(lse, new_lse)
    e0, e1 = torch.exp(lse - mx), torch.exp(new_lse - mx)
    d = e0 + e1
    return (out * e0 + new_out * e1) / d, (mx + torch.log(d)).squeeze(-1)


@torch.compile(fullgraph=True)
def ref_attn(q, k, v, block_mask, scale):
    return flex_attention(q, k, v, block_mask=block_mask, scale=scale)


def main(fix_bug: bool = False):
    torch.manual_seed(42)
    B, H, S, D = 1, 1, 512, 16
    device = "cuda"
    NUM_CHUNKS = 4
    chunk_size = S // NUM_CHUNKS

    q = torch.randn(B, H, S, D, device=device)
    k = torch.randn(B, H, S, D, device=device)
    v = torch.randn(B, H, S, D, device=device)
    scale = D**-0.5

    merged_out = merged_lse = None
    for step in range(NUM_CHUNKS):
        kv_offset = step * chunk_size

        def mask_mod(b, h, q_idx, kv_idx, _offset=kv_offset):
            return q_idx >= kv_idx + _offset

        bm = create_block_mask(
            mask_mod, B=B, H=H, Q_LEN=S, KV_LEN=chunk_size, device=device
        )
        if fix_bug:
            torch._dynamo.reset()
        out, lse = flex_chunk(
            q, k[:, :, kv_offset : kv_offset + chunk_size],
            v[:, :, kv_offset : kv_offset + chunk_size], bm, scale,
        )
        if merged_out is None:
            merged_out, merged_lse = out, lse
        else:
            merged_out, merged_lse = merge(merged_out, merged_lse, out, lse)

    def causal(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    ref_bm = create_block_mask(causal, B=B, H=H, Q_LEN=S, KV_LEN=S, device=device)
    ref_out = ref_attn(q, k, v, ref_bm, scale)

    diff = (merged_out - ref_out).abs()
    max_diff = diff.max().item()
    num_bad = (diff > 1e-4).sum().item()
    total = diff.numel()
    print(f"Max abs diff:       {max_diff:.6f}")
    print(f"Mismatched (>1e-4): {num_bad} / {total} ({100*num_bad/total:.1f}%)")

    if max_diff < 1e-3:
        print("PASS - __defaults__ properly guarded")
    else:
        print(
            "FAIL - __defaults__ NOT guarded; mask_mod offset baked in from first call"
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fix-bug", action="store_true")
    args = parser.parse_args()
    main(fix_bug=args.fix_bug)

Without workaround:

Max abs diff:       0.299253
Mismatched (>1e-4): 6072 / 8192 (74.1%)
FAIL - __defaults__ NOT guarded; mask_mod offset baked in from first call

With --fix-bug (torch._dynamo.reset() before each call):

Max abs diff:       0.000000
Mismatched (>1e-4): 0 / 8192 (0.0%)
PASS - __defaults__ properly guarded

Root Cause

When Dynamo traces flex_attention, it extracts mask_mod from BlockMask.mask_mod and passes it to speculate_subgraph. The guard installed is only:

ID_MATCH: ___check_obj_id(L['block_mask'].mask_mod.__code__, <id>)

No guard is installed on mask_mod.__defaults__. Since all mask_mod closures in the loop are defined at the same source location, they share the same __code__ object. The guard passes, and the first graph (with _offset=0 baked in as a constant) is reused for all subsequent calls.

Expected Behavior

Dynamo should also guard on function.__defaults__ (or treat default parameter values as guardable state). When __defaults__ changes, a recompilation should be triggered so the new constant is baked into a fresh graph.
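
The difference between the buggy and the expected behavior can be modeled with a toy guarded cache in plain Python. This is a sketch under stated assumptions, not Dynamo's implementation: `ToyCompileCache`, `baked_offset`, and `make_mask_mod` are hypothetical names, and "compiling" here just means reading the offset out of __defaults__ and storing it as a constant.

```python
class ToyCompileCache:
    """Toy stand-in for a guarded compile cache (hypothetical, not Dynamo's API)."""

    def __init__(self, guard_on_defaults):
        self.guard_on_defaults = guard_on_defaults
        self.entries = []  # (function seen at compile time, baked-in offset)
        self.compile_count = 0

    def _guard_passes(self, cached_fn, fn):
        # Identity check on the code object (the only check before the fix).
        if cached_fn.__code__ is not fn.__code__:
            return False
        # Optional identity check on the defaults tuple (the expected behavior).
        if self.guard_on_defaults and cached_fn.__defaults__ is not fn.__defaults__:
            return False
        return True

    def baked_offset(self, fn):
        for cached_fn, offset in self.entries:
            if self._guard_passes(cached_fn, fn):
                return offset  # cache hit: reuse the previously baked constant
        # Cache miss: "compile" by baking the current default into a new entry.
        self.compile_count += 1
        offset = fn.__defaults__[0]
        self.entries.append((fn, offset))
        return offset


def make_mask_mod(kv_offset):
    def mask_mod(b, h, q_idx, kv_idx, _offset=kv_offset):
        return q_idx >= kv_idx + _offset

    return mask_mod


mods = [make_mask_mod(o) for o in (0, 128, 256, 384)]

code_only = ToyCompileCache(guard_on_defaults=False)
with_defaults = ToyCompileCache(guard_on_defaults=True)

stale = [code_only.baked_offset(m) for m in mods]
fresh = [with_defaults.baked_offset(m) for m in mods]

print(stale, code_only.compile_count)      # [0, 0, 0, 0] 1
print(fresh, with_defaults.compile_count)  # [0, 128, 256, 384] 4
```

With the code-only guard, every call after the first silently reuses offset 0, which matches the failure in the repro. With the additional defaults check, each new defaults tuple causes one recompilation.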

Workaround

Call torch._dynamo.reset() before each invocation with a different mask_mod to force a fresh compilation.

Versions

  • Nightly: 2.12.0.dev20260310+cu130
  • Also reproduced on main (dev build)

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @chauhang @amjames @Lucaskabela @jataylo @Chillee @yanboliang @BoyuanFeng @liangel-02 @howardzhang-cv @drisspg

extent analysis

Fix Plan

To fix the issue, Dynamo must guard on function.__defaults__ as well as on function.__code__. According to the linked PR #178401, this was done in the CLOSURE_MATCH guard rather than in speculate_subgraph:

  • Extend CLOSURE_MATCH in torch/_dynamo/guards.py so that it installs an ID_MATCH guard on __defaults__ in addition to the existing one on __code__.
  • Add a regression test (test_function_defaults_guarded in test/dynamo/test_misc.py) that passes functions with different defaults and checks that recompilation happens.

Example code (an illustrative pure-Python model of the check, not the actual guards.py change; make_closure_guard is a hypothetical name):

def make_closure_guard(fn):
    # Keep references to the objects seen at compile time so identity checks stay valid.
    code, defaults = fn.__code__, fn.__defaults__

    def check(candidate):
        # Pass only if both the code object and the defaults tuple are the same objects.
        return candidate.__code__ is code and candidate.__defaults__ is defaults

    return check

Alternatively, you can use the workaround provided in the issue body, which is to call torch._dynamo.reset() before each invocation with a different mask_mod to force a fresh compilation.

Verification

To verify the fix, run the repro script without --fix-bug on a build that includes PR #178401. The script should report a max abs diff close to zero and 0 mismatched elements, and print the PASS line.

Extra Tips

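  • Because the new guard is an identity check on __defaults__, every function re-created with a new defaults tuple triggers a recompilation. Watch for extra recompiles in hot loops, for example by running with TORCH_LOGS=recompiles.
  • The PR already adds a regression test, test_function_defaults_guarded, in test/dynamo/test_misc.py.
  • If you are using a version of PyTorch older than the one that includes this fix, use the workaround from the issue body. Note that torch._dynamo.reset() discards all compiled code, so every call after it pays the full compilation cost again.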
