pytorch - 💡(How to fix) Fix Use of dynamic shapes impedes Inductor fusion for XSA projection [2 comments, 1 participants]

GitHub stats
pytorch/pytorch#181108 · Fetched 2026-04-23 07:22:34
Comments: 2 · Participants: 1 · Timeline events: 57 · Reactions: 1
Timeline (top): mentioned ×24, subscribed ×24, labeled ×7, commented ×2


I found a user who was doing very naughty things with regex'ing out constant integers for constexpr variables in Inductor generated code and I wanted to see if I could do it with dynamic shapes, but alas, it does not work. Here is a repro:

"""Compare kernel counts: static vs dynamic, unsqueeze vs expand.

Run:
    python test_xsa_mark_dynamic.py
"""

import re
import sys

import torch
import torch.nn.functional as F
from torch._inductor.utils import run_and_get_code


def xsa_unsqueeze(
    output: torch.Tensor,   # (B, T, Hq, D)
    xv: torch.Tensor,       # (B, T, Hkv, D)
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    head_dim = xv.shape[-1]
    n_kv = xv.shape[2]
    n_heads = output.shape[2]
    torch._check(output.shape[3] == head_dim)
    torch._check(n_heads % n_kv == 0)
    n_rep = n_heads // n_kv

    vn = F.normalize(xv, dim=-1, eps=1e-6).unsqueeze(3)
    y = output.view(*output.shape[:2], n_kv, n_rep, head_dim)
    ps = (y * vn).sum(dim=-1, keepdim=True)
    cos_sim = ps.squeeze(-1).abs()
    proj_frac = cos_sim / y.norm(dim=-1).clamp(min=1e-6)
    result = (y - ps * vn).view_as(output)
    return result, cos_sim, proj_frac


def xsa_expand(
    output: torch.Tensor,   # (B, T, Hq, D)
    xv: torch.Tensor,       # (B, T, Hkv, D)
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    B, T, Hq, D = output.shape
    Hkv = xv.shape[2]
    torch._check(xv.shape[3] == D)
    torch._check(Hq % Hkv == 0)
    gs = Hq // Hkv

    v_exp = xv[:, :, :, None, :].expand(B, T, Hkv, gs, D).reshape(B, T, Hq, D)
    vn = F.normalize(v_exp, p=2, dim=-1, eps=1e-6)
    coeff = (output * vn).sum(dim=-1, keepdim=True)
    projected = output - coeff * vn

    cg = coeff.squeeze(-1).view(B, T, Hkv, gs)
    cos_sim = cg.abs()
    yg = output.view(B, T, Hkv, gs, D)
    proj_frac = cos_sim / yg.norm(dim=-1).clamp(min=1e-6)
    return projected, cos_sim, proj_frac


def run_variant(name, fn, B, T, HQ, HKV, N_REP, D, dynamic):
    device = "cuda"
    dtype = torch.bfloat16
    torch._dynamo.reset()

    output = torch.randn(B, T, HQ, D, device=device, dtype=dtype, requires_grad=True)
    xv = torch.randn(B, T, HKV, D, device=device, dtype=dtype, requires_grad=True)
    g_res = torch.randn(B, T, HQ, D, device=device, dtype=dtype)
    g_cos = torch.randn(B, T, HKV, N_REP, device=device, dtype=dtype)
    g_pf = torch.randn(B, T, HKV, N_REP, device=device, dtype=dtype)

    if dynamic:
        torch._dynamo.mark_dynamic(output, 2)
        torch._dynamo.mark_dynamic(output, 3)
        torch._dynamo.mark_dynamic(xv, 3)
        torch._dynamo.mark_dynamic(g_res, 2)
        torch._dynamo.mark_dynamic(g_res, 3)
        torch._dynamo.mark_dynamic(g_cos, 3)
        torch._dynamo.mark_dynamic(g_pf, 3)

    def fwd_bwd():
        results = fn(output, xv)
        torch.autograd.backward(list(results), [g_res, g_cos, g_pf])

    compiled = torch.compile(fwd_bwd)
    _, codes = run_and_get_code(compiled)

    counts = []
    for i, code in enumerate(codes):
        n = len(re.findall(r"def (triton_\w+)\(", code))
        counts.append(n)
    tag = "dynamic" if dynamic else "static "
    print(f"  {name:20s} {tag}  fwd={counts[0]} bwd={counts[1]}  total={sum(counts)}", file=sys.stderr)


def main():
    assert torch.cuda.is_available()
    B, T, HKV, N_REP, D = 1, 2048, 4, 8, 64
    HQ = HKV * N_REP

    print(f"\n{'variant':20s} {'shapes':8s}  kernels", file=sys.stderr)
    print("-" * 55, file=sys.stderr)

    for fn, name in [(xsa_unsqueeze, "unsqueeze"), (xsa_expand, "expand")]:
        for dynamic in [False, True]:
            run_variant(name, fn, B, T, HQ, HKV, N_REP, D, dynamic)

    import os
    os._exit(0)


if __name__ == "__main__":
    main()

On 7231f9e7a302e0368eee7adf8dcbcd6fd79fe2be this prints for me:

variant              shapes    kernels
-------------------------------------------------------
  unsqueeze            static   fwd=2 bwd=4  total=6
  unsqueeze            dynamic  fwd=4 bwd=4  total=8
  expand               static   fwd=1 bwd=2  total=3
  expand               dynamic  fwd=2 bwd=2  total=4
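
The counts above come from matching Triton kernel definitions in each generated module with a regex. A minimal standalone sketch of that counting step follows; the two code strings and the kernel names in them are hypothetical stand-ins for Inductor output, and `summarize` is an illustrative helper rather than part of the repro. Unlike the repro, it does not assume that exactly two modules (forward, backward) come back.

```python
import re

# Same pattern as the repro: one match per generated Triton kernel function.
KERNEL_DEF = re.compile(r"def (triton_\w+)\(")


def kernel_names(code: str) -> list[str]:
    """Return the names of all Triton kernels defined in one generated module."""
    return KERNEL_DEF.findall(code)


def summarize(codes: list[str]) -> dict[str, int]:
    """Count kernels per module, treating codes[0] as forward and codes[1] as backward."""
    counts = [len(kernel_names(code)) for code in codes]
    fwd = counts[0] if len(counts) > 0 else 0
    bwd = counts[1] if len(counts) > 1 else 0
    return {"fwd": fwd, "bwd": bwd, "total": sum(counts)}


# Hypothetical stand-ins for generated code; real Inductor output is much longer.
FAKE_FWD = '''
@triton.jit
def triton_poi_fused_mul_0(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
    pass

@triton.jit
def triton_red_fused_sum_1(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK: tl.constexpr):
    pass
'''

FAKE_BWD = '''
@triton.jit
def triton_poi_fused_sub_0(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
    pass
'''

if __name__ == "__main__":
    print(summarize([FAKE_FWD, FAKE_BWD]))
```

Keeping the names (not just the count) makes it easier to compare which kernels exist in the static run but are split in the dynamic one.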

The unsqueeze config is the one I actually care about. The expand one showed up because that's what ChatGPT generated for XSA when I asked it to read the paper and then improve its generated XSA projection with GQA support. It's inefficient because you recompute the norm of each KV head once per query head in its group. That would be nice benchmaxxing to fix, but IMO it is generally a bad idea to rely on the compiler to give you asymptotic perf improvements, so this would mostly be cleanup for naive users.
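
For reference, both variants in the repro return the same three quantities per query head. Read off the code (this is a transcription of the repro, not a formula taken from the XSA paper), with $y \in \mathbb{R}^{D}$ one query head's output, $v \in \mathbb{R}^{D}$ the value vector of its KV group, and $\varepsilon = 10^{-6}$:

```latex
\begin{aligned}
\hat{v} &= \frac{v}{\max(\lVert v \rVert_2,\ \varepsilon)}
  && \text{(\texttt{F.normalize}, \texttt{vn})} \\
p &= \langle y, \hat{v} \rangle
  && \text{(\texttt{ps} / \texttt{coeff})} \\
\text{result} &= y - p\,\hat{v}
  && \text{(\texttt{result} / \texttt{projected})} \\
\text{cos\_sim} &= \lvert p \rvert \\
\text{proj\_frac} &= \frac{\lvert p \rvert}{\max(\lVert y \rVert_2,\ \varepsilon)}
\end{aligned}
```

When $\lVert v \rVert_2 \ge \varepsilon$, $\hat{v}$ is a unit vector, so $\langle \text{result}, \hat{v} \rangle = p - p\,\lVert \hat{v} \rVert_2^2 = 0$: the result is $y$ with its component along $v$ removed. The two variants differ only in whether $\hat{v}$ is computed once per KV head (unsqueeze) or once per query head (expand).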

In any case, it's not great that we end up with so many kernels when we make HQ and D symbolic. I'm not enough of a Triton expert to know if it's actually unavoidable if you truly want one kernel. In the original use case, the user was happy to make HQ and D constexpr and force Triton to recompile the kernel at different sizes, but I don't think we have a mode for this in Inductor's codegen, nor is it clear to me that the juice would be worth the squeeze here.
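
The closest thing to "recompile at each size" that exists today is at the Dynamo level rather than in Triton codegen: `torch.compile(..., dynamic=False)` specializes on every input size and retraces when a size changes. The sketch below only illustrates that trade-off; `project_out` and `count_compiles` are hypothetical helpers, and it uses `CompileCounter` from `torch._dynamo.testing` as the backend so it counts traces on CPU without invoking Inductor. It therefore says nothing about kernel counts.

```python
import torch
import torch.nn.functional as F
from torch._dynamo.testing import CompileCounter


def project_out(y: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Remove from y its component along the normalized v (over the last dim)."""
    vn = F.normalize(v, dim=-1, eps=1e-6)
    return y - (y * vn).sum(dim=-1, keepdim=True) * vn


def count_compiles(head_dims: list[int], dynamic: bool) -> int:
    """Call the compiled function once per head dim; return how many graphs were traced."""
    torch._dynamo.reset()
    counter = CompileCounter()  # backend that counts traced graphs, then runs them eagerly
    compiled = torch.compile(project_out, backend=counter, dynamic=dynamic)
    for d in head_dims:
        y = torch.randn(2, 8, 4, d)
        v = torch.randn(2, 8, 4, d)
        compiled(y, v)
    return counter.frame_count


if __name__ == "__main__":
    dims = [64, 128]
    print("dynamic=False traces:", count_compiles(dims, dynamic=False))
    print("dynamic=True  traces:", count_compiles(dims, dynamic=True))
```

With `dynamic=False`, each distinct `D` should cost one trace (and, under Inductor, one fully specialized set of kernels), which is the static row of the table above paid once per size. Whether that is acceptable depends on how many distinct `HQ` and `D` values show up in practice.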

cc @jerryzh168 @chauhang @penguinwu @bobrenjc93 @aditvenk @laithsakka @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo

extent analysis

TL;DR

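Marking HQ and D dynamic with torch._dynamo.mark_dynamic raises the Inductor kernel count in the repro from 6 to 8 for xsa_unsqueeze and from 3 to 4 for xsa_expand; the extra kernels all appear in the forward pass. The issue does not identify a fix. The one option it discusses, treating HQ and D as constexpr and recompiling per size, is not, as far as the author knows, an existing mode in Inductor's codegen.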

Guidance

  • If HQ and D take only a few distinct values, leave those dimensions static instead of calling torch._dynamo.mark_dynamic on them: the static runs in the repro produce fewer kernels, at the cost of a recompile for each new size.
  • The original user was content to make HQ and D constexpr and let Triton recompile per size. The author is not aware of an Inductor codegen mode that does this and is unsure it would be worth the effort.
  • Compare the generated code of the static and dynamic runs (run_and_get_code returns it) to see which forward kernels stop fusing once HQ and D are symbolic.
  • Prefer xsa_unsqueeze over xsa_expand regardless of kernel count: the expand variant normalizes the expanded tensor, so the norm is recomputed for every query head in a KV group.

Example

The repro script above is the working example: it compiles xsa_unsqueeze and xsa_expand with static and dynamic shapes and counts the Triton kernels in each generated module.

Notes

The extra kernels appear only when dynamic shapes are used with Inductor. The repro is self-contained but requires a CUDA device, and the counts were reported on a single commit (7231f9e7a302e0368eee7adf8dcbcd6fd79fe2be). A real fix would likely require changes to Inductor's fusion or codegen rather than to user code.

Recommendation

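No workaround is confirmed in the issue. The practical options are to keep HQ and D static and accept a recompile per size, or to accept the extra kernels under dynamic shapes. Per-size constexpr recompilation is not, as far as the author knows, something Inductor's codegen supports.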
