pytorch - 💡 How to fix: torch.compile(flex_attention) fails with compiled create_block_mask(..., BLOCK_SIZE=(16, 16)) in decoding case - LoweringException: NoValidChoicesError [3 comments, 2 participants]

GitHub stats: pytorch/pytorch#178437 (fetched 2026-04-08 01:30:28)
Comments: 3 · Participants: 2 · Timeline events: 62 (mentioned ×26, subscribed ×26, labeled ×6, commented ×3) · Reactions: 0

When flex_attention is used with enable_gqa=True in a decode-style setup (query length L=1, key/value length S=100), building the block mask with a compiled create_block_mask and an explicit BLOCK_SIZE=(16, 16) makes the subsequent compiled flex_attention call fail during Inductor lowering with:

NoValidChoicesError: No choices to select

If BLOCK_SIZE is not passed, the same example works. Mask creation itself succeeds; the failure occurs when flex_attention is compiled.

FYI: I found this while running vLLM. Using FlexAttention there requires setting FORCE_USE_FLEX_ATTENTION to true, but in that configuration flex_decode cannot be used, which leads to very poor performance.
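The reported setup is a single-token decode step with grouped-query attention: 16 query heads share 8 key/value heads. A minimal sketch of that shape bookkeeping, assuming the usual GQA convention that query head h reads key/value head h // (Hq // Hkv); gqa_kv_head_map is a hypothetical helper, not a PyTorch API. The shapes match args[0] and args[1] in the lowering arguments of the error output.

```python
from typing import List


def gqa_kv_head_map(hq: int, hkv: int) -> List[int]:
    """Hypothetical helper: key/value head read by each query head under GQA."""
    if hkv <= 0 or hq % hkv != 0:
        # GQA needs the query head count to be a multiple of the KV head count
        raise ValueError("Hq must be a positive multiple of Hkv")
    group = hq // hkv  # query heads sharing one key/value head
    return [h // group for h in range(hq)]


# Shapes from the report, laid out as (batch, heads, sequence, head_dim)
B, Hq, Hkv, L, S, E = 1, 16, 8, 1, 100, 256
q_shape = (B, Hq, L, E)    # one new query token per head
kv_shape = (B, Hkv, S, E)  # 100 cached key/value positions

mapping = gqa_kv_head_map(Hq, Hkv)
print(mapping)  # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]
```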

Error Message

torch._inductor.exc.InductorError: LoweringException: NoValidChoicesError: No choices to select, please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice.

The full traceback and lowering arguments are reproduced after the code example.

Root Cause

The issue body does not identify a root cause, but the traceback narrows it down. For this single-token query, Inductor lowers flex_attention through the flex-decoding kernel (create_flex_decoding_kernel in torch/_inductor/kernel/flex/flex_decoding.py), and autotune_select_algorithm there ends up with no valid kernel choices when the block mask uses 16×16 blocks. The hint about adding ATEN to max_autotune_gemm_backends appears to be the generic NoValidChoicesError text rather than a flex_attention-specific setting.

Fix / Workaround

According to the reporter, omitting BLOCK_SIZE, so that create_block_mask falls back to its default block size, makes the same example work; this is the commented-out "ok case" in the code example below.

Code Example

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

flex_attention_compiled = torch.compile(flex_attention)

device = "cuda"
dtype = torch.bfloat16

B = 1
Hq = 16
Hkv = 8
L = 1
S = 100
E = 256

q = torch.randn((B, Hq, L, E), dtype=dtype, device=device)
k = torch.randn((B, Hkv, S, E), dtype=dtype, device=device)
v = torch.randn((B, Hkv, S, E), dtype=dtype, device=device)

def causal_mask_mod(
    b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
):
    return q_idx >= kv_idx

# failure case
block_mask = torch.compile(create_block_mask)(
    causal_mask_mod,
    B=None,
    H=None,
    Q_LEN=q.shape[2],
    KV_LEN=k.shape[2],
    BLOCK_SIZE=(16, 16),
)

# ok case
# block_mask = torch.compile(create_block_mask)(
#     causal_mask_mod,
#     B=None,
#     H=None,
#     Q_LEN=q.shape[2],
#     KV_LEN=k.shape[2],
# )

out = flex_attention_compiled(
    q,
    k,
    v,
    block_mask=block_mask,
    enable_gqa=True,
)

torch.cuda.synchronize()
print(out.shape)

Error output (failure case):

/usr/local/lib/python3.11/dist-packages/torch/__init__.py:1551: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:80.)
  return _C._get_float32_matmul_precision()
Traceback (most recent call last):
  File "test_flex.py", line 35, in <module>
    out = flex_attention_compiled(
          ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 845, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 990, in _compile_fx_inner
    raise InductorError(e, currentframe()).with_traceback(
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 974, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
                        ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 1695, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 1420, in codegen_and_compile
    graph.run(*example_inputs)
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/graph.py", line 937, in run
    return super().run(*args)
           ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/fx/interpreter.py", line 174, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/graph.py", line 1624, in run_node
    result = super().run_node(n)
             ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/fx/interpreter.py", line 256, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/graph.py", line 1289, in call_function
    raise LoweringException(e, target, args, kwargs).with_traceback(
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/graph.py", line 1279, in call_function
    out = lowerings[target](*args, **kwargs)  # type: ignore[index]
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/lowering.py", line 488, in wrapped
    out = decomp_fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/kernel/flex/flex_attention.py", line 172, in flex_attention
    return create_flex_decoding_kernel(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/kernel/flex/flex_decoding.py", line 388, in create_flex_decoding_kernel
    buf_ACC = autotune_select_algorithm(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/select_algorithm.py", line 3443, in autotune_select_algorithm
    return cache(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/select_algorithm.py", line 2423, in __call__
    raise NoValidChoicesError(
torch._inductor.exc.InductorError: LoweringException: NoValidChoicesError: No choices to select, please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice. 
  target: flex_attention
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.bfloat16, size=[1, 16, 1, 256], stride=[4096, 256, 256, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.bfloat16, size=[1, 8, 100, 256], stride=[204800, 25600, 256, 1]))
  ))
  args[2]: TensorBox(StorageBox(
    InputBuffer(name='arg2_1', layout=FixedLayout('cuda:0', torch.bfloat16, size=[1, 8, 100, 256], stride=[204800, 25600, 256, 1]))
  ))
  args[3]: Subgraph(name='sdpa_score0', graph_module=<lambda>(), graph=None)
  args[4]: (1, 100, TensorBox(StorageBox(
    InputBuffer(name='arg4_1', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg3_1', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 1, 7], stride=[7, 7, 7, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg5_1', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg6_1', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 1, 7], stride=[7, 7, 7, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg7_1', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 7], stride=[7, 7, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg8_1', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 7, 1], stride=[7, 7, 1, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg9_1', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 7], stride=[7, 7, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='arg10_1', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 7, 1], stride=[7, 7, 1, 1]))
  )), 16, 16, Subgraph(name='sdpa_mask0', graph_module=<lambda>(), graph=None))
  args[5]: 0.0625
  args[6]: {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}
  args[7]: ()
  args[8]: ()

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
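The lowering arguments above can be tied back to the repro with simple arithmetic. A sketch, assuming the block mask covers each sequence with ceil(length / block) blocks and that create_block_mask uses 128×128 blocks when BLOCK_SIZE is omitted: with 16×16 blocks the 100 key/value positions span 7 blocks, matching the trailing dimension 7 of the int32 index tensors in args[4], while the single query token fits in one block. The default scale 1/sqrt(E) likewise accounts for args[5] = 0.0625. This only connects the shapes in the error to the repro; it does not show which kernel constraint rejects the 16×16 configuration.

```python
import math


def num_blocks(length: int, block: int) -> int:
    """Number of fixed-size blocks needed to cover `length` positions."""
    return math.ceil(length / block)  # same as (length + block - 1) // block


Q_LEN, KV_LEN, E = 1, 100, 256

# Explicit BLOCK_SIZE=(16, 16) from the failure case
explicit = (num_blocks(Q_LEN, 16), num_blocks(KV_LEN, 16))
# Assumed default block size of 128 (the "ok case")
default = (num_blocks(Q_LEN, 128), num_blocks(KV_LEN, 128))
# Default softmax scale 1/sqrt(head_dim)
scale = 1 / math.sqrt(E)

print(explicit, default, scale)  # (1, 7) (1, 1) 0.0625
```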

🐛 Describe the bug

Bug: torch.compile(flex_attention) fails with compiled create_block_mask(..., BLOCK_SIZE=(16, 16)) in decoding case


Versions

PyTorch version: 2.9.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 12 (bookworm) (x86_64)
GCC version: (Debian 12.2.0-14+deb12u1) 12.2.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.36

Python version: 3.11.2 (main, Nov 30 2024, 21:22:50) [GCC 12.2.0] (64-bit runtime)
Python platform: Linux-5.15.152.bsk.10-amd64-x86_64-with-glibc2.36
Is CUDA available: True
CUDA runtime version: 12.8.61
CUDA_MODULE_LOADING set to:
GPU models and configuration: GPU 0: NVIDIA B200
Nvidia driver version: 580.105.08
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.7.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.7.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.7.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.7.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.7.1
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.7.1
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.7.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.7.1
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 216
On-line CPU(s) list: 0-215
Vendor ID: GenuineIntel
Model name: INTEL(R) XEON(R) PLATINUM 8581C CPU @ 2.10GHz
CPU family: 6
Model: 207
Thread(s) per core: 2
Core(s) per socket: 54
Socket(s): 2
Stepping: 2
BogoMIPS: 4200.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 wbnoinvd arat avx512vbmi umip avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid cldemote movdiri movdir64b fsrm md_clear serialize tsxldtrk amx_bf16 avx512_fp16 amx_tile amx_int8 arch_capabilities
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 5.1 MiB (108 instances)
L1i cache: 3.4 MiB (108 instances)
L2 cache: 216 MiB (108 instances)
L3 cache: 520 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-53,108-161
NUMA node1 CPU(s): 54-107,162-215
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==2.2.6
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cudnn-frontend==1.18.0
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] torch==2.9.0
[pip3] torchaudio==2.9.0
[pip3] torchcodec==0.9.0
[pip3] torchdata==0.11.0
[pip3] torchvision==0.24.0
[pip3] triton==3.5.0
[conda] Could not collect

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo @Chillee @drisspg @yanboliang @BoyuanFeng @liangel-02 @howardzhang-cv @xmfan


Fix Plan

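The workaround is to stop passing an explicit BLOCK_SIZE to create_block_mask in the decoding case; the function itself does not need to change. This avoids the error rather than fixing it, since the failure is raised in Inductor's flex-decoding lowering, not in create_block_mask.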

Here are the steps:

  • Remove the explicit BLOCK_SIZE argument from the create_block_mask function call.
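  • If a specific block size is required, a custom mask-creation function is unlikely to help on its own, because the error is raised while lowering flex_attention for decoding rather than while building the mask; track the upstream issue (pytorch/pytorch#178437) instead.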
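One way to apply the first step without dropping a custom block size everywhere is to omit BLOCK_SIZE only for the decode-shaped call. A minimal sketch: block_mask_kwargs is a hypothetical helper, and treating q_len == 1 as "decode" is an assumption that covers only the reported case (the report does not say which query lengths Inductor routes to the flex-decoding kernel).

```python
from typing import Any, Dict, Optional, Tuple


def block_mask_kwargs(
    q_len: int, block_size: Optional[Tuple[int, int]]
) -> Dict[str, Any]:
    """Hypothetical helper: extra keyword arguments for create_block_mask.

    Omits BLOCK_SIZE for single-token decode (q_len == 1), the shape reported
    to fail with BLOCK_SIZE=(16, 16); otherwise passes the requested size through.
    """
    if block_size is None or q_len == 1:
        return {}  # let create_block_mask use its default block size
    return {"BLOCK_SIZE": block_size}


# Intended use (needs torch; left as a comment so this block runs standalone):
# block_mask = torch.compile(create_block_mask)(
#     causal_mask_mod, B=None, H=None,
#     Q_LEN=q.shape[2], KV_LEN=k.shape[2],
#     **block_mask_kwargs(q.shape[2], (16, 16)),
# )
```

The report does not establish that 16×16 blocks are safe for longer queries either; the sketch only routes the one failing shape to the configuration reported to work.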

Example code:

# Remove BLOCK_SIZE argument
block_mask = torch.compile(create_block_mask)(
    causal_mask_mod,
    B=None,
    H=None,
    Q_LEN=q.shape[2],
    KV_LEN=k.shape[2],
)

Verification

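To verify the workaround, rerun the script and confirm that the compiled flex_attention call no longer raises NoValidChoicesError. To check that the compiled kernel is also correct, compare its output against eager flex_attention with the same block mask (for example with torch.testing.assert_close), then time the compiled call to confirm the decode path performs as expected.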

Extra Tips

  • torch.compile failures can originate inside a backend lowering rather than in user code; here the error comes from Inductor's flex-decoding autotuning, not from the mask_mod or the input tensors.
  • If you hit a torch.compile failure, set TORCHDYNAMO_VERBOSE=1 for the internal stack trace and TORCH_LOGS="+dynamo" for more developer context, as the error message suggests.
  • flex_attention and its Inductor lowerings are still evolving, so this behavior may change between PyTorch releases; the report was made against PyTorch 2.9.0.
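To rerun the repro with the diagnostics named in the error message (TORCHDYNAMO_VERBOSE=1 and TORCH_LOGS="+dynamo"), the variables can be set for a child process. A sketch, assuming the repro is saved as test_flex.py (the file name shown in the traceback); debug_env is a hypothetical helper. A child process is used because these variables are generally read when torch is imported, so setting them later in an already-running interpreter may have no effect.

```python
import os
from typing import Dict, Mapping, Optional


def debug_env(base: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Hypothetical helper: copy of an environment with compile diagnostics enabled."""
    env = dict(os.environ if base is None else base)
    env["TORCHDYNAMO_VERBOSE"] = "1"  # include the internal stack trace
    env["TORCH_LOGS"] = "+dynamo"     # DEBUG-level Dynamo logging
    return env
```

Usage: `subprocess.run([sys.executable, "test_flex.py"], env=debug_env())`.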
