pytorch - 💡(How to fix) Fix [NestedTensor] torch.compile fails when accessing _max_seqlen with GuardOnDataDependentSymNode

Official PRs (…)
ON THIS PAGE

Recommended Tools

×6

Utilities matched from this issue’s tags and category — try them while you read without losing context.

GitHub issue graph ai analysis

Paste a GitHub issue URL. We fetch that issue, discover linked issues from bodies/comments/timeline, collect linked pull requests, and produce a structured English report.

The report is written in English Markdown for sharing and archival.

Helpful · Quick feedback

Loading…

Error Message

#!/usr/bin/env python3

-- coding: utf-8 --

import os import platform import traceback

import torch from torch.nested._internal.nested_tensor import NestedTensor

def print_env(): print("Python:", platform.python_version()) print("Platform:", platform.platform()) print("PyTorch:", torch.version) print("CUDA available:", torch.cuda.is_available()) print("CUDA device count:", torch.cuda.device_count()) print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES", "")) if torch.cuda.is_available(): print("Current CUDA device:", torch.cuda.current_device()) print("CUDA device name:", torch.cuda.get_device_name(0))

def create_njt(el_per_row: int): # Do not access nt._max_seqlen here. # Accessing it before torch.compile may cache the value and hide the issue. values = torch.randn(10 * el_per_row, device="cuda") offsets = torch.arange(11, device="cuda") * el_per_row return NestedTensor(values, offsets)

def f(nt): nt = nt.clamp(0.1, 0.5) nt *= nt._max_seqlen return nt

def get_values(nt): return nt.values()

def main(): print_env()

if not torch.cuda.is_available():
    raise RuntimeError("This repro expects CUDA.")

torch.manual_seed(0)

eager_input = create_njt(el_per_row=2)
compiled_input = create_njt(el_per_row=2)

print("\nRunning eager...")
with torch.no_grad():
    eager_out = f(eager_input)
    eager_values = get_values(eager_out)
print("Eager succeeded.")
print("Eager output sum:", eager_values.sum().item())

print("\nRunning compiled...")
compiled_f = torch.compile(
    f,
    backend="inductor",
    fullgraph=True,
    dynamic=True,
)

try:
    with torch.no_grad():
        compiled_out = compiled_f(compiled_input)
        compiled_values = get_values(compiled_out)

    print("Compiled succeeded.")
    print("Compiled output sum:", compiled_values.sum().item())

    torch.testing.assert_close(compiled_values, eager_values)
    print("Compiled output matches eager.")

except Exception:
    print("Compiled failed:")
    traceback.print_exc()
    raise

if name == "main": main()

Code Example

torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
Could not extract specialized integer from data-dependent expression u0

---

nt *= nt._max_seqlen

---

File "torch/nested/_internal/nested_tensor.py", line 237, in _max_seqlen
    return self._get_max_seqlen()

File "torch/nested/_internal/nested_tensor.py", line 194, in _get_max_seqlen
    max_val = _get_sdpa_extreme_seqlen(...)

File "torch/nested/_internal/nested_tensor.py", line 36, in _get_sdpa_extreme_seqlen
    return int(func(tensor).item())

---

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import platform
import traceback

import torch
from torch.nested._internal.nested_tensor import NestedTensor


def print_env():
    """Print interpreter, platform, PyTorch, and CUDA details to stdout.

    One line is written per item and nothing is returned. The two
    device-specific lines are printed only when CUDA is available.
    """
    cuda_available = torch.cuda.is_available()
    print("Python:", platform.python_version())
    print("Platform:", platform.platform())
    print("PyTorch:", torch.__version__)
    print("CUDA available:", cuda_available)
    print("CUDA device count:", torch.cuda.device_count())
    print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES", ""))
    if cuda_available:
        current = torch.cuda.current_device()
        print("Current CUDA device:", current)
        # Name the device that is actually current rather than always index 0,
        # so both lines describe the same GPU.
        print("CUDA device name:", torch.cuda.get_device_name(current))


def create_njt(el_per_row: int):
    # Do not access nt._max_seqlen here.
    # Accessing it before torch.compile may cache the value and hide the issue.
    values = torch.randn(10 * el_per_row, device="cuda")
    offsets = torch.arange(11, device="cuda") * el_per_row
    return NestedTensor(values, offsets)


def f(nt):
    """Clamp ``nt`` to [0.1, 0.5] and scale the result in place by its ``_max_seqlen``.

    This is the function handed to ``torch.compile`` in ``main``; per the
    report, reading ``_max_seqlen`` here is what raises
    ``GuardOnDataDependentSymNode``.
    """
    nt = nt.clamp(0.1, 0.5)
    # Reads _max_seqlen on the clamped tensor (not the caller's input);
    # this is the failing line shown in the traceback.
    nt *= nt._max_seqlen
    return nt


