transformers - ✅(Solved) Fix [MPS] Silent correctness issue in bidirectional attention [1 pull requests, 18 comments, 5 participants]

GitHub stats
huggingface/transformers#44247 (fetched 2026-04-08 00:29:36)
Comments: 18 · Participants: 5 · Timeline: 39 · Reactions: 0
Timeline (top): commented ×18, subscribed ×9, mentioned ×8, cross-referenced ×2

Error Message

RuntimeError: probability tensor contains either inf, nan or element < 0

PR fix notes

PR #44591: Add MPS SDPA workarounds for value head dim and bidirectional attention

Description (problem / solution / changelog)

Adds _apply_mps_fixes in sdpa_attention.py to handle two upstream PyTorch MPS bugs:

  1. pytorch/pytorch#176767 (fixed in PyTorch 2.12): pads value tensor when v_head_dim != q_head_dim to avoid corrupted output. Affects DeepSeek models with MQA.

  2. pytorch/pytorch#174861 (fixed in PyTorch 2.11): forces a non-bool attention mask for non-causal, non-float32 attention to route through sdpa_general_mps instead of broken sdpa_vector_2pass_mps.
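Both PR items can be sketched with public PyTorch calls. The helpers below are hypothetical (they are not the `_apply_mps_fixes` code from the PR) and assume (batch, heads, seq_len, head_dim) tensors, that the value tensor is zero-padded up to the query head dim, and that the forced non-bool mask is an additive float mask with 0.0 for attended positions and -inf elsewhere:

```python
from typing import Optional

import torch
import torch.nn.functional as F


def sdpa_with_padded_value(query, key, value, **kwargs):
    """Zero-pad V's head dim up to Q's head dim, run SDPA, then slice back.

    Each output column is a softmax-weighted sum of the matching V column,
    so zero-padded columns only add zero columns to the output.
    """
    q_dim, v_dim = query.shape[-1], value.shape[-1]
    if v_dim == q_dim:
        return F.scaled_dot_product_attention(query, key, value, **kwargs)
    if v_dim > q_dim:
        raise ValueError("this sketch only pads V up to Q's head dim")
    padded = F.pad(value, (0, q_dim - v_dim))  # zero-pad the last dim on the right
    out = F.scaled_dot_product_attention(query, key, padded, **kwargs)
    return out[..., :v_dim]


def to_additive_mask(attn_mask: Optional[torch.Tensor], q_len: int, k_len: int,
                     dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    """Turn a boolean (or missing) SDPA mask into an additive float mask.

    In a boolean SDPA mask, True marks positions that take part in attention;
    the additive equivalent is 0.0 there and -inf elsewhere.
    """
    if attn_mask is None:
        # An all-zero (L, S) mask broadcasts over batch and heads.
        return torch.zeros(q_len, k_len, dtype=dtype, device=device)
    if attn_mask.dtype != torch.bool:
        return attn_mask  # already a float mask, leave it alone
    additive = torch.zeros(attn_mask.shape, dtype=dtype, device=attn_mask.device)
    return additive.masked_fill(~attn_mask, float("-inf"))
```

The default SDPA scale depends only on the query head dim, so padding V does not change it. After padding, the query and value head dims match, so the padded call should also be checked against the conditions for pytorch/pytorch#174861.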

Both fixes are version-gated and will no-op once the upstream PyTorch fixes are available.
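A minimal sketch of such version gating, using only the standard library. The function names are hypothetical; the upper bounds come from the PR notes (2.12 and 2.11) and the lower bound for the mask fix comes from the issue's proposed window (after v2.7.1), while the absence of a lower bound for the value-head-dim fix is an assumption:

```python
import re
from typing import Dict, Tuple


def parse_torch_version(version: str) -> Tuple[int, int, int]:
    """Parse the numeric prefix of a version string such as '2.10.0+cpu'."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def mps_sdpa_fixes_needed(version: str) -> Dict[str, bool]:
    """Hypothetical gate: which MPS SDPA workaround applies to a torch version."""
    v = parse_torch_version(version)
    return {
        # pytorch/pytorch#176767, fixed in 2.12 (no lower bound stated; assumed none).
        "pad_value_head_dim": v < (2, 12, 0),
        # pytorch/pytorch#174861, present since 2.8.0 and fixed in 2.11.
        "force_float_mask": (2, 7, 1) < v < (2, 11, 0),
    }


print(mps_sdpa_fixes_needed("2.10.0+cpu"))
# {'pad_value_head_dim': True, 'force_float_mask': True}
```

Local and pre-release builds (for example `2.10.0+cpu`) are treated as their base release here, which is itself an assumption.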

Fixes #44554 Fixes #44247

Changed files

  • src/transformers/integrations/sdpa_attention.py (modified, +60/-0)

System Info

A bug in PyTorch for the MPS backend (pytorch/pytorch#174861) results in a silent correctness issue in bidirectional attention when all of the following conditions hold:

  • dtype != torch.float (e.g. float16 or bfloat16)
  • no mask or a boolean mask
  • non-causal
  • query sequence length <= 8
  • query sequence length <= key sequence length
  • query head dim = value head dim
  • query head dim in {64, 96, 128}
  • key sequence length >= 1024 OR (number of key heads < number of query heads AND key sequence length >= 4096)
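The conditions above can be written as a single predicate. This is a hypothetical pure-Python helper that takes plain shape and dtype descriptions instead of tensors; it transcribes the list verbatim and has not been checked against Attention.mm. As listed, the final clause is already implied by key sequence length >= 1024 alone (since 4096 >= 1024), so the upstream check may be phrased differently:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SdpaCall:
    """Hypothetical summary of one SDPA call (field names are illustrative)."""
    dtype: str             # e.g. "float32", "float16", "bfloat16"
    mask: Optional[str]    # None, "bool", or "float"
    is_causal: bool
    q_len: int
    k_len: int
    q_head_dim: int
    v_head_dim: int
    n_q_heads: int
    n_kv_heads: int


def hits_buggy_2pass_path(c: SdpaCall) -> bool:
    """True if every condition listed for pytorch/pytorch#174861 holds."""
    return (
        c.dtype != "float32"
        and c.mask in (None, "bool")
        and not c.is_causal
        and c.q_len <= 8
        and c.q_len <= c.k_len
        and c.q_head_dim == c.v_head_dim
        and c.q_head_dim in (64, 96, 128)
        and (c.k_len >= 1024 or (c.n_kv_heads < c.n_q_heads and c.k_len >= 4096))
    )


call = SdpaCall(dtype="bfloat16", mask=None, is_causal=False, q_len=1, k_len=4096,
                q_head_dim=128, v_head_dim=128, n_q_heads=32, n_kv_heads=8)
print(hits_buggy_2pass_path(call))  # True
```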

These conditions are met by several models in transformers, for example GLM-Image. I've seen it in other vision and TTS models too, but I can't remember which ones off the top of my head.

The issue has been in PyTorch since v2.8.0 (released 2025-08-06). I've submitted a fix to PyTorch (pytorch/pytorch#174945), which is on track to make it into the v2.11.0 release.

I recommend that we add a workaround for PyTorch versions between v2.7.1 and v2.11.0 (exclusive), for example by upcasting the q/k/v tensors to fp32 when calling SDPA under the conditions outlined above. These conditions come from the PyTorch source: the bug is in sdpa_vector_2pass_mps, so inputs that take the other code path through sdpa_general_mps produce correct output.

https://github.com/pytorch/pytorch/blob/07180141f03510ab0b53dd0a544b71a5d5bdba93/aten/src/ATen/native/mps/operations/Attention.mm#L488

I'm happy to provide a PR once we're aligned on mitigation steps.

Who can help?

@vasqu @ArthurZucker @CyrilVallez @ivarflakstad

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

MRE

import torch
from diffusers.pipelines.glm_image import GlmImagePipeline

pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image", torch_dtype=torch.bfloat16, device_map="mps")
prompt = "A beautifully designed modern food magazine style dessert recipe illustration, themed around a raspberry mousse cake. The overall layout is clean and bright, divided into four main areas: the top left features a bold black title 'Raspberry Mousse Cake Recipe Guide', with a soft-lit close-up photo of the finished cake on the right, showcasing a light pink cake adorned with fresh raspberries and mint leaves; the bottom left contains an ingredient list section, titled 'Ingredients' in a simple font, listing 'Flour 150g', 'Eggs 3', 'Sugar 120g', 'Raspberry puree 200g', 'Gelatin sheets 10g', 'Whipping cream 300ml', and 'Fresh raspberries', each accompanied by minimalist line icons (like a flour bag, eggs, sugar jar, etc.); the bottom right displays four equally sized step boxes, each containing high-definition macro photos and corresponding instructions, arranged from top to bottom as follows: Step 1 shows a whisk whipping white foam (with the instruction 'Whip egg whites to stiff peaks'), Step 2 shows a red-and-white mixture being folded with a spatula (with the instruction 'Gently fold in the puree and batter'), Step 3 shows pink liquid being poured into a round mold (with the instruction 'Pour into mold and chill for 4 hours'), Step 4 shows the finished cake decorated with raspberries and mint leaves (with the instruction 'Decorate with raspberries and mint'); a light brown information bar runs along the bottom edge, with icons on the left representing 'Preparation time: 30 minutes', 'Cooking time: 20 minutes', and 'Servings: 8'. The overall color scheme is dominated by creamy white and light pink, with a subtle paper texture in the background, featuring compact and orderly text and image layout with clear information hierarchy."
image = pipe(
    prompt=prompt,
    height=32 * 32,
    width=36 * 32,
    num_inference_steps=50,
    guidance_scale=1.5,
    generator=torch.Generator(device="mps").manual_seed(42),
).images[0]

image.save("output_t2i.png")

With PyTorch v2.10.0, this results in the following Python stack trace:

File ~/dev/huggingface/transformers/src/transformers/generation/utils.py:2884, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)
   2882     probs = nn.functional.softmax(next_token_scores, dim=-1)
   2883     # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
-> 2884     next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
   2885 else:
   2886     next_tokens = torch.argmax(next_token_scores, dim=-1)

RuntimeError: probability tensor contains either `inf`, `nan` or element < 0
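The error surfaces in sampling rather than in attention: torch.multinomial rejects the probabilities once non-finite values, presumably coming from the corrupted attention output, reach the logits. To check the attention call directly, one can compare SDPA in the model's dtype against an fp32 reference on the same inputs. A minimal sketch; the helper name and the example shapes (chosen to satisfy the conditions listed under System Info) are assumptions:

```python
import math

import torch
import torch.nn.functional as F


def sdpa_error_vs_fp32(query, key, value, **kwargs) -> float:
    """Max abs difference between SDPA in the inputs' dtype and an fp32 reference.

    Pass a boolean mask (or none): a low-precision float mask would not match
    the fp32 query dtype in the reference call. A NaN or inf result means the
    low-precision output itself is non-finite.
    """
    low = F.scaled_dot_product_attention(query, key, value, **kwargs)
    ref = F.scaled_dot_product_attention(query.float(), key.float(), value.float(), **kwargs)
    return (low.float() - ref).abs().max().item()


device = "mps" if torch.backends.mps.is_available() else "cpu"
# bf16, no mask, non-causal, q_len=1, k_len=4096, head dim 128.
q = torch.randn(1, 8, 1, 128, device=device, dtype=torch.bfloat16)
k = torch.randn(1, 8, 4096, 128, device=device, dtype=torch.bfloat16)
v = torch.randn(1, 8, 4096, 128, device=device, dtype=torch.bfloat16)
err = sdpa_error_vs_fp32(q, k, v)
print(f"{device}: max abs error vs fp32 = {err:.4g}, finite = {math.isfinite(err)}")
```

On an affected PyTorch build running on MPS, a large or non-finite error here would point at the attention kernel rather than at the sampling code; this expectation is inferred from the issue, not measured.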

Expected behavior

I expected there to be no error and that the output image is generated successfully.

extent analysis

Fix Plan

Workaround for PyTorch versions between v2.7.1 and v2.11.0 exclusive

To work around the issue, upcast the q/k/v tensors to fp32 before calling SDPA under the conditions outlined above, then cast the output back to the original dtype (fp32 inputs fall outside those conditions). The workaround belongs around the SDPA call (PR #44591 targets src/transformers/integrations/sdpa_attention.py) rather than in the diffusers GlmImagePipeline class. A minimal sketch with a hypothetical wrapper name; it checks only the device, dtype, mask and causal conditions, so it upcasts a superset of the affected calls:

import torch
import torch.nn.functional as F


def sdpa_fp32_fallback(query, key, value, attn_mask=None, is_causal=False, **kwargs):
    # Hypothetical wrapper, not the code from PR #44591.
    # Non-MPS devices, fp32 inputs, causal attention and float masks fall
    # outside the listed conditions, so pass them through unchanged.
    float_mask = attn_mask is not None and attn_mask.dtype != torch.bool
    if query.device.type != "mps" or query.dtype == torch.float32 or is_causal or float_mask:
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_mask, is_causal=is_causal, **kwargs
        )

    # Run SDPA on fp32 copies of q/k/v, then restore the caller's dtype.
    out = F.scaled_dot_product_attention(
        query.float(), key.float(), value.float(),
        attn_mask=attn_mask, is_causal=is_causal, **kwargs,
    )
    return out.to(query.dtype)
