pytorch - 💡(How to fix) Fix pytree type + AC composability issue [1 comment, 1 participant]

GitHub stats
pytorch/pytorch#179189 · Fetched 2026-04-08 02:32:59
Comments: 1 · Participants: 1 · Timeline: 54 · Reactions: 0
Timeline (top): mentioned ×20, subscribed ×20, labeled ×8, unlabeled ×4

Error Message

File /data/users/tmanlaibaatar/.bento/kernels/bento_kernel_pytorch/2807/bento_kernel_pytorch_binary-inplace#link-tree/torch/utils/checkpoint.py:145, in _infer_device_type(*args)
    143     if isinstance(arg, torch.Tensor) and arg.device.type != "cpu":
    144         device_types.append(arg.device.type)
--> 145 tree_map(add_device_types, args)
    147 device_types_set = set(device_types)
    148 if len(device_types_set) > 1:
File /data/users/tmanlaibaatar/.bento/kernels/bento_kernel_pytorch/2807/bento_kernel_pytorch_binary-inplace#link-tree/torch/utils/_pytree.py:1577, in tree_map(func, tree, is_leaf, *rests)
   1575 leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
   1576 flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
-> 1577 return treespec.unflatten(map(func, *flat_args))
File /data/users/tmanlaibaatar/.bento/kernels/bento_kernel_pytorch/2807/bento_kernel_pytorch_binary-inplace#link-tree/torch/utils/_pytree.py:1330, in TreeSpec.unflatten(self, leaves)
   1328 for child_spec in self._children:
   1329     end += child_spec.num_leaves
-> 1330     child_pytrees.append(child_spec.unflatten(leaves[start:end]))
   1331     start = end
   1333 return unflatten_fn(child_pytrees, self._context)
File /data/users/tmanlaibaatar/.bento/kernels/bento_kernel_pytorch/2807/bento_kernel_pytorch_binary-inplace#link-tree/torch/utils/_pytree.py:1333, in TreeSpec.unflatten(self, leaves)
   1330     child_pytrees.append(child_spec.unflatten(leaves[start:end]))
   1331     start = end
-> 1333 return unflatten_fn(child_pytrees, self._context)
File /data/users/tmanlaibaatar/.bento/kernels/bento_kernel_pytorch/2807/bento_kernel_pytorch_binary-inplace#link-tree/torch/nn/attention/flex_attention.py:1070, in BlockMask._unflatten(cls, tensors, context)
   1065 kwargs = {
   1066     attr: cls._unwrap_context_value(attr, val)
   1067     for attr, val in zip(cls._CONTEXT_ATTRS, context)
   1068 }
   1069 kwargs.update(zip(cls._TENSOR_ATTRS, tensors))
-> 1070 return cls(**kwargs)
File /data/users/tmanlaibaatar/.bento/kernels/bento_kernel_pytorch/2807/bento_kernel_pytorch_binary-inplace#link-tree/torch/nn/attention/flex_attention.py:601, in BlockMask.__init__(self, seq_lengths, kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices, q_num_blocks, q_indices, full_q_num_blocks, full_q_indices, BLOCK_SIZE, mask_mod)
    587 def __init__(
    588     self,
    589     seq_lengths: tuple[int, int],
   (...)
    599     mask_mod: _mask_mod_signature,
    600 ) -> None:
--> 601     if kv_indices.dim() < 2:
    602         raise RuntimeError("BlockMask must have at least 2 dimensions")
    603     if kv_num_blocks is None:
AttributeError: 'NoneType' object has no attribute 'dim'

Root Cause

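This shows up when enabling the non-strict tracer on Titan. Once BlockMask is registered as a pytree type, the eager activation checkpointing (AC) path fails: _infer_device_type calls tree_map(add_device_types, args) with a callback that only appends to a list and returns nothing, but tree_map rebuilds the tree from the callback's return values. BlockMask._unflatten therefore receives None in place of every tensor, and BlockMask.__init__ fails on kv_indices.dim().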
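The mechanism can be illustrated without PyTorch. The sketch below is a toy model only: `Mask`, `toy_tree_map`, and `record` are hypothetical stand-ins for BlockMask, tree_map, and add_device_types, and the toy raises a TypeError where the real code raises an AttributeError.

```python
# Toy model of the failure. Mask, toy_tree_map and record are hypothetical
# stand-ins, not PyTorch APIs.

class Mask:
    """Stand-in for BlockMask: the constructor validates its field."""

    def __init__(self, indices):
        if len(indices) < 2:  # raises TypeError when indices is None
            raise ValueError("Mask needs at least 2 entries")
        self.indices = indices


def toy_tree_map(func, mask):
    """Flatten the mask into leaves, map func over them, then rebuild the mask."""
    leaves = [mask.indices]
    return Mask(*[func(leaf) for leaf in leaves])


seen = []


def record(leaf):
    seen.append(leaf)  # side effect only; implicitly returns None


error = None
try:
    toy_tree_map(record, Mask([1, 2, 3]))
except TypeError as exc:
    error = exc  # analogous to the AttributeError on kv_indices.dim()

print(seen)
print(type(error).__name__)
```

The callback does its job (the leaf is recorded), yet the call still fails because the rebuild step feeds the callback's `None` return value into the validating constructor.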

Code Example

import torch
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
from torch.utils._pytree import register_pytree_node
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

# Register BlockMask as a pytree node so that tree_map flattens and rebuilds it.
register_pytree_node(
    BlockMask, BlockMask._flatten, BlockMask._unflatten,
    flatten_with_keys_fn=BlockMask._flatten_with_keys,
)

# Causal mask with B=1, H=1, Q_LEN=128, KV_LEN=128.
mask = create_block_mask(lambda b, h, q, kv: q >= kv, 1, 1, 128, 128)

class Block(torch.nn.Module):
    def forward(self, x, mask):  # mask is unused; passing it is enough to trigger the bug
        return x * 2

model = checkpoint_wrapper(Block()).cuda()
x = torch.randn(4, 128, device="cuda")
model(x, mask)  # crashes in _infer_device_type with the AttributeError shown above


Versions

main

cc @soulitzer @zou3519 @XuehaiPan @chauhang @penguinwu @Chillee @drisspg @yanboliang @BoyuanFeng @liangel-02 @howardzhang-cv

extent analysis

TL;DR

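_infer_device_type in torch/utils/checkpoint.py calls tree_map with a callback that returns nothing. tree_map rebuilds every registered pytree node from the callback's return values, so once BlockMask is registered it is reconstructed with None in place of each tensor, and BlockMask.__init__ raises the AttributeError on kv_indices.dim().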

Guidance

  • kv_indices is None because tree_map hands the return values of add_device_types (which only appends to a list and implicitly returns None) to BlockMask._unflatten. The mask returned by create_block_mask is itself valid.
  • Before registration, BlockMask is treated as an opaque leaf, so nothing is rebuilt. After register_pytree_node, tree_map flattens the mask and then calls BlockMask._unflatten, whose constructor validates its inputs.
  • The fix most likely belongs in _infer_device_type rather than in user code: visit the leaves without unflattening, or have the callback return its argument.
  • It is worth checking whether other call sites on the eager AC path use tree_map purely for side effects, since they would fail the same way for any registered type whose constructor validates its inputs.

Example

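# Sketch of a possible change inside _infer_device_type (an assumption, not the confirmed upstream fix):
# visit the flattened leaves instead of calling tree_map, so that nothing is rebuilt.
from torch.utils._pytree import tree_leaves
for arg in tree_leaves(args):
    if isinstance(arg, torch.Tensor) and arg.device.type != "cpu":
        device_types.append(arg.device.type)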

Notes

  • The failure comes from how torch/utils/checkpoint.py uses tree_map, not from the reproduction script or from create_block_mask.
  • The issue (pytorch/pytorch#179189) does not say which fix is adopted upstream, so any patch described here is an assumption.

Recommendation

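Apply workaround: the root fix is for _infer_device_type to stop rebuilding the tree when it only needs to inspect leaves. Until that lands, a possible interim workaround (not confirmed in the issue) is to avoid passing a pytree-registered BlockMask as a direct argument to a checkpointed module in eager mode, for example by registering BlockMask as a pytree type only on the tracing path. Changing BlockMask to accept kv_indices=None would hide the error rather than fix it.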
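Two ways to avoid rebuilding a validated object from `None` can be sketched in plain Python. This is a toy illustration under stated assumptions: `Mask`, `toy_tree_leaves`, `toy_tree_map`, and `record_and_return` are hypothetical names, not PyTorch APIs, and the sketch does not claim to match the change made upstream.

```python
# Toy sketch of two possible fixes. All names here are hypothetical stand-ins.

class Mask:
    """Stand-in for a pytree-registered type with a validating constructor."""

    def __init__(self, indices):
        if len(indices) < 2:
            raise ValueError("Mask needs at least 2 entries")
        self.indices = indices


def toy_tree_leaves(mask):
    """Return the leaves only; nothing is rebuilt."""
    return [mask.indices]


def toy_tree_map(func, mask):
    """Map func over the leaves and rebuild the mask from the results."""
    return Mask(*[func(leaf) for leaf in toy_tree_leaves(mask)])


mask = Mask([1, 2, 3])

# Option 1: iterate over the leaves, so no constructor is ever called.
seen_via_leaves = []
for leaf in toy_tree_leaves(mask):
    seen_via_leaves.append(leaf)

# Option 2: keep the map, but return the leaf so the rebuild gets valid inputs.
seen_via_map = []


def record_and_return(leaf):
    seen_via_map.append(leaf)
    return leaf


rebuilt = toy_tree_map(record_and_return, mask)
print(seen_via_leaves, seen_via_map, rebuilt.indices)
```

Option 1 skips the rebuild entirely, which also avoids constructing a throwaway copy; option 2 is the smaller change when the surrounding code must keep calling a map-style helper.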
