pytorch - ✅(Solved) Fix [inductor] Merge SymPy printer path into typed ops/CSE path for Triton codegen [2 pull requests, 1 participants]

pytorch/pytorch#184394 (fetched 2026-05-20 03:38:59)
PR fix notes

PR #182872: [inductor] Fix index_expr CSE variable dtype for ks* kernel args

Description (problem / solution / changelog)

Stack from ghstack (oldest at bottom):

  • #183740
  • #183938
  • -> #182872

ks* kernel args (symbolic scalars like alpha in torch.add) are always int64 in the Triton signature (per _decide_tl_dtype), but index_expr was setting the CSE variable dtype to the kernel's index_dtype which can be int32 for small tensors. This caused two issues:

  1. When a ks* arg needed casting (e.g., int64 -> float for math.sqrt), the CSE variable's wrong dtype could cause the cast to be skipped.
  2. runtime_triton_dtype_assert would fire incorrect static_assert checks comparing against int32 when the actual type was int64.

As detailed here, the internal computation and dtype analysis of index_expr are very inconsistent.

I am leaving that as a separate problem and just making index_expr consistent under dtype analysis by casting the output to the correct kernel indexing dtype.

Authored with Claude.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo @azahed98

Changed files

  • benchmarks/dynamo/check_accuracy.py (modified, +2/-0)
  • test/inductor/test_op_dtype_prop.py (modified, +48/-11)
  • torch/_inductor/codegen/common.py (modified, +5/-3)
  • torch/_inductor/codegen/triton.py (modified, +15/-43)


🚀 The feature, motivation and pitch

[inductor] Merge SymPy printer path into typed ops/CSE path for Triton codegen

Problem

The Triton backend has two expression-generation paths:

  • Path A: ops.* → CSEProxy → TritonOverrides
    • tracks dtype on every TritonCSEVariable
    • uses DtypePropagationOpsHandler
    • participates in CSE, bounds, shape, runtime dtype checks, and cast elision
  • Path B: texpr(sympy_expr) → raw string
    • no dtype tracking
    • the printer can emit casts or libdevice calls whose output dtype is not reported to CSE

Path B produces expressions whose dtype is unknown to the rest of the compiler. This causes dtype analysis inconsistencies, unnecessary casts, and requires workarounds like suppressing runtime_triton_dtype_assert.

Ref: https://github.com/liqiangxl/pytorch/pull/9#discussion_r3242279481

Proposal

Incrementally merge Path B into Path A by routing SymPy expressions through sympy_interp(ops, env, expr), which already maps SymPy nodes to ops.* calls. This gives every sub-expression a typed CSE variable and reuses the existing dtype propagation, shape propagation, and CSE infrastructure.
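A minimal sketch of the idea, assuming a toy expression format (nested tuples) rather than real SymPy nodes: each node is dispatched to a handler that emits one temporary with a recorded dtype, and repeated sub-expressions reuse the same temporary. TypedCSE and interp are hypothetical names for illustration only; they are not Inductor's sympy_interp or its CSE classes.

```python
from dataclasses import dataclass, field


@dataclass
class TypedCSE:
    """Toy stand-in for a typed CSE cache (hypothetical, not Inductor's class)."""
    lines: list = field(default_factory=list)
    cache: dict = field(default_factory=dict)   # expression text -> variable name
    dtypes: dict = field(default_factory=dict)  # variable name -> dtype

    def emit(self, text: str, dtype: str) -> str:
        # Reuse the existing variable when the same expression was emitted before.
        if text in self.cache:
            return self.cache[text]
        name = f"tmp{len(self.lines)}"
        self.lines.append(f"{name} = {text}")
        self.cache[text] = name
        self.dtypes[name] = dtype
        return name


def interp(cse: TypedCSE, env: dict, expr) -> str:
    """Walk a nested-tuple expression, emitting one typed variable per node."""
    if isinstance(expr, str):
        # Leaf symbol: its dtype comes from the typed environment.
        cse.dtypes[expr] = env[expr]
        return expr
    op, *args = expr
    vals = [interp(cse, env, a) for a in args]
    if op == "to_float":
        return cse.emit(f"{vals[0]}.to(tl.float64)", "float64")
    if op == "log2":
        return cse.emit(f"libdevice.log2({vals[0]})", cse.dtypes[vals[0]])
    if op == "add":
        return cse.emit(f"{vals[0]} + {vals[1]}", cse.dtypes[vals[0]])
    raise NotImplementedError(op)


cse = TypedCSE()
log_term = ("log2", ("to_float", "ks0"))
out = interp(cse, {"ks0": "int64"}, ("add", log_term, log_term))
print("\n".join(cse.lines))
```

In this toy, the repeated log2 sub-expression is emitted once and every temporary carries a dtype, which is the property the proposal wants from the real ops.* path.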

Concrete example: index_expr

ops.index_expr(expr, dtype) is the most prominent user of Path B. It receives a SymPy expression, renders it through the printer into a monolithic string, then stamps the requested dtype on the CSE variable.

PR #182872 fixes the boundary dtype by wrapping the output in an explicit cast. After that fix, the boundary is correct but the internals remain untyped:

# After #182872: correct at boundary, untyped inside
tmp1 = (libdevice.pow(tl.full([], 2.0, tl.float64), (libdevice.floor(
    tl.full([], 1.0, tl.float64) + libdevice.log2((ks0.to(tl.float64)).to(tl.float32))
).to(tl.int32)).to(tl.float64))).to(tl.float32)

Remaining issues:

  1. Double-casts: _print_ToFloat and _print_OpaqueUnaryFn_log2 insert conflicting casts (ks0.to(fp64).to(fp32)) that the boundary cast cannot fix
  2. No CSE deduplication — entire expression is one cache key
  3. No per-intermediate dtype tracking — blocks cast elision between adjacent index_expr results

Routing through sympy_interp eliminates all three:

# Typed path: each node is a verified CSE variable
tmp1 = tl.full([1], 2, tl.int64)           # dtype=int64
tmp2 = ks0.to(tl.float64)                   # dtype=float64
tmp3 = libdevice.log2(tmp2)                  # dtype=float64
tmp4 = tl.full([1], 1.0, tl.float64)        # dtype=float64
tmp5 = tmp4 + tmp3                           # dtype=float64
tmp6 = libdevice.floor(tmp5).to(tl.int32)   # dtype=int32
tmp7 = libdevice.pow(...).to(tl.int64)      # dtype=int64
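Why per-intermediate dtype tracking matters for cast elision can be shown with a small sketch. It assumes a hypothetical Value record and to_dtype helper (neither is an Inductor API): a cast can be skipped only when the source dtype is known and already equals the target, so a printer-produced string with unknown dtype always pays for the cast.

```python
from typing import NamedTuple, Optional


class Value(NamedTuple):
    text: str
    dtype: Optional[str]  # None models a printed string whose dtype is unknown


def to_dtype(value: Value, target: str) -> Value:
    # Elide the cast only when the tracked dtype already matches the target.
    if value.dtype == target:
        return value
    return Value(f"({value.text}).to(tl.{target})", target)


# Untyped (Path B style): dtype unknown, so a redundant cast is emitted.
untyped = to_dtype(Value("ks0.to(tl.float64)", None), "float64")
# Typed (Path A style): dtype known, so the cast is elided.
typed = to_dtype(Value("ks0.to(tl.float64)", "float64"), "float64")

print(untyped.text)
print(typed.text)
```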

Repro:

import torch, math
from torch._inductor.utils import run_and_get_code

def fn(x):
    return x + 2 ** (math.floor(math.log2(x.shape[0]) + 1))

x = torch.arange(10, device="cuda", dtype=torch.float32)
out, codes = run_and_get_code(torch.compile(fn, dynamic=True), x)
print(codes[0])  # shows (ks0.to(tl.float64)).to(tl.float32) inside libdevice.log2

Implementation steps

Step 1: Route index_expr through sympy_interp

  • Add KernelArgs.sizevar_dtypes to track per-symbol dtypes
  • Add TritonKernel.sympy_interp_env() to build typed env from symbol assumptions
  • Rewrite index_expr to call sympy_interp(ops, env, indexing.index)
  • Add missing ops handlers (pow_by_natural, python_mod, sym_sum) so sympy_interp can dispatch all SymPy nodes that appear in index_expr
  • Populate ValueRanges bounds from SymPy symbol assumptions
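The first two and the last bullets of Step 1 could fit together roughly as follows. This is only a sketch under stated assumptions: TypedSymbol and build_env are hypothetical names, the bounds assume every symbol is an int64 kernel argument, and the "positive" set stands in for whatever SymPy symbol assumptions the real implementation would consult.

```python
from dataclasses import dataclass

INT64_MIN, INT64_MAX = -(2**63), 2**63 - 1


@dataclass(frozen=True)
class TypedSymbol:
    """Hypothetical env entry: a symbol with its dtype and value bounds."""
    name: str
    dtype: str
    lower: int
    upper: int


def build_env(sizevar_dtypes: dict, positive: set) -> dict:
    """Build a typed environment from per-symbol dtypes and assumptions.

    sizevar_dtypes: symbol name -> dtype (assumed int64 for ks* arguments).
    positive: names of symbols assumed to be strictly positive.
    """
    env = {}
    for name, dtype in sizevar_dtypes.items():
        # A positivity assumption tightens the lower bound from INT64_MIN to 1.
        lower = 1 if name in positive else INT64_MIN
        env[name] = TypedSymbol(name, dtype, lower, INT64_MAX)
    return env


env = build_env({"ks0": "int64", "ks1": "int64"}, positive={"ks0"})
print(env["ks0"])
print(env["ks1"])
```

An interpreter that receives such an environment can start dtype and bounds propagation from the leaves instead of stamping a dtype on the final string.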

Step 2: Extend to other texpr call sites

  • rename_indexing precomputed size expressions
  • Range-tree bound expressions
  • Block pointer offset/stride computation

Step 3: Remove Path B from value-producing codegen

  • After Steps 1+2, all value-producing expressions route through Path A (ops.* → typed CSE)
  • TritonPrinter remains for structural uses (block pointer shape/stride arrays, range-tree variable definitions, kernel launch args, debug strings) but no longer produces expressions whose dtype matters for correctness
  • Eliminates the two-path inconsistency: one path for values, one for structure

Alternatives

No response

Additional context

No response

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo
