pytorch: `aot_export_joint_with_descriptors` fails with `flex_attention` when BlockMask tensors are graph placeholders [1 comment, 2 participants]

pytorch/pytorch#182659 (fetched 2026-05-07 03:30:47)

🐛 Describe the bug

Description:

When using dynamo_graph_capture_for_export to capture a model with flex_attention, BlockMask index tensors become graph placeholders. Passing this graph to aot_export_joint_with_descriptors fails during metadata collection (run_functionalized_fw_and_collect_metadata) with "Please convert all Tensors to FakeTensors first".

The same model works with _dynamo_graph_capture_for_export (private API) because BlockMask tensors are captured as get_attr nodes instead of placeholders.
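To check which of the two shapes a captured graph has, it can help to list its `placeholder` and `get_attr` nodes. The helper below is a minimal sketch: `summarize_inputs` and the `Toy` module are hypothetical names, and the toy graph is traced on CPU with `torch.fx.symbolic_trace` rather than with the Dynamo capture APIs from this issue, so it only illustrates the node kinds, not the failure itself.

```python
import torch
import torch.fx as fx


def summarize_inputs(gm: fx.GraphModule):
    """Return (placeholder targets, get_attr targets) for a GraphModule."""
    placeholders = [n.target for n in gm.graph.nodes if n.op == "placeholder"]
    attrs = [n.target for n in gm.graph.nodes if n.op == "get_attr"]
    return placeholders, attrs


class Toy(torch.nn.Module):
    """Stand-in model: `mask` plays the role of a BlockMask index tensor."""

    def __init__(self):
        super().__init__()
        self.register_buffer("mask", torch.ones(4))

    def forward(self, x):
        return x * self.mask


gm = fx.symbolic_trace(Toy())
placeholders, attrs = summarize_inputs(gm)
print("placeholders:", placeholders)
print("get_attr:", attrs)
```

Running the same helper on the `gm` from the repro should show whether the BlockMask tensors appear in the first list (the failing case described here) or the second.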

Repro:

import torch, torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
from torch.nn.attention.flex_attention import BlockMask, flex_attention, create_block_mask
from contextlib import ExitStack
import torch.utils._pytree as pytree

def bm_flatten(bm):
    children = [bm.kv_num_blocks, bm.kv_indices, bm.full_kv_num_blocks, bm.full_kv_indices,
                bm.q_num_blocks, bm.q_indices, bm.full_q_num_blocks, bm.full_q_indices]
    return children, (bm.BLOCK_SIZE, bm.mask_mod, bm.seq_lengths)
def bm_unflatten(children, aux):
    BLOCK_SIZE, mask_mod, seq_lengths = aux
    return BlockMask(kv_num_blocks=children[0], kv_indices=children[1],
        full_kv_num_blocks=children[2], full_kv_indices=children[3],
        q_num_blocks=children[4], q_indices=children[5],
        full_q_num_blocks=children[6], full_q_indices=children[7],
        BLOCK_SIZE=BLOCK_SIZE, mask_mod=mask_mod, seq_lengths=seq_lengths)
pytree.register_pytree_node(BlockMask, bm_flatten, bm_unflatten)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.wq = nn.Linear(64, 64, bias=False)
        self.block_mask = create_block_mask(
            lambda b, h, q, kv: q >= kv, B=None, H=None, Q_LEN=16, KV_LEN=16, device="cuda")
    def forward(self, x):
        q = self.wq(x).view(1, 16, 4, 16).transpose(1, 2)
        return flex_attention(q, q, q, block_mask=self.block_mask).sum()

with torch.device("meta"):
    model = Model()

fm = FakeTensorMode()
with fm:
    for name, param in list(model.named_parameters()):
        parts = name.split(".")
        mod = model
        for part in parts[:-1]:
            mod = getattr(mod, part)
        setattr(mod, parts[-1], nn.Parameter(
            torch.empty(param.shape, dtype=param.dtype, device="cuda"),
            requires_grad=param.requires_grad))
    x = torch.randn(1, 16, 64, device="cuda")

torch._dynamo.reset()
gm = dynamo_graph_capture_for_export(model)(x)

# Use the ExitStack as a context manager so it is closed even when the export raises.
with ExitStack() as stack:
    joint = aot_export_joint_with_descriptors(stack, gm, (x,))  # Fails here

Error:

AssertionError: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode 
with 'allow_non_fake_inputs'. Found in flex_attention(FakeTensor(...), FakeTensor(...), 
FakeTensor(...), ..., (16, 16, tensor([...], size=(1, 1, 1), dtype=torch.int32), ...))

Analysis:

The error occurs in run_functionalized_fw_and_collect_metadata, which enters FunctionalTensorMode and runs the graph. The flex_attention HOP dispatch chain goes through flex_attention_autograd → flex_attention_functionalize (unwraps FunctionalTensor) → FakeTensorMode.__torch_dispatch__ → validate_and_convert_non_fake_tensors. At this point, the BlockMask index tensors have lost their FakeTensor identity after being wrapped/unwrapped through FunctionalTensorMode.

The private _dynamo_graph_capture_for_export avoids this because BlockMask tensors are get_attr nodes (module constants) that don't flow through FunctionalTensorMode wrapping — they're loaded directly from the module during interpretation.
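The placeholder versus get_attr distinction can be illustrated on a toy `torch.fx` graph by turning one placeholder into a module constant. This is a sketch only: `lift_placeholder_to_attr` and the `_lifted_` prefix are hypothetical, the graph comes from `torch.fx.symbolic_trace` on CPU, and nothing here has been verified against the failing repro. A graph captured by Dynamo may also carry input specs that would need updating after an input is removed.

```python
import torch
import torch.fx as fx


def lift_placeholder_to_attr(gm: fx.GraphModule, name: str, value: torch.Tensor):
    """Replace the placeholder called `name` with a get_attr node backed by a buffer."""
    for node in list(gm.graph.nodes):
        if node.op == "placeholder" and node.target == name:
            attr_name = f"_lifted_{name}"  # hypothetical naming scheme
            gm.register_buffer(attr_name, value)
            with gm.graph.inserting_after(node):
                attr_node = gm.graph.get_attr(attr_name)
            node.replace_all_uses_with(attr_node)
            gm.graph.erase_node(node)
            break
    else:
        raise KeyError(f"no placeholder named {name!r}")
    gm.recompile()  # regenerate forward() without the removed input
    return gm


def toy(x, mask):
    # `mask` stands in for a BlockMask index tensor passed as a graph input.
    return x * mask


gm = fx.symbolic_trace(toy)
mask = torch.tensor([1.0, 0.0, 1.0])
lift_placeholder_to_attr(gm, "mask", mask)

x = torch.tensor([2.0, 3.0, 4.0])
out = gm(x)  # mask is now read from the module, not passed in
ops = [n.op for n in gm.graph.nodes]
print(ops)
```

After the rewrite the graph takes a single input and reads the mask from the module, which is the shape the analysis attributes to the private capture path.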

Versions

PyTorch version 2.13.0.dev20260428+cu130

cc @bdhirsh @ezyang @chauhang @penguinwu @eellison @aorenste @ydwu4 @bobrenjc93 @Chillee @drisspg @yanboliang @BoyuanFeng @liangel-02 @howardzhang-cv
