pytorch: [Inductor][flex attention] flex_attention autotuner uses misleading block-mask, possibly causing overhead and slower configs

GitHub stats: pytorch/pytorch#180651 (fetched 2026-04-18 05:51:41)
Comments: 0 · Participants: 1 · Timeline: 57 · Reactions: 0
Timeline (top): subscribed ×26, mentioned ×25, labeled ×6

Root Cause

Because mask_mod cost is per-iteration and different BLOCK_N sizes iterate differently over each sparse block (SPARSE_KV_BLOCK_SIZE // BLOCK_N iterations), the inflated overhead hits some configs harder than others, distorting the autotuner ranking and potentially causing it to mispick.
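To make the iteration argument concrete, the sketch below counts how many inner-loop iterations per Q-row would evaluate mask_mod for the real and the fake block mask. It assumes SPARSE_KV_BLOCK_SIZE = 128 (the BLOCK_SIZE of the reproducer's causal mask), four Q-rows, and the per-row partial-block counts reported in the issue (1 for the real mask, 13 for the fake one). It only counts iterations; it does not model what each iteration costs on the GPU.

```python
from typing import List

# Assumptions: sparse KV block of 128, 4 Q-rows, and the partial-block
# counts per Q-row reported in the issue (real: 1, fake: 13).
SPARSE_KV_BLOCK_SIZE = 128
NUM_Q_ROWS = 4
REAL_PARTIAL = [1] * NUM_Q_ROWS
FAKE_PARTIAL = [13] * NUM_Q_ROWS


def inner_iterations(block_n: int) -> int:
    """Inner-loop iterations needed to cover one sparse KV block."""
    if SPARSE_KV_BLOCK_SIZE % block_n:
        raise ValueError("BLOCK_N must divide SPARSE_KV_BLOCK_SIZE")
    return SPARSE_KV_BLOCK_SIZE // block_n


def masked_iterations(partial_per_row: List[int], block_n: int) -> List[int]:
    """Iterations per Q-row that evaluate mask_mod (partial blocks only)."""
    return [p * inner_iterations(block_n) for p in partial_per_row]


if __name__ == "__main__":
    for block_n in (32, 64, 128):
        real = masked_iterations(REAL_PARTIAL, block_n)[0]
        fake = masked_iterations(FAKE_PARTIAL, block_n)[0]
        print(f"BLOCK_N={block_n:3d}: real {real:2d}, fake {fake:2d} masked iterations per Q-row")
```

With these numbers the fake mask adds 48 extra masked iterations per Q-row at BLOCK_N=32 but only 12 at BLOCK_N=128, which is consistent with the claim that the inflation does not affect all configs equally.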

Fix Action

Fix / Workaround
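One possible direction, sketched here and not taken from the issue or any PR: give kv_num_blocks and full_kv_num_blocks separate fake values that respect the invariant partial + full = total blocks per Q-row. `fake_num_blocks_pair` and its `partial_per_row` parameter are hypothetical names, not Inductor APIs, and how many partial blocks per row to assume is a guess that a real fix would have to settle.

```python
from typing import List, Tuple


def fake_num_blocks_pair(
    num_q_blocks: int, max_blocks: int, partial_per_row: int = 1
) -> Tuple[List[int], List[int]]:
    """Hypothetical alternative to sharing one generator between
    kv_num_blocks and full_kv_num_blocks.

    Splits max_blocks between the partial and full sets for every Q-row
    instead of assigning max_blocks to both, so no KV block is counted twice.
    """
    if not 0 <= partial_per_row <= max_blocks:
        raise ValueError("partial_per_row must be in [0, max_blocks]")
    partial = [partial_per_row] * num_q_blocks
    full = [max_blocks - partial_per_row] * num_q_blocks
    return partial, full


if __name__ == "__main__":
    partial, full = fake_num_blocks_pair(num_q_blocks=4, max_blocks=13)
    print("kv_num_blocks     :", partial)  # [1, 1, 1, 1]
    print("full_kv_num_blocks:", full)     # [12, 12, 12, 12]
```

Compared with the real causal mask (full counts 9 to 12), this still assumes every row spans all 13 KV blocks, so it overestimates total work for the early rows. It removes the double counting and the mask_mod inflation for this shape, but not every mismatch with the real mask.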





🐛 Describe the bug

This issue was written by the Claude Opus model.

The Inductor autotuner's fake block-mask generator (create_num_blocks_fake_generator in torch/_inductor/kernel/flex/common.py) is shared by both kv_num_blocks (partial/boundary blocks that need mask_mod) and full_kv_num_blocks (interior blocks that skip mask_mod). Both are filled with max_blocks, causing:

  1. Every KV block is double-counted — once as partial, once as full — giving ~2× the real workload per Q-row.
  2. mask_mod evaluations are inflated ~13× — real causal masks have ~1 partial block per Q-row, but the fake fills all 13 as partial.

Because mask_mod cost is per-iteration and different BLOCK_N sizes iterate differently over each sparse block (SPARSE_KV_BLOCK_SIZE // BLOCK_N iterations), the inflated overhead hits some configs harder than others, distorting the autotuner ranking and potentially causing it to mispick.

Detailed analysis

Flex attention splits KV blocks into two disjoint sets per Q-row:

  • Partial (boundary) blocks: straddle the mask edge → need expensive mask_mod evaluation
  • Full (interior) blocks: entirely inside the mask → skip mask_mod

The invariant is: partial + full = total_blocks_to_process.
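The partial/full split can be reproduced without a GPU. The pure-Python sketch below classifies every (Q block, KV block) tile of the reproducer's causal mask by evaluating mask_mod on each element: all True means full, some True means partial, none True means the tile belongs to neither set and is skipped. It is a brute-force illustration of the rule above, not how create_block_mask is implemented; the 128 block size and the `b, h` arguments follow the reproducer.

```python
import math
from typing import Callable, List, Tuple

Q_LEN, KV_LEN, BLOCK_SIZE = 512, 1664, 128
OFFSET = KV_LEN - Q_LEN  # causal offset, as in the reproducer


def causal_mask(b: int, h: int, q_idx: int, kv_idx: int) -> bool:
    return q_idx >= (kv_idx - OFFSET)


def classify_blocks(
    mask_mod: Callable[[int, int, int, int], bool],
    q_len: int,
    kv_len: int,
    block: int,
) -> Tuple[List[int], List[int]]:
    """Count partial and full KV blocks per Q block (batch 0, head 0)."""
    partial_counts, full_counts = [], []
    for qb in range(math.ceil(q_len / block)):
        q_rows = range(qb * block, min((qb + 1) * block, q_len))
        n_partial = n_full = 0
        for kb in range(math.ceil(kv_len / block)):
            kv_cols = range(kb * block, min((kb + 1) * block, kv_len))
            seen = {mask_mod(0, 0, q, kv) for q in q_rows for kv in kv_cols}
            if seen == {True}:
                n_full += 1      # entirely inside the mask: mask_mod skipped
            elif True in seen:
                n_partial += 1   # straddles the mask edge: mask_mod needed
            # all-False tiles are in neither set and are not processed
        partial_counts.append(n_partial)
        full_counts.append(n_full)
    return partial_counts, full_counts


if __name__ == "__main__":
    partial, full = classify_blocks(causal_mask, Q_LEN, KV_LEN, BLOCK_SIZE)
    print("partial:", partial)  # [1, 1, 1, 1]
    print("full   :", full)     # [9, 10, 11, 12]
    print("total  :", [p + f for p, f in zip(partial, full)])  # [10, 11, 12, 13]
```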

For a typical causal mask with Q_LEN=512, KV_LEN=1664, BLOCK_SIZE=128:

| Mask | Partial (kv_num_blocks) | Full (full_kv_num_blocks) | Total |
|---|---|---|---|
| Real mask | [1, 1, 1, 1] | [9, 10, 11, 12] | 10–13 |
| Fake (autotuner) | [13, 13, 13, 13] | [13, 13, 13, 13] | 26 (double-counted) |
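The "~2×" workload and "~13×" mask_mod figures can be checked directly against the per-row counts in the table, summed over the four Q-rows:

```python
# Per-Q-row block counts copied from the table.
REAL_PARTIAL = [1, 1, 1, 1]
REAL_FULL = [9, 10, 11, 12]
FAKE_PARTIAL = [13, 13, 13, 13]
FAKE_FULL = [13, 13, 13, 13]

real_total = [p + f for p, f in zip(REAL_PARTIAL, REAL_FULL)]  # [10, 11, 12, 13]
fake_total = [p + f for p, f in zip(FAKE_PARTIAL, FAKE_FULL)]  # [26, 26, 26, 26]

# Total blocks processed, summed over all Q-rows: 104 / 46.
workload_ratio = sum(fake_total) / sum(real_total)
# Blocks that run mask_mod, summed over all Q-rows: 52 / 4.
mask_mod_ratio = sum(FAKE_PARTIAL) / sum(REAL_PARTIAL)

print(f"workload inflation: {workload_ratio:.2f}x")  # 2.26x
print(f"mask_mod inflation: {mask_mod_ratio:.1f}x")  # 13.0x
```

For this shape "~2×" is slightly conservative: summed over rows the fake mask processes about 2.26× the real number of blocks, and the excess is largest for the first Q-row (26 vs. 10).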

The fake generator fills both with max_blocks=13 because the same create_num_blocks_fake_generator function is called for both:

# flex_flash_attention.py, lines 449-453 (same pattern in flex_attention.py and flex_decoding.py)
input_gen_fns = {
    4: create_num_blocks_fake_generator(kv_indices),       # partial blocks
    5: create_indices_fake,
    6: create_num_blocks_fake_generator(full_kv_indices),  # full blocks  ← BUG: same generator!
    7: create_indices_fake,
}
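Why one generator cannot serve both inputs can be shown with a small pure-Python stand-in. Assumptions: the fake generator fills every Q-row with the last dimension of the index tensor's shape (the max_blocks behaviour the issue describes), and kv_indices and full_kv_indices share the shape (1, 1, 4, 13), as in the reproducer's BlockMask. `fake_num_blocks_from_shape` is hypothetical, not the Inductor function.

```python
from typing import Callable, List, Tuple

Shape = Tuple[int, ...]


def fake_num_blocks_from_shape(indices_shape: Shape) -> Callable[[], List[int]]:
    """Hypothetical stand-in for create_num_blocks_fake_generator: the
    returned generator fills every Q-row with max_blocks (the last dimension
    of the indices shape), regardless of what the indices represent."""
    *_, num_q_blocks, max_blocks = indices_shape

    def generate() -> List[int]:
        return [max_blocks] * num_q_blocks

    return generate


kv_indices_shape = (1, 1, 4, 13)       # partial-block indices
full_kv_indices_shape = (1, 1, 4, 13)  # full-block indices: same shape

fake_partial = fake_num_blocks_from_shape(kv_indices_shape)()
fake_full = fake_num_blocks_from_shape(full_kv_indices_shape)()

print(fake_partial)  # [13, 13, 13, 13]
print(fake_full)     # [13, 13, 13, 13]
print([p + f for p, f in zip(fake_partial, fake_full)])  # [26, 26, 26, 26]
```

Since the only input is a shape the two index tensors share, a generator built this way has no means of producing complementary counts; inputs 4 and 6 would need different generators.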

reproducer

#!/usr/bin/env python3
"""
Reproducer (CUDA): flex_attention autotuner uses wrong fake block-mask data,
                   inflating mask_mod overhead and distorting config ranking.

Bug summary
-----------
`create_num_blocks_fake_generator` is shared by *both* `kv_num_blocks`
(partial / boundary blocks that need mask_mod) and `full_kv_num_blocks`
(interior blocks that skip mask_mod).  Both are filled with max_blocks,
so every block is counted TWICE — once as partial, once as full — giving
~2× the real workload and ~13× the real mask_mod evaluations.

Because mask_mod cost is per-iteration and different BLOCK_N sizes have
different iteration counts per sparse block (SPARSE_KV_BLOCK_SIZE // BLOCK_N),
the inflated mask_mod overhead hits some configs harder than others,
distorting the autotuner's ranking.

This script demonstrates the problem in two steps:
  Step 1 — Show the real vs fake block-mask structure (the root cause)
  Step 2 — Show autotuner timing vs independent ground-truth benchmark

Usage:  python issue_repro_cuda.py     (requires NVIDIA GPU with CUDA)
"""

import os, sys, io, re, shutil

# ── Environment (set before importing torch) ─────────────────────────
os.environ["TRITON_PRINT_AUTOTUNING"] = "1"
os.environ["TORCHINDUCTOR_MAX_AUTOTUNE"] = "1"

CACHE_DIR = "/tmp/fa_issue_repro_cuda"
os.environ["TRITON_CACHE_DIR"] = os.path.join(CACHE_DIR, "triton")
os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CACHE_DIR, "inductor")

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# ── Workload: causal attention with Q_LEN < KV_LEN (LLM append/decode) ──
DEVICE   = "cuda"
DTYPE    = torch.float16
Z, H_Q, H_KV        = 1, 32, 8          # batch, q heads, kv heads (GQA)
Q_LEN, KV_LEN       = 512, 1664         # append-style: short query, long cache
D_HEAD               = 128
SM_SCALE             = 0.125
OFFSET               = KV_LEN - Q_LEN   # causal offset

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= (kv_idx - OFFSET)


# =====================================================================
# Step 1: Expose the fake-mask bug  (pure data comparison)
# =====================================================================
def step1_show_mask_mismatch():
    """Compare the *real* BlockMask with what the autotuner's fake generator produces."""
    print("=" * 72)
    print("STEP 1: Real block-mask vs autotuner's fake block-mask")
    print("=" * 72)

    bm = create_block_mask(causal_mask, 1, 1, Q_LEN, KV_LEN, device=DEVICE)

    # ── Real mask ────────────────────────────────────────────────────
    real_partial = bm.kv_num_blocks[0, 0]          # partial (boundary) blocks per Q-row
    real_full    = bm.full_kv_num_blocks[0, 0]     # full (interior) blocks per Q-row
    real_total   = real_partial + real_full
    max_kv_idx   = bm.kv_indices.shape[-1]         # max blocks in index tensor

    print(f"\n  Real causal mask  (BLOCK_SIZE={bm.BLOCK_SIZE}):")
    print(f"    kv_num_blocks (partial) per Q-row : {real_partial.tolist()}")
    print(f"    full_kv_num_blocks (full) per Q-row: {real_full.tolist()}")
    print(f"    total per Q-row                    : {real_total.tolist()}")

    # ── Fake mask (what the autotuner sees) ──────────────────────────
    fake_partial = max_kv_idx
    fake_full    = max_kv_idx

    print(f"\n  Fake mask produced by create_num_blocks_fake_generator:")
    print(f"    kv_num_blocks (partial) per Q-row : {fake_partial}  (all rows)")
    print(f"    full_kv_num_blocks (full) per Q-row: {fake_full}  (all rows)")
    print(f"    total per Q-row                    : {fake_partial + fake_full}  (DOUBLE counted!)")

    # ── Impact ───────────────────────────────────────────────────────
    avg_real_partial = real_partial.float().mean().item()
    print(f"\n  Impact:")
    print(f"    Real mask_mod evaluations per Q-row:  ~{avg_real_partial:.0f} partial blocks")
    print(f"    Fake mask_mod evaluations per Q-row:  {fake_partial} partial blocks")
    print(f"    Inflation factor: {fake_partial / max(avg_real_partial, 1):.0f}×")
    print(f"    Real total work per Q-row: ~{real_total.float().mean().item():.0f} blocks")
    print(f"    Fake total work per Q-row: {fake_partial + fake_full} blocks (2× double-counted)")
    print()


# =====================================================================
# Step 2: Show autotuner overhead vs ground truth
# =====================================================================
def step2_show_mispick():
    """Run the autotuner, then independently benchmark to show the overhead."""
    from torch._inductor.template_heuristics.triton import FlexConfig
    import torch._inductor.kernel.flex.flex_attention as flex_attn
    from triton.testing import do_bench as triton_do_bench

    # Configs typical for CUDA (Triton Jinja template path)
    CONFIGS = [
        FlexConfig(128, 64, 2, 4),
        FlexConfig(128, 32, 2, 4),
        FlexConfig(64,  64, 2, 4),
        FlexConfig(128, 128, 2, 8),
    ]

    def label(c):
        return f"{c.block_m}x{c.block_n}_w{c.num_warps}"

    def clear():
        torch._dynamo.reset()
        shutil.rmtree(CACHE_DIR, ignore_errors=True)
        os.makedirs(os.path.join(CACHE_DIR, "triton"), exist_ok=True)
        os.makedirs(os.path.join(CACHE_DIR, "inductor"), exist_ok=True)

    def make_inputs():
        q = torch.randn(Z, H_Q, Q_LEN, D_HEAD, device=DEVICE, dtype=DTYPE)
        k = torch.randn(Z, H_KV, KV_LEN, D_HEAD, device=DEVICE, dtype=DTYPE)
        v = torch.randn(Z, H_KV, KV_LEN, D_HEAD, device=DEVICE, dtype=DTYPE)
        bm = create_block_mask(causal_mask, 1, 1, Q_LEN, KV_LEN, device=DEVICE)
        return q, k, v, bm

    COMPILE_KWARGS = dict(scale=SM_SCALE, enable_gqa=True)

    # ── 2a: Run autotuner ────────────────────────────────────────────
    print("=" * 72)
    print("STEP 2: Autotuner ranking vs independent ground-truth benchmark")
    print("=" * 72)
    print("\n  2a) Running Inductor autotuner with all configs...\n")

    clear()
    flex_attn.V.choices.get_flex_attention_fwd_configs = lambda *a, **kw: CONFIGS
    q, k, v, bm = make_inputs()
    torch._dynamo.config.recompile_limit = 100
    compiled = torch.compile(flex_attention, dynamic=False)

    old_stderr = sys.stderr
    captured = io.StringIO()
    sys.stderr = captured
    try:
        compiled(q, k, v, block_mask=bm, **COMPILE_KWARGS)
        torch.cuda.synchronize()
    finally:
        sys.stderr = old_stderr  # restore stderr even if compilation or autotuning fails
    text = captured.getvalue()
    print(text, file=sys.stderr)

    # Parse autotuner timings
    at_timings = {}
    for line in text.split("\n"):
        m = re.match(
            r"\s+triton_flex_attention_\d+\s+([\d.]+)\s+ms\s+"
            r"([\d.]+)%.*BLOCK_M=(\d+),\s*BLOCK_N=(\d+).*num_warps=(\d+)",
            line,
        )
        if m:
            ms = float(m.group(1))
            key = f"{m.group(3)}x{m.group(4)}_w{m.group(5)}"
            at_timings[key] = ms

    if at_timings:
        picked = min(at_timings, key=at_timings.get)
        print(f"  Autotuner ranking (lower = better):")
        # Distinct loop names so the q/k/v tensors above are not shadowed.
        for name, ms in sorted(at_timings.items(), key=lambda x: x[1]):
            tag = " <-- PICKED" if name == picked else ""
            print(f"    {name:20s}  {ms:.4f} ms{tag}")
    else:
        picked = "?"
        print("  WARNING: could not parse autotuner output")

    # ── 2b: Independent benchmark (ground truth) ─────────────────────
    print(f"\n  2b) Independent benchmark per config (triton do_bench, warmup=600ms):\n")

    gt_timings = {}
    for cfg in CONFIGS:
        clear()
        flex_attn.V.choices.get_flex_attention_fwd_configs = lambda *a, c=cfg, **kw: [c]
        q, k, v, bm = make_inputs()
        torch._dynamo.config.recompile_limit = 100
        comp = torch.compile(flex_attention, dynamic=False)
        fn = lambda: comp(q, k, v, block_mask=bm, **COMPILE_KWARGS)
        fn(); torch.cuda.synchronize()
        med = triton_do_bench(fn, warmup=600, rep=100, quantiles=(0.5, 0.2, 0.8))[0]
        gt_timings[label(cfg)] = med
        print(f"    {label(cfg):20s}  {med:.4f} ms")

    gt_best = min(gt_timings, key=gt_timings.get)

    # ── Summary ──────────────────────────────────────────────────────
    print(f"\n{'=' * 72}")
    print(f"SUMMARY")
    print(f"{'=' * 72}")
    print(f"\n  {'Config':20s} | {'Autotuner':>12s} | {'Ground truth':>12s} | {'Overhead':>10s}")
    print(f"  {'-'*20}-+-{'-'*12}-+-{'-'*12}-+-{'-'*10}")
    for k in sorted(at_timings, key=at_timings.get):
        at_ms = at_timings.get(k, 0)
        gt_ms = gt_timings.get(k, 0)
        oh = f"{at_ms - gt_ms:+.4f} ms" if gt_ms else "N/A"
        tag = ""
        if k == picked:
            tag = " <-- autotuner picks"
        elif k == gt_best:
            tag = " <-- actually fastest"
        print(f"  {k:20s} | {at_ms:9.4f} ms | {gt_ms:9.4f} ms | {oh:>10s}{tag}")

    if picked not in gt_timings:
        # Autotuner output could not be parsed (picked == "?"), so there is nothing to compare.
        print(f"\n  Cannot compare: no ground-truth timing for autotuner pick '{picked}'.")
    elif picked != gt_best:
        # Relative slowdown of the picked config versus the true fastest one.
        loss = (gt_timings[picked] - gt_timings[gt_best]) / gt_timings[gt_best] * 100
        print(f"\n  BUG: Autotuner picks '{picked}', which is {loss:.1f}% slower than "
              f"'{gt_best}', the actual fastest config.")
        print("       Root cause: fake block-mask inflates mask_mod overhead unevenly.")
    else:
        print(f"\n  Autotuner picked the correct config '{picked}'.")
        at_best_ms = at_timings.get(picked, 0)
        gt_best_ms = gt_timings.get(picked, 0)
        if at_best_ms and gt_best_ms:
            oh_pct = (at_best_ms - gt_best_ms) / gt_best_ms * 100
            print(f"  However, autotuner overhead is still {oh_pct:+.1f}% "
                  f"({at_best_ms:.4f} vs {gt_best_ms:.4f} ms) due to inflated fake mask.")

    print("=" * 72)


# ── Entry point ──────────────────────────────────────────────────────
if __name__ == "__main__":
    assert torch.cuda.is_available(), "This reproducer requires an NVIDIA CUDA GPU"
    print(f"PyTorch {torch.__version__}  |  Device: {torch.cuda.get_device_name()}")
    print(f"Shape: Z={Z} H_Q={H_Q} H_KV={H_KV} Q={Q_LEN} KV={KV_LEN} D={D_HEAD}\n")
    step1_show_mask_mismatch()
    step2_show_mispick()

Versions

Collecting environment information...
PyTorch version: 2.11.0.dev20260112+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04.3) 11.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.19 | packaged by conda-forge | (main, Oct 22 2025, 22:29:10) [GCC 14.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-164-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.4.99
CUDA_MODULE_LOADING set to: 
GPU models and configuration: GPU 0: NVIDIA A100-PCIE-40GB
Nvidia driver version: 570.211.01
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture:                            x86_64
CPU op-mode(s):                          32-bit, 64-bit
Address sizes:                           48 bits physical, 48 bits virtual
Byte Order:                              Little Endian
CPU(s):                                  64
On-line CPU(s) list:                     0-63
Vendor ID:                               AuthenticAMD
Model name:                              AMD EPYC 7713 64-Core Processor
CPU family:                              25
Model:                                   1
Thread(s) per core:                      1
Core(s) per socket:                      64
Socket(s):                               1
Stepping:                                1
Frequency boost:                         disabled
CPU max MHz:                             3720.7029
CPU min MHz:                             1500.0000
BogoMIPS:                                4000.03
Flags:                                   fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm ibpb_exit_to_user
Virtualization:                          AMD-V
L1d cache:                               2 MiB (64 instances)
L1i cache:                               2 MiB (64 instances)
L2 cache:                                32 MiB (64 instances)
L3 cache:                                256 MiB (8 instances)
NUMA node(s):                            8
NUMA node0 CPU(s):                       0-7
NUMA node1 CPU(s):                       8-15
NUMA node2 CPU(s):                       16-23
NUMA node3 CPU(s):                       24-31
NUMA node4 CPU(s):                       32-39
NUMA node5 CPU(s):                       40-47
NUMA node6 CPU(s):                       48-55
NUMA node7 CPU(s):                       56-63
Vulnerability Gather data sampling:      Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit:             Not affected
Vulnerability L1tf:                      Not affected
Vulnerability Mds:                       Not affected
Vulnerability Meltdown:                  Not affected
Vulnerability Mmio stale data:           Not affected
Vulnerability Reg file data sampling:    Not affected
Vulnerability Retbleed:                  Not affected
Vulnerability Spec rstack overflow:      Mitigation; safe RET
Vulnerability Spec store bypass:         Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:                Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:                Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                     Not affected
Vulnerability Tsa:                       Vulnerable: Clear CPU buffers attempted, no microcode
Vulnerability Tsx async abort:           Not affected
Vulnerability Vmscape:                   Mitigation; IBPB before exit to userspace

Versions of relevant libraries:
[pip3] numpy==2.2.6
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.28.9
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] torch==2.11.0.dev20260112+cu126
[pip3] torchao==0.16.0+git4b3ebc444
[pip3] torchvision==0.25.0.dev20260112+cu126
[pip3] triton==3.6.0+git9844da95
[conda] numpy                       2.2.6                     pypi_0              pypi
[conda] nvidia-cublas-cu12          12.6.4.1                  pypi_0              pypi
[conda] nvidia-cuda-cupti-cu12      12.6.80                   pypi_0              pypi
[conda] nvidia-cuda-nvrtc-cu12      12.6.77                   pypi_0              pypi
[conda] nvidia-cuda-runtime-cu12    12.6.77                   pypi_0              pypi
[conda] nvidia-cudnn-cu12           9.10.2.21                 pypi_0              pypi
[conda] nvidia-cufft-cu12           11.3.0.4                  pypi_0              pypi
[conda] nvidia-curand-cu12          10.3.7.77                 pypi_0              pypi
[conda] nvidia-cusolver-cu12        11.7.1.2                  pypi_0              pypi
[conda] nvidia-cusparse-cu12        12.5.4.2                  pypi_0              pypi
[conda] nvidia-cusparselt-cu12      0.7.1                     pypi_0              pypi
[conda] nvidia-nccl-cu12            2.28.9                    pypi_0              pypi
[conda] nvidia-nvjitlink-cu12       12.6.85                   pypi_0              pypi
[conda] nvidia-nvtx-cu12            12.6.77                   pypi_0              pypi
[conda] torch                       2.11.0.dev20260112+cu126  pypi_0              pypi
[conda] torchao                     0.16.0+git4b3ebc444       pypi_0              pypi
[conda] torchvision                 0.25.0.dev20260112+cu126  pypi_0              pypi
[conda] triton                      3.6.0+git9844da95         pypi_0              pypi

cc @jerryzh168 @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo @Chillee @drisspg @yanboliang @BoyuanFeng @liangel-02 @howardzhang-cv

Extent analysis

TL;DR

The most likely fix is to change create_num_blocks_fake_generator so that the fake block mask used during autotuning keeps partial blocks (where mask_mod is evaluated) and full blocks (where mask_mod is skipped) disjoint. Each KV block is then counted once instead of twice.
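To make "counted once" concrete, here is a pure-Python sketch that classifies each (query block, KV block) tile of a mask into one of three groups. A tile is partial if mask_mod must be evaluated, full if mask_mod can be skipped, and skipped otherwise. It does not call FlexAttention. The function name `classify_blocks` is hypothetical. For brevity, mask_mod takes only (q, kv) indices, a simplification of FlexAttention's (b, h, q_idx, kv_idx) signature.

```python
def classify_blocks(q_len, kv_len, block, mask_mod):
    """Count partial and full KV blocks for each query-block row.

    A tile is "full" if mask_mod is True for every (q, kv) pair in it,
    "partial" if it is True for some but not all pairs, and skipped if it
    is never True. Each tile lands in at most one category, so
    partial + full never exceeds the number of KV blocks in a row.
    """
    rows = []
    for q_start in range(0, q_len, block):
        partial = full = 0
        for kv_start in range(0, kv_len, block):
            vals = [mask_mod(q, kv)
                    for q in range(q_start, min(q_start + block, q_len))
                    for kv in range(kv_start, min(kv_start + block, kv_len))]
            if all(vals):
                full += 1
            elif any(vals):
                partial += 1
        rows.append((partial, full))
    return rows


def causal(q, kv):
    # Standard causal mask: a query may attend to itself and earlier keys.
    return q >= kv


# 256 tokens in 64-token blocks: the diagonal tile of each row is partial,
# tiles below the diagonal are full, and tiles above it are skipped.
print(classify_blocks(256, 256, 64, causal))  # [(1, 0), (1, 1), (1, 2), (1, 3)]
```

The issue describes one fake generator filling both the partial and the full counts with the total number of KV blocks. If so, every row in this example would instead report (4, 4). That is eight block visits per row, with mask_mod applied on four of them. The real mask needs at most four visits and applies mask_mod on only one.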

Guidance

  1. Identify the root cause: create_num_blocks_fake_generator is used for both the partial-block and the full-block inputs. During autotuning, every KV block is therefore counted twice (once as partial, once as full), and mask_mod is evaluated more often than it would be with the real mask.
  2. Modify the fake generator: update create_num_blocks_fake_generator (and, if needed, the corresponding fake indices) so that the fake partial and full blocks are disjoint and each block is counted only once.
  3. Verify the fix: run the reproducer script. Confirm that the autotuner's ranking matches the ground-truth benchmark and that mask_mod evaluations are no longer inflated.
  4. Test other configurations: validate the fix across different shapes, mask_mods, and block sizes to make sure it generalizes and does not introduce regressions.
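A toy cost model shows why this inflation distorts the ranking instead of adding a uniform penalty. The issue says each sparse KV block is processed in SPARSE_KV_BLOCK_SIZE // BLOCK_N inner iterations, and mask_mod runs only in iterations over partial blocks. The sketch compares a hypothetical real row (1 partial, 3 full blocks) with a fake row where both counts equal the total number of KV blocks.

The following values are assumptions chosen for illustration, not measurements:
- SPARSE_KV_BLOCK_SIZE = 128
- the per-column compute cost
- the per-iteration mask_mod overhead

```python
SPARSE_KV_BLOCK_SIZE = 128  # assumed sparse block size (FlexAttention default)


def iterations(num_partial, num_full, block_n, sparse=SPARSE_KV_BLOCK_SIZE):
    """Inner-loop iterations (masked, unmasked) for one query-block row."""
    per_block = sparse // block_n  # iterations needed to cover one sparse block
    return num_partial * per_block, num_full * per_block


def modeled_cost(num_partial, num_full, block_n,
                 cost_per_col=1.0, mask_overhead_per_iter=50.0):
    """Toy cost: compute proportional to KV columns visited, plus a fixed
    mask_mod overhead per masked iteration. Both constants are assumptions."""
    masked, unmasked = iterations(num_partial, num_full, block_n)
    compute = (masked + unmasked) * block_n * cost_per_col
    return compute + masked * mask_overhead_per_iter


num_kv_blocks = 8
real = (1, 3)                          # hypothetical real row: 1 partial, 3 full
fake = (num_kv_blocks, num_kv_blocks)  # both counts set to all blocks
for block_n in (32, 64, 128):
    r, f = modeled_cost(*real, block_n), modeled_cost(*fake, block_n)
    print(f"BLOCK_N={block_n:3d}  real={r:6.0f}  fake={f:6.0f}  inflation={f / r:.2f}x")
```

In this toy model the fake mask inflates cost by 5.12x at BLOCK_N=32, 4.65x at 64, and 4.36x at 128. Configs with smaller BLOCK_N run more masked iterations per sparse block, so they are penalized more. In real kernels, where tile size affects many other costs this model ignores, that uneven penalty can be enough to reorder configs.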

Example

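# Hypothetical sketch of a fake-block generator that keeps partial and full
# blocks disjoint. Names and signature are illustrative, not Inductor's API.
import torch

def make_fake_num_blocks(size, num_kv_blocks, is_partial, partial_fraction=0.5):
    # Split the KV block budget so that partial + full == num_kv_blocks,
    # i.e. every block is counted exactly once (partial_fraction is an assumption).
    num_partial = int(num_kv_blocks * partial_fraction)
    count = num_partial if is_partial else num_kv_blocks - num_partial
    # The matching fake index tensors must be disjoint as well, e.g. partial
    # indices 0..num_partial-1 and full indices num_partial..num_kv_blocks-1.
    return torch.full(size, count, dtype=torch.int32)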

Notes

  • The provided example is a simplified illustration and may require adjustments to fit the actual implementation.
  • The fix should ensure that the autotuner's ranking accurately reflects the performance of different configurations.
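One way to check the second note is to compare the two timing dictionaries the reproducer builds (at_timings and gt_timings). The check has two parts:
- whether the top-1 pick matches;
- what fraction of config pairs are ordered the same way, using Kendall's tau-a.

The helper `ranking_agreement` is a hypothetical addition, not part of the reproducer. The sample timings below are synthetic placeholders, not measurements.

```python
from itertools import combinations


def ranking_agreement(at_timings, gt_timings):
    """Compare autotuner timings with ground-truth timings (both in ms).

    Returns (top1_match, tau), where tau is Kendall's tau-a over the configs
    present in both dicts: +1 means identical ordering, -1 fully reversed.
    """
    common = sorted(set(at_timings) & set(gt_timings))
    if len(common) < 2:
        raise ValueError("need at least two configs timed by both methods")
    top1 = min(common, key=at_timings.get) == min(common, key=gt_timings.get)
    concordant = discordant = 0
    for a, b in combinations(common, 2):
        # Same sign means both methods order this pair the same way.
        s = (at_timings[a] - at_timings[b]) * (gt_timings[a] - gt_timings[b])
        if s > 0:
            concordant += 1
        elif s < 0:
            discordant += 1
    n_pairs = len(common) * (len(common) - 1) // 2
    return top1, (concordant - discordant) / n_pairs


# Synthetic timings in which the autotuner swaps the two fastest configs.
at = {"cfg_a": 1.0, "cfg_b": 1.2, "cfg_c": 2.0}
gt = {"cfg_a": 0.9, "cfg_b": 0.8, "cfg_c": 1.5}
print(ranking_agreement(at, gt))  # (False, 0.3333333333333333)
```

A fixed fake mask should drive top1_match to True and tau toward 1.0 across the configurations tested in step 4.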

Recommendation

Fix create_num_blocks_fake_generator so that it treats partial and full blocks separately, since this directly addresses the root cause of the issue.
