[Inductor/CuteDSL] UnboundLocalError on captured mask_mod tensor in flash FlexAttention

pytorch/pytorch#181800 (fetched 2026-04-29 06:10:59)

Error Message

UnboundLocalError: cannot access local variable 'buf180' where it is not associated with a value

Code Example

triton_poi_fused_gt_10.run(arg103_1, buf138, 262144, stream=stream0)
del arg103_1
assert_size_stride(arg58_1, (128, 1003), (1003, 1))
assert_size_stride(arg105_1, (128, 1003), (1003, 1))
assert_size_stride(arg106_1, (128, 1003), (1003, 1))
assert_size_stride(arg59_1, (128, 1003), (1003, 1))
arg58_1 = copy_if_misaligned(arg58_1)
arg105_1 = copy_if_misaligned(arg105_1)
arg106_1 = copy_if_misaligned(arg106_1)
arg59_1 = copy_if_misaligned(arg59_1)
buf180 = empty_strided_cuda((3584, 1003), (1003, 1), torch.bool)  # buf180 is allocated here
buf152 = reinterpret_tensor(buf180, (128, 1003), (1003, 1), 0)  # alias
buf154 = reinterpret_tensor(buf180, (128, 1003), (1003, 1), 256768)  # alias
...
triton_red_fused__to_copy_any_bitwise_and_index_le_lift_fresh_stack_sum_unsqueeze_view_23.run(buf180, arg104_1, _tensor_constant0, arg5_1, buf203, buf250, buf274, 43776, 1003, stream=stream0)
del arg104_1
del arg5_1
del buf180  # buf180 is freed here
...
# buf180 is passed to the CuTeDSL kernel below, after the del above
cutedsl_fused__to_copy_add_bitwise_and_clone_eq_flex_attention_gt_index_le_lift_fresh_lt_permute_slice_sort_stack_sum_transpose_unsqueeze_view_53177e8e.run(reinterpret_tensor(buf215, (128, 4, 342, 128), (175104, 43776, 128, 1), 0), reinterpret_tensor(buf210, (128, 4, 1003, 128), (513536, 128, 512, 1), 0), reinterpret_tensor(buf214, (128, 4, 1003, 128), (513536, 128, 512, 1), 0), buf225, buf216, buf217, buf218, buf219, buf180, arg104_1, _tensor_constant0, arg5_1, buf226, stream=stream0)
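The pattern in the excerpt above (a name is deleted and then passed to a later kernel call) can be checked for mechanically. The following is a minimal sketch of such a check, not part of Inductor: `find_use_after_del` and `SAMPLE` are hypothetical names introduced here, and the scan is purely line-based. It assumes one statement per line, only handles single-name `del` statements, and ignores control flow.

```python
import re

# Matches a single-name `del` statement, e.g. "del buf180".
DEL_RE = re.compile(r"^\s*del\s+([A-Za-z_]\w*)\s*$")
# Matches a simple rebinding, e.g. "buf180 = empty_strided_cuda(...)".
ASSIGN_RE = re.compile(r"^\s*([A-Za-z_]\w*)\s*=(?!=)\s*(.*)$")
IDENT_RE = re.compile(r"\b[A-Za-z_]\w*\b")


def find_use_after_del(wrapper_source):
    """Return (name, del_line, use_line) tuples for names read after `del name`."""
    deleted = {}   # name -> line number of the most recent `del`
    findings = []
    for lineno, raw in enumerate(wrapper_source.splitlines(), start=1):
        line = raw.split("#", 1)[0]  # drop trailing comments
        del_match = DEL_RE.match(line)
        if del_match:
            deleted[del_match.group(1)] = lineno
            continue
        assign_match = ASSIGN_RE.match(line)
        if assign_match:
            target, body = assign_match.group(1), assign_match.group(2)
        else:
            target, body = None, line
        # Report each deleted name at most once per line, in order of appearance.
        for name in dict.fromkeys(IDENT_RE.findall(body)):
            if name in deleted:
                findings.append((name, deleted[name], lineno))
        # A fresh assignment makes the name live again.
        if target is not None:
            deleted.pop(target, None)
    return findings


# Shortened stand-in for a generated wrapper; kernel names are placeholders.
SAMPLE = """\
buf180 = empty_strided_cuda((3584, 1003), (1003, 1), torch.bool)
kernel_a.run(buf180, arg104_1, stream=stream0)
del arg104_1
del buf180
kernel_b.run(buf225, buf180, arg104_1, stream=stream0)
"""

for name, del_line, use_line in find_use_after_del(SAMPLE):
    print(f"{name}: deleted on line {del_line}, used on line {use_line}")
```

Running a scan like this over the saved wrapper file can help confirm which buffers are affected before reducing the model to a smaller repro.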

---

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

device = "cuda"  # assumed: `device` is not defined in the original snippet
dtype = torch.bfloat16

B, H, Q_LEN, KV_LEN, D = 2, 4, 256, 512, 128

def model(q, k, v, rule_masks_per_query, prefix_mask):
    masks = rule_masks_per_query & prefix_mask.unsqueeze(0)  # bitwise_and → buf_T

    def mask_mod(b, h, q_idx, kv_idx):
        return masks[q_idx, b, kv_idx]  # captures buf_T

    block_mask = create_block_mask(
        mask_mod, B=B, H=None, Q_LEN=Q_LEN, KV_LEN=KV_LEN, device=device,
    )
    return flex_attention(
        q, k, v,
        block_mask=block_mask,
        kernel_options={"BACKEND": "FLASH"},  # FA4 / CuTeDSL backend
    )

compiled = torch.compile(model, dynamic=False)

q = torch.randn(B, H, Q_LEN, D, device=device, dtype=dtype)
k = torch.randn(B, H, KV_LEN, D, device=device, dtype=dtype)
v = torch.randn(B, H, KV_LEN, D, device=device, dtype=dtype)
rule_masks = torch.randint(0, 2, (Q_LEN, B, KV_LEN), dtype=torch.bool, device=device)
prefix = torch.randint(0, 2, (B, KV_LEN), dtype=torch.bool, device=device)

# The original snippet stops before invoking the compiled function; this call is assumed.
out = compiled(q, k, v, rule_masks, prefix)

🐛 Describe the bug

When using flex_attention with kernel_options={"BACKEND": "FLASH"} (the FA4 / CuTeDSL backend), if a mask_mod captures a tensor that is also consumed by other ops earlier in the same graph, Inductor's scheduler emits del buf<N> on that captured tensor before the FlexAttention kernel runs. The generated Python wrapper then raises:

UnboundLocalError: cannot access local variable 'buf180' where it is not associated with a value

The same model compiles and runs fine with the default Triton backend.
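The failure mode itself is ordinary Python behavior and can be illustrated without PyTorch. The sketch below is only an analogy for the generated wrapper: `wrapper_like_call` and `run_and_capture` are hypothetical names, and a plain list stands in for the buffer. A local name is deleted after what looks like its last use, and a later statement still reads it.

```python
def wrapper_like_call():
    # Stand-in for a buffer allocated inside the generated wrapper.
    buf = [True, False, True]
    # An earlier consumer reads the buffer...
    earlier_result = any(buf)
    # ...and the name is deleted as if that had been its last use.
    del buf
    # A later consumer still references the name, so Python raises
    # UnboundLocalError when this line executes.
    return earlier_result, all(buf)


def run_and_capture():
    """Call the wrapper stand-in and return (exception type name, message)."""
    try:
        wrapper_like_call()
    except UnboundLocalError as exc:
        return type(exc).__name__, str(exc)
    return None


print(run_and_capture())
```

The exact wording of the message depends on the Python version; the "where it is not associated with a value" phrasing seen in this issue is what Python 3.12 prints.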

Error trace:

/export/apps/python/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1789 in _call_impl                                   │
│                                                                                                                                       │
│ ❱ 1789 │   │   │   return forward_call(*args, **kwargs)                                                                               │
│                                                                                                                                       │
│ /home/jobuser/mbf/modeling/model_builder/model_builder.py:58 in forward                                                               │
│                                                                                                                                       │
│ ❱  58 │   def forward(                                                                                                                │
│                                                                                                                                       │
│ /export/apps/python/3.12/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:473 in __call__                                     │
│                                                                                                                                       │
│ ❱  473 │   │   │   return super().__call__(*args, **kwargs)                                                                           │
│                                                                                                                                       │
│ /export/apps/python/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1778 in _wrapped_call_impl                           │
│                                                                                                                                       │
│ ❱ 1778 │   │   │   return self._call_impl(*args, **kwargs)                                                                            │
│                                                                                                                                       │
│ /export/apps/python/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1789 in _call_impl                                   │
│                                                                                                                                       │
│ ❱ 1789 │   │   │   return forward_call(*args, **kwargs)                                                                               │
│                                                                                                                                       │
│ /export/apps/python/3.12/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:1304 in _fn                                         │
│                                                                                                                                       │
│ ❱ 1304 │   │   │   │   │   return fn(*args, **kwargs)                                                                                 │
│                                                                                                                                       │
│ /export/apps/python/3.12/lib/python3.12/site-packages/torch/fx/graph_module.py:981 in call_wrapped                                    │
│                                                                                                                                       │
│ ❱  981 │   │   │   return self._wrapped_call(self, *args, **kwargs)                                                                   │
│                                                                                                                                       │
│ /export/apps/python/3.12/lib/python3.12/site-packages/torch/fx/graph_module.py:488 in __call__                                        │
│                                                                                                                                       │
│ ❱  488 │   │   │   │   raise e                                                                                                        │
│                                                                                                                                       │
│ /export/apps/python/3.12/lib/python3.12/site-packages/torch/fx/graph_module.py:474 in __call__                                        │
│                                                                                                                                       │
│ ❱  474 │   │   │   │   return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[mi                                       │
│                                                                                                                                       │
│ /export/apps/python/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1778 in _wrapped_call_impl                           │
│                                                                                                                                       │
│ ❱ 1778 │   │   │   return self._call_impl(*args, **kwargs)                                                                            │
│                                                                                                                                       │
│ /export/apps/python/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1789 in _call_impl                                   │
│                                                                                                                                       │
│ ❱ 1789 │   │   │   return forward_call(*args, **kwargs)                                                                               │
│ in forward:508                                                                                                                        │
│                                                                                                                                       │
│ /export/apps/python/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1778 in _wrapped_call_impl                           │
│                                                                                                                                       │
│ ❱ 1778 │   │   │   return self._call_impl(*args, **kwargs)                                                                            │
│                                                                                                                                       │
│ /export/apps/python/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1789 in _call_impl                                   │
│                                                                                                                                       │
│ ❱ 1789 │   │   │   return forward_call(*args, **kwargs)                                                                               │
│                                                                                                                                       │
│ /export/apps/python/3.12/lib/python3.12/site-packages/torch/_dynamo/backends/distributed.py:217 in forward                            │
│                                                                                                                                       │
│ ❱ 217 │   │   │   │   x = self.submod(*args)                                                                                          │
│                                                                                                                                       │
│ /export/apps/python/3.12/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:1304 in _fn                                         │
│                                                                                                                                       │
│ ❱ 1304 │   │   │   │   │   return fn(*args, **kwargs)                                                                                 │
│                                                                                                                                       │
│ /export/apps/python/3.12/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py:1273 in forward                                │
│                                                                                                                                       │
│ ❱ 1273 │   │   │   return compiled_fn(full_args)                                                                                      │
│                                                                                                                                       │
│ /export/apps/python/3.12/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py:924 in runtime_wrapper       │
│                                                                                                                                       │
│ ❱  924 │   │   │   all_outs = compiled_invoker.run(args, on_before_call=exit_prologue)                                                │
│                                                                                                                                       │
│ /export/apps/python/3.12/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py:522 in run                   │
│                                                                                                                                       │
│ ❱  522 │   │   │   │   return call_func_at_runtime_with_args(                                                                         │
│                                                                                                                                       │
│ /export/apps/python/3.12/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py:126 in call_func_at_runtime_with_args   │
│                                                                                                                                       │
│ ❱ 126 │   │   │   out = normalize_as_list(f(args))                                                                                    │
│                                                                                                                                       │
│ /export/apps/python/3.12/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py:1128 in wrapper              │
│                                                                                                                                       │
│ ❱ 1128 │   │   │   │   return compiled_fn(runtime_args)                                                                               │
│                                                                                                                                       │
│ /export/apps/python/3.12/lib/python3.12/site-packages/torch/_inductor/output_code.py:725 in __call__                                  │
│                                                                                                                                       │
│ ❱  725 │   │   │   │   return self.current_callable(inputs)                                                                           │
│                                                                                                                                       │
│ /tmp/torchinductor_jobuser/da/cda2n2rajnp5fr6zoobcxgc5zsbtquntl2adcfswlsdaasn4edia.py:5839 in call                                    │
│                                                                                                                                       │
│ ❱ 5839 │   │   │   cutedsl_fused__to_copy_add_bitwise_and_clone_eq_flex_attention_gt_index_le_l                                       │
╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
UnboundLocalError: cannot access local variable 'buf180' where it is not associated with a value

Inductor-generated code (buf180 is deleted before the CuteDSL kernel call that uses it):

 triton_poi_fused_gt_10.run(arg103_1, buf138, 262144, stream=stream0)
            del arg103_1
            assert_size_stride(arg58_1, (128, 1003), (1003, 1))
            assert_size_stride(arg105_1, (128, 1003), (1003, 1))
            assert_size_stride(arg106_1, (128, 1003), (1003, 1))
            assert_size_stride(arg59_1, (128, 1003), (1003, 1))
            arg58_1 = copy_if_misaligned(arg58_1)
            arg105_1 = copy_if_misaligned(arg105_1)
            arg106_1 = copy_if_misaligned(arg106_1)
            arg59_1 = copy_if_misaligned(arg59_1)
            buf180 = empty_strided_cuda((3584, 1003), (1003, 1), torch.bool)  # <-- buf180 allocated here
            buf152 = reinterpret_tensor(buf180, (128, 1003), (1003, 1), 0)  # alias
            buf154 = reinterpret_tensor(buf180, (128, 1003), (1003, 1), 256768)  # alias
....................
triton_red_fused__to_copy_any_bitwise_and_index_le_lift_fresh_stack_sum_unsqueeze_view_23.run(buf180, arg104_1, _tensor_constant0, arg5_1, buf203, buf250, buf274, 43776, 1003, stream=stream0)
            del arg104_1
            del arg5_1
            del buf180  # <-- buf180 deleted here
............
cutedsl_fused__to_copy_add_bitwise_and_clone_eq_flex_attention_gt_index_le_lift_fresh_lt_permute_slice_sort_stack_sum_transpose_unsqueeze_view_53177e8e.run(reinterpret_tensor(buf215, (128, 4, 342, 128), (175104, 43776, 128, 1), 0), reinterpret_tensor(buf210, (128, 4, 1003, 128), (513536, 128, 512, 1), 0), reinterpret_tensor(buf214, (128, 4, 1003, 128), (513536, 128, 512, 1), 0), buf225, buf216, buf217, buf218, buf219, buf180, arg104_1, _tensor_constant0, arg5_1, buf226, stream=stream0)  # <-- buf180 used after its del (arg104_1 and arg5_1 are also deleted earlier in this excerpt)
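
The pattern in the excerpt above can be reproduced in plain Python without Inductor. The sketch below is only an illustration of the symptom: the function name `call` and the list standing in for the CUDA buffer are assumptions, and nothing here is taken from the real generated wrapper.

```python
def call():
    # Stand-in for `buf180 = empty_strided_cuda(...)` in the generated wrapper.
    buf180 = [True, False, True]

    # First consumer (the Triton reduction in the real wrapper).
    total = sum(buf180)

    # The wrapper frees the buffer after what it believes is the last use.
    del buf180

    # A later consumer (the CuteDSL kernel in the real wrapper) still reads it.
    return total + len(buf180)


try:
    call()
except UnboundLocalError as exc:
    # Python raises UnboundLocalError because the local name was unbound by `del`.
    print(type(exc).__name__, exc)
```

Because `del` unbinds the local name, any later read in the same function fails at run time rather than at compile time, which matches the error being raised from inside the generated `call`.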

example model code:

  import torch
  from torch.nn.attention.flex_attention import create_block_mask, flex_attention

  device = "cuda"  # assumed; the original snippet uses `device` without defining it
  dtype = torch.bfloat16

  B, H, Q_LEN, KV_LEN, D = 2, 4, 256, 512, 128

  def model(q, k, v, rule_masks_per_query, prefix_mask):
      # Computed inside the compiled graph, so it becomes an Inductor buffer (buf_T).
      masks = rule_masks_per_query & prefix_mask.unsqueeze(0)   # bitwise_and → buf_T

      def mask_mod(b, h, q_idx, kv_idx):
          return masks[q_idx, b, kv_idx]                       # captures buf_T

      block_mask = create_block_mask(
          mask_mod, B=B, H=None, Q_LEN=Q_LEN, KV_LEN=KV_LEN, device=device,
      )
      return flex_attention(
          q, k, v,
          block_mask=block_mask,
          kernel_options={"BACKEND": "FLASH"},
      )

  compiled = torch.compile(model, dynamic=False)

  q = torch.randn(B, H, Q_LEN, D, device=device, dtype=dtype)
  k = torch.randn(B, H, KV_LEN, D, device=device, dtype=dtype)
  v = torch.randn(B, H, KV_LEN, D, device=device, dtype=dtype)
  rule_masks = torch.randint(0, 2, (Q_LEN, B, KV_LEN), dtype=torch.bool, device=device)
  prefix = torch.randint(0, 2, (B, KV_LEN), dtype=torch.bool, device=device)

  # Assumed invocation; the original report stops after creating the inputs.
  out = compiled(q, k, v, rule_masks, prefix)

Versions

PyTorch nightly 2.13.0.dev20260426+cu130 (includes the fix for the earlier unsupported index_expr issue: https://github.com/pytorch/pytorch/issues/181182)

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo @Chillee @drisspg @yanboliang @BoyuanFeng @liangel-02 @howardzhang-cv @oulgen @jansel @yf225 @Sibylau @choijon5

extent analysis

TL;DR

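The Inductor-generated wrapper emits del buf180 before launching the CuteDSL FlexAttention kernel that still takes buf180 as an argument, so the call fails with UnboundLocalError. The likely cause is that the scheduler does not treat the tensor captured by mask_mod as a live input of the flash kernel. Until this is fixed upstream, possible workarounds are to select a backend other than FLASH in kernel_options or to restructure the model so that mask_mod does not capture a tensor produced inside the compiled graph.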

Guidance

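  • The error occurs when flex_attention is called under torch.compile with kernel_options={"BACKEND": "FLASH"} and a mask_mod that captures a tensor which other ops earlier in the graph also read.
  • The generated wrapper emits del buf180 after the last Triton kernel that reads the buffer, but before the cutedsl_fused_... FlexAttention kernel that also reads it, which raises the UnboundLocalError. In the excerpt, arg104_1 and arg5_1 are deleted early in the same way.
  • To mitigate, try a backend other than FLASH in kernel_options, or change mask_mod so that it does not capture a tensor produced inside the compiled function (for example, compute the combined mask before entering the compiled region). Neither workaround is confirmed in the issue.
  • To verify, rerun the compiled model and inspect the generated wrapper (the traceback shows it under /tmp/torchinductor_jobuser/): no del of the captured buffer should appear before the cutedsl_fused_... call.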
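
The check on the generated wrapper can be partly automated. The following is a rough, text-based sketch rather than an Inductor tool: `find_use_after_del` is a hypothetical helper, it only recognises names shaped like `bufN` and `argN_M`, and it scans line by line, so it ignores control flow and may miss or over-report cases.

```python
import re

# Names that look like Inductor buffers or graph inputs (assumed naming scheme).
NAME = re.compile(r"\b(?:buf\d+|arg\d+_\d+)\b")


def find_use_after_del(source: str) -> list[tuple[int, str]]:
    """Return (line_number, name) pairs where a name is read after `del name`."""
    deleted: dict[str, int] = {}
    hits: list[tuple[int, str]] = []
    for lineno, raw in enumerate(source.splitlines(), start=1):
        line = raw.strip()

        # Record every name removed by a `del a, b` statement.
        del_match = re.fullmatch(r"del\s+(.+)", line)
        if del_match:
            for name in del_match.group(1).split(","):
                deleted[name.strip()] = lineno
            continue

        # Split a simple `target = expr` assignment so the target is not
        # counted as a read; everything else is treated as reads.
        assign = re.match(r"([A-Za-z_]\w*)\s*=(?!=)", line)
        reads = line[assign.end():] if assign else line

        for name in sorted(set(NAME.findall(reads))):
            if name in deleted:
                hits.append((lineno, name))

        # Re-assigning a name makes it live again.
        if assign:
            deleted.pop(assign.group(1), None)
    return hits


sample = """\
buf180 = empty_strided_cuda((3584, 1003), (1003, 1), torch.bool)
kernel_a.run(buf180, arg104_1, stream=stream0)
del arg104_1
del buf180
kernel_b.run(buf215, buf180, arg104_1, stream=stream0)
"""

for lineno, name in find_use_after_del(sample):
    print(f"line {lineno}: {name} is used after del")
```

To use it on a real run, read the wrapper file named in the traceback and pass its text to `find_use_after_del`; an empty result is a hint, not proof, that the ordering problem is gone.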

Example

The issue contains no fix snippet: the defect is in the wrapper code that Inductor generates, not in user code. The reproduction is the example model code given in the issue body above.

Notes

The issue was reported against PyTorch nightly 2.13.0.dev20260426+cu130; other versions are not covered by the report. Both suggested workarounds change either the attention backend or the model code, so which one is acceptable depends on the use case.

Recommendation

Until an upstream fix lands for pytorch/pytorch#181800, work around the premature del buf180 by using a backend other than FLASH or by keeping mask_mod from capturing a tensor that other ops in the compiled graph also read.
