pytorch - 💡(How to fix) Fix torch.export generates unexpected guards on dynamic shapes [1 comments, 2 participants]

GitHub stats: pytorch/pytorch#181097 (fetched 2026-04-23 07:22:40)
Comments: 1 · Participants: 2 · Timeline events: 18 (mentioned ×7, subscribed ×7, labeled ×3, commented ×1) · Reactions: 0


🐛 Describe the bug

torch.export can generate unexpected guards/assertions on dynamic shapes that reject a dimension size of 1 from the declared range, even when the model works correctly for that size. This usually happens when one input's dynamic dimension is declared as a constant multiple of another's.
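The claim that the model works for a dimension size of 1 can be checked with plain arithmetic. The sketch below is a hypothetical pure-Python shape checker (no torch import). `forward_output_shapes` and `broadcast_1d` are illustrative helpers, not PyTorch APIs. The checker mirrors the two shape-dependent steps in `forward`: `x.view(5, -1)` needs `x.size(0)` to be divisible by 5, and `x1[i] - y` needs the row length to broadcast against `y.size(0)`.

```python
C = 5  # mirrors self.c in MyModel


def broadcast_1d(a: int, b: int) -> int:
    """Return the broadcast length of two 1-D sizes, or raise if incompatible."""
    if a == b or b == 1:
        return a
    if a == 1:
        return b
    raise ValueError(f"sizes {a} and {b} do not broadcast")


def forward_output_shapes(x_len: int, y_len: int) -> list[tuple[int, ...]]:
    """Hypothetical helper: output shapes of MyModel.forward for 1-D inputs."""
    # x.view(C, -1) infers the second dimension as x_len // C and needs exact division.
    if x_len % C != 0:
        raise ValueError(f"cannot view size {x_len} as ({C}, -1)")
    row_len = x_len // C
    # Each x1[i] has shape (row_len,); subtracting y broadcasts row_len against y_len.
    return [(broadcast_1d(row_len, y_len),) for _ in range(C)]


# Walk the declared range of "batch" (min=1, max=8) with x = 5 * batch, y = batch.
for batch in range(1, 9):
    shapes = forward_output_shapes(5 * batch, batch)
    print(batch, shapes[0])
```

Every batch size in the declared range, including 1, produces well-defined output shapes. This suggests that the failure at batch = 1 comes from the guard rather than from the model's computation.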

Sample to reproduce:

import torch
from torch.export import export, Dim

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.c = 5

    def forward(self, x, y):
        x1 = x.view(self.c, -1)
        return [x1[i] - y for i in range(self.c)]

m = MyModel()
_dim = Dim("batch", min=1, max=8)
ep = export(
    m,
    (torch.randn(5*4, ), torch.randn(4, )),
    dynamic_shapes={
        "x": {0: _dim * 5, },
        "y": {0: _dim, },
    },
)
print(ep)

print(ep.module()(torch.randn(10, ), torch.randn(2, )))  # Succeeds
print(ep.module()(torch.randn(5, ), torch.randn(1, )))  # Fails with AssertionError: Guard failed: x.size()[0] // 5 != 1

Error message:

Traceback (most recent call last):
  File "/home/junyiq/newscratch/spring26/recsys-examples/examples/hstu/./inference/x.py", line 28, in <module>
    print(ep.module()(torch.randn(5, ), torch.randn(1, ))) # Fail
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 936, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 455, in __call__
    raise e
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 442, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1884, in _call_impl
    return inner()
           ^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1832, in inner
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<eval_with_key>.18", line 6, in forward
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1789, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/external_utils.py", line 216, in inner
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 7, in _
  File "/usr/local/lib/python3.12/dist-packages/torch/__init__.py", line 2249, in _assert
    raise AssertionError(message)
AssertionError: Guard failed: x.size()[0] // 5 != 1

When the model is traced symbolically, the range of y's dynamic dimension starts as [2, 8] instead of the declared [1, 8]. This leads to the guards x.size(0) // 5 != 1 and x.size(0) >= 2. These guards are unwanted. Could they be removed in the future, or is there a workaround for exporting such models?
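The effect of the two reported guards on the declared range can be tabulated in plain Python. This is an illustrative sketch, not PyTorch code. The guard expressions are copied from the issue, and `guard_holds` is a hypothetical name. With x.size(0) = 5 * batch, the guard `x.size(0) // 5 != 1` fails only at batch = 1. The exported program therefore effectively accepts batch in [2, 8], which matches the range observed during tracing.

```python
def guard_holds(x_len: int) -> bool:
    """Evaluate the two guards reported in the issue for a given x.size(0)."""
    # Guard expressions copied from the issue text.
    return x_len // 5 != 1 and x_len >= 2


DECLARED = range(1, 9)  # Dim("batch", min=1, max=8)

accepted = [b for b in DECLARED if guard_holds(5 * b)]
rejected = [b for b in DECLARED if not guard_holds(5 * b)]
print("accepted:", accepted)
print("rejected:", rejected)
```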

Versions

[pip3] intel-openmp==2021.4.0
[pip3] mkl==2021.1.1
[pip3] mkl-devel==2021.1.1
[pip3] mkl-include==2021.1.1
[pip3] mypy_extensions==1.1.0
[pip3] numpy==2.1.0
[pip3] nvidia-cuda-runtime-cu13==0.0.0a0
[pip3] nvidia-cudnn-frontend==1.18.0
[pip3] nvtx==0.2.14
[pip3] onnx==1.18.0
[pip3] onnx-ir==0.1.16
[pip3] onnxscript==0.6.2
[pip3] optree==0.18.0
[pip3] pytorch-triton==3.6.0+git9844da95.nv26.2
[pip3] tbb==2021.13.1
[pip3] torch==2.11.0a0+eb65b36914.nv26.2
[pip3] torch_tensorrt==2.11.0a0
[pip3] torchao==0.16.0+gita89eaab2
[pip3] torchdata==0.11.0
[pip3] torchmetrics==1.0.3
[pip3] torchrec==1.5.0+4002d8b
[pip3] torchtitan==0.2.1+git9f211ec1
[pip3] torchvision==0.25.0a0+1e53952f.nv26.2.44259020
[pip3] torchx==0.7.0
[pip3] triton==3.6.0
[pip3] triton_kernels==1.0.0+git9844da95.nv26.2
[conda] Could not collect

cc @chauhang @penguinwu @avikchaudhuri @zhxchen17 @tugsbayasgalan @angelayi @ydwu4

extent analysis

TL;DR

The issue may be worked around by changing the dynamic shape configuration so that torch.export does not emit the unwanted guards/assertions.

Guidance

  • Review the dynamic shape configuration for the model inputs, specifically the derived dimension `_dim * 5` used for the x input.
  • Consider adjusting the dynamic_shapes dictionary, for example by changing or removing the multiplication factor. Then check that the result still matches the model's requirement x.size(0) == 5 * y.size(0).
  • Verify that the modified configuration still exports without new errors or constraint violations.
  • Test the exported program with several input shapes, including the lower bound of the range (batch = 1, i.e. x of size 5 and y of size 1).
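For the last guidance item, the test inputs should include the lower bound of the declared range, because that is exactly where the reported guard fails. The sketch below is a hypothetical helper: `boundary_cases` is not a PyTorch API. It derives (x, y) lengths from a Dim's bounds and the model's constant factor of 5. The commented loop shows how each pair might be fed to the exported program, assuming `ep` and `torch` from the issue's sample.

```python
def boundary_cases(dim_min: int, dim_max: int, factor: int = 5) -> list[tuple[int, int]]:
    """Return (x_len, y_len) pairs at and next to the bounds of a batch Dim."""
    candidates = {dim_min, dim_min + 1, dim_max - 1, dim_max}
    # Keep only values inside the declared range (matters when the range is tiny).
    batches = sorted(candidates & set(range(dim_min, dim_max + 1)))
    return [(factor * b, b) for b in batches]


print(boundary_cases(1, 8))

# Hypothetical usage against the exported program from the issue (requires torch):
# for x_len, y_len in boundary_cases(1, 8):
#     ep.module()(torch.randn(x_len), torch.randn(y_len))
```

On the affected build, the (5, 1) pair is the one expected to hit `Guard failed: x.size()[0] // 5 != 1`.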

Example

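# Caution: sharing _dim forces x.size(0) == y.size(0), while the model needs
# x.size(0) == 5 * y.size(0) and the issue's sample inputs are (20,) and (4,).
dynamic_shapes={
    "x": {0: _dim, },  # Removes the multiplication factor
    "y": {0: _dim, },
}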
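Before applying a configuration like this, it helps to check it against the model's size relation. The sketch below is plain Python, not torch.export. `spec_consistent` is a hypothetical helper that treats each dynamic_shapes entry as "coefficient × shared symbol" and checks whether one value per symbol explains the given input sizes. When both inputs are mapped to `_dim`, the issue's sample inputs (20,) and (4,) cannot both match. The model also needs x.size(0) == 5 * y.size(0), which equal sizes satisfy only at 0, and 0 is outside the declared range.

```python
# Each spec entry maps an input name to (coefficient, symbol): size = coefficient * symbol.
ORIGINAL_SPEC = {"x": (5, "batch"), "y": (1, "batch")}  # {"x": _dim * 5, "y": _dim}
SHARED_SPEC = {"x": (1, "batch"), "y": (1, "batch")}    # {"x": _dim, "y": _dim}


def spec_consistent(spec: dict[str, tuple[int, str]], sizes: dict[str, int]) -> bool:
    """Return True if one value per symbol explains every input size."""
    values: dict[str, int] = {}
    for name, (coeff, sym) in spec.items():
        size = sizes[name]
        if size % coeff != 0:
            return False
        value = size // coeff
        # A symbol shared by several inputs must take the same value for all of them.
        if values.setdefault(sym, value) != value:
            return False
    return True


sample = {"x": 20, "y": 4}  # the example inputs passed to export in the issue
print(spec_consistent(ORIGINAL_SPEC, sample))  # True
print(spec_consistent(SHARED_SPEC, sample))    # False
```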

Notes

The provided example code snippet is specific to the PyTorch library and its torch.export module. The workaround may not be applicable to other libraries or frameworks.

Recommendation

Apply the workaround by modifying the dynamic shape configuration to avoid generating unwanted guards/assertions. Any modified configuration should still encode the model's size relation (x.size(0) == 5 * y.size(0)). It should also be tested at the lower bound of the range, where the reported guard fails.
