transformers - ✅(Solved) Fix flash-attn-4 (flash_attn.cute) is not supported by attn_implementation="flash_attention_2" [1 pull requests, 2 comments, 2 participants]

GitHub stats
huggingface/transformers#44559 (fetched 2026-04-08 00:27:41)
Comments: 2
Participants: 2
Timeline events: 4
Reactions: 0
Timeline (top): commented ×2, closed ×1, labeled ×1

Error Message

ImportError: cannot import name 'flash_attn_func' from 'flash_attn' (unknown location)

Root Cause

Blackwell users moving to newer CUDA / PyTorch stacks are likely to try flash-attn-4, but the current import path fails before training begins. This makes the newer FA4 stack unusable from stock Transformers attention selection even though the FA4 functions are present and importable from flash_attn.cute.

Fix Action

Fix / Workaround

My main motivation is to make newer Blackwell-oriented FlashAttention stacks usable from Transformers without requiring users to patch library internals locally.

PR fix notes

PR #840: [MoBA] Integrate MOBA and FlashMOBA

Description (problem / solution / changelog)

This pull request introduces two state-of-the-art sparse attention modules—MobaAttention and FlashMoBA—to the FLA codebase, along with the necessary integration and registration steps.

  • Background: MoBA (Mixture-of-Block Attention) originates from MoonshotAI MoonshotAI/MoBA, while FlashMoBA is its highly optimized CUDA implementation developed by the MIT HAN Lab mit-han-lab/flash-moba.
  • Motivation: Considering both are typical and highly advanced sparse attention mechanisms, I have refactored and wrapped them into an FLA-compatible format. The goal is to enrich the scope of models supported by FLA and provide the open-source community with a unified, convenient interface to easily utilize MoBA and Flash-MoBA.
  • Technical Integration:
    • Added the MobaAttention module implementation in fla/layers/moba.py, which provides a native mixture of block attention mechanism, supporting chunking and output gating.
    • Added the FlashMoBA attention module implementation in fla/layers/moba.py, supporting efficient mixture of block attention.
    • Introduced the moba_attn_varlen function (and its naive counterpart) in fla/ops/moba/ as the core operator for the module.
    • Registered MobaAttention in the main package imports and __all__ lists in fla/__init__.py and fla/layers/__init__.py.

Test plan:

  • Framework Compatibility: Verified that both MobaAttention and FlashMoBA are fully compatible with the FLAME training framework.
  • Evaluation Compatibility: Successfully tested compatibility with HuggingFace's lm-evaluation-harness for downstream task evaluations.

Breaking changes:

  • None. This PR strictly adds new modules and operators without modifying existing public APIs.

Summary by CodeRabbit

  • New Features
    • Two new attention layers (FlashMoBA, MobaAttention) added to the public API.
    • Mixture-of-block attention with configurable chunking and top‑k selection for selective attention.
    • Efficient variable‑length attention with caching support, rotary embeddings, and optimized native ops for faster execution and lower memory.

Changed files

  • fla/__init__.py (modified, +2/-0)
  • fla/layers/__init__.py (modified, +2/-0)
  • fla/layers/moba.py (added, +269/-0)
  • fla/ops/moba/__init__.py (added, +10/-0)
  • fla/ops/moba/naive.py (added, +412/-0)

Code Example

from transformers import AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

---

from flash_attn import flash_attn_func, flash_attn_varlen_func

---

from flash_attn.cute import flash_attn_func, flash_attn_varlen_func

---

ImportError: cannot import name 'flash_attn_func' from 'flash_attn' (unknown location)


Feature request

Support flash-attn-4 (flash_attn.cute) in Transformers attention backend selection

System Info

  • transformers==5.3.0
  • torch==2.10.0+cu128
  • flash-attn-4==4.0.0b4
  • accelerate==1.13.0
  • trl==0.29.0
  • peft==0.18.0
  • deepspeed==0.18.7
  • tokenizers==0.22.2
  • huggingface_hub==1.6.0
  • Python 3.12
  • CUDA 12.8
  • GPU: NVIDIA Blackwell (sm120)

Information

  • The official example scripts
  • My own modified scripts
  • I am willing to open a PR

Reproduction

I am testing a Blackwell environment with:

  • PyTorch 2.10
  • CUDA 12.8
  • flash-attn-4
  • transformers 5.3.0

Model loading fails before training starts when I pass:

from transformers import AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

Actual behavior

Transformers appears to still expect the FlashAttention v2-style top-level import:

from flash_attn import flash_attn_func, flash_attn_varlen_func

But flash-attn-4 exposes the relevant API under:

from flash_attn.cute import flash_attn_func, flash_attn_varlen_func

As a result, model initialization fails with:

ImportError: cannot import name 'flash_attn_func' from 'flash_attn' (unknown location)

This is the traceback I get when launching training (the failure occurs during model loading):

Traceback (most recent call last):
  File "/home/joshua/llm-fine-tune-hf/main.py", line 68, in <module>
    main()
  File "/home/joshua/anaconda3/envs/joshpp/lib/python3.12/site-packages/hydra/main.py", line 94, in decorated_main
    _run_hydra(
  File "/home/joshua/anaconda3/envs/joshpp/lib/python3.12/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra
    _run_app(
  File "/home/joshua/anaconda3/envs/joshpp/lib/python3.12/site-packages/hydra/_internal/utils.py", line 457, in _run_app
    run_and_report(
  File "/home/joshua/anaconda3/envs/joshpp/lib/python3.12/site-packages/hydra/_internal/utils.py", line 223, in run_and_report
    raise ex
  File "/home/joshua/anaconda3/envs/joshpp/lib/python3.12/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
    return func()
  File "/home/joshua/anaconda3/envs/joshpp/lib/python3.12/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
    lambda: hydra.run(
  File "/home/joshua/anaconda3/envs/joshpp/lib/python3.12/site-packages/hydra/_internal/hydra.py", line 132, in run
    _ = ret.return_value
  File "/home/joshua/anaconda3/envs/joshpp/lib/python3.12/site-packages/hydra/core/utils.py", line 260, in return_value
    raise self._return_value
  File "/home/joshua/anaconda3/envs/joshpp/lib/python3.12/site-packages/hydra/core/utils.py", line 186, in run_job
    ret.return_value = task_function(task_cfg)
  File "/home/joshua/llm-fine-tune-hf/main.py", line 54, in main
    return train(config)
  File "/home/joshua/llm-fine-tune-hf/src/pipelines/pipeline.py", line 55, in train
    model = setup.get_model()
  File "/home/joshua/llm-fine-tune-hf/src/utils/setup.py", line 116, in get_model
    model = AutoModelForCausalLM.from_pretrained(
  File "/home/joshua/anaconda3/envs/joshpp/lib/python3.12/site-packages/transformers/models/auto/auto_factory.py", line 374, in from_pretrained
    return model_class.from_pretrained(
  File "/home/joshua/anaconda3/envs/joshpp/lib/python3.12/site-packages/transformers/modeling_utils.py", line 4094, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
  File "/home/joshua/anaconda3/envs/joshpp/lib/python3.12/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 461, in __init__
    super().__init__(config)
  File "/home/joshua/anaconda3/envs/joshpp/lib/python3.12/site-packages/transformers/modeling_utils.py", line 1260, in __init__
    self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
  File "/home/joshua/anaconda3/envs/joshpp/lib/python3.12/site-packages/transformers/modeling_utils.py", line 1893, in _check_and_adjust_attn_implementation
    lazy_import_flash_attention(applicable_attn_implementation)
  File "/home/joshua/anaconda3/envs/joshpp/lib/python3.12/site-packages/transformers/modeling_flash_attention_utils.py", line 171, in lazy_import_flash_attention
    _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn = _lazy_imports(
  File "/home/joshua/anaconda3/envs/joshpp/lib/python3.12/site-packages/transformers/modeling_flash_attention_utils.py", line 96, in _lazy_imports
    from flash_attn import flash_attn_func, flash_attn_varlen_func
ImportError: cannot import name 'flash_attn_func' from 'flash_attn' (unknown location)

Expected behavior

One of the following would solve this cleanly:

  1. Detect flash-attn-4 and import from flash_attn.cute when that package is installed.
  2. Introduce an explicit backend such as attn_implementation="flash_attention_4".
  3. Document that flash-attn-4 is not yet supported by the current attention backend selection logic.
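Until one of these options lands, a caller can avoid the crash by choosing the attention backend at runtime. The following is a minimal sketch, not part of Transformers: pick_attn_implementation is a hypothetical helper name, and it assumes that "sdpa" is an acceptable fallback for the model in question. It only checks whether the v2-style top-level names are importable; Transformers may apply further checks of its own.

```python
import importlib


def pick_attn_implementation(preferred="flash_attention_2", fallback="sdpa"):
    """Return `preferred` only if the v2-style top-level import would succeed.

    Hypothetical helper: it mirrors the import that fails in the traceback
    (`from flash_attn import flash_attn_func, flash_attn_varlen_func`).
    """
    try:
        module = importlib.import_module("flash_attn")
    except ImportError:
        # flash_attn is not installed (or cannot be imported) at all.
        return fallback
    needed = ("flash_attn_func", "flash_attn_varlen_func")
    if all(hasattr(module, name) for name in needed):
        return preferred
    # Package is present but lacks the top-level exports (the FA4 case here).
    return fallback
```

The result can then be passed as attn_implementation=pick_attn_implementation() in the from_pretrained call from the reproduction. On the flash-attn-4 environment described in this issue, this would select the fallback rather than use the FA4 kernels, so it is a stopgap and not a fix.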

Why this matters

Blackwell users moving to newer CUDA / PyTorch stacks are likely to try flash-attn-4, but the current import path fails before training begins. This makes the newer FA4 stack unusable from stock Transformers attention selection even though the FA4 functions are present and importable from flash_attn.cute.

Additional notes

On the same machine, an older stack works:

  • transformers==4.57.3
  • flash_attn==2.8.3

That older stack exports flash_attn_func at the package top level, so the current Transformers import path works there.

By contrast, in the newer environment:

from flash_attn.cute import flash_attn_func, flash_attn_varlen_func

works, while:

from flash_attn import flash_attn_func

does not.
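The difference between the two environments can be checked with a small probe. This is a sketch for diagnosis only; probe and NAMES are illustrative names, and the only library facts it relies on are the two import paths quoted in this issue.

```python
import importlib

# The two functions Transformers tries to import, per the traceback above.
NAMES = ("flash_attn_func", "flash_attn_varlen_func")


def probe(module_name, names=NAMES):
    """Return a short status string for one import path."""
    try:
        module = importlib.import_module(module_name)
    except ImportError as exc:
        return f"import failed: {exc}"
    missing = [name for name in names if not hasattr(module, name)]
    return "ok" if not missing else "missing: " + ", ".join(missing)


if __name__ == "__main__":
    for path in ("flash_attn", "flash_attn.cute"):
        print(f"{path}: {probe(path)}")
```

Running it in each environment shows which of the two import paths exposes the expected functions, which is useful context to attach to a bug report.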

Motivation

I am trying to use a newer Blackwell training stack with torch==2.10.0+cu128, CUDA 12.8, and flash-attn-4==4.0.0b4.

At the moment, transformers==5.3.0 appears to assume the FlashAttention v2-style top-level API when attn_implementation="flash_attention_2" is selected. However, flash-attn-4 exposes its functions under flash_attn.cute instead of the older top-level import path. Because of this, model loading fails before training even starts with:

ImportError: cannot import name 'flash_attn_func' from 'flash_attn'

This makes the newer FA4 stack unusable from stock Transformers attention backend selection, even though the FA4 functions themselves are present and importable.

This seems related in spirit to earlier flash-attn compatibility/import issues, for example:

  • #35899
  • #27002

My main motivation is to make newer Blackwell-oriented FlashAttention stacks usable from Transformers without requiring users to patch library internals locally.

Your contribution

Yes, I can help with a PR.

I can test proposed changes on a local Blackwell environment using:

  • torch==2.10.0+cu128
  • flash-attn-4==4.0.0b4
  • transformers==5.3.0

If the maintainers agree on the intended direction, I can help with:

  • validating a fix for flash-attn-4 detection/import
  • testing whether flash_attn.cute can be supported safely
  • verifying that the change does not break the existing flash_attn v2 path
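One way to validate both layouts without two separate installations is to simulate them with stand-in modules. The sketch below is an assumption-laden illustration, not the Transformers test suite: resolve_flash_attn is a hypothetical resolver that tries the top-level path first and flash_attn.cute second, and fake_modules builds placeholder modules whose functions are simple lambdas.

```python
import importlib
import sys
import types
from unittest import mock


def resolve_flash_attn():
    """Hypothetical resolver: try the v2 top-level path, then flash_attn.cute."""
    for module_name in ("flash_attn", "flash_attn.cute"):
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue
        fn = getattr(module, "flash_attn_func", None)
        varlen_fn = getattr(module, "flash_attn_varlen_func", None)
        if fn is not None and varlen_fn is not None:
            return module_name, fn, varlen_fn
    raise ImportError("no usable flash_attn_func / flash_attn_varlen_func found")


def fake_modules(layout):
    """Build stand-in modules for sys.modules; layout is 'v2' or 'fa4'."""
    pkg = types.ModuleType("flash_attn")
    pkg.__path__ = []  # mark the stand-in as a package
    if layout == "v2":
        pkg.flash_attn_func = lambda *args, **kwargs: "v2"
        pkg.flash_attn_varlen_func = lambda *args, **kwargs: "v2-varlen"
        return {"flash_attn": pkg}
    cute = types.ModuleType("flash_attn.cute")
    cute.flash_attn_func = lambda *args, **kwargs: "fa4"
    cute.flash_attn_varlen_func = lambda *args, **kwargs: "fa4-varlen"
    pkg.cute = cute
    return {"flash_attn": pkg, "flash_attn.cute": cute}


def resolved_path(layout):
    """Return the module path the resolver picks under a simulated layout."""
    # patch.dict restores sys.modules when the block exits.
    with mock.patch.dict(sys.modules, fake_modules(layout)):
        return resolve_flash_attn()[0]
```

With these stand-ins, resolved_path("v2") should pick the top-level package and resolved_path("fa4") should pick flash_attn.cute, which covers the "does not break the existing v2 path" check at the import level. It says nothing about kernel correctness, which still needs a real GPU run.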

I have read the contribution guidance and can prepare a focused PR once the preferred approach is confirmed.

extent analysis

Fix Plan

Update Transformers to Support flash-attn-4

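Step 1: Detect flash-attn-4 at Import Time

Detect flash-attn-4 and import from flash_attn.cute when that package is installed. The traceback places the failing import in _lazy_imports (transformers/modeling_flash_attention_utils.py), so the code below is a sketch of the idea as a hypothetical standalone helper, not the actual Transformers implementation.

# Hypothetical helper (the name _import_flash_attn_functions is illustrative).
import importlib.util

def _import_flash_attn_functions():
    """Return (flash_attn_func, flash_attn_varlen_func), preferring the FA4 path."""
    try:
        has_cute = importlib.util.find_spec("flash_attn.cute") is not None
    except ModuleNotFoundError:
        # Raised when the parent package flash_attn is not installed at all.
        has_cute = False
    if has_cute:
        # flash-attn-4 layout: functions live under flash_attn.cute.
        from flash_attn.cute import flash_attn_func, flash_attn_varlen_func
    else:
        # flash-attn v2 layout: functions are exported at the package top level.
        from flash_attn import flash_attn_func, flash_attn_varlen_func
    return flash_attn_func, flash_attn_varlen_func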

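Step 2: Add flash_attention_4 as a Valid Backend

Add flash_attention_4 (the name proposed in the issue) as an explicit backend option. As in Step 1, this is a sketch with a hypothetical helper name, not the actual Transformers dispatch code.

# Hypothetical helper (the name _resolve_flash_backend is illustrative).
import importlib.util

def _resolve_flash_backend(attn_implementation):
    """Return the module path to import the flash-attention functions from."""
    if attn_implementation == "flash_attention_2":
        return "flash_attn"
    elif attn_implementation == "flash_attention_4":
        try:
            spec = importlib.util.find_spec("flash_attn.cute")
        except ModuleNotFoundError:
            # The parent package flash_attn is not installed.
            spec = None
        if spec is None:
            raise ImportError(
                "attn_implementation='flash_attention_4' requires flash-attn-4 (flash_attn.cute)"
            )
        return "flash_attn.cute"
    raise ValueError(f"Unsupported attention implementation: {attn_implementation!r}")

Step 3: Update Example Scripts

Update example scripts to use the new flash_attention_4 backend option.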

# transformers/examples.py
def train(config):
    model = AutoModelForCausalLM.from_pretrained(
