pytorch - 💡(How to fix) Fix FP8 Per-Head Quantized Scaled Dot Product Attention [1 participants]

GitHub stats
pytorch/pytorch#181181 (fetched 2026-04-23 07:22:06)
Comments: 0, Participants: 1, Timeline: 19, Reactions: 0
Timeline (top): mentioned ×8, subscribed ×8, labeled ×3

New Feature for Release

A new prototype op for FP8 per-head quantized scaled dot product attention (SDPA) using the FlashAttention 3 backend. The prototype op can be called directly as torch.nn.attention.experimental._scaled_dot_product_attention_quantized. More detailed end-to-end usage guides for building on these primitives are in torchao (see the torchao 0.17.0 release).
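The issue gives the op's import location but not its signature. The following is a minimal sketch that assumes only that path. It checks whether a given PyTorch build exposes the prototype before other code depends on it. The helper name `find_quantized_sdpa` is hypothetical.

```python
import importlib
from typing import Callable, Optional


def find_quantized_sdpa() -> Optional[Callable]:
    """Return the prototype FP8 per-head SDPA op if this build exposes it, else None."""
    try:
        # Path taken from the issue; older builds (or no torch at all) raise ImportError.
        mod = importlib.import_module("torch.nn.attention.experimental")
    except ImportError:
        return None
    op = getattr(mod, "_scaled_dot_product_attention_quantized", None)
    return op if callable(op) else None


op = find_quantized_sdpa()
if op is None:
    print("FP8 per-head SDPA prototype not available in this build")
else:
    # The signature is deliberately not assumed here.
    print("found prototype op:", op)
```

The op is private (underscore-prefixed) and a prototype. Take its arguments from the PyTorch source or the torchao wrapper rather than guessing them.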

Point(s) of contact

howardzhang-cv

Release Mode (pytorch/pytorch features only)

In-tree

Out-Of-Tree Repo

No response

Description and value to the user

Provides prototype ops for calling the FP8 per-head quantized FlashAttention 3 backend. Initial tests show a good speedup with minimal quality degradation for image generation, video generation, and LLMs.

Link to design doc, GitHub issues, past submissions, etc

No response

What feedback adopters have provided

No response

Plan for documentations / tutorials

Tutorial exists

Additional context for tutorials

The PyTorch op can be used on its own. Because it is a bare op, no tutorial is needed for that case. Users are encouraged to look at the torchao wrapper for this op, which includes end-to-end model wrapping and direct SDPA replacements (with Triton per-head quantization and RoPE-fused quantization variants). The tutorial can be found in the torchao 0.17.0 release notes.

Marketing/Blog Coverage

Yes

Are you requesting other marketing assistance with this feature?

No response

Release Version

No response

OS / Platform / Compute Coverage

CUDA on SM90 (Hopper) architectures

Testing Support (CI, test cases, etc..)

Integration tests are located in test/nn/attention/test_fa3.py. Additional end-to-end (E2E) tests have also been run, with the results below.

[Image: single-layer attention benchmark results, https://github.com/user-attachments/assets/4432dc4d-18e8-4bea-bd79-e80c739000c4]

Single-layer attention tests show that FP8 is faster than BF16 beyond a sequence length of 8192, even when the time spent on per-head quantization is included.

[Image: Llama 3 prefill benchmark results, https://github.com/user-attachments/assets/ac395ef9-87dd-4792-bc99-0f02c127d374]

Tests with Llama 3 prefill show that FP8 is faster than BF16 beyond a sequence length of 4096. These tests fuse the RoPE and per-head quantization computations, which improves performance in end-to-end model inference (see the torchao documentation for more information). Quality degradation is minimal: perplexity rises from 7.54 to 7.62.
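To put the reported Llama 3 perplexity change in perspective, it corresponds to a relative increase of about 1%:

```latex
\frac{7.62 - 7.54}{7.54} = \frac{0.08}{7.54} \approx 0.0106 \approx 1.06\%
```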

Additional tests with Wan2.1-14B show a 1.18x speedup in video generation (1080p, 81 frames). Naive application of FP8 attention causes severe quality drops. Skipping quantization on the first and last few layers is suggested for a good runtime/quality tradeoff.
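The layer-skipping suggestion can be expressed as a per-layer plan. This is a minimal sketch. It assumes hypothetical `skip_first` and `skip_last` counts, because the issue says only "the first and last few layers" and gives no numbers. The function name `fp8_attention_plan` is also hypothetical.

```python
def fp8_attention_plan(num_layers: int, skip_first: int, skip_last: int) -> list[bool]:
    """Per layer index: True = use FP8 attention, False = keep BF16.

    skip_first / skip_last are hypothetical knobs; the issue does not give counts.
    """
    if skip_first < 0 or skip_last < 0:
        raise ValueError("skip counts must be non-negative")
    # Quantize only the middle layers; the outermost layers stay in BF16.
    return [skip_first <= i < num_layers - skip_last for i in range(num_layers)]


plan = fp8_attention_plan(num_layers=10, skip_first=2, skip_last=2)
print(plan)  # [False, False, True, True, True, True, True, True, False, False]
```

Each `True` entry marks a layer whose attention would be routed to the FP8 path, for example through the torchao wrapper. The skip counts would need tuning per model against both runtime and quality.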

cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @leslie-fang-intel

extent analysis

TL;DR

To use the new prototype op for FP8 per-head quantized scaled dot product attention, call torch.nn.attention.experimental._scaled_dot_product_attention_quantized directly. Alternatively, use the torchao wrapper for end-to-end model wrapping and direct SDPA replacements.

Guidance

  • For end-to-end use, consider the torchao wrapper. It provides model wrapping and direct SDPA replacements with Triton per-head quantization and RoPE-fused quantization variants, and the Llama 3 results were measured with RoPE and quantization fused.
  • Mind the sequence length: FP8 beat BF16 beyond a sequence length of 8192 in single-layer attention tests, and beyond 4096 in Llama 3 prefill tests.
  • To limit quality degradation, consider skipping quantization on the first and last few layers, as suggested by the Wan2.1-14B tests.
  • See the tutorial in the torchao 0.17.0 release notes for more detailed guidance on the prototype op and the torchao wrapper.
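The sequence-length guidance amounts to a threshold check. This sketch assumes that the crossover points reported in the issue carry over to your workload, which they may not. Those points are 8192 for single-layer attention and 4096 for Llama 3 prefill with fused RoPE and quantization. The constant and function names are hypothetical.

```python
# Crossover points reported in the issue for SM90 (Hopper); re-measure on your own setup.
SINGLE_LAYER_CROSSOVER = 8192    # standalone attention, per-head quantization time included
LLAMA3_PREFILL_CROSSOVER = 4096  # Llama 3 prefill with RoPE and quantization fused


def prefer_fp8_attention(seq_len: int, crossover: int = SINGLE_LAYER_CROSSOVER) -> bool:
    """Return True when seq_len is past the crossover where FP8 beat BF16."""
    return seq_len > crossover


print(prefer_fp8_attention(4096))                            # False: below the single-layer crossover
print(prefer_fp8_attention(8192, LLAMA3_PREFILL_CROSSOVER))  # True: past the Llama 3 prefill crossover
```

The check uses a strict `>` because the issue says "past" each length. Whether FP8 already wins exactly at the crossover is not stated.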

Example

The issue itself does not include a code example, and it does not document the op's signature.
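To illustrate what per-head FP8 quantization does numerically, the following sketch emulates it in plain Python. Each head gets one scale derived from its absolute maximum (amax). Values are rounded to the float8 e4m3 format, whose largest finite value is 448, then dequantized and passed through ordinary softmax attention.

This is a reference sketch built on assumptions, not the FlashAttention 3 kernel or the torchao implementation:
- The amax-based scaling recipe is assumed.
- The kernel's internal precision choices are not modeled, including accumulation and handling of the softmax output.
- All function names are hypothetical.

```python
import math
import random

FP8_E4M3_MAX = 448.0  # largest finite float8 e4m3 (e4m3fn) value


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest e4m3 value: round-half-to-even, clamped to +/-448."""
    if x == 0.0:
        return 0.0
    a = abs(x)
    _, e = math.frexp(a)              # a = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, -6)              # clamp: subnormals share the smallest normal binade's spacing
    step = 2.0 ** (exp - 3)           # 3 mantissa bits -> 8 representable steps per binade
    q = min(round(a / step) * step, FP8_E4M3_MAX)  # round() on floats is half-to-even
    return math.copysign(q, x)


def quantize_per_head(x):
    """x: [heads][seq][dim] -> (e4m3-rounded values, one scale per head)."""
    codes, scales = [], []
    for head in x:
        amax = max(abs(val) for row in head for val in row)
        scale = amax / FP8_E4M3_MAX if amax > 0 else 1.0  # map the head's amax to 448
        codes.append([[round_to_e4m3(val / scale) for val in row] for row in head])
        scales.append(scale)
    return codes, scales


def dequantize_per_head(codes, scales):
    """Multiply each head's rounded values back by that head's scale."""
    return [[[c * s for c in row] for row in head] for head, s in zip(codes, scales)]


def sdpa(q, k, v):
    """Single-head softmax(q @ k^T / sqrt(d)) @ v with no mask, on nested lists."""
    inv_sqrt_d = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        logits = [sum(a * b for a, b in zip(qi, kj)) * inv_sqrt_d for kj in k]
        top = max(logits)  # subtract the max for a numerically stable softmax
        w = [math.exp(z - top) for z in logits]
        total = sum(w)
        out.append([sum(wj * vj[c] for wj, vj in zip(w, v)) / total for c in range(len(v[0]))])
    return out


def fp8_per_head_sdpa_emulated(q, k, v):
    """Round Q, K, V to e4m3 per head, dequantize, then run reference attention."""
    qd, kd, vd = (dequantize_per_head(*quantize_per_head(t)) for t in (q, k, v))
    return [sdpa(qh, kh, vh) for qh, kh, vh in zip(qd, kd, vd)]


# Usage: compare against unquantized attention on random [heads][seq][dim] inputs.
random.seed(0)
H, S, D = 2, 8, 16


def rand_tensor():
    return [[[random.uniform(-1, 1) for _ in range(D)] for _ in range(S)] for _ in range(H)]


q, k, v = rand_tensor(), rand_tensor(), rand_tensor()
ref = [sdpa(qh, kh, vh) for qh, kh, vh in zip(q, k, v)]
emu = fp8_per_head_sdpa_emulated(q, k, v)
max_err = max(abs(a - b) for rh, eh in zip(ref, emu) for rr, er in zip(rh, eh) for a, b in zip(rr, er))
print(f"max abs error vs. unquantized attention: {max_err:.4f}")
```

Comparing `emu` against `ref` gives a rough sense of the rounding error that per-head FP8 inputs introduce. The real op's error and speed still have to be measured on SM90 hardware.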

Notes

The prototype op targets CUDA SM90 (Hopper) architectures, and the issue does not cover other platforms. Speedups and quality degradation may vary with the use case and model architecture.

Recommendation

Use the torchao wrapper, and balance runtime against quality by weighing sequence length and which layers to quantize. This approach captures the FP8 speedup while keeping quality degradation small.
