transformers - 💡(How to fix) Fix [Bug] GlmMoeDsa crashes on second forward pass — stale indexer cache [2 comments, 2 participants]

huggingface/transformers#44995 (fetched 2026-04-08 01:30:57)

System Info

  • transformers version: 5.3.0
  • Platform: Linux
  • Python version: 3.13.5
  • PyTorch version: 2.8.0+cu128

Who can help?

@Rocketknight1

Information

  • The official example scripts
  • My own modified scripts

Reproduction

GlmMoeDsa models crash on any second forward pass. The DSA indexer's _cached_keys and _cached_indices persist between calls and cause shape mismatches or out-of-bounds scatter indices.

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("yujiepan/glm-5-tiny-random", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("yujiepan/glm-5-tiny-random")

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)

# First forward: OK
out1 = model(**inputs)
print(out1.logits.shape)  # torch.Size([1, 1, 154880])

# Second forward: CRASH
out2 = model(**inputs)  # AcceleratorError: CUDA error: device-side assert triggered

Same issue with yujiepan/glm-moe-dsa-tiny-random.

Error

With CUDA_LAUNCH_BLOCKING=1:

  File ".../transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py", line 414, in forward
    index_mask.scatter_(-1, topk_indices, 0.0)  # [B, S, T]
torch.AcceleratorError: CUDA error: device-side assert triggered

The underlying issue is at modeling_glm_moe_dsa.py:198:

k_cached = torch.cat([self._cached_keys, k], dim=1)  # [B, T, D]

On the second forward call, self._cached_keys still holds stale state from the first call, leading to shape mismatches or invalid indices.
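
The failure mode can be illustrated without a GPU. The following is a toy sketch in plain Python, not the transformers implementation: `ToyIndexer`, `select`, and `forward` are hypothetical names, and only the `_cached_keys` attribute name is taken from the report. It assumes the indexer ranks all keys it has accumulated, while the mask is sized for the current input only, so a second call can return positions that do not exist in the mask.

```python
# Toy stand-in for an indexer that keeps keys across calls.
# NOT the transformers code; class and function names are illustrative.

class ToyIndexer:
    def __init__(self, top_k):
        self.top_k = top_k
        self._cached_keys = None  # persists between forward calls

    def select(self, keys):
        """Return positions of the top_k largest keys accumulated so far."""
        if self._cached_keys is None:
            all_keys = list(keys)
        else:
            # Stale keys from the previous call are prepended here.
            all_keys = self._cached_keys + list(keys)
        self._cached_keys = all_keys
        ranked = sorted(range(len(all_keys)), key=lambda i: all_keys[i], reverse=True)
        return ranked[: self.top_k]


def forward(indexer, keys):
    """Build a mask sized for the current input and clear the selected slots."""
    mask = [float("-inf")] * len(keys)
    for idx in indexer.select(keys):
        mask[idx] = 0.0  # raises IndexError when idx >= len(keys)
    return mask
```

Calling `forward` twice on the same `ToyIndexer` with the same three keys succeeds the first time and raises `IndexError` the second time, because the ranking now covers six accumulated keys. This is loosely analogous to the out-of-bounds `scatter_` reported above. Setting `indexer._cached_keys = None` between calls makes the second call succeed.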

Expected behavior

The model should be callable multiple times without error. The DSA indexer should either reset its cache between forward passes or not use persistent state for inference without KV cache.

Additional context

This is related to other known GlmMoeDsa indexer issues (#44360, #44263). The stale cache issue compounds with those bugs — even if the indexer logic is fixed, the persistent cache between calls will continue to cause problems.

extent analysis

Fix Plan

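GlmMoeDsa models crash on the second forward pass because the DSA indexer's _cached_keys and _cached_indices persist between calls. The fix is to clear that cache at the start of any forward pass that does not continue from an existing KV cache. Clearing it unconditionally would presumably also discard state that incremental decoding relies on, so the reset should be limited to calls that start a new sequence.

The reset can be applied in either of two places:

  • In the model's forward method in modeling_glm_moe_dsa.py, before the indexer runs.
  • In the indexer class itself, at the start of its __call__ method.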

Example Code

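# In modeling_glm_moe_dsa.py
# Sketch only: `dsa_indexer` is a placeholder attribute name, and the real
# indexer may live inside each attention layer rather than on the model.
class GlmMoeDsaModel(...):
    def forward(self, ...):
        # Reset the DSA indexer's cache so state from a previous call cannot leak in
        self.dsa_indexer._cached_keys = None
        self.dsa_indexer._cached_indices = None

        # Rest of the forward method remains the same
        ...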

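Alternatively, the reset can live in the indexer class:

# `DSAIndexer` and `dsa_indexer.py` are placeholder names. The traceback above
# places the `_cached_keys` code in modeling_glm_moe_dsa.py.
class DSAIndexer(...):
    def __call__(self, ...):
        # Reset the cache before computing new indices
        self._cached_keys = None
        self._cached_indices = None

        # Rest of the __call__ method remains the same
        ...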
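
Until a fix lands upstream, the same reset can be sketched as a user-side workaround that needs no edits to the library. The helper below is a hypothetical function, not part of transformers. It assumes that `None` is a valid empty state for `_cached_keys` and `_cached_indices` (as the fix plan above does) and that the attributes sit on submodules reachable through `model.modules()`, which `torch.nn.Module` provides.

```python
def reset_indexer_caches(model, attr_names=("_cached_keys", "_cached_indices")):
    """Set the named cache attributes to None on every submodule that has them.

    `model` only needs a `.modules()` iterator, as on `torch.nn.Module`.
    Returns the number of attributes that were reset.
    """
    reset_count = 0
    for module in model.modules():
        for name in attr_names:
            if hasattr(module, name):
                setattr(module, name, None)
                reset_count += 1
    return reset_count
```

It would be called as `reset_indexer_caches(model)` before each independent `model(**inputs)` call. It is not meant to be called between decoding steps of a single generation, where the cached state is presumably still needed.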

Verification

To verify the fix, rerun the reproduction and confirm that the second call completes:

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("yujiepan/glm-5-tiny-random", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("yujiepan/glm-5-tiny-random")

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)

# First forward: OK
out1 = model(**inputs)
print(out1.logits.shape)  # torch.Size([1, 1, 154880])

# Second forward: Should not crash
out2 = model(**inputs)
print(out2.logits.shape)  # torch.Size([1, 1, 154880])

If the second forward pass does not crash, the fix is successful.

Extra Tips

  • Test with different sequence lengths and batch sizes, and with both checkpoints named in the report, to confirm that the reset does not introduce new issues.
  • Consider submitting a pull request to the transformers repository to fix the issue for all users.
