pytorch - 💡(How to fix) Fix: make_fx doesn't work with tensors bound via functools.partial

Official PRs (…)
ON THIS PAGE

Recommended Tools

×6

Utilities matched from this issue’s tags and category — try them while you read without losing context.

GitHub issue graph ai analysis

Paste a GitHub issue URL. We fetch that issue, discover linked issues from bodies/comments/timeline, collect linked pull requests, and produce a structured English report.

The report is written in English Markdown for sharing and archival.

Helpful · Quick feedback

Loading…

Root Cause

The reproduction below prints:

partial_replays_batch0 True

partial_matches_batch1 False

closure_replays_batch0 False

closure_matches_batch1 True

This is because pytree flatten/unflatten handles closure mask_mods but not partially bound ones, so tensors bound with functools.partial never become inputs of the traced graph. There is probably nothing to fix here; this is filed to document a common gotcha in non-strict tracing.

Code Example

from functools import partial

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.nn.attention.flex_attention import BlockMask, create_block_mask


# Index grids for an 8-token sequence: queries as a column and keys as a row,
# so indexing with both broadcasts to an 8x8 grid.
query_indices = torch.arange(8, dtype=torch.int32).unsqueeze(1)
key_indices = torch.arange(8, dtype=torch.int32).unsqueeze(0)
# Every token belongs to document 0.
document_ids = torch.zeros(8, dtype=torch.int64)
# Per-token attention-region ids for two batches: contiguous halves vs. alternating pairs.
batch0_attn_regions = torch.tensor([0] * 4 + [1] * 4, dtype=torch.int32)
batch1_attn_regions = torch.tensor([0, 0, 1, 1] * 2, dtype=torch.int32)


def mask_rule(batch, head, query_idx, key_idx, attn_regions, document_ids):
    """Return a boolean mask of allowed (query, key) pairs.

    A pair is allowed when the key does not come after the query, both positions
    share an attention region, and both belong to the same document. ``batch``
    and ``head`` are unused; they only fill the mask_mod call signature.
    """
    causal = query_idx >= key_idx
    same_region = attn_regions[query_idx] == attn_regions[key_idx]
    same_document = document_ids[query_idx] == document_ids[key_idx]
    return causal & same_region & same_document
def closure_mask_rule(attn_regions, document_ids):
    """Return a 4-argument mask_mod that captures the region/document tensors in a closure.

    The returned callable forwards to ``mask_rule`` with the captured tensors
    appended. It is the closure-based alternative to ``partial(mask_rule, ...)``.
    """
    def _closure_mask_mod(batch, head, query_idx, key_idx):
        return mask_rule(batch, head, query_idx, key_idx, attn_regions, document_ids)

    return _closure_mask_mod
def make_block_mask(mask_mod): 
    return create_block_mask(mask_mod, B=1, H=1, Q_LEN=8, KV_LEN=8, device="cpu", BLOCK_SIZE=4)


def trace_mask_mod(block_mask):
    """Trace ``block_mask.mask_mod`` with make_fx over the BlockMask's flattened leaves.

    The returned graph module takes the leaves produced by ``BlockMask._flatten()``.
    It rebuilds the BlockMask from them and evaluates its mask_mod at
    ``(0, 0, query_indices, key_indices)``.

    Gotcha: tensors bound into the mask_mod via ``functools.partial`` are not
    flattened as leaves. Replaying the graph with another BlockMask's leaves
    therefore still uses the originally bound tensors, as the
    partial_replays_batch0 check in this script shows. A warning is emitted in
    that case; capture such tensors in a closure instead.
    """
    import warnings

    mask_mod = block_mask.mask_mod
    # Surface the silent-constant pitfall instead of returning a graph that
    # ignores the tensors of whatever BlockMask it is replayed with.
    if isinstance(mask_mod, partial) and any(
        isinstance(bound, torch.Tensor)
        for bound in (*mask_mod.args, *mask_mod.keywords.values())
    ):
        warnings.warn(
            "mask_mod is a functools.partial with bound tensors; they are not "
            "flattened as BlockMask leaves, so the traced graph will keep using "
            "them on replay. Capture them in a closure instead.",
            stacklevel=2,
        )

    leaves, spec = block_mask._flatten()

    def _apply_mask_mod(*flat_leaves):
        # Rebuild the BlockMask from the leaves make_fx passes in, so the graph
        # records how the mask_mod consumes them.
        return BlockMask._unflatten(flat_leaves, spec).mask_mod(0, 0, query_indices, key_indices)

    return make_fx(_apply_mask_mod)(*leaves)


# Trace the same mask logic twice against batch0's tensors: once with the
# tensors bound via functools.partial, once with them captured in a closure.
traced_partial_mask_mod = trace_mask_mod(make_block_mask(partial(mask_rule, attn_regions=batch0_attn_regions, document_ids=document_ids)))
traced_closure_mask_mod = trace_mask_mod(make_block_mask(closure_mask_rule(batch0_attn_regions, document_ids)))

# Replay each traced graph with the flattened leaves of a batch1 BlockMask.
# If the region tensor is a real graph input, the result should follow batch1.
replayed_partial_on_batch1 = traced_partial_mask_mod(*make_block_mask(partial(mask_rule, attn_regions=batch1_attn_regions, document_ids=document_ids))._flatten()[0])
replayed_closure_on_batch1 = traced_closure_mask_mod(*make_block_mask(closure_mask_rule(batch1_attn_regions, document_ids))._flatten()[0])
# Reference results from calling each batch's mask_mod directly (no tracing).
expected_batch0 = make_block_mask(partial(mask_rule, attn_regions=batch0_attn_regions, document_ids=document_ids)).mask_mod(0, 0, query_indices, key_indices)
expected_batch1 = make_block_mask(partial(mask_rule, attn_regions=batch1_attn_regions, document_ids=document_ids)).mask_mod(0, 0, query_indices, key_indices)

# Compare each replay against both references to see which batch's tensors it used.
print("partial_replays_batch0", torch.equal(replayed_partial_on_batch1, expected_batch0))
print("partial_matches_batch1", torch.equal(replayed_partial_on_batch1, expected_batch1))
print("closure_replays_batch0", torch.equal(replayed_closure_on_batch1, expected_batch0))
print("closure_matches_batch1", torch.equal(replayed_closure_on_batch1, expected_batch1))

# Observed output: the partial-based graph still reproduces batch0 (its bound
# tensors were presumably baked into the graph), while the closure-based graph
# picks up batch1's tensors.
# partial_replays_batch0 True
# partial_matches_batch1 False
# closure_replays_batch0 False
# closure_matches_batch1 True
RAW_BUFFERClick to expand / collapse

🐛 Describe the bug

from functools import partial

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.nn.attention.flex_attention import BlockMask, create_block_mask


# Index grids for an 8-token sequence: queries as a column and keys as a row,
# so indexing with both broadcasts to an 8x8 grid.
query_indices = torch.arange(8, dtype=torch.int32).unsqueeze(1)
key_indices = torch.arange(8, dtype=torch.int32).unsqueeze(0)
# Every token belongs to document 0.
document_ids = torch.zeros(8, dtype=torch.int64)
# Per-token attention-region ids for two batches: contiguous halves vs. alternating pairs.
batch0_attn_regions = torch.tensor([0] * 4 + [1] * 4, dtype=torch.int32)
batch1_attn_regions = torch.tensor([0, 0, 1, 1] * 2, dtype=torch.int32)


def mask_rule(batch, head, query_idx, key_idx, attn_regions, document_ids):
    """Return a boolean mask of allowed (query, key) pairs.

    A pair is allowed when the key does not come after the query, both positions
    share an attention region, and both belong to the same document. ``batch``
    and ``head`` are unused; they only fill the mask_mod call signature.
    """
    causal = query_idx >= key_idx
    same_region = attn_regions[query_idx] == attn_regions[key_idx]
    same_document = document_ids[query_idx] == document_ids[key_idx]
    return causal & same_region & same_document
def closure_mask_rule(attn_regions, document_ids):
    """Return a 4-argument mask_mod that captures the region/document tensors in a closure.

    The returned callable forwards to ``mask_rule`` with the captured tensors
    appended. It is the closure-based alternative to ``partial(mask_rule, ...)``.
    """
    def _closure_mask_mod(batch, head, query_idx, key_idx):
        return mask_rule(batch, head, query_idx, key_idx, attn_regions, document_ids)

    return _closure_mask_mod
def make_block_mask(mask_mod): 
    return create_block_mask(mask_mod, B=1, H=1, Q_LEN=8, KV_LEN=8, device="cpu", BLOCK_SIZE=4)


def trace_mask_mod(block_mask):
    """Trace ``block_mask.mask_mod`` with make_fx over the BlockMask's flattened leaves.

    The returned graph module takes the leaves produced by ``BlockMask._flatten()``.
    It rebuilds the BlockMask from them and evaluates its mask_mod at
    ``(0, 0, query_indices, key_indices)``.

    Gotcha: tensors bound into the mask_mod via ``functools.partial`` are not
    flattened as leaves. Replaying the graph with another BlockMask's leaves
    therefore still uses the originally bound tensors, as the
    partial_replays_batch0 check in this script shows. A warning is emitted in
    that case; capture such tensors in a closure instead.
    """
    import warnings

    mask_mod = block_mask.mask_mod
    # Surface the silent-constant pitfall instead of returning a graph that
    # ignores the tensors of whatever BlockMask it is replayed with.
    if isinstance(mask_mod, partial) and any(
        isinstance(bound, torch.Tensor)
        for bound in (*mask_mod.args, *mask_mod.keywords.values())
    ):
        warnings.warn(
            "mask_mod is a functools.partial with bound tensors; they are not "
            "flattened as BlockMask leaves, so the traced graph will keep using "
            "them on replay. Capture them in a closure instead.",
            stacklevel=2,
        )

    leaves, spec = block_mask._flatten()

    def _apply_mask_mod(*flat_leaves):
        # Rebuild the BlockMask from the leaves make_fx passes in, so the graph
        # records how the mask_mod consumes them.
        return BlockMask._unflatten(flat_leaves, spec).mask_mod(0, 0, query_indices, key_indices)

    return make_fx(_apply_mask_mod)(*leaves)


# Trace the same mask logic twice against batch0's tensors: once with the
# tensors bound via functools.partial, once with them captured in a closure.
traced_partial_mask_mod = trace_mask_mod(make_block_mask(partial(mask_rule, attn_regions=batch0_attn_regions, document_ids=document_ids)))
traced_closure_mask_mod = trace_mask_mod(make_block_mask(closure_mask_rule(batch0_attn_regions, document_ids)))

# Replay each traced graph with the flattened leaves of a batch1 BlockMask.
# If the region tensor is a real graph input, the result should follow batch1.
replayed_partial_on_batch1 = traced_partial_mask_mod(*make_block_mask(partial(mask_rule, attn_regions=batch1_attn_regions, document_ids=document_ids))._flatten()[0])
replayed_closure_on_batch1 = traced_closure_mask_mod(*make_block_mask(closure_mask_rule(batch1_attn_regions, document_ids))._flatten()[0])
# Reference results from calling each batch's mask_mod directly (no tracing).
expected_batch0 = make_block_mask(partial(mask_rule, attn_regions=batch0_attn_regions, document_ids=document_ids)).mask_mod(0, 0, query_indices, key_indices)
expected_batch1 = make_block_mask(partial(mask_rule, attn_regions=batch1_attn_regions, document_ids=document_ids)).mask_mod(0, 0, query_indices, key_indices)

# Compare each replay against both references to see which batch's tensors it used.
print("partial_replays_batch0", torch.equal(replayed_partial_on_batch1, expected_batch0))
print("partial_matches_batch1", torch.equal(replayed_partial_on_batch1, expected_batch1))
print("closure_replays_batch0", torch.equal(replayed_closure_on_batch1, expected_batch0))
print("closure_matches_batch1", torch.equal(replayed_closure_on_batch1, expected_batch1))

# Observed output: the partial-based graph still reproduces batch0 (its bound
# tensors were presumably baked into the graph), while the closure-based graph
# picks up batch1's tensors.
# partial_replays_batch0 True
# partial_matches_batch1 False
# closure_replays_batch0 False
# closure_matches_batch1 True

This is because pytree flatten/unflatten handles closure mask_mods but not partially bound ones, so tensors bound with functools.partial never become inputs of the traced graph. There is probably nothing to fix here; this is filed to document a common gotcha in non-strict tracing.

Versions

main

cc @ezyang @EikanWang @jgong5 @wenzhe-nrv @chauhang @penguinwu @Chillee @drisspg @yanboliang @BoyuanFeng @liangel-02 @howardzhang-cv

extent analysis

TL;DR

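A mask mod traced with make_fx ignores replacement tensors when it was built with functools.partial, because partially bound mask mods are not pytree flattened/unflattened the way closures are. The workaround is to use closure mask mods instead.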

Guidance

  • The problem appears when a traced partial mask mod is replayed on a different batch. It does not raise an error: it silently keeps using the original batch's tensors, because partially bound mask mods are not pytree flattened/unflattened.
  • To mitigate this, use closure mask mods, which can be pytree flattened/unflattened, as demonstrated by the closure_mask_rule function.
  • When using make_fx to trace a mask mod, ensure that the mask mod is a closure, not a partially bound function.
  • Verify that the traced mask mod works correctly by comparing its output with the expected output, as shown in the provided code.

Example

# Define a closure mask mod
def closure_mask_rule(attn_regions, document_ids):
    """Return a 4-argument mask_mod that captures the region/document tensors in a closure."""
    def _closure_mask_mod(batch, head, query_idx, key_idx):
        return mask_rule(batch, head, query_idx, key_idx, attn_regions, document_ids)

    return _closure_mask_mod

# Build a BlockMask from the closure mask mod (batch0 tensors) and trace its mask_mod with make_fx
traced_closure_mask_mod = trace_mask_mod(make_block_mask(closure_mask_rule(batch0_attn_regions, document_ids)))

# Replay the traced graph with the flattened leaves of a batch1 BlockMask; since the
# closure's tensors are among those leaves, the result follows batch1
replayed_closure_on_batch1 = traced_closure_mask_mod(*make_block_mask(closure_mask_rule(batch1_attn_regions, document_ids))._flatten()[0])

Notes

This issue highlights the importance of using closures when working with pytree flattening/unflattening in non-strict tracing.

Recommendation

Apply workaround: use closure mask mods instead of partially bound mask mods to ensure correct behavior when replaying traced mask mods on different batches.

Vote matrix · Quick signals

Works
Did the solution work? Tap to confirm.
Easy Fix
Was it a quick fix?
Time Saver
Did it save you time?
Blocking
Was it severely blocking?
Common Issue
Are others likely hitting this too?
Flaky / Intermittent
Is it intermittent?
Verified / Reproducible
Can you reproduce it reliably?
Loading…

Still need to ship something?

×6

Another batch ranked right after the header list — different links, same matching logic.

Back to top recommendations

TRENDING

pytorch - 💡(How to fix) Fix: make_fx doesn't work with tensors bound via functools.partial