pytorch - 💡(How to fix) Fix `torch.export` of `sdpa` fails with unbacked batch dim due to guard [2 comments, 2 participants]

Official PRs (…)
ON THIS PAGE

GitHub issue graph ai analysis

Paste a GitHub issue URL. We fetch that issue, discover linked issues from bodies/comments/timeline, collect linked pull requests, and produce a structured English report.

The report is written in English Markdown for sharing and archival.

Helpful · Quick feedback

Loading…
GitHub stats
pytorch/pytorch#180202 · Fetched 2026-04-15 06:19:33
View on GitHub
Comments
2
Participants
2
Timeline
24
Reactions
0
Timeline (top)
mentioned ×9 · subscribed ×9 · labeled ×4 · commented ×2

Error Message

import torch
import torch.nn as nn
import torch.nn.functional as F

class ModelWithMaskedSDPA(nn.Module):
    """Minimal attention module used to reproduce the export failure.

    Projects the input into query/key/value, splits them into heads and runs
    ``F.scaled_dot_product_attention`` with an explicit attention mask.
    """

    def __init__(self, hidden=32, heads=4):
        """Create the three ``hidden -> hidden`` projections.

        Args:
            hidden: Input and output feature size of each linear projection.
            heads: Number of attention heads; ``head_dim`` is ``hidden // heads``.
        """
        super().__init__()
        self.proj_q = nn.Linear(hidden, hidden)
        self.proj_k = nn.Linear(hidden, hidden)
        self.proj_v = nn.Linear(hidden, hidden)
        self.heads = heads
        self.head_dim = hidden // heads

    def forward(self, x, bool_mask):
        """Select rows of ``x`` with ``bool_mask`` and apply masked attention.

        Args:
            x: Input tensor, indexed along dim 0 by ``bool_mask``.
            bool_mask: Boolean tensor used to index ``x``.

        Returns:
            The attention output with the head dimensions merged back,
            i.e. ``out.transpose(1, 2).flatten(-2)``.
        """
        # Boolean indexing creates unbacked symbolic dim u0
        x = x[bool_mask]  # (u0, seq, hidden)
        batch_size = x.shape[0]
        seq = x.shape[1]

        # Split each projection into heads:
        # (u0, seq, heads, head_dim) -> (u0, heads, seq, head_dim)
        q = self.proj_q(x).view(batch_size, seq, self.heads, self.head_dim).transpose(1, 2)
        k = self.proj_k(x).view(batch_size, seq, self.heads, self.head_dim).transpose(1, 2)
        v = self.proj_v(x).view(batch_size, seq, self.heads, self.head_dim).transpose(1, 2)

        # Create attention mask with unbacked batch dim (like idefics3)
        attn_mask = torch.ones(batch_size, 1, seq, seq, dtype=x.dtype, device=x.device)

        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        # (u0, heads, seq, head_dim) -> (u0, seq, heads * head_dim)
        return out.transpose(1, 2).flatten(-2)

model = ModelWithMaskedSDPA().eval()

x = torch.randn(4, 8, 32)
mask = torch.tensor([True, True, False, True])

# Baseline: run the module eagerly (no export) first.
with torch.no_grad():
    eager_out = model(x, mask)
    print(f"Eager OK: {eager_out.shape}")

# Export and run the module on each device in turn.
for device in ["cuda", "cpu"]:
    # The .to(device) calls below are outside the try block, so without this
    # check a machine with no CUDA would crash before reaching the CPU case.
    if device == "cuda" and not torch.cuda.is_available():
        print(f"Skipping {device.upper()}: not available")
        continue
    x = x.to(device)
    mask = mask.to(device)
    model = model.to(device)
    print(f"Exporting on {device.upper()}...")
    try:
        ep = torch.export.export(model, (x, mask), strict=False)
        print(f"Export on {device.upper()} OK")
        with torch.no_grad():
            out = ep.module()(x, mask)
            print(f"Run OK: {out.shape}")
    except Exception as e:
        print(f"Export on {device.upper()} failed: {e}")

Root Cause

consider using data-dependent friendly APIs such as guard_or_false, guard_or_true and statically_known_true. Caused by: (_ops.py:865 in call) For more information, run with TORCH_LOGS="dynamic" For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0" If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

Fix Action

Fix / Workaround

This snippet fails export on cpu only (works on cuda) due to an added guard on single batch size (I guess some sort of batch size dependent implementation dispatch)

Code Example

import torch
import torch.nn as nn
import torch.nn.functional as F


class ModelWithMaskedSDPA(nn.Module):
    """Minimal attention module used to reproduce the export failure.

    Projects the input into query/key/value, splits them into heads and runs
    ``F.scaled_dot_product_attention`` with an explicit attention mask.
    """

    def __init__(self, hidden=32, heads=4):
        """Create the three ``hidden -> hidden`` projections.

        Args:
            hidden: Input and output feature size of each linear projection.
            heads: Number of attention heads; ``head_dim`` is ``hidden // heads``.
        """
        super().__init__()
        self.proj_q = nn.Linear(hidden, hidden)
        self.proj_k = nn.Linear(hidden, hidden)
        self.proj_v = nn.Linear(hidden, hidden)
        self.heads = heads
        self.head_dim = hidden // heads

    def forward(self, x, bool_mask):
        """Select rows of ``x`` with ``bool_mask`` and apply masked attention.

        Args:
            x: Input tensor, indexed along dim 0 by ``bool_mask``.
            bool_mask: Boolean tensor used to index ``x``.

        Returns:
            The attention output with the head dimensions merged back,
            i.e. ``out.transpose(1, 2).flatten(-2)``.
        """
        # Boolean indexing creates unbacked symbolic dim u0
        x = x[bool_mask]  # (u0, seq, hidden)
        batch_size = x.shape[0]
        seq = x.shape[1]

        # Split each projection into heads:
        # (u0, seq, heads, head_dim) -> (u0, heads, seq, head_dim)
        q = self.proj_q(x).view(batch_size, seq, self.heads, self.head_dim).transpose(1, 2)
        k = self.proj_k(x).view(batch_size, seq, self.heads, self.head_dim).transpose(1, 2)
        v = self.proj_v(x).view(batch_size, seq, self.heads, self.head_dim).transpose(1, 2)

        # Create attention mask with unbacked batch dim (like idefics3)
        attn_mask = torch.ones(batch_size, 1, seq, seq, dtype=x.dtype, device=x.device)

        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        # (u0, heads, seq, head_dim) -> (u0, seq, heads * head_dim)
        return out.transpose(1, 2).flatten(-2)


model = ModelWithMaskedSDPA().eval()

x = torch.randn(4, 8, 32)
mask = torch.tensor([True, True, False, True])

# Baseline: run the module eagerly (no export) first.
with torch.no_grad():
    eager_out = model(x, mask)
    print(f"Eager OK: {eager_out.shape}")

# Export and run the module on each device in turn.
for device in ["cuda", "cpu"]:
    # The .to(device) calls below are outside the try block, so without this
    # check a machine with no CUDA would crash before reaching the CPU case.
    if device == "cuda" and not torch.cuda.is_available():
        print(f"Skipping {device.upper()}: not available")
        continue
    x = x.to(device)
    mask = mask.to(device)
    model = model.to(device)
    print(f"Exporting on {device.upper()}...")
    try:
        ep = torch.export.export(model, (x, mask), strict=False)
        print(f"Export on {device.upper()} OK")
        with torch.no_grad():
            out = ep.module()(x, mask)
            print(f"Run OK: {out.shape}")
    except Exception as e:
        print(f"Export on {device.upper()} failed: {e}")

---

Eager OK: torch.Size([3, 8, 32])
Exporting on CUDA...
Export on CUDA OK
Run OK: torch.Size([3, 8, 32])
Exporting on CPU...



def forward(self, arg0_1: "f32[32, 32]", arg1_1: "f32[32]", arg2_1: "f32[32, 32]", arg3_1: "f32[32]", arg4_1: "f32[32, 32]", arg5_1: "f32[32]", arg6_1: "f32[4, 8, 32]", arg7_1: "b8[4]"):
    # File: /home/ilyas/transformers/test_idefics3_export.py:17 in forward, code: x = x[bool_mask]  # (u0, seq, hidden)
    index: "f32[u0, 8, 32]" = torch.ops.aten.index.Tensor(arg6_1, [arg7_1]);  arg6_1 = arg7_1 = None
    
    # File: /home/ilyas/transformers/.venv/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
    linear: "f32[u0, 8, 32]" = torch.ops.aten.linear.default(index, arg0_1, arg1_1);  arg0_1 = arg1_1 = None
    
    # File: /home/ilyas/transformers/test_idefics3_export.py:21 in forward, code: q = self.proj_q(x).view(batch_size, seq, self.heads, self.head_dim).transpose(1, 2)
    sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(index, 0)
    view: "f32[u0, 8, 4, 8]" = torch.ops.aten.view.default(linear, [sym_size_int, 8, 4, 8]);  linear = None
    transpose: "f32[u0, 4, 8, 8]" = torch.ops.aten.transpose.int(view, 1, 2);  view = None
    
    # File: /home/ilyas/transformers/.venv/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
    linear_1: "f32[u0, 8, 32]" = torch.ops.aten.linear.default(index, arg2_1, arg3_1);  arg2_1 = arg3_1 = None
    
    # File: /home/ilyas/transformers/test_idefics3_export.py:22 in forward, code: k = self.proj_k(x).view(batch_size, seq, self.heads, self.head_dim).transpose(1, 2)
    view_1: "f32[u0, 8, 4, 8]" = torch.ops.aten.view.default(linear_1, [sym_size_int, 8, 4, 8]);  linear_1 = None
    transpose_1: "f32[u0, 4, 8, 8]" = torch.ops.aten.transpose.int(view_1, 1, 2);  view_1 = None
    
    # File: /home/ilyas/transformers/.venv/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
    linear_2: "f32[u0, 8, 32]" = torch.ops.aten.linear.default(index, arg4_1, arg5_1);  index = arg4_1 = arg5_1 = None
    
    # File: /home/ilyas/transformers/test_idefics3_export.py:23 in forward, code: v = self.proj_v(x).view(batch_size, seq, self.heads, self.head_dim).transpose(1, 2)
    view_2: "f32[u0, 8, 4, 8]" = torch.ops.aten.view.default(linear_2, [sym_size_int, 8, 4, 8]);  linear_2 = None
    transpose_2: "f32[u0, 4, 8, 8]" = torch.ops.aten.transpose.int(view_2, 1, 2);  view_2 = None
    
    # File: /home/ilyas/transformers/test_idefics3_export.py:26 in forward, code: attn_mask = torch.ones(batch_size, 1, seq, seq, dtype=x.dtype, device=x.device)
    ones: "f32[u0, 1, 8, 8]" = torch.ops.aten.ones.default([sym_size_int, 1, 8, 8], dtype = torch.float32, device = device(type='cpu'), pin_memory = False);  sym_size_int = None
    
    # File: /home/ilyas/transformers/test_idefics3_export.py:28 in forward, code: out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    scaled_dot_product_attention = torch.ops.aten.scaled_dot_product_attention.default(transpose, transpose_1, transpose_2, ones);  transpose = transpose_1 = transpose_2 = ones = scaled_dot_product_attention = None
    



def forward(self, arg0_1: "f32[32, 32]", arg1_1: "f32[32]", arg2_1: "f32[32, 32]", arg3_1: "f32[32]", arg4_1: "f32[32, 32]", arg5_1: "f32[32]", arg6_1: "f32[4, 8, 32]", arg7_1: "b8[4]"):
    # File: /home/ilyas/transformers/test_idefics3_export.py:17 in forward, code: x = x[bool_mask]  # (u0, seq, hidden)
    index: "f32[u0, 8, 32]" = torch.ops.aten.index.Tensor(arg6_1, [arg7_1]);  arg6_1 = arg7_1 = None
    
    # File: /home/ilyas/transformers/.venv/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
    linear: "f32[u0, 8, 32]" = torch.ops.aten.linear.default(index, arg0_1, arg1_1);  arg0_1 = arg1_1 = None
    
    # File: /home/ilyas/transformers/test_idefics3_export.py:21 in forward, code: q = self.proj_q(x).view(batch_size, seq, self.heads, self.head_dim).transpose(1, 2)
    sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(index, 0)
    view: "f32[u0, 8, 4, 8]" = torch.ops.aten.view.default(linear, [sym_size_int, 8, 4, 8]);  linear = None
    transpose: "f32[u0, 4, 8, 8]" = torch.ops.aten.transpose.int(view, 1, 2);  view = None
    
    # File: /home/ilyas/transformers/.venv/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
    linear_1: "f32[u0, 8, 32]" = torch.ops.aten.linear.default(index, arg2_1, arg3_1);  arg2_1 = arg3_1 = None
    
    # File: /home/ilyas/transformers/test_idefics3_export.py:22 in forward, code: k = self.proj_k(x).view(batch_size, seq, self.heads, self.head_dim).transpose(1, 2)
    view_1: "f32[u0, 8, 4, 8]" = torch.ops.aten.view.default(linear_1, [sym_size_int, 8, 4, 8]);  linear_1 = None
    transpose_1: "f32[u0, 4, 8, 8]" = torch.ops.aten.transpose.int(view_1, 1, 2);  view_1 = None
    
    # File: /home/ilyas/transformers/.venv/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
    linear_2: "f32[u0, 8, 32]" = torch.ops.aten.linear.default(index, arg4_1, arg5_1);  index = arg4_1 = arg5_1 = None
    
    # File: /home/ilyas/transformers/test_idefics3_export.py:23 in forward, code: v = self.proj_v(x).view(batch_size, seq, self.heads, self.head_dim).transpose(1, 2)
    view_2: "f32[u0, 8, 4, 8]" = torch.ops.aten.view.default(linear_2, [sym_size_int, 8, 4, 8]);  linear_2 = None
    transpose_2: "f32[u0, 4, 8, 8]" = torch.ops.aten.transpose.int(view_2, 1, 2);  view_2 = None
    
    # File: /home/ilyas/transformers/test_idefics3_export.py:26 in forward, code: attn_mask = torch.ones(batch_size, 1, seq, seq, dtype=x.dtype, device=x.device)
    ones: "f32[u0, 1, 8, 8]" = torch.ops.aten.ones.default([sym_size_int, 1, 8, 8], dtype = torch.float32, device = device(type='cpu'), pin_memory = False);  sym_size_int = None
    
    # File: /home/ilyas/transformers/test_idefics3_export.py:28 in forward, code: out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    scaled_dot_product_attention = torch.ops.aten.scaled_dot_product_attention.default(transpose, transpose_1, transpose_2, ones);  transpose = transpose_1 = transpose_2 = ones = scaled_dot_product_attention = None
    
Export on CPU failed: Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)).  (Size-like symbols: u0)

consider using data-dependent friendly APIs such as guard_or_false, guard_or_true and statically_known_true.
Caused by: (_ops.py:865 in __call__)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The following call raised this error:
  File "/home/ilyas/transformers/test_idefics3_export.py", line 28, in forward
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

To fix the error, insert one of the following checks before this call:
  1. torch._check(x.shape[0] == 1)
  2. torch._check(x.shape[0] != 1)

(These suggested fixes were derived by replacing `u0` with x.shape[0] or batch_size or q.shape[0] or k.shape[0] or v.shape[0] or attn_mask.shape[0] in Eq(u0, 1) and its negation.)

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
RAW_BUFFERClick to expand / collapse

🐛 Describe the bug

This snippet fails export on cpu only (works on cuda) due to an added guard on single batch size (I guess some sort of batch size dependent implementation dispatch)

import torch
import torch.nn as nn
import torch.nn.functional as F


class ModelWithMaskedSDPA(nn.Module):
    """Minimal attention module used to reproduce the export failure.

    Projects the input into query/key/value, splits them into heads and runs
    ``F.scaled_dot_product_attention`` with an explicit attention mask.
    """

    def __init__(self, hidden=32, heads=4):
        """Create the three ``hidden -> hidden`` projections.

        Args:
            hidden: Input and output feature size of each linear projection.
            heads: Number of attention heads; ``head_dim`` is ``hidden // heads``.
        """
        super().__init__()
        self.proj_q = nn.Linear(hidden, hidden)
        self.proj_k = nn.Linear(hidden, hidden)
        self.proj_v = nn.Linear(hidden, hidden)
        self.heads = heads
        self.head_dim = hidden // heads

    def forward(self, x, bool_mask):
        """Select rows of ``x`` with ``bool_mask`` and apply masked attention.

        Args:
            x: Input tensor, indexed along dim 0 by ``bool_mask``.
            bool_mask: Boolean tensor used to index ``x``.

        Returns:
            The attention output with the head dimensions merged back,
            i.e. ``out.transpose(1, 2).flatten(-2)``.
        """
        # Boolean indexing creates unbacked symbolic dim u0
        x = x[bool_mask]  # (u0, seq, hidden)
        batch_size = x.shape[0]
        seq = x.shape[1]

        # Split each projection into heads:
        # (u0, seq, heads, head_dim) -> (u0, heads, seq, head_dim)
        q = self.proj_q(x).view(batch_size, seq, self.heads, self.head_dim).transpose(1, 2)
        k = self.proj_k(x).view(batch_size, seq, self.heads, self.head_dim).transpose(1, 2)
        v = self.proj_v(x).view(batch_size, seq, self.heads, self.head_dim).transpose(1, 2)

        # Create attention mask with unbacked batch dim (like idefics3)
        attn_mask = torch.ones(batch_size, 1, seq, seq, dtype=x.dtype, device=x.device)

        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        # (u0, heads, seq, head_dim) -> (u0, seq, heads * head_dim)
        return out.transpose(1, 2).flatten(-2)


model = ModelWithMaskedSDPA().eval()

x = torch.randn(4, 8, 32)
mask = torch.tensor([True, True, False, True])

# Baseline: run the module eagerly (no export) first.
with torch.no_grad():
    eager_out = model(x, mask)
    print(f"Eager OK: {eager_out.shape}")

# Export and run the module on each device in turn.
for device in ["cuda", "cpu"]:
    # The .to(device) calls below are outside the try block, so without this
    # check a machine with no CUDA would crash before reaching the CPU case.
    if device == "cuda" and not torch.cuda.is_available():
        print(f"Skipping {device.upper()}: not available")
        continue
    x = x.to(device)
    mask = mask.to(device)
    model = model.to(device)
    print(f"Exporting on {device.upper()}...")
    try:
        ep = torch.export.export(model, (x, mask), strict=False)
        print(f"Export on {device.upper()} OK")
        with torch.no_grad():
            out = ep.module()(x, mask)
            print(f"Run OK: {out.shape}")
    except Exception as e:
        print(f"Export on {device.upper()} failed: {e}")

fails with

Eager OK: torch.Size([3, 8, 32])
Exporting on CUDA...
Export on CUDA OK
Run OK: torch.Size([3, 8, 32])
Exporting on CPU...



def forward(self, arg0_1: "f32[32, 32]", arg1_1: "f32[32]", arg2_1: "f32[32, 32]", arg3_1: "f32[32]", arg4_1: "f32[32, 32]", arg5_1: "f32[32]", arg6_1: "f32[4, 8, 32]", arg7_1: "b8[4]"):
    # File: /home/ilyas/transformers/test_idefics3_export.py:17 in forward, code: x = x[bool_mask]  # (u0, seq, hidden)
    index: "f32[u0, 8, 32]" = torch.ops.aten.index.Tensor(arg6_1, [arg7_1]);  arg6_1 = arg7_1 = None
    
    # File: /home/ilyas/transformers/.venv/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
    linear: "f32[u0, 8, 32]" = torch.ops.aten.linear.default(index, arg0_1, arg1_1);  arg0_1 = arg1_1 = None
    
    # File: /home/ilyas/transformers/test_idefics3_export.py:21 in forward, code: q = self.proj_q(x).view(batch_size, seq, self.heads, self.head_dim).transpose(1, 2)
    sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(index, 0)
    view: "f32[u0, 8, 4, 8]" = torch.ops.aten.view.default(linear, [sym_size_int, 8, 4, 8]);  linear = None
    transpose: "f32[u0, 4, 8, 8]" = torch.ops.aten.transpose.int(view, 1, 2);  view = None
    
    # File: /home/ilyas/transformers/.venv/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
    linear_1: "f32[u0, 8, 32]" = torch.ops.aten.linear.default(index, arg2_1, arg3_1);  arg2_1 = arg3_1 = None
    
    # File: /home/ilyas/transformers/test_idefics3_export.py:22 in forward, code: k = self.proj_k(x).view(batch_size, seq, self.heads, self.head_dim).transpose(1, 2)
    view_1: "f32[u0, 8, 4, 8]" = torch.ops.aten.view.default(linear_1, [sym_size_int, 8, 4, 8]);  linear_1 = None
    transpose_1: "f32[u0, 4, 8, 8]" = torch.ops.aten.transpose.int(view_1, 1, 2);  view_1 = None
    
    # File: /home/ilyas/transformers/.venv/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
    linear_2: "f32[u0, 8, 32]" = torch.ops.aten.linear.default(index, arg4_1, arg5_1);  index = arg4_1 = arg5_1 = None
    
    # File: /home/ilyas/transformers/test_idefics3_export.py:23 in forward, code: v = self.proj_v(x).view(batch_size, seq, self.heads, self.head_dim).transpose(1, 2)
    view_2: "f32[u0, 8, 4, 8]" = torch.ops.aten.view.default(linear_2, [sym_size_int, 8, 4, 8]);  linear_2 = None
    transpose_2: "f32[u0, 4, 8, 8]" = torch.ops.aten.transpose.int(view_2, 1, 2);  view_2 = None
    
    # File: /home/ilyas/transformers/test_idefics3_export.py:26 in forward, code: attn_mask = torch.ones(batch_size, 1, seq, seq, dtype=x.dtype, device=x.device)
    ones: "f32[u0, 1, 8, 8]" = torch.ops.aten.ones.default([sym_size_int, 1, 8, 8], dtype = torch.float32, device = device(type='cpu'), pin_memory = False);  sym_size_int = None
    
    # File: /home/ilyas/transformers/test_idefics3_export.py:28 in forward, code: out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    scaled_dot_product_attention = torch.ops.aten.scaled_dot_product_attention.default(transpose, transpose_1, transpose_2, ones);  transpose = transpose_1 = transpose_2 = ones = scaled_dot_product_attention = None
    



def forward(self, arg0_1: "f32[32, 32]", arg1_1: "f32[32]", arg2_1: "f32[32, 32]", arg3_1: "f32[32]", arg4_1: "f32[32, 32]", arg5_1: "f32[32]", arg6_1: "f32[4, 8, 32]", arg7_1: "b8[4]"):
    # File: /home/ilyas/transformers/test_idefics3_export.py:17 in forward, code: x = x[bool_mask]  # (u0, seq, hidden)
    index: "f32[u0, 8, 32]" = torch.ops.aten.index.Tensor(arg6_1, [arg7_1]);  arg6_1 = arg7_1 = None
    
    # File: /home/ilyas/transformers/.venv/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
    linear: "f32[u0, 8, 32]" = torch.ops.aten.linear.default(index, arg0_1, arg1_1);  arg0_1 = arg1_1 = None
    
    # File: /home/ilyas/transformers/test_idefics3_export.py:21 in forward, code: q = self.proj_q(x).view(batch_size, seq, self.heads, self.head_dim).transpose(1, 2)
    sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(index, 0)
    view: "f32[u0, 8, 4, 8]" = torch.ops.aten.view.default(linear, [sym_size_int, 8, 4, 8]);  linear = None
    transpose: "f32[u0, 4, 8, 8]" = torch.ops.aten.transpose.int(view, 1, 2);  view = None
    
    # File: /home/ilyas/transformers/.venv/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
    linear_1: "f32[u0, 8, 32]" = torch.ops.aten.linear.default(index, arg2_1, arg3_1);  arg2_1 = arg3_1 = None
    
    # File: /home/ilyas/transformers/test_idefics3_export.py:22 in forward, code: k = self.proj_k(x).view(batch_size, seq, self.heads, self.head_dim).transpose(1, 2)
    view_1: "f32[u0, 8, 4, 8]" = torch.ops.aten.view.default(linear_1, [sym_size_int, 8, 4, 8]);  linear_1 = None
    transpose_1: "f32[u0, 4, 8, 8]" = torch.ops.aten.transpose.int(view_1, 1, 2);  view_1 = None
    
    # File: /home/ilyas/transformers/.venv/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
    linear_2: "f32[u0, 8, 32]" = torch.ops.aten.linear.default(index, arg4_1, arg5_1);  index = arg4_1 = arg5_1 = None
    
    # File: /home/ilyas/transformers/test_idefics3_export.py:23 in forward, code: v = self.proj_v(x).view(batch_size, seq, self.heads, self.head_dim).transpose(1, 2)
    view_2: "f32[u0, 8, 4, 8]" = torch.ops.aten.view.default(linear_2, [sym_size_int, 8, 4, 8]);  linear_2 = None
    transpose_2: "f32[u0, 4, 8, 8]" = torch.ops.aten.transpose.int(view_2, 1, 2);  view_2 = None
    
    # File: /home/ilyas/transformers/test_idefics3_export.py:26 in forward, code: attn_mask = torch.ones(batch_size, 1, seq, seq, dtype=x.dtype, device=x.device)
    ones: "f32[u0, 1, 8, 8]" = torch.ops.aten.ones.default([sym_size_int, 1, 8, 8], dtype = torch.float32, device = device(type='cpu'), pin_memory = False);  sym_size_int = None
    
    # File: /home/ilyas/transformers/test_idefics3_export.py:28 in forward, code: out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    scaled_dot_product_attention = torch.ops.aten.scaled_dot_product_attention.default(transpose, transpose_1, transpose_2, ones);  transpose = transpose_1 = transpose_2 = ones = scaled_dot_product_attention = None
    
Export on CPU failed: Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)).  (Size-like symbols: u0)

consider using data-dependent friendly APIs such as guard_or_false, guard_or_true and statically_known_true.
Caused by: (_ops.py:865 in __call__)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The following call raised this error:
  File "/home/ilyas/transformers/test_idefics3_export.py", line 28, in forward
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

To fix the error, insert one of the following checks before this call:
  1. torch._check(x.shape[0] == 1)
  2. torch._check(x.shape[0] != 1)

(These suggested fixes were derived by replacing `u0` with x.shape[0] or batch_size or q.shape[0] or k.shape[0] or v.shape[0] or attn_mask.shape[0] in Eq(u0, 1) and its negation.)

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

Versions

latest torch.

cc @chauhang @penguinwu @avikchaudhuri @zhxchen17 @tugsbayasgalan @angelayi @ydwu4

extent analysis

TL;DR

The issue is likely due to a data-dependent expression that cannot be guarded, and a potential fix is to add a check for the batch size before calling F.scaled_dot_product_attention.

Guidance

  • The error message suggests using data-dependent-friendly APIs such as guard_or_false, guard_or_true, and statically_known_true.
  • The issue is related to the line out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask), where the batch size is dynamic and cannot be guarded.
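  • To fix the error, insert exactly one check before this call: torch._check(x.shape[0] != 1) if the batch never contains exactly one element, or torch._check(x.shape[0] == 1) if it always does. The two checks assert opposite facts, so they are alternatives, not interchangeable.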
  • Consider using draft_export() instead of export() to get more information about the error and other potential issues.

Example

# Do not branch on x.shape[0] in Python: `if x.shape[0] == 1` asks export to
# guard on the unbacked size u0, which is exactly the reported error.
# Assert the condition instead; this is option 2 from the error message.
# Note: an input that selects exactly one row will then fail this check.
torch._check(x.shape[0] != 1)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

Notes

  • The issue is specific to the CPU export, and the code works on CUDA.
  • The error message provides suggestions for fixing the issue, but the root cause is related to the dynamic batch size.

Recommendation

Apply a workaround by adding a check for the batch size before calling F.scaled_dot_product_attention, as suggested in the error message. This will allow the code to export on CPU while handling the dynamic batch size.

Vote matrix · Quick signals

Works
Did the solution work? Tap to confirm.
Easy Fix
Was it a quick fix?
Time Saver
Did it save you time?
Blocking
Was it severely blocking?
Common Issue
Are others likely hitting this too?
Flaky / Intermittent
Is it intermittent?
Verified / Reproducible
Can you reproduce it reliably?
Loading…

Still need to ship something?

×6

Another batch ranked right after the header list — different links, same matching logic.

Back to top recommendations

TRENDING