pytorch - ✅ (Solved) Fix: "Fusable GQA-style QKV matmuls are overlooked" [1 pull request, 2 comments, 2 participants]

GitHub stats
pytorch/pytorch#178387 (fetched 2026-04-08 01:30:44)
Comments: 2
Participants: 2
Timeline events: 83
Reactions: 0
Timeline (top): subscribed ×31, mentioned ×30, labeled ×8, referenced ×6

Fix Action

Fix / Workaround


PR fix notes

PR #178523: Relax concat-linear fusion to support GQA QKV

Description (problem / solution / changelog)

The freezing concat-linear optimization fuses independent matmuls that share the same input (e.g., QKV projections) into a single GEMM. It previously required all weights to have identical shapes, preventing fusion of GQA-style attention where Q has a larger output dimension than K/V (e.g., Q=960, K=V=320).

Relax the shape check to require only that non-concatenated dimensions match (shape[:-1]), allowing weights and biases to differ in the last dimension. Use split (which supports unequal sizes) instead of chunk (which requires equal sizes) to split the fused result.
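A minimal sketch of the relaxed check and of why split is needed, assuming the weights are laid out as (in_features, out_features) so that they are concatenated along the last dimension. `shapes_concat_compatible` is a hypothetical helper for illustration, not the function in freezing_patterns.py:

```python
import torch

def shapes_concat_compatible(weights):
    # Hypothetical helper mirroring the relaxed rule: weights are concatenated
    # along the last dimension, so only shape[:-1] has to match.
    ref = weights[0].shape[:-1]
    return all(w.shape[:-1] == ref for w in weights)

# Assumed (in_features, out_features) layout, so that y = x @ W.
w_q = torch.randn(720, 960)
w_k = torch.randn(720, 320)
w_v = torch.randn(720, 320)
weights = [w_q, w_k, w_v]

old_rule = all(w.shape == w_q.shape for w in weights)  # identical shapes only
new_rule = shapes_concat_compatible(weights)           # last dim may differ

x = torch.randn(50, 720)
fused = x @ torch.cat(weights, dim=-1)   # one (50, 1600) GEMM instead of three
sizes = [w.shape[-1] for w in weights]   # [960, 320, 320]

by_chunk = fused.chunk(3, dim=-1)        # equal pieces: wrong Q/K/V boundaries
by_split = fused.split(sizes, dim=-1)    # per-projection sizes: correct boundaries

print(old_rule, new_rule)                # False True
print([t.shape[-1] for t in by_chunk])   # [534, 534, 532]
print([t.shape[-1] for t in by_split])   # [960, 320, 320]
```

In the reproducer's terms, chunk would hand Q only 534 of its 960 output columns, while split keeps each projection's output intact.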

Fixes #178387, "Fusable GQA-style QKV matmuls are overlooked."

Running the reproducer script from the ticket at commit c0c089817e69e999ee07b5efa0396c1cc5b9c655:

$ LD_LIBRARY_PATH="$(rocm-sdk path --root)/lib:${LD_LIBRARY_PATH}" python ./repro_concat_linear.py
PASS: 1 matmul(s) in generated code

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo

Changed files

  • test/inductor/test_inductor_freezing.py (modified, +2/-4)
  • torch/_inductor/fx_passes/freezing_patterns.py (modified, +10/-7)


🐛 Describe the bug

TorchInductor's freezing_patterns.py has a concat-linear optimization that fuses independent matmuls sharing the same input (e.g. QKV projections, gate+up in MLP) into a single larger GEMM. This reduces kernel launch overhead, which is critical for small-GEMM-dominated VLA models like SmolVLA on Strix Halo (gfx1151).

Currently fused (when inductor_config.freezing = True):

  • gate+up projections with the same output dim (e.g. 2 x (720->2048) -> 1 x (720->4096))
  • QKV where all three have the same output dim (e.g. 3 x (768->768) -> 1 x (768->2304))

Not fused, because of the same-shape requirement in check_concat_weights():

  • GQA-style QKV where Q has a different output dim than K/V (e.g. Q=960, K=320, V=320)

Reproducer (repro_concat_linear.py):

"""Reproducer: does TorchInductor concat-linear fuse GQA-style QKV projections?
The submitter ran this in a ROCm venv:

    LD_LIBRARY_PATH="$(rocm-sdk path --root)/lib:${LD_LIBRARY_PATH}" \
        python ./repro_concat_linear.py
"""
import re
import torch
import torch._inductor.config as inductor_config
from torch._inductor.utils import run_and_get_code

inductor_config.freezing = True

class GQA_QKV(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(720, 960, bias=False)
        self.k_proj = torch.nn.Linear(720, 320, bias=False)
        self.v_proj = torch.nn.Linear(720, 320, bias=False)

    def forward(self, x):
        return self.q_proj(x), self.k_proj(x), self.v_proj(x)

model = GQA_QKV().cuda().bfloat16().eval()
x = torch.randn(50, 720, device="cuda", dtype=torch.bfloat16)

compiled = torch.compile(model, mode="max-autotune-no-cudagraphs")
with torch.no_grad():
    _, code = run_and_get_code(compiled, x)

# Extract the "def call" body from the generated code and count mm kernel
# launches.  Fully fused: 1 launch.  Partially fused (pre-existing): 2.
call_body = re.search(r"def call\(self.*?return .*?\n", code[0], re.DOTALL).group()
mm_calls = re.findall(r"\.run\(", call_body)
fused = len(mm_calls) == 1
print(f"{'PASS' if fused else 'FAIL'}: {len(mm_calls)} matmul(s) in generated code")

Output at b19e5a7f7329dadcf639fd6fc53f934588193595:

FAIL: 3 matmul(s) in generated code

Expected output:

PASS: 1 matmul(s) in generated code

Versions

Environment gathered with curl -sL https://raw.githubusercontent.com/pytorch/pytorch/main/torch/utils/collect_env.py | LD_LIBRARY_PATH="$(rocm-sdk path --root)/lib:${LD_LIBRARY_PATH}" python.

Collecting environment information...
PyTorch version: 2.12.0a0+gitb19e5a7
Is debug build: False
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: 7.13.60800

OS: Ubuntu 24.04.3 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: version 3.31.10
Libc version: glibc-2.39

Python version: 3.12.3 (main, Jan  8 2026, 11:30:50) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.14.0-1019-oem-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: 
GPU models and configuration: Radeon 8060S Graphics (gfx1151)
Nvidia driver version: Could not collect
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: 7.13.60800
MIOpen runtime version: 3.5.1
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture:                            x86_64
CPU op-mode(s):                          32-bit, 64-bit
Address sizes:                           48 bits physical, 48 bits virtual
Byte Order:                              Little Endian
CPU(s):                                  32
On-line CPU(s) list:                     0-31
Vendor ID:                               AuthenticAMD
Model name:                              AMD RYZEN AI MAX+ 395 w/ Radeon 8060S
CPU family:                              26
Model:                                   112
Thread(s) per core:                      2
Core(s) per socket:                      16
Socket(s):                               1
Stepping:                                0
Frequency boost:                         enabled
CPU(s) scaling MHz:                      50%
CPU max MHz:                             5187.0000
CPU min MHz:                             599.0000
BogoMIPS:                                5988.24
Flags:                                   fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl xtopology nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx_vnni avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid bus_lock_detect movdiri movdir64b overflow_recov succor smca fsrm avx512_vp2intersect flush_l1d amd_lbr_pmc_freeze
Virtualization:                          AMD-V
L1d cache:                               768 KiB (16 instances)
L1i cache:                               512 KiB (16 instances)
L2 cache:                                16 MiB (16 instances)
L3 cache:                                64 MiB (2 instances)
NUMA node(s):                            1
NUMA node0 CPU(s):                       0-31
Vulnerability Gather data sampling:      Not affected
Vulnerability Ghostwrite:                Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit:             Not affected
Vulnerability L1tf:                      Not affected
Vulnerability Mds:                       Not affected
Vulnerability Meltdown:                  Not affected
Vulnerability Mmio stale data:           Not affected
Vulnerability Reg file data sampling:    Not affected
Vulnerability Retbleed:                  Not affected
Vulnerability Spec rstack overflow:      Mitigation; IBPB on VMEXIT only
Vulnerability Spec store bypass:         Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:                Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:                Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                     Not affected
Vulnerability Tsa:                       Not affected
Vulnerability Tsx async abort:           Not affected
Vulnerability Vmscape:                   Mitigation; IBPB on VMEXIT

Versions of relevant libraries:
[pip3] numpy==2.4.3
[pip3] optree==0.19.0
[pip3] torch==2.12.0a0+gitb19e5a7
[pip3] triton==3.6.0+rocm7.13.0a20260324
[conda] Could not collect

cc @jerryzh168 @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @jataylo @hongxiayang @naromero77amd @pragupta @jerrymannil @xinyazhang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben

extent analysis

Fix Plan

To make TorchInductor's freezing_patterns.py fuse GQA-style QKV projections, relax the same-shape requirement in check_concat_weights().

Here are the steps:

  • Allow check_concat_weights() to accept weights that differ only in the concatenated (output) dimension; every other dimension must still match.
  • Split the fused GEMM output with torch.split and the per-projection output sizes instead of chunk, which assumes equal-sized pieces.

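Code Changes

The sketch below illustrates the idea at the nn.Module level; it is not the actual pattern code in freezing_patterns.py, which rewrites the FX graph during freezing.

import torch

def check_concat_weights(weights):
    # nn.Linear weights are (out_features, in_features) and are concatenated
    # along dim 0, so only the output dimension may differ between them.
    first = weights[0]
    return all(
        w.dtype == first.dtype and w.device == first.device and w.shape[1:] == first.shape[1:]
        for w in weights
    )

def concat_linear(linears):
    # Bias-free projections only, as in the reproducer.
    weights = [linear.weight for linear in linears]
    if any(linear.bias is not None for linear in linears) or not check_concat_weights(weights):
        # Weights cannot be concatenated: keep the separate projections.
        return linears, None
    with torch.no_grad():
        concatenated_weight = torch.cat(weights, dim=0)  # e.g. 960 + 320 + 320 = 1600 rows
    new_linear = torch.nn.Linear(linears[0].in_features, concatenated_weight.shape[0], bias=False,
                                 device=concatenated_weight.device, dtype=concatenated_weight.dtype)
    # Registered parameters must be nn.Parameter; assigning a plain tensor raises TypeError.
    new_linear.weight = torch.nn.Parameter(concatenated_weight, requires_grad=False)
    # Per-projection sizes for torch.split (chunk would assume equal sizes).
    split_sizes = [w.shape[0] for w in weights]
    return new_linear, split_sizes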

Verification

To verify the fix, run the reproducer from the issue on a PyTorch build that includes PR #178523:

import re
import torch
import torch._inductor.config as inductor_config
from torch._inductor.utils import run_and_get_code

# The concat-linear optimization runs as part of freezing, so enable it first.
inductor_config.freezing = True

class GQA_QKV(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(720, 960, bias=False)
        self.k_proj = torch.nn.Linear(720, 320, bias=False)
        self.v_proj = torch.nn.Linear(720, 320, bias=False)

    def forward(self, x):
        return self.q_proj(x), self.k_proj(x), self.v_proj(x)

model = GQA_QKV().cuda().bfloat16().eval()
x = torch.randn(50, 720, device="cuda", dtype=torch.bfloat16)

compiled = torch.compile(model, mode="max-autotune-no-cudagraphs")
with torch.no_grad():
    _, code = run_and_get_code(compiled, x)

# Count kernel launches in the body of the generated call() function.
call_body = re.search(r"def call\(self.*?return .*?\n", code[0], re.DOTALL).group()
mm_calls = re.findall(r"\.run\(", call_body)
fused = len(mm_calls) == 1
print(f"{'PASS' if fused else 'FAIL'}: {len(mm_calls)} matmul(s) in generated code")

If the fix is correct, the output should be PASS: 1 matmul(s) in generated code.

Extra Tips

  • Test other projection layouts too, e.g. biased projections, equal-size QKV (3 x (768->768)) and gate+up (2 x (720->2048)), to confirm that cases which fused before the change still fuse.
  • Add a regression test for the GQA case to test/inductor/test_inductor_freezing.py, the test file PR #178523 modifies.
  • Report any issues or regressions on the PyTorch GitHub issue tracker.
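Following the first tip, here is a sketch of a numerical check across the projection shapes mentioned in this issue, with and without bias. It assumes nn.Linear's (out_features, in_features) weight layout, so weights are concatenated along dim 0. `fused_projections` and `check_config` are hypothetical helpers that only check the concat-then-split math on CPU; they do not exercise Inductor's generated code.

```python
import torch
import torch.nn.functional as F

def fused_projections(x, weights, biases=None):
    # Hypothetical helper: one GEMM over weights concatenated along the output
    # dimension (dim 0 in nn.Linear's layout), then split per projection.
    if len({w.shape[1] for w in weights}) != 1:
        raise ValueError("in_features must match to concatenate along dim 0")
    weight = torch.cat(weights, dim=0)
    bias = torch.cat(biases, dim=0) if biases is not None else None
    sizes = [w.shape[0] for w in weights]
    return torch.split(F.linear(x, weight, bias), sizes, dim=-1)

# (in_features, output sizes) taken from this issue:
# equal-size QKV, gate+up, and GQA-style QKV.
CONFIGS = [(768, [768, 768, 768]), (720, [2048, 2048]), (720, [960, 320, 320])]

def check_config(in_features, out_sizes, use_bias, batch=50):
    torch.manual_seed(0)
    layers = [torch.nn.Linear(in_features, n, bias=use_bias) for n in out_sizes]
    x = torch.randn(batch, in_features)
    with torch.no_grad():
        expected = [layer(x) for layer in layers]
        biases = [layer.bias for layer in layers] if use_bias else None
        actual = fused_projections(x, [layer.weight for layer in layers], biases)
    return all(torch.allclose(e, a, atol=1e-4) for e, a in zip(expected, actual))

results = {
    (n, tuple(sizes), bias): check_config(n, sizes, bias)
    for n, sizes in CONFIGS
    for bias in (False, True)
}
print(results)

# Weights whose in_features differ must be rejected rather than concatenated.
try:
    fused_projections(torch.randn(2, 720), [torch.randn(960, 720), torch.randn(320, 768)])
    rejected = False
except ValueError:
    rejected = True
print(rejected)
```

A regression test inside the Inductor suite would instead compile with freezing enabled and count kernel launches, as the reproducer does.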
