pytorch - ✅(Solved) Fix [CPU] `torch.compile` inductor C++ codegen fails when closure-composed mask functions are used with `F.scaled_dot_product_attention` across recompilations [1 pull requests, 1 comments, 1 participants]

GitHub stats
pytorch/pytorch#178244 (fetched 2026-04-08 01:21:07)
Comments: 1
Participants: 1
Timeline: 21
Reactions: 0
Timeline (top): mentioned ×8, subscribed ×8, labeled ×3, commented ×1




Fix Action

Fix / Workaround

Upstream issue: huggingface/transformers#44458
Upstream workaround PR: huggingface/transformers#44845 (avoids closure composition for the padding mask)

PR fix notes

PR #44845: Fix Mllama torch.compile failure caused by new attention mask logic

Description (problem / solution / changelog)

What does this PR do?

Fixes torch.compile failure for Mllama after #42848 introduced a new unified attention mask creation path.

The root cause is a torch inductor C++ codegen bug: when padding_mask_function uses advanced tensor indexing (padding_mask[batch_idx, kv_idx]), the generated C++ boundary-check code references an undeclared variable (tmp2), causing g++ compilation to fail with CppCompileError.

This PR applies two changes:

  1. masking_utils.py: In the non-vmap sdpa_mask path, apply the padding mask separately using slice-based indexing (padding_mask[:, kv_offset : kv_offset + kv_length]) instead of merging it into the mask_function with advanced tensor indexing. This avoids the inductor codegen bug while producing identical results.

  2. modeling_mllama.py: Replace torch.arange-based fancy indexing with simple slice indexing when extracting cross_attention_mask and full_text_row_masked_out_mask for the current sequence position. This is semantically equivalent but more torch.compile-friendly.

Fixes #44458
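
The first change can be illustrated with a minimal sketch, assuming a boolean padding_mask of shape (batch, total_kv_length). The helper names below are hypothetical and are not the functions in masking_utils.py; only the two indexing expressions are taken from the PR description. The sketch checks, in eager mode, that applying the padding mask through a slice gives the same result as folding it in through advanced tensor indexing.

```python
import torch


def mask_advanced_indexing(padding_mask, q_len, kv_length, q_offset=0, kv_offset=0):
    """Hypothetical helper: padding folded in via padding_mask[batch_idx, kv_idx]."""
    batch = padding_mask.shape[0]
    batch_idx = torch.arange(batch)[:, None, None, None]
    q_idx = (torch.arange(q_len) + q_offset)[None, None, :, None]
    kv_idx = (torch.arange(kv_length) + kv_offset)[None, None, None, :]
    mask = (kv_idx <= q_idx) & padding_mask[batch_idx, kv_idx]
    return mask.expand(batch, 1, q_len, kv_length)


def mask_slice_indexing(padding_mask, q_len, kv_length, q_offset=0, kv_offset=0):
    """Hypothetical helper: padding applied separately through a plain slice."""
    batch = padding_mask.shape[0]
    q_idx = (torch.arange(q_len) + q_offset)[None, None, :, None]
    kv_idx = (torch.arange(kv_length) + kv_offset)[None, None, None, :]
    causal = kv_idx <= q_idx
    # Slice the key/value window, then add broadcast dims for heads and queries.
    pad = padding_mask[:, kv_offset : kv_offset + kv_length][:, None, None, :]
    return (causal & pad).expand(batch, 1, q_len, kv_length)


padding_mask = torch.ones(2, 11, dtype=torch.bool)
padding_mask[0, :2] = False  # two padding tokens in the first sequence
padding_mask[1, :5] = False  # five padding tokens in the second sequence

# Prefill: 8 queries over 8 keys.
prefill_same = torch.equal(
    mask_advanced_indexing(padding_mask[:, :8], 8, 8),
    mask_slice_indexing(padding_mask[:, :8], 8, 8),
)
# Decode: 1 query at position 10 over 11 keys.
decode_same = torch.equal(
    mask_advanced_indexing(padding_mask, 1, 11, q_offset=10),
    mask_slice_indexing(padding_mask, 1, 11, q_offset=10),
)
print(prefill_same, decode_same)
```

The sketch runs without torch.compile, so it only demonstrates that the two indexing styles agree; it does not by itself show that the slice form avoids the codegen error.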

Changed files

  • src/transformers/masking_utils.py (modified, +11/-1)
  • src/transformers/models/mllama/modeling_mllama.py (modified, +4/-4)


🐛 Describe the bug

On CPU, the failure occurs when torch.compile is used with a model that:

  1. Builds an attention mask via closure composition (nested closures that capture a tensor and index it with broadcast indices)
  2. Feeds the mask to F.scaled_dot_product_attention
  3. Is called multiple times with different input shapes (triggering recompilation, e.g. prefill then decode in an autoregressive generation loop)

Under these conditions, the inductor C++ codegen generates code that references an undeclared variable (e.g. tmp2), causing g++ compilation to fail with CppCompileError.

The same mask logic written inline (without closures) works fine. The two approaches are mathematically equivalent — the only difference is whether the mask is constructed through composed closures or directly.

This bug was discovered in transformers (huggingface/transformers#44458), where masking_utils.py uses exactly this closure composition pattern to build attention masks for all models.

Key observation

Mask construction | Result
Closure-composed: and_masks(causal_fn, padding_fn(mask)) applied to broadcast indices | FAILS on recompilation
Inline equivalent: mask[b, kv] & (kv <= q) with same broadcast indices | Works fine

This suggests dynamo traces the closure composition into a different FX graph structure than the inline equivalent, and inductor's C++ backend cannot correctly codegen for that specific graph on recompilation.

Reproduction script

"""
torch.compile inductor bug: closure-based mask + SDPA fails on recompilation (CPU only).

Same mask logic written inline works fine. Only the closure composition
pattern triggers a C++ codegen error (undeclared variable) on recompilation.
"""
import torch
import torch.nn.functional as F


# ---- Closure-based mask construction (triggers bug) ----

def causal_fn(b, h, q, kv):
    return kv <= q

def padding_fn(padding_mask):
    def inner(b, h, q, kv):
        return padding_mask[b, kv]
    return inner

def and_masks(f1, f2):
    def combined(b, h, q, kv):
        return f1(b, h, q, kv) & f2(b, h, q, kv)
    return combined

def make_mask_closure(batch_size, q_len, kv_len, q_offset, padding_mask):
    fn = and_masks(causal_fn, padding_fn(padding_mask))
    b = torch.arange(batch_size)[:, None, None, None]
    h = torch.arange(1)[None, :, None, None]
    q = (torch.arange(q_len) + q_offset)[None, None, :, None]
    kv = torch.arange(kv_len)[None, None, None, :]
    return fn(b, h, q, kv).expand(batch_size, 1, q_len, kv_len)


# ---- Inline mask construction (no closures, works fine) ----

def make_mask_inline(batch_size, q_len, kv_len, q_offset, padding_mask):
    b = torch.arange(batch_size)[:, None, None, None]
    kv = torch.arange(kv_len)[None, None, None, :]
    q = (torch.arange(q_len) + q_offset)[None, None, :, None]
    return (padding_mask[b, kv] & (kv <= q)).expand(batch_size, 1, q_len, kv_len)


# ---- Minimal model: mask + SDPA ----

class Model(torch.nn.Module):
    def __init__(self, use_closure=True):
        super().__init__()
        self.proj = torch.nn.Linear(32, 32, bias=False)
        self.use_closure = use_closure

    def forward(self, x, past_k, past_v, padding_mask, q_offset):
        B, S, _ = x.shape
        q = self.proj(x).view(B, S, 4, 8).transpose(1, 2)
        k = x.view(B, S, 4, 8).transpose(1, 2)
        v = k.clone()
        if past_k is not None:
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        if self.use_closure:
            mask = make_mask_closure(B, S, k.shape[2], q_offset, padding_mask)
        else:
            mask = make_mask_inline(B, S, k.shape[2], q_offset, padding_mask)

        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        return out.transpose(1, 2).reshape(B, S, 32), k, v


# ---- Test: prefill then decode (shape changes → recompilation) ----

def test(label, use_closure):
    print(f"\n--- {label} ---")
    torch._dynamo.reset()
    model = Model(use_closure=use_closure).eval()
    compiled = torch.compile(model)

    pad_mask = torch.ones(1, 8, dtype=torch.bool)
    pad_mask[0, :2] = False  # 2 padding tokens

    # Prefill: seq_len=8
    x = torch.randn(1, 8, 32)
    with torch.no_grad():
        out, pk, pv = compiled(x, None, None, pad_mask, q_offset=0)
    print(f"  Prefill (kv=8): OK")

    # Decode steps: seq_len=1, kv grows each step → recompilation
    for step in range(3):
        pad_mask = torch.cat([pad_mask, torch.ones(1, 1, dtype=torch.bool)], dim=1)
        x = torch.randn(1, 1, 32)
        try:
            with torch.no_grad():
                out, pk, pv = compiled(x, pk, pv, pad_mask, q_offset=pk.shape[2])
            print(f"  Decode step {step+1} (kv={pk.shape[2]}): OK")
        except Exception as e:
            import traceback
            print(f"  Decode step {step+1}: FAILED — {type(e).__name__}")
            traceback.print_exc()
            return


if __name__ == "__main__":
    print(f"torch {torch.__version__}")
    test("Closure-based mask (BUG)", use_closure=True)
    test("Inline mask (OK)", use_closure=False)

Expected output

Both closure-based and inline tests should pass all decode steps, since they compute the same mask.

Actual output (2.12.0.dev20260316+cpu / 2.10.0+cpu)

--- Closure-based mask (BUG) ---
  Prefill (kv=8): OK
  Decode step 1: FAILED — InductorError
  ...
  CppCompileError: C++ compile error
  ...
  error: 'tmp2' was not declared in this scope
  ...

--- Inline mask (OK) ---
  Prefill (kv=8): OK
  Decode step 1 (kv=9): OK
  Decode step 2 (kv=10): OK
  Decode step 3 (kv=11): OK

Environment

  • Platform: CPU only (tested on Linux x86_64)
  • PyTorch: 2.12.0.dev20260316+cpu
  • Python: 3.12
  • Compiler: g++ (system default)

Context

This pattern is used in the Hugging Face transformers library's unified mask creation (masking_utils.py), which composes causal_mask_function, padding_mask_function, sliding_window_mask_function, etc. via and_masks() / or_masks(). As a result, models that use torch.compile with padding tokens on CPU can hit this bug (e.g. Mllama, Llama with padding).
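
To make the composition pattern concrete, here is a simplified sketch that works on plain Python integers instead of broadcast index tensors. The function names mirror the ones mentioned above, but the bodies are illustrative stand-ins written for this example, not the transformers implementation; in particular, the sliding-window rule (q - kv < window) is an assumption.

```python
from typing import Callable, Sequence

# (batch_idx, head_idx, q_idx, kv_idx) -> may the query attend to the key?
MaskFn = Callable[[int, int, int, int], bool]


def causal_mask_function(b: int, h: int, q: int, kv: int) -> bool:
    return kv <= q


def sliding_window_mask_function(window: int) -> MaskFn:
    def inner(b: int, h: int, q: int, kv: int) -> bool:
        return q - kv < window  # assumed window rule for this sketch
    return inner


def padding_mask_function(padding_mask: Sequence[Sequence[bool]]) -> MaskFn:
    def inner(b: int, h: int, q: int, kv: int) -> bool:
        return padding_mask[b][kv]  # closure captures the padding mask
    return inner


def and_masks(*fns: MaskFn) -> MaskFn:
    def combined(b: int, h: int, q: int, kv: int) -> bool:
        return all(fn(b, h, q, kv) for fn in fns)
    return combined


def or_masks(*fns: MaskFn) -> MaskFn:
    def combined(b: int, h: int, q: int, kv: int) -> bool:
        return any(fn(b, h, q, kv) for fn in fns)
    return combined


padding = [[False, False, True, True, True, True]]  # first two tokens are padding
mask_fn = and_masks(
    causal_mask_function,
    sliding_window_mask_function(3),
    padding_mask_function(padding),
)
row = [mask_fn(0, 0, 5, kv) for kv in range(6)]
print(row)  # [False, False, False, True, True, True]
```

In the tensor version, the same composed function is called once with broadcast index tensors, so the captured padding mask ends up indexed as padding_mask[b, kv] inside nested closures, which is the situation the bug report describes.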

Upstream issue: huggingface/transformers#44458
Upstream workaround PR: huggingface/transformers#44845 (avoids closure composition for the padding mask)

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168 @aditew01

extent analysis

Fix Plan

The steps below work around the torch.compile failure on the model side by avoiding closure composition when the mask is built; they do not change inductor itself:

  • Replace the make_mask_closure function with the make_mask_inline function, which constructs the mask directly without using closures.
  • Update the Model class to use the make_mask_inline function by default.
  • If the closure-composed mask comes from transformers rather than your own code, pick up the upstream workaround (huggingface/transformers#44845), which applies the padding mask separately with slice-based indexing.

Example code changes:

class Model(torch.nn.Module):
    def __init__(self, use_closure=False):
        super().__init__()
        self.proj = torch.nn.Linear(32, 32, bias=False)
        # use_closure is accepted so existing call sites keep working, but it is ignored.

    def forward(self, x, past_k, past_v, padding_mask, q_offset):
        # ... (B, S, q, k, v computed as in the reproduction script)
        # Always build the mask inline; the closure-composed path is gone.
        mask = make_mask_inline(B, S, k.shape[2], q_offset, padding_mask)
        # ...

Alternatively, you can modify the make_mask_closure function to inline the closure composition, similar to the make_mask_inline function:

def make_mask_closure(batch_size, q_len, kv_len, q_offset, padding_mask):
    # Same broadcast indices as before; the head index was unused and is dropped.
    b = torch.arange(batch_size)[:, None, None, None]
    q = (torch.arange(q_len) + q_offset)[None, None, :, None]
    kv = torch.arange(kv_len)[None, None, None, :]
    causal = kv <= q
    # Use a new name instead of rebinding the padding_mask argument.
    padding = padding_mask[b, kv]
    return (causal & padding).expand(batch_size, 1, q_len, kv_len)

Verification

To verify the workaround, run the reproduction script again with the modified code. Every decode step should now print OK, because the model no longer builds its mask through composed closures.
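
Independently of torch.compile, it can help to confirm in eager mode that a rewritten mask builder still produces the intended mask. The following is a minimal sketch under stated assumptions: make_mask_rewritten stands for whichever closure-free builder you ended up with, and reference_mask and check_prefill_then_decode are hypothetical helpers that replay the prefill and decode shapes from the reproduction script against a brute-force definition.

```python
import torch


def make_mask_rewritten(batch_size, q_len, kv_len, q_offset, padding_mask):
    """Candidate under test: a closure-free mask builder."""
    b = torch.arange(batch_size)[:, None, None, None]
    q = (torch.arange(q_len) + q_offset)[None, None, :, None]
    kv = torch.arange(kv_len)[None, None, None, :]
    return ((kv <= q) & padding_mask[b, kv]).expand(batch_size, 1, q_len, kv_len)


def reference_mask(batch_size, q_len, kv_len, q_offset, padding_mask):
    """Brute force: attend iff the key is not padding and not in the future."""
    mask = torch.zeros(batch_size, 1, q_len, kv_len, dtype=torch.bool)
    for b in range(batch_size):
        for qi in range(q_len):
            for ki in range(kv_len):
                mask[b, 0, qi, ki] = bool(padding_mask[b, ki]) and ki <= qi + q_offset
    return mask


def check_prefill_then_decode(make_mask, decode_steps=3):
    """Replay the script's shapes: one prefill of length 8, then single-token decodes."""
    pad_mask = torch.ones(1, 8, dtype=torch.bool)
    pad_mask[0, :2] = False  # 2 padding tokens
    cases = [(8, 8, 0, pad_mask)]  # (q_len, kv_len, q_offset, padding_mask)
    for _ in range(decode_steps):
        pad_mask = torch.cat([pad_mask, torch.ones(1, 1, dtype=torch.bool)], dim=1)
        kv_len = pad_mask.shape[1]
        cases.append((1, kv_len, kv_len - 1, pad_mask))
    return all(
        torch.equal(
            make_mask(1, q_len, kv_len, q_off, pm),
            reference_mask(1, q_len, kv_len, q_off, pm),
        )
        for q_len, kv_len, q_off, pm in cases
    )


masks_match = check_prefill_then_decode(make_mask_rewritten)
print("masks match reference:", masks_match)
```

Passing this check only shows that the mask values are unchanged; the compiled decode steps still need to be exercised with the reproduction script.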

Extra Tips

  • The report only shows this failure for closure-composed masks on CPU when the model is recompiled for a new input shape; the inline equivalent compiles and runs in the same setup.
  • If you need to keep closure composition, consider inlining the closures on the compiled path or using a separate implementation that avoids closure composition there.
  • The failure was observed on both 2.10.0+cpu and 2.12.0.dev20260316+cpu, so re-run the reproduction script after upgrading PyTorch before removing the workaround.
