transformers - ✅(Solved) Fix [BUG][CI] SwitchTransformers and TimmWrapperModel dtype mismatches in bfloat16 inference [1 pull request, 1 participant]

GitHub stats
huggingface/transformers#45072 (fetched 2026-04-08 01:40:43)
Comments: 0
Participants: 1
Timeline: 4
Reactions: 0
Timeline (top): cross-referenced ×2, closed ×1, labeled ×1

Error Message

SwitchTransformersModel (MoE router's linear layer): dtype mismatch; got: float != c10::BFloat16
TimmWrapperModel (first conv layer): Input type (torch.cuda.FloatTensor) and weight type (CUDABFloat16Type) should be the same

Fix Action

Fixed

PR fix notes

PR #45074: fix(models): Fix dtype mismatch in SwitchTransformers and TimmWrapperModel

Description (problem / solution / changelog)

What does this PR do?

The following dtype mismatch use cases were identified and fixed in this PR:

→ Switch Transformers: 7938e91fa refactored all MoE models for vLLM compatibility; in that refactor, the _cast_classifier() method was removed from SwitchTransformersTop1Router but no dtype cast was added. Casting hidden_states to classifier.weight.dtype before the linear call fixes that!
→ TimmWrapper: 6217adc6c8 changed the default dtype behavior to "auto"; in that commit, pixel_values.to(self.device, self.dtype) was regressed to pixel_values.to(self.device), dropping the dtype cast. I'm not too sure why it was dropped, but restoring it seems logical to fix the use case.
→ For more details on reproducing the bug and the output screenshots, please visit the linked issue!
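The router fix described above can be illustrated with a minimal sketch. ToyTop1Router, its cast_input flag, and the tensor sizes are hypothetical stand-ins chosen for illustration; this is not the SwitchTransformersTop1Router code, only the same idea of casting hidden_states to classifier.weight.dtype before the linear call. It runs on CPU.

```python
import torch
from torch import nn


class ToyTop1Router(nn.Module):
    """Hypothetical stand-in for an MoE router; not the transformers class."""

    def __init__(self, hidden_size=8, num_experts=4, cast_input=True):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_experts, bias=False)
        self.cast_input = cast_input

    def forward(self, hidden_states):
        if self.cast_input:
            # Match the activations to the classifier weights before the linear call.
            hidden_states = hidden_states.to(self.classifier.weight.dtype)
        return self.classifier(hidden_states)


hidden_states = torch.randn(2, 3, 8)  # float32 activations

# With the cast, float32 activations pass through a bfloat16 classifier.
router = ToyTop1Router().to(torch.bfloat16)
logits = router(hidden_states)
print(logits.shape, logits.dtype)

# Without the cast, the linear layer rejects the mixed dtypes.
try:
    ToyTop1Router(cast_input=False).to(torch.bfloat16)(hidden_states)
    mismatch_raised = False
except RuntimeError as error:
    mismatch_raised = True
    print(error)
```

The cast is a no-op when the dtypes already agree, so it costs nothing in the common case.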

cc: @Rocketknight1

Fixes #45072

CI run test coverage of this behavior (as suggested by @ydshieh) :):

SwitchTransformers:
→ test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_generate_with_past_key_values
→ test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_model_fp16_forward
→ test_modeling_switch_transformers.py::SwitchTransformerModelIntegrationTests::test_small_logits

TimmWrapper:
→ TimmWrapperModelTest does not have explicit bfloat16 forward pass tests; added one in this PR for complete coverage.

Repro output after the fixes (feel free to cross-check):

<img width="500" height="300" alt="1" src="https://github.com/user-attachments/assets/6c97e0ca-e3ce-4fe7-b501-31e1b144d4ad" /> <img width="380" height="175" alt="1-1" src="https://github.com/user-attachments/assets/5cad72ed-9deb-43d4-b54d-5a3ad8ad793d" />

Code Agent Policy

  • I confirm that this is not a pure code agent PR.

Changed files

  • src/transformers/models/switch_transformers/modeling_switch_transformers.py (modified, +1/-1)
  • src/transformers/models/switch_transformers/modular_switch_transformers.py (modified, +1/-1)
  • src/transformers/models/timm_wrapper/modeling_timm_wrapper.py (modified, +1/-1)
  • tests/models/timm_wrapper/test_modeling_timm_wrapper.py (modified, +9/-0)

System Info

  • transformers version: 5.0.0.dev0
  • Platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.39
  • Python version: 3.12.3
  • huggingface_hub version: 1.3.2
  • safetensors version: 0.7.0
  • accelerate version: 1.12.0
  • Accelerate config: not installed
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.9.1+cu128 (CUDA)
  • GPU type: NVIDIA L4
  • NVIDIA driver version: 550.90.07
  • CUDA version: 12.4

Reproduction

Switch Transformers:

import torch
from transformers import SwitchTransformersModel

try:
    model = SwitchTransformersModel.from_pretrained("google/switch-base-8", torch_dtype=torch.bfloat16).to("cuda").eval()
    input_ids = torch.ones(1, 16, dtype=torch.long, device="cuda")
    output = model(input_ids, decoder_input_ids=input_ids)
    print(output.last_hidden_state.shape)
except Exception as e:
    print(e)

TimmWrapper:

import torch
from transformers import TimmWrapperModel, TimmWrapperConfig

try:
    config = TimmWrapperConfig(architecture="resnet18")
    model = TimmWrapperModel(config).to("cuda", torch.bfloat16).eval()
    pixel_values = torch.randn(1, 3, 224, 224, device="cuda")
    output = model(pixel_values)
    print(output.last_hidden_state.shape)
except Exception as e:
    print(e)

→ Loading "google/switch-base-8" in bfloat16 and running a forward pass crashes with a dtype mismatch in the MoE router's linear layer; got: float != c10::BFloat16.
→ Instantiating a TimmWrapperModel in bfloat16 on CUDA and passing float32 pixel_values crashes; the first conv layer raises Input type (torch.cuda.FloatTensor) and weight type (CUDABFloat16Type) should be the same.

Current Repro Output:

<img width="500" height="300" alt="Image" src="https://github.com/user-attachments/assets/0402636c-c8f2-4273-855b-750c37ec93dd" /> <img width="500" height="300" alt="Image" src="https://github.com/user-attachments/assets/6a3a2586-ee85-4b79-8985-0e7146b3eb97" />

Expected behavior

→ Both models should complete bfloat16 inference successfully.

extent analysis

Fix Plan

Both errors occur because a floating-point tensor reaches a layer whose weights have a different dtype. The fix in PR #45074 is made inside the models: each one casts the tensor to the dtype of the weights that consume it.

For Switch Transformers:

  1. Keep input_ids and decoder_input_ids as torch.long. They are token indices for the embedding lookup, so casting them to bfloat16 would break the lookup instead of fixing the error.
input_ids = torch.ones(1, 16, dtype=torch.long, device="cuda")
  2. Fix the router: the error is raised in the MoE router's linear layer, where the hidden states are float32 and the classifier weights are bfloat16. Casting hidden_states to classifier.weight.dtype before the linear call in SwitchTransformersTop1Router, as PR #45074 does, resolves it.

For TimmWrapper:

  1. Caller-side workaround: create pixel_values in the same dtype as the model.
pixel_values = torch.randn(1, 3, 224, 224, device="cuda", dtype=torch.bfloat16)
  2. Library fix: restore the dtype cast in the model's forward pass, pixel_values.to(self.device, self.dtype), so that float32 inputs are converted before they reach the first conv layer.
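A minimal sketch of the TimmWrapper side, assuming a toy module in place of the real model: ToyImageWrapper, its single conv layer, and the 16×16 input are hypothetical. It only shows the pixel_values.to(self.device, self.dtype) pattern named in the PR notes, on CPU.

```python
import torch
from torch import nn


class ToyImageWrapper(nn.Module):
    """Hypothetical stand-in for an image model wrapper; not TimmWrapperModel."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)

    @property
    def dtype(self):
        # dtype of the wrapper's parameters
        return next(self.parameters()).dtype

    @property
    def device(self):
        # device of the wrapper's parameters
        return next(self.parameters()).device

    def forward(self, pixel_values):
        # Move and cast the input so it matches the conv weights.
        pixel_values = pixel_values.to(self.device, self.dtype)
        return self.conv(pixel_values)


model = ToyImageWrapper().to(torch.bfloat16).eval()
pixel_values = torch.randn(1, 3, 16, 16)  # float32 input, as in the repro

with torch.no_grad():
    features = model(pixel_values)

print(features.shape, features.dtype)
```

Casting inside forward means callers do not need to know which dtype the model was loaded in.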

Verification

To verify that the fixes work, run the modified code and check that:

  • The models complete inference without crashing due to dtype mismatches.
  • The output shapes and types are as expected.
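The two checks above could be scripted roughly as follows. check_forward and the toy model are hypothetical names used for illustration; with the real models, the repro snippets from the issue would supply the model and inputs instead.

```python
import torch
from torch import nn


def check_forward(model, inputs, expected_dtype, expected_shape):
    """Run one inference pass; report whether output dtype and shape match."""
    model.eval()
    with torch.no_grad():
        output = model(inputs)
    return output.dtype == expected_dtype and tuple(output.shape) == tuple(expected_shape)


# Toy bfloat16 model standing in for the real one (assumption for this sketch).
toy_model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).to(torch.bfloat16)
example = torch.randn(2, 16, dtype=torch.bfloat16)

ok = check_forward(toy_model, example, torch.bfloat16, (2, 4))
print(ok)
```

A dtype mismatch would surface as a RuntimeError from the forward call, so a returned True covers both bullet points.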

Extra Tips

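  • When working with reduced precision, make sure every floating-point input has the same dtype as the weights of the layer that consumes it.
  • torch.autocast can cast the inputs of eligible ops to the target dtype for the duration of a forward pass; torch.cuda.amp.autocast is the older, deprecated spelling of the same context manager.
  • Do not cast integer inputs such as input_ids to a floating-point type; they are indices into the embedding table and must stay an integer type (torch.long here).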
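As a small illustration of the autocast tip, the sketch below runs a float32 linear layer under torch.autocast on CPU. The layer and tensor sizes are arbitrary assumptions, and this shows autocast in general; it is not the fix applied in PR #45074.

```python
import torch
from torch import nn

layer = nn.Linear(8, 4)  # float32 weights
x = torch.randn(2, 8)    # float32 input

# Inside the context, eligible ops such as linear run in bfloat16.
with torch.no_grad(), torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = layer(x)

print(y.dtype)
print(layer.weight.dtype)  # the stored weights are unchanged
```

On a GPU the same context manager would be used with device_type="cuda".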