def get_values(nt):
    """Return the result of calling ``nt.values()``.

    ``main`` applies this to the eager and compiled outputs before comparing them.
    """
    flat = nt.values()
    return flat


def main():
    """Run ``f`` eagerly and under ``torch.compile`` and compare the outputs.

    Prints the environment, requires CUDA, then runs ``f`` once eagerly and
    once through ``torch.compile(backend="inductor", fullgraph=True,
    dynamic=True)`` on separately built inputs, both under ``torch.no_grad()``.

    Raises:
        RuntimeError: If CUDA is not available.
        Exception: Anything raised by the compiled run (including a mismatch
            reported by ``torch.testing.assert_close``) is printed with its
            traceback and re-raised.
    """
    print_env()

    if not torch.cuda.is_available():
        raise RuntimeError("This repro expects CUDA.")

    # Re-seed before building each input so both hold the same random values;
    # seeding only once would give the two inputs different data and make the
    # final assert_close fail even if compilation succeeded.
    torch.manual_seed(0)
    eager_input = create_njt(el_per_row=2)
    torch.manual_seed(0)
    # Built separately so _max_seqlen has never been read on this tensor
    # before it reaches the compiled function.
    compiled_input = create_njt(el_per_row=2)

    print("\nRunning eager...")
    with torch.no_grad():
        eager_out = f(eager_input)
        eager_values = get_values(eager_out)
    print("Eager succeeded.")
    print("Eager output sum:", eager_values.sum().item())

    print("\nRunning compiled...")
    compiled_f = torch.compile(
        f,
        backend="inductor",
        fullgraph=True,
        dynamic=True,
    )

    try:
        with torch.no_grad():
            compiled_out = compiled_f(compiled_input)
            compiled_values = get_values(compiled_out)

        print("Compiled succeeded.")
        print("Compiled output sum:", compiled_values.sum().item())

        torch.testing.assert_close(compiled_values, eager_values)
        print("Compiled output matches eager.")

    except Exception:
        # Top-level boundary: show the traceback, then let the failure propagate.
        print("Compiled failed:")
        traceback.print_exc()
        raise


# Run the repro only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

---

Running eager...
Eager succeeded.
Eager output sum: 9.313737869262695

---

torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
Could not extract specialized integer from data-dependent expression u0

---

File "repro.py", line 35, in f
    nt *= nt._max_seqlen

File "torch/nested/_internal/nested_tensor.py", line 237, in _max_seqlen
    return self._get_max_seqlen()

File "torch/nested/_internal/nested_tensor.py", line 194, in _get_max_seqlen
    max_val = _get_sdpa_extreme_seqlen(...)

File "torch/nested/_internal/nested_tensor.py", line 36, in _get_sdpa_extreme_seqlen
    return int(func(tensor).item())

---

