pytorch - ✅(Solved) Fix Dynamo `TritonHOPifier.call_run` doesn't propagate kernel_source -> `AssertionError`: Can't construct an AttrSource without a valid base source [1 pull requests, 1 participants]

GitHub stats
pytorch/pytorch#178447, fetched 2026-04-08 01:30:26
Comments: 0
Participants: 1
Timeline: 81
Reactions: 0
Timeline (top): mentioned ×34, subscribed ×34, labeled ×7, referenced ×3

Error Message

AssertionError: Can't construct an AttrSource without a valid base source

Root Cause

When a Triton kernel with @triton.autotune + prune_configs_by is called via run() under torch.compile, Dynamo's TritonHOPifier.call_run creates a new TritonKernelVariable WITHOUT copying kernel_source from the original variable. The subsequent call_triton_kernel hits prune_configs_by, which calls wrap_user_defined_obj, which needs kernel_source to construct an AttrSource. Since kernel_source is None, the assertion fails.

In TritonHOPifier.call_run:

https://github.com/pytorch/pytorch/blob/73cae840f56193235f71c5c4c69f494f49e6d365/torch/_higher_order_ops/triton_kernel_wrap.py#L1866-L1885

def call_run(self, variable, args, kwargs, tx):
    ...
    return self.call_triton_kernel(
        type(variable)(
            kernel=variable.kernel, kernel_idx=variable.kernel_idx, grid=grid
        ),  # <-- NO kernel_source!
        args, kwargs, tx,
    )

Compare with DynamoTritonHOPifier.call_getitem:

https://github.com/pytorch/pytorch/blob/73cae840f56193235f71c5c4c69f494f49e6d365/torch/_dynamo/variables/functions.py#L3103-L3121

def call_getitem(self, variable, args):
    return type(variable)(
        kernel=variable.kernel,
        kernel_idx=variable.kernel_idx,
        grid=args[0],
        kernel_source=variable.source,  # <-- propagated
    )

The fix is to propagate kernel_source in call_run (and in the other call sites in call_triton_kernel that create new variables).
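The failure mode can be illustrated with a toy model that does not import PyTorch. The classes `ToySource`, `ToyAttrSource`, and `ToyKernelVariable`, and the functions `recreate_without_source`, `recreate_with_source`, and `wrap_pruner`, are hypothetical stand-ins. They assume only what the issue states: AttrSource asserts on a missing base, and call_run rebuilds the variable without kernel_source.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ToySource:
    """Hypothetical stand-in for a Dynamo source (where a value came from)."""
    name: str


@dataclass(frozen=True)
class ToyAttrSource:
    """Hypothetical stand-in for AttrSource: refuses a missing base."""
    base: Optional[ToySource]
    member: str

    def __post_init__(self):
        assert self.base, "Can't construct an AttrSource without a valid base source"


@dataclass(frozen=True)
class ToyKernelVariable:
    """Hypothetical stand-in for TritonKernelVariable."""
    kernel: str
    grid: Optional[tuple] = None
    kernel_source: Optional[ToySource] = None


def recreate_without_source(var, grid):
    # Mirrors the reported call_run behaviour: kernel_source is dropped.
    return ToyKernelVariable(kernel=var.kernel, grid=grid)


def recreate_with_source(var, grid):
    # Mirrors the suggested fix: carry kernel_source onto the new variable.
    return ToyKernelVariable(
        kernel=var.kernel, grid=grid, kernel_source=var.kernel_source
    )


def wrap_pruner(var):
    # Stand-in for wrap_user_defined_obj: needs a base source for the pruner.
    return ToyAttrSource(var.kernel_source, "early_config_prune")


original = ToyKernelVariable("add_kernel", kernel_source=ToySource("add_kernel"))

# With propagation, the attribute source can be built.
fixed = wrap_pruner(recreate_with_source(original, grid=(2,)))

# Without propagation, the same assertion as in the issue fires.
try:
    wrap_pruner(recreate_without_source(original, grid=(2,)))
    failed = False
except AssertionError:
    failed = True
```

The toy only models the data flow; it says nothing about how Dynamo guards on the resulting source.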

Fix Action

PR fix notes

PR #178996: Preserve Triton kernel_source across HOPifier rewrites

Description (problem / solution / changelog)

Fix #178179

Root cause problem

Dynamo only seeded Triton kernel_source at the edge of tracing, but later Triton kernel rewrites recreated wrapper variables without carrying that source forward. Once @triton.autotune(..., prune_configs_by=...) wraps user-defined pruning helpers, wrap_user_defined_obj needs that source to build an AttrSource, and the missing base source triggers the assertion from the issue. The same invariant break also affects the related .run() path from #178447.

Proposed fix

Add a single helper in TritonHOPifier that recreates Triton kernel variables while preserving the existing kernel_source (or the original source when seeding it for the first time). Route .run() and every autotune/heuristics/pruning rewrite through that helper, seed TritonKernelVariable.kernel_source from source at construction, and let TraceableTritonKernelWrapper carry the same optional field so both HOPifier paths share the same interface. Add regression tests for the reported heuristics + autotune + prune combination and for the direct .run() path.
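One way to picture the helper described above is a single function that every rewrite path calls. This is a sketch under assumptions, not the code from the PR: `ToyKernelVariable`, `recreate_variable`, and the string-valued `source` field are hypothetical names chosen for illustration.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ToyKernelVariable:
    """Hypothetical stand-in for a Triton kernel wrapper variable."""
    kernel: Any
    kernel_idx: int
    grid: Any = None
    source: Optional[str] = None         # where this variable itself came from
    kernel_source: Optional[str] = None  # where the original kernel came from

    def __post_init__(self):
        # Seed kernel_source from source the first time the variable is built.
        if self.kernel_source is None:
            self.kernel_source = self.source


def recreate_variable(var, **overrides):
    """Hypothetical helper: rebuild a variable while keeping kernel_source.

    Falls back to var.source when kernel_source was never seeded.
    """
    fields = {
        "kernel": var.kernel,
        "kernel_idx": var.kernel_idx,
        "grid": var.grid,
        "kernel_source": (
            var.kernel_source if var.kernel_source is not None else var.source
        ),
    }
    fields.update(overrides)
    return type(var)(**fields)


traced = ToyKernelVariable(kernel="add_kernel", kernel_idx=0, source="G['add_kernel']")
after_run = recreate_variable(traced, grid=(2,))            # e.g. a .run() rewrite
after_autotune = recreate_variable(after_run, kernel="fn")  # e.g. an autotune rewrite
```

Because every rewrite goes through one function, a new rewrite path cannot forget the field without bypassing the helper, which is the invariant the PR description argues for.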

Why this is the right long term fix

The bug is not one specific decorator combination; it is a broken source-tracking invariant. Centralizing Triton variable recreation makes that invariant explicit in one place, which prevents the same omission from reappearing in future rewrite paths and fixes the related .run() variant with the same mechanism.

Drafted via Codex, published after manual review by @bobrenjc93

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo @Lucaskabela @azahed98

Changed files

  • test/inductor/test_triton_kernels.py (modified, +67/-0)
  • torch/_dynamo/variables/functions.py (modified, +3/-3)
  • torch/_higher_order_ops/triton_kernel_wrap.py (modified, +59/-7)

Code Example

#!/usr/bin/env python3
"""
Reproducer: TritonHOPifier.call_run doesn't propagate kernel_source

    AssertionError: Can't construct an AttrSource without a valid base source
"""

import sys
import traceback

import torch
import torch.nn as nn
import triton
import triton.language as tl


def early_config_prune(configs, named_args, **kwargs):
    """No-op config pruner. The existence of prune_configs_by triggers
    Dynamo's wrap_user_defined_obj, which needs kernel_source."""
    return configs


@triton.autotune(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK": 128}, num_warps=4, num_stages=1),
    ],
    key=["N"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def add_kernel(X, Y, N, BLOCK: tl.constexpr):
    """Simple vector add-one kernel."""
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(X + offs, mask=mask)
    tl.store(Y + offs, x + 1.0, mask=mask)


def kernel_fn(x):
    """Launch the kernel via .run() instead of [grid]().

    kernel.run() goes through TritonHOPifier.call_run, which creates a new
    TritonKernelVariable WITHOUT kernel_source. Then call_triton_kernel hits
    prune_configs_by and crashes in wrap_user_defined_obj.

    In practice this path is triggered indirectly: kernel[grid]() syntax returns
    a lambda that calls .run(), and Dynamo traces into that lambda on
    recompilation after a graph break.
    """
    N = x.shape[0]
    y = torch.empty_like(x)
    grid = (triton.cdiv(N, 64),)
    add_kernel.run(x, y, N, grid=grid, warmup=False)
    return y


class MyModule(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        x = self.linear(x)
        x = kernel_fn(x.view(-1)).view(x.shape)
        return x


def main():
    print("=" * 72)
    print("Dynamo bug: call_run doesn't propagate kernel_source")
    print("=" * 72)
    print()
    print(f"Python:  {sys.version.split()[0]}")
    print(f"PyTorch: {torch.__version__}")
    print(f"Triton:  {triton.__version__}")
    if torch.cuda.is_available():
        print(f"GPU:     {torch.cuda.get_device_name(0)}")
    print()

    if not torch.cuda.is_available():
        print("ERROR: No GPU available.")
        sys.exit(1)

    device = "cuda"
    dim = 64

    model = MyModule(dim).to(device).eval()
    model.compile(dynamic=None, fullgraph=False)

    torch._dynamo.reset()

    try:
        with torch.no_grad():
            print("Running compiled model with kernel.run() + prune_configs_by ...")
            x = torch.randn(128, dim, device=device)
            y = model(x)
            print(f"  UNEXPECTED PASS: {y.shape}")
            print("  Bug may be fixed in this PyTorch version.")

    except AssertionError as e:
        if "AttrSource" in str(e) or "valid base source" in str(e):
            print()
            print(f"EXPECTED CRASH: {e}")
            sys.exit(0)
        else:
            print(f"UNEXPECTED AssertionError: {e}")
            traceback.print_exc()
            sys.exit(1)

    except Exception as e:
        print(f"UNEXPECTED ERROR: {type(e).__name__}: {e}")
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()

---

from user code:
   File "dynamo_double_compile_bug.py", line 140, in forward
    x = kernel_fn(x.view(-1)).view(x.shape)
  File "dynamo_double_compile_bug.py", line 129, in kernel_fn
    add_kernel.run(x, y, N, grid=grid, warmup=False)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

Traceback (most recent call last):
  File "dynamo_double_compile_bug.py", line 194, in <module>
    main()
  File "dynamo_double_compile_bug.py", line 172, in main
    y = model(x)
        ^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1774, in _wrapped_call_impl
    return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 953, in compile_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 2202, in __call__
    result = self._torchdynamo_orig_backend(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1945, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 707, in __call__
    result = _compile(
             ^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1752, in _compile
    guarded_code, tracer_output = compile_inner(code, one_graph, hooks)
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_utils_internal.py", line 97, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1433, in compile_inner
    return _compile_inner(code, one_graph, hooks)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1467, in _compile_inner
    dynamo_output = compile_frame(
                    ^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1341, in compile_frame
    bytecode, tracer_output = transform_code_object(code, transform)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1600, in transform_code_object
    tracer_output = transformations(instructions, code_options)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1313, in transform
    tracer_output = trace_frame(
                    ^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 838, in trace_frame
    run_tracer()
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 819, in run_tracer
    tracer.run()
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1654, in run
    while self.step():
          ^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1334, in step
    self.dispatch_table[inst.opcode](self, inst)
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 866, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3807, in CALL
    self._call(inst)
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3798, in _call
    self.call_function(fn, args, kwargs)
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1240, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py", line 229, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 685, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 401, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)  # type: ignore[attr-defined]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1262, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 4718, in inline_call
    return tracer.inline_call_()
           ^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 4935, in inline_call_
    self.run()
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1654, in run
    while self.step():
          ^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1334, in step
    self.dispatch_table[inst.opcode](self, inst)
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 866, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3807, in CALL
    self._call(inst)
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3798, in _call
    self.call_function(fn, args, kwargs)
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1240, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/variables/misc.py", line 1148, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 2806, in call_method
    return dynamo_triton_hopifier_singleton.call_run(self, args, kwargs, tx)  # type: ignore[return-value]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 1634, in call_run
    return self.call_triton_kernel(
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 1839, in call_triton_kernel
    wrapped_early_configs_prune = self.wrap_user_defined_obj(
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 2652, in wrap_user_defined_obj
    tx, AttrSource(variable.kernel_source, f"{name}")
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 5, in __init__
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/source.py", line 288, in __post_init__
    assert self.base, "Can't construct an AttrSource without a valid base source"
           ^^^^^^^^^
AssertionError: Can't construct an AttrSource without a valid base source

from user code:
   File "dynamo_double_compile_bug.py", line 140, in forward
    x = kernel_fn(x.view(-1)).view(x.shape)
  File "dynamo_double_compile_bug.py", line 129, in kernel_fn
    add_kernel.run(x, y, N, grid=grid, warmup=False)
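The hint in the log above names two environment variables, TORCHDYNAMO_VERBOSE and TORCH_LOGS. A minimal sketch of rerunning a reproducer with both set follows. The helper names `dynamo_debug_env` and `rerun_with_dynamo_logs` are hypothetical, and the script path is whatever file holds the reproducer.

```python
import os
import subprocess
import sys


def dynamo_debug_env(base=None):
    """Return a copy of the environment with the two variables from the hint set."""
    env = dict(os.environ if base is None else base)
    env["TORCHDYNAMO_VERBOSE"] = "1"  # ask for the internal Dynamo stack trace
    env["TORCH_LOGS"] = "+dynamo"     # ask for extra Dynamo developer logging
    return env


def rerun_with_dynamo_logs(script_path):
    """Run a script in a child process so the variables are set before torch is imported."""
    return subprocess.run(
        [sys.executable, script_path], env=dynamo_debug_env(), check=False
    )


# Building the environment has no side effects on the current process.
env = dynamo_debug_env({"PATH": "/usr/bin"})
```

Running the script in a child process sidesteps the question of whether the variables are read at import time, since they are in place before the interpreter starts.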

---

Collecting environment information...
PyTorch version: 2.10.0+rocm7.1
Is debug build: False
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: 7.1.25424

OS: Ubuntu 24.04.3 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: version 3.28.3
Libc version: glibc-2.39

Python version: 3.12.3 (main, Jan  8 2026, 11:30:50) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.8.0-85-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to:
GPU models and configuration: AMD Instinct MI300X VF (gfx942:sramecc+:xnack-)
Nvidia driver version: Could not collect
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: 7.1.25424
MIOpen runtime version: 3.5.1
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        46 bits physical, 57 bits virtual
Byte Order:                           Little Endian
CPU(s):                               160
On-line CPU(s) list:                  0-159
Vendor ID:                            GenuineIntel
Model name:                           INTEL(R) XEON(R) PLATINUM 8568Y+
CPU family:                           6
Model:                                207
Thread(s) per core:                   1
Core(s) per socket:                   80
Socket(s):                            2
Stepping:                             2
BogoMIPS:                             4600.00
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq dtes64 vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 wbnoinvd arat vnmi avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b fsrm md_clear serialize tsxldtrk avx512_fp16 arch_capabilities
Virtualization:                       VT-x
Hypervisor vendor:                    KVM
Virtualization type:                  full
L1d cache:                            5 MiB (160 instances)
L1i cache:                            5 MiB (160 instances)
L2 cache:                             640 MiB (160 instances)
NUMA node(s):                         2
NUMA node0 CPU(s):                    0-79
NUMA node1 CPU(s):                    80-159
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Unknown: No mitigations
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] pytorch-lightning==2.5.0
[pip3] torch==2.10.0+rocm7.1
[pip3] torchmetrics==1.8.2
[pip3] torchvision==0.25.0+rocm7.1
[pip3] triton==3.6.0
[conda] Could not collect

🐛 Describe the bug

I hit AssertionError: Can't construct an AttrSource without a valid base source when using torch.compile with a Triton kernel decorated with @triton.autotune and prune_configs_by. The kernel was used in two separate modules, both of which were compiled and had graph breaks.

I asked Claude Code to root cause and create a minimal reproducer and this is its analysis plus reproducer. I haven't fact-checked everything, but it seems right and the reproducer does reproduce the same error.

Root cause

When a Triton kernel with @triton.autotune + prune_configs_by is called via run() under torch.compile, Dynamo's TritonHOPifier.call_run creates a new TritonKernelVariable WITHOUT copying kernel_source from the original variable. The subsequent call_triton_kernel hits prune_configs_by, which calls wrap_user_defined_obj, which needs kernel_source to construct an AttrSource. Since kernel_source is None, the assertion fails.

In TritonHOPifier.call_run:

https://github.com/pytorch/pytorch/blob/73cae840f56193235f71c5c4c69f494f49e6d365/torch/_higher_order_ops/triton_kernel_wrap.py#L1866-L1885

def call_run(self, variable, args, kwargs, tx):
    ...
    return self.call_triton_kernel(
        type(variable)(
            kernel=variable.kernel, kernel_idx=variable.kernel_idx, grid=grid
        ),  # <-- NO kernel_source!
        args, kwargs, tx,
    )

Compare with DynamoTritonHOPifier.call_getitem:

https://github.com/pytorch/pytorch/blob/73cae840f56193235f71c5c4c69f494f49e6d365/torch/_dynamo/variables/functions.py#L3103-L3121

def call_getitem(self, variable, args):
    return type(variable)(
        kernel=variable.kernel,
        kernel_idx=variable.kernel_idx,
        grid=args[0],
        kernel_source=variable.source,  # <-- propagated
    )

The fix is to propagate kernel_source in call_run (and in the other call sites in call_triton_kernel that create new variables).
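The failure mode and the proposed fix can be reproduced in miniature without PyTorch. The sketch below is a toy model, not Dynamo's actual code: `ToyAttrSource`, `ToyKernelVariable`, `rebuild_without_source`, `rebuild_with_source`, and `wrap_prune_fn` are hypothetical names that only mimic the base-source assertion and the two ways of constructing a new variable quoted above.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ToyAttrSource:
    """Toy stand-in for Dynamo's AttrSource; it only mirrors the base check."""
    base: Optional[str]
    member: str

    def __post_init__(self):
        assert self.base, "Can't construct an AttrSource without a valid base source"


@dataclass
class ToyKernelVariable:
    """Toy stand-in for TritonKernelVariable (hypothetical fields)."""
    kernel: str
    kernel_idx: int
    grid: Optional[tuple] = None
    kernel_source: Optional[str] = None


def rebuild_without_source(variable, grid):
    # Mirrors the call_run snippet: kernel_source is not passed along.
    return type(variable)(
        kernel=variable.kernel, kernel_idx=variable.kernel_idx, grid=grid
    )


def rebuild_with_source(variable, grid):
    # Mirrors the proposed fix: kernel_source is propagated to the new variable.
    return type(variable)(
        kernel=variable.kernel,
        kernel_idx=variable.kernel_idx,
        grid=grid,
        kernel_source=variable.kernel_source,
    )


def wrap_prune_fn(variable):
    # Mirrors the role of wrap_user_defined_obj: it needs kernel_source as a base.
    return ToyAttrSource(variable.kernel_source, "early_config_prune")


original = ToyKernelVariable("add_kernel", 0, kernel_source="G['add_kernel']")

try:
    wrap_prune_fn(rebuild_without_source(original, grid=(2,)))
    buggy_outcome = "no error"
except AssertionError as exc:
    buggy_outcome = str(exc)

fixed_source = wrap_prune_fn(rebuild_with_source(original, grid=(2,)))
print(buggy_outcome)
print(fixed_source)
```

In this toy, the only difference between the two rebuild helpers is the `kernel_source` keyword, which is enough to decide whether the later wrapping step asserts.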

How this manifests in practice

The reproducer uses run directly rather than kernel[grid](*args) to enable a minimal reproducer. This issue was hit in the wild with a kernel[grid](*args) call but it was really hard to turn that into a minimal reproducer. The .run() path was triggered because:

  1. kernel[grid] calls kernel.__getitem__(grid), which returns a lambda: lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  2. Normally, Dynamo intercepts kernel[grid](*args) as __getitem__ + __call__, routing through call_getitem (which propagates kernel_source).
  3. But when the kernel call is in a Dynamo "resume function" (created after a graph break from @torch.compiler.disable), Dynamo sometimes traces INTO the __getitem__ lambda and sees self.run(), routing through call_run instead.
  4. This only happens on recompilation (batch 2 with different shapes), because the resume function's compilation context differs from the original.

Using kernel.run() directly triggers the same underlying bug without needing the full complexity.
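Step 1 can be illustrated with a toy class. `ToyKernel` below is a hypothetical stand-in that copies only the `__getitem__` lambda quoted in step 1; it is not Triton's real launcher and it launches nothing. It shows that `kernel[grid](*args)` and `kernel.run(*args, grid=grid, warmup=False)` reach the same `run` method, which is why a tracer that inlines the lambda ends up on the `run` path.

```python
class ToyKernel:
    """Hypothetical stand-in for a Triton kernel object."""

    def __init__(self):
        self.calls = []

    def run(self, *args, grid, warmup, **kwargs):
        # Record how run() was reached instead of launching anything.
        self.calls.append({"args": args, "grid": grid, "warmup": warmup})
        return len(self.calls)

    def __getitem__(self, grid):
        # Same shape as the lambda quoted in step 1.
        return lambda *args, **kwargs: self.run(
            grid=grid, warmup=False, *args, **kwargs
        )


kernel = ToyKernel()
kernel[(2,)]("x", "y", 128)                         # subscript-then-call syntax
kernel.run("x", "y", 128, grid=(2,), warmup=False)  # direct run() call
print(kernel.calls[0] == kernel.calls[1])
```

Both calls record identical arguments, so from the point of view of `run` the two spellings are indistinguishable.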

Environment

  • PyTorch 2.10.0+rocm7.1 (also likely affects CUDA builds)
  • Triton 3.6.0
  • AMD MI300X (not hardware-specific)

Related

  • The "ListVariable already tracked for mutation" bug (https://github.com/pytorch/pytorch/issues/177600) is at the same crash site but has a different root cause (._wrap vs __call__).
  • The @triton.heuristics + prune_configs_by bug (https://github.com/pytorch/pytorch/issues/178179) is also at the same site but caused by Heuristics unwrapping dropping kernel_source.
  • All three bugs share the same underlying issue: kernel_source is not consistently propagated when creating new TritonKernelVariable instances.

Workarounds

  1. Remove prune_configs_by from @triton.autotune.
  2. Add @torch.compiler.disable to the function containing the kernel call.
<details><summary>repro</summary>
#!/usr/bin/env python3
"""
Reproducer: TritonHOPifier.call_run doesn't propagate kernel_source

    AssertionError: Can't construct an AttrSource without a valid base source
"""

import sys
import traceback

import torch
import torch.nn as nn
import triton
import triton.language as tl


def early_config_prune(configs, named_args, **kwargs):
    """No-op config pruner. The existence of prune_configs_by triggers
    Dynamo's wrap_user_defined_obj, which needs kernel_source."""
    return configs


@triton.autotune(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK": 128}, num_warps=4, num_stages=1),
    ],
    key=["N"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def add_kernel(X, Y, N, BLOCK: tl.constexpr):
    """Simple vector add-one kernel."""
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(X + offs, mask=mask)
    tl.store(Y + offs, x + 1.0, mask=mask)


def kernel_fn(x):
    """Launch the kernel via .run() instead of [grid]().

    kernel.run() goes through TritonHOPifier.call_run, which creates a new
    TritonKernelVariable WITHOUT kernel_source. Then call_triton_kernel hits
    prune_configs_by and crashes in wrap_user_defined_obj.

    In practice this path is triggered indirectly: kernel[grid]() syntax returns
    a lambda that calls .run(), and Dynamo traces into that lambda on
    recompilation after a graph break.
    """
    N = x.shape[0]
    y = torch.empty_like(x)
    grid = (triton.cdiv(N, 64),)
    add_kernel.run(x, y, N, grid=grid, warmup=False)
    return y


class MyModule(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        x = self.linear(x)
        x = kernel_fn(x.view(-1)).view(x.shape)
        return x


def main():
    print("=" * 72)
    print("Dynamo bug: call_run doesn't propagate kernel_source")
    print("=" * 72)
    print()
    print(f"Python:  {sys.version.split()[0]}")
    print(f"PyTorch: {torch.__version__}")
    print(f"Triton:  {triton.__version__}")
    if torch.cuda.is_available():
        print(f"GPU:     {torch.cuda.get_device_name(0)}")
    print()

    if not torch.cuda.is_available():
        print("ERROR: No GPU available.")
        sys.exit(1)

    device = "cuda"
    dim = 64

    model = MyModule(dim).to(device).eval()
    model.compile(dynamic=None, fullgraph=False)

    torch._dynamo.reset()

    try:
        with torch.no_grad():
            print("Running compiled model with kernel.run() + prune_configs_by ...")
            x = torch.randn(128, dim, device=device)
            y = model(x)
            print(f"  UNEXPECTED PASS: {y.shape}")
            print("  Bug may be fixed in this PyTorch version.")

    except AssertionError as e:
        if "AttrSource" in str(e) or "valid base source" in str(e):
            print()
            print(f"EXPECTED CRASH: {e}")
            sys.exit(0)
        else:
            print(f"UNEXPECTED AssertionError: {e}")
            traceback.print_exc()
            sys.exit(1)

    except Exception as e:
        print(f"UNEXPECTED ERROR: {type(e).__name__}: {e}")
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
</details>

Error logs

<details><summary>Error logs</summary>
from user code:
   File "dynamo_double_compile_bug.py", line 140, in forward
    x = kernel_fn(x.view(-1)).view(x.shape)
  File "dynamo_double_compile_bug.py", line 129, in kernel_fn
    add_kernel.run(x, y, N, grid=grid, warmup=False)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

Traceback (most recent call last):
  File "dynamo_double_compile_bug.py", line 194, in <module>
    main()
  File "dynamo_double_compile_bug.py", line 172, in main
    y = model(x)
        ^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1774, in _wrapped_call_impl
    return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 953, in compile_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 2202, in __call__
    result = self._torchdynamo_orig_backend(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1945, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 707, in __call__
    result = _compile(
             ^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1752, in _compile
    guarded_code, tracer_output = compile_inner(code, one_graph, hooks)
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_utils_internal.py", line 97, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1433, in compile_inner
    return _compile_inner(code, one_graph, hooks)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1467, in _compile_inner
    dynamo_output = compile_frame(
                    ^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1341, in compile_frame
    bytecode, tracer_output = transform_code_object(code, transform)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1600, in transform_code_object
    tracer_output = transformations(instructions, code_options)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1313, in transform
    tracer_output = trace_frame(
                    ^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 838, in trace_frame
    run_tracer()
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 819, in run_tracer
    tracer.run()
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1654, in run
    while self.step():
          ^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1334, in step
    self.dispatch_table[inst.opcode](self, inst)
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 866, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3807, in CALL
    self._call(inst)
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3798, in _call
    self.call_function(fn, args, kwargs)
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1240, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py", line 229, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 685, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 401, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)  # type: ignore[attr-defined]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1262, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 4718, in inline_call
    return tracer.inline_call_()
           ^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 4935, in inline_call_
    self.run()
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1654, in run
    while self.step():
          ^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1334, in step
    self.dispatch_table[inst.opcode](self, inst)
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 866, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3807, in CALL
    self._call(inst)
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3798, in _call
    self.call_function(fn, args, kwargs)
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1240, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/variables/misc.py", line 1148, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 2806, in call_method
    return dynamo_triton_hopifier_singleton.call_run(self, args, kwargs, tx)  # type: ignore[return-value]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 1634, in call_run
    return self.call_triton_kernel(
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 1839, in call_triton_kernel
    wrapped_early_configs_prune = self.wrap_user_defined_obj(
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 2652, in wrap_user_defined_obj
    tx, AttrSource(variable.kernel_source, f"{name}")
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 5, in __init__
  File ".venv/lib/python3.12/site-packages/torch/_dynamo/source.py", line 288, in __post_init__
    assert self.base, "Can't construct an AttrSource without a valid base source"
           ^^^^^^^^^
AssertionError: Can't construct an AttrSource without a valid base source

from user code:
   File "dynamo_double_compile_bug.py", line 140, in forward
    x = kernel_fn(x.view(-1)).view(x.shape)
  File "dynamo_double_compile_bug.py", line 129, in kernel_fn
    add_kernel.run(x, y, N, grid=grid, warmup=False)
</details>

Versions

<details><summary>env</summary>
Collecting environment information...
PyTorch version: 2.10.0+rocm7.1
Is debug build: False
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: 7.1.25424

OS: Ubuntu 24.04.3 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: version 3.28.3
Libc version: glibc-2.39

Python version: 3.12.3 (main, Jan  8 2026, 11:30:50) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.8.0-85-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to:
GPU models and configuration: AMD Instinct MI300X VF (gfx942:sramecc+:xnack-)
Nvidia driver version: Could not collect
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: 7.1.25424
MIOpen runtime version: 3.5.1
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        46 bits physical, 57 bits virtual
Byte Order:                           Little Endian
CPU(s):                               160
On-line CPU(s) list:                  0-159
Vendor ID:                            GenuineIntel
Model name:                           INTEL(R) XEON(R) PLATINUM 8568Y+
CPU family:                           6
Model:                                207
Thread(s) per core:                   1
Core(s) per socket:                   80
Socket(s):                            2
Stepping:                             2
BogoMIPS:                             4600.00
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq dtes64 vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 wbnoinvd arat vnmi avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b fsrm md_clear serialize tsxldtrk avx512_fp16 arch_capabilities
Virtualization:                       VT-x
Hypervisor vendor:                    KVM
Virtualization type:                  full
L1d cache:                            5 MiB (160 instances)
L1i cache:                            5 MiB (160 instances)
L2 cache:                             640 MiB (160 instances)
NUMA node(s):                         2
NUMA node0 CPU(s):                    0-79
NUMA node1 CPU(s):                    80-159
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Unknown: No mitigations
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] pytorch-lightning==2.5.0
[pip3] torch==2.10.0+rocm7.1
[pip3] torchmetrics==1.8.2
[pip3] torchvision==0.25.0+rocm7.1
[pip3] triton==3.6.0
[conda] Could not collect
</details>

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @amjames @Lucaskabela @jataylo @oulgen @aakhundov @davidberard98

extent analysis

Fix Plan

To fix the issue, we need to propagate kernel_source in TritonHOPifier.call_run.

Here are the steps:

  • Modify the call_run method in TritonHOPifier to include kernel_source when creating a new TritonKernelVariable.
  • Update the other call sites in call_triton_kernel that create new variables so that they pass kernel_source along as well.

Example code:

def call_run(self, variable, args, kwargs, tx):
    ...
    return self.call_triton_kernel(
        type(variable)(
            kernel=variable.kernel, 
            kernel_idx=variable.kernel_idx, 
            grid=grid, 
            kernel_source=variable.kernel_source  # <-- Add this line
        ), 
        args, kwargs, tx,
    )

Verification

To verify the fix, run the reproducer again. With kernel_source propagated, the AssertionError should no longer be raised: the script should print "UNEXPECTED PASS" with the output shape instead of "EXPECTED CRASH".

Extra Tips

  • Make sure to update the TritonHOPifier class with the modified call_run method.
  • If you're using a version of PyTorch that doesn't have the TritonHOPifier class, you may need to update your PyTorch version or apply the fix manually.
  • This fix assumes that the kernel_source attribute is available in the variable object. If this attribute is not available, you may need to modify the fix accordingly.
