pytorch: flex_attention(BACKEND='FLASH') with return_lse=True returns lse off by a factor of ln(2) [1 participant]

GitHub stats
pytorch/pytorch#181645 (fetched 2026-04-28 06:24:09)
Comments: 0 · Participants: 1 · Timeline events: 150 · Reactions: 0
Timeline (top): mentioned ×72, subscribed ×72, labeled ×6

Bug

flex_attention(..., kernel_options={"BACKEND": "FLASH"}, return_lse=True) returns lse * ln(2) instead of lse. The Triton backend and the eager fallback both return correct lse.

Why

_finalize_outputs (torch/nn/attention/flex_attention.py:1588) unconditionally does lse * ln(2), with the convention that backend implementations return lse in base-2. The eager math_attention honors this with an explicit pre-divide:

# torch/_higher_order_ops/flex_attention.py:260-266
# NB: kernel computes in ln2 space, we always convert back at the top level op, so
# for math impl we divide by log(2) because we will multiply by log(2)
return (
    post_mod_scores.to(query.dtype) @ value,
    logsumexp / math.log(2),
    max_scores / math.log(2),
)

The Triton template stores lse in base-2 natively (exp2/log2), so it also satisfies the convention. The FA4 dispatch path (_inductor/kernel/flex/templates/flash_attention.py.jinja) calls flash_attn.cute.interface._flash_attn_fwd, which returns lse in natural log (the conversion happens inside FA4's softmax.finalize). Nothing converts it back to base-2 before _finalize_outputs runs, so the multiply over-converts.
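
The convention can be checked with plain arithmetic. The following is a minimal sketch under stated assumptions, not the real implementation: `finalize` is a hypothetical stand-in for the unconditional multiply in _finalize_outputs, and the scores are arbitrary illustrative values for a single query row.

```python
import math

LN2 = math.log(2.0)

# Arbitrary illustrative attention scores for one query row (assumption).
scores = [0.5, -1.25, 2.0, 0.75]

# Natural-log logsumexp: what a caller of return_lse=True expects.
lse_nat = math.log(sum(math.exp(s) for s in scores))

# Base-2 form, as an exp2/log2 kernel would store it.
# Since 2 ** (s / ln 2) == exp(s), this equals lse_nat / ln 2.
lse_base2 = math.log2(sum(2.0 ** (s / LN2) for s in scores))


def finalize(lse_from_backend):
    # Hypothetical stand-in for the top-level multiply by ln(2).
    return lse_from_backend * LN2


ok = finalize(lse_base2)         # backend honoured the base-2 convention
over = finalize(lse_nat)         # backend returned natural log: over-converted
fixed = finalize(lse_nat / LN2)  # pre-divide restores the expected value

print(lse_nat, ok, over, fixed)
```

Here `ok` and `fixed` match `lse_nat`, while `over` equals `lse_nat * ln(2)`, which is the behaviour reported for the FLASH backend.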

max_scores (line 1590) has the same issue, though I haven't tested whether the FA4 max_scores path is exposed publicly yet.

Repro

import math, torch
from functools import partial
from torch.nn.attention.flex_attention import flex_attention

torch.manual_seed(0)
N, D = 1024, 32
q = torch.randn(1, 1, N, D, device="cuda", dtype=torch.float16)
k = torch.randn(1, 1, N, D, device="cuda", dtype=torch.float16)
v = torch.randn(1, 1, N, D, device="cuda", dtype=torch.float16)

triton = torch.compile(partial(flex_attention, kernel_options={"BACKEND": "TRITON"}))
flash  = torch.compile(partial(flex_attention, kernel_options={"BACKEND": "FLASH"}))

_, lse_t = triton(q, k, v, scale=1.0, return_lse=True)
_, lse_f = flash (q, k, v, scale=1.0, return_lse=True)

ref = torch.logsumexp(q.squeeze().float() @ k.squeeze().float().T, dim=-1)
print((lse_t.squeeze().float() - ref).abs().max())              # ~1e-6
print((lse_f.squeeze().float() - ref).abs().max())              # ~9 (lse is scaled by a factor of ln(2))
print((lse_f.squeeze().float() / math.log(2.0) - ref).abs().max())  # ~1e-6

The bug only manifests inside torch.compile. The eager fallback path goes through math_attention, which honors the base-2 convention; only the compiled FA4 dispatch skips the pre-divide.

Workaround

Divide returned FLASH lse by math.log(2.0).
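
One way to keep the workaround in a single place is a small wrapper. This is only a sketch: `to_natural_lse` and its `backend` argument are hypothetical names, not part of PyTorch, and the caller is assumed to know which backend produced the lse and whether the installed build still has the bug.

```python
import math

LN2 = math.log(2.0)


def to_natural_lse(lse, backend):
    """Hypothetical helper: undo the extra ln(2) factor on FLASH lse.

    Accepts a float or a torch tensor, since both support division
    by a Python float. Other backends are returned unchanged.
    """
    if backend == "FLASH":
        return lse / LN2
    return lse
```

With the repro's variables, `to_natural_lse(lse_f, "FLASH")` would take the place of the manual division. The helper should be dropped once the fix lands, otherwise it would under-scale a correct lse.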

Fix

Add an lse / ln(2) pre-divide in the FA4 dispatch glue (after _flash_attn_fwd returns, before _finalize_outputs). This is a one-line change that mirrors what math_attention already does for eager.

Environment

Reproduced on:

  • PyTorch: 2.11.0a0+a6c236b9fd.nv26.03.46836102 (NGC nvcr.io/nvidia/pytorch:26.03-py3)
  • CUDA: 13.2
  • flash-attn-4: 4.0.0b10
  • nvidia-cutlass-dsl: 4.4.2
  • GPU: H200 (also seen on H100 and B200)

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo @Chillee @drisspg @yanboliang @BoyuanFeng @liangel-02 @howardzhang-cv

extent analysis

TL;DR

To fix the issue, divide the returned FLASH lse by math.log(2.0) or add a pre-divide in the FA4 dispatch glue.

Guidance

  • The issue is caused by the FA4 dispatch path returning lse in natural log, which is not converted back to base-2 before _finalize_outputs runs.
  • To verify the issue, compare the lse values returned by the Triton and FLASH backends using the provided repro code.
  • To mitigate the issue, divide the returned FLASH lse by math.log(2.0) as a workaround.
  • To fix the issue, add an lse / ln(2) pre-divide in the FA4 dispatch glue after _flash_attn_fwd returns and before _finalize_outputs runs.
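
The verification step in the guidance above can be made a little more specific by estimating the scale factor between a returned lse and a reference, instead of only looking at the maximum absolute error. A minimal sketch, assuming the values are passed as flat Python sequences (for tensors, for example via `.flatten().tolist()`); both function names are hypothetical.

```python
import math
from statistics import median


def estimate_lse_factor(lse_values, ref_values):
    """Median of lse/ref over entries whose reference is not near zero.

    Raises statistics.StatisticsError if every reference entry is near zero.
    """
    ratios = [l / r for l, r in zip(lse_values, ref_values) if abs(r) > 1e-6]
    return median(ratios)


def looks_like_ln2_bug(lse_values, ref_values, rel_tol=1e-3):
    """True when the median ratio is close to ln(2)."""
    factor = estimate_lse_factor(lse_values, ref_values)
    return math.isclose(factor, math.log(2.0), rel_tol=rel_tol)
```

A ratio near 1.0 indicates a correct lse, while a ratio near ln(2) ≈ 0.693 matches the over-conversion described in this issue. The 1e-3 tolerance is an arbitrary choice and may need loosening for low-precision inputs.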

Example

# Workaround
lse_f_corrected = lse_f.squeeze().float() / math.log(2.0)
print((lse_f_corrected - ref).abs().max())  # ~1e-6

Notes

The issue only manifests inside torch.compile and does not affect the eager fallback path.

Recommendation

Apply the workaround by dividing the returned FLASH lse by math.log(2.0) until the fix is implemented in the FA4 dispatch glue.