PyTorch version:  2.13.0a0+git059c270
Is debug build: True
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 12.3.0-1ubuntu1~22.04) 12.3.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.16 (main, Dec 11 2024, 16:24:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.8.0-59-generic-x86_64-with-glibc2.35
Is CUDA available: True
RAW_BUFFER · Click to expand / collapse

🐛 Describe the bug

torch.compile fails when a function accesses NestedTensor._max_seqlen inside the compiled region.

Eager execution succeeds, but the compiled function fails during Dynamo tracing with:

torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
Could not extract specialized integer from data-dependent expression u0

The failure is triggered by:

nt *= nt._max_seqlen

The stack trace shows that _max_seqlen calls _get_max_seqlen, which in turn calls _get_sdpa_extreme_seqlen; that helper evaluates int(func(tensor).item()), which attempts to convert a data-dependent symbolic value (u0) to a Python integer:

File "torch/nested/_internal/nested_tensor.py", line 237, in _max_seqlen
    return self._get_max_seqlen()

File "torch/nested/_internal/nested_tensor.py", line 194, in _get_max_seqlen
    max_val = _get_sdpa_extreme_seqlen(...)

File "torch/nested/_internal/nested_tensor.py", line 36, in _get_sdpa_extreme_seqlen
    return int(func(tensor).item())

The important detail is that _max_seqlen is not accessed before calling the compiled function. If _max_seqlen is accessed before compilation, the value may be cached, which can hide the failure.

Minimal repro

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import platform
import traceback

import torch
from torch.nested._internal.nested_tensor import NestedTensor


def print_env():
    """Print interpreter, platform, PyTorch, and CUDA details to stdout.

    One line is written per item and nothing is returned. The two
    device-specific lines are printed only when CUDA is available.
    """
    cuda_available = torch.cuda.is_available()
    print("Python:", platform.python_version())
    print("Platform:", platform.platform())
    print("PyTorch:", torch.__version__)
    print("CUDA available:", cuda_available)
    print("CUDA device count:", torch.cuda.device_count())
    print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES", ""))
    if cuda_available:
        current = torch.cuda.current_device()
        print("Current CUDA device:", current)
        # Name the device that is actually current rather than always index 0,
        # so both lines describe the same GPU.
        print("CUDA device name:", torch.cuda.get_device_name(current))


def create_njt(el_per_row: int):
    # Do not access nt._max_seqlen here.
    # Accessing it before torch.compile may cache the value and hide the issue.
    values = torch.randn(10 * el_per_row, device="cuda")
    offsets = torch.arange(11, device="cuda") * el_per_row
    return NestedTensor(values, offsets)


def f(nt):
    """Clamp ``nt`` to [0.1, 0.5] and scale the result in place by its ``_max_seqlen``.

    This is the function handed to ``torch.compile`` in ``main``; per the
    report, reading ``_max_seqlen`` here is what raises
    ``GuardOnDataDependentSymNode``.
    """
    nt = nt.clamp(0.1, 0.5)
    # Reads _max_seqlen on the clamped tensor (not the caller's input);
    # this is the failing line shown in the traceback.
    nt *= nt._max_seqlen
    return nt


def get_values(nt):
    """Return the result of calling ``nt.values()``.

    ``main`` applies this to the eager and compiled outputs before comparing them.
    """
    flat = nt.values()
    return flat


def main():
    """Run ``f`` eagerly and under ``torch.compile`` and compare the outputs.

    Prints the environment, requires CUDA, then runs ``f`` once eagerly and
    once through ``torch.compile(backend="inductor", fullgraph=True,
    dynamic=True)`` on separately built inputs, both under ``torch.no_grad()``.

    Raises:
        RuntimeError: If CUDA is not available.
        Exception: Anything raised by the compiled run (including a mismatch
            reported by ``torch.testing.assert_close``) is printed with its
            traceback and re-raised.
    """
    print_env()

    if not torch.cuda.is_available():
        raise RuntimeError("This repro expects CUDA.")

    # Re-seed before building each input so both hold the same random values;
    # seeding only once would give the two inputs different data and make the
    # final assert_close fail even if compilation succeeded.
    torch.manual_seed(0)
    eager_input = create_njt(el_per_row=2)
    torch.manual_seed(0)
    # Built separately so _max_seqlen has never been read on this tensor
    # before it reaches the compiled function.
    compiled_input = create_njt(el_per_row=2)

    print("\nRunning eager...")
    with torch.no_grad():
        eager_out = f(eager_input)
        eager_values = get_values(eager_out)
    print("Eager succeeded.")
    print("Eager output sum:", eager_values.sum().item())

    print("\nRunning compiled...")
    compiled_f = torch.compile(
        f,
        backend="inductor",
        fullgraph=True,
        dynamic=True,
    )

    try:
        with torch.no_grad():
            compiled_out = compiled_f(compiled_input)
            compiled_values = get_values(compiled_out)

        print("Compiled succeeded.")
        print("Compiled output sum:", compiled_values.sum().item())

        torch.testing.assert_close(compiled_values, eager_values)
        print("Compiled output matches eager.")

    except Exception:
        # Top-level boundary: show the traceback, then let the failure propagate.
        print("Compiled failed:")
        traceback.print_exc()
        raise


# Run the repro only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

Actual behavior

Eager execution succeeds:

Running eager...
Eager succeeded.
Eager output sum: 9.313737869262695

Compiled execution fails:

Running compiled...
Compiled failed:

torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
Could not extract specialized integer from data-dependent expression u0

Relevant stack trace:

File "repro.py", line 35, in f
    nt *= nt._max_seqlen

File "torch/nested/_internal/nested_tensor.py", line 237, in _max_seqlen
    return self._get_max_seqlen()

File "torch/nested/_internal/nested_tensor.py", line 194, in _get_max_seqlen
    max_val = _get_sdpa_extreme_seqlen(...)

File "torch/nested/_internal/nested_tensor.py", line 36, in _get_sdpa_extreme_seqlen
    return int(func(tensor).item())

Versions

PyTorch version:  2.13.0a0+git059c270
Is debug build: True
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 12.3.0-1ubuntu1~22.04) 12.3.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.16 (main, Dec 11 2024, 16:24:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.8.0-59-generic-x86_64-with-glibc2.35
Is CUDA available: True

cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ @chauhang @penguinwu @ezyang @bobrenjc93 @aditvenk @laithsakka @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @amjames @jataylo @azahed98

Vote matrix · Quick signals

Works
Did the solution work? Tap to confirm.
Easy Fix
Was it a quick fix?
Time Saver
Did it save you time?
Blocking
Was it severely blocking?
Common Issue
Are others likely hitting this too?
Flaky / Intermittent
Is it intermittent?
Verified / Reproducible
Can you reproduce it reliably?
Loading…

Still need to ship something?

×6

Another batch ranked right after the header list — different links, same matching logic.

Back to top recommendations

TRENDING

pytorch - 💡(How to fix) Fix [NestedTensor] torch.compile fails when accessing _max_seqlen with GuardOnDataDependentSymNode