transformers - ✅ (Solved) Fix: `Qwen3VLVisionPatchEmbed.proj` (`nn.Conv3d` with `stride == kernel`) is ~50,000× slower than equivalent `nn.Linear` on Blackwell + bf16 [1 pull request, 1 participant]

GitHub stats
huggingface/transformers#45750, fetched 2026-05-04 04:58:16
Comments: 0 · Participants: 1 · Timeline events: 8 · Reactions: 0
Timeline (top): mentioned ×3, subscribed ×3, cross-referenced ×1, labeled ×1

Fix Action

Fix / Workaround

   torch.cuda.synchronize(); t0 = time.time()
   h = vt.patch_embed(pv); torch.cuda.synchronize()
   print(f"patch_embed: {(time.time()-t0)*1000:.1f} ms, shape={tuple(h.shape)}")

   patch_embed:  16111.3 ms, shape=(6080, 1024)   ← 96% of total forward
   pos_embed:    22.8 ms
   rot_pos_emb:  20.7 ms
   24 blocks total: 56.4 ms (mean 2.3 ms)
   merger:       0.5 ms

The 24 ViT blocks run in 56 ms total. The single patch_embed call takes 16,111 ms: about 286× the blocks, and about 160× everything else in the vision tower combined (100.4 ms).

PR fix notes

PR #45771: perf(qwen3_vl): replace Conv3d with F.linear in patch embed forward

Description (problem / solution / changelog)

What does this PR do?

Replaces the Conv3d forward pass in Qwen3VLVisionPatchEmbed with F.linear on the reshaped weight. When stride == kernel_size, Conv3d is mathematically equivalent to extracting non-overlapping patches and applying a linear projection. The Conv3d codepath triggers an extremely slow cuDNN kernel on some GPU/dtype combinations (~50,000x slower in the isolated Blackwell + bf16 benchmark, ~62x end to end at bs=1, per the issue benchmarks).

The fix reshapes the input to (batch, in_channels * t * h * w) and uses F.linear(input, weight.view(embed_dim, -1), bias). Same weight tensor, just reshaped at forward time, so existing checkpoints load without changes.
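As a minimal sketch of the approach described above (not the merged code itself), the stand-in module below keeps the `nn.Conv3d` parameter but runs the forward pass through `F.linear` on a flattened view of the same weight. The class name `PatchEmbedSketch`, the method names, and the small sizes are assumptions made for illustration; the check runs on CPU in fp32.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbedSketch(nn.Module):
    """Hypothetical stand-in for a patch embed whose Conv3d has stride == kernel."""

    def __init__(self, in_channels=3, embed_dim=32, temporal_patch_size=2, patch_size=4):
        super().__init__()
        kernel = (temporal_patch_size, patch_size, patch_size)
        self.in_channels = in_channels
        self.kernel = kernel
        self.embed_dim = embed_dim
        # The parameter stays a Conv3d, so the 5-D checkpoint layout is unchanged.
        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel, stride=kernel, bias=True)

    def forward_conv(self, x):
        # Original path: view each row as one patch and run the Conv3d.
        x = x.view(-1, self.in_channels, *self.kernel)
        return self.proj(x).reshape(-1, self.embed_dim)

    def forward_linear(self, x):
        # Linear path: flatten each patch and reuse the same weight as a 2-D matrix.
        w = self.proj.weight
        x = x.reshape(-1, w[0].numel())
        return F.linear(x, w.view(self.embed_dim, -1), self.proj.bias)


torch.manual_seed(0)
module = PatchEmbedSketch().eval()
patches = torch.randn(10, 3 * 2 * 4 * 4)  # 10 flattened patches

with torch.no_grad():
    out_conv = module.forward_conv(patches)
    out_lin = module.forward_linear(patches)

max_abs_diff = (out_conv - out_lin).abs().max().item()
print(f"max abs diff: {max_abs_diff:.2e}")
print("weight shape:", tuple(module.proj.weight.shape))
```

Because only the forward computation changes, the stored weight keeps its `(out, in, k_t, k_h, k_w)` shape, which is why existing checkpoints would load unchanged under this approach.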

Before submitting

  • Did you read the contributor guideline?
  • This PR fixes a bug (issue #45750)
  • Backward compatible (same weights, same outputs, just faster)

Fixes #45750.

Changed files

  • src/transformers/models/qwen3_vl/modeling_qwen3_vl.py (modified, +8/-3)


System Info

transformers version: 5.0.0.dev0
PyTorch:              2.9.0+cu128
CUDA:                 12.8
cuDNN:                9.10.0.2 (91002)
Python:               3.14.0
flash-attn:           2.8.3 (installed)
GPU:                  NVIDIA GeForce RTX 5090 (Blackwell, compute capability 12.0, sm_120)
OS:                   Linux 6.8.0-110-generic, glibc 2.39

Who can help?

@yonigozlan @molbap @zucchini-nlp

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. Confirm GPU is healthy. RTX 5090 should hit ~100–209 TFLOPS bf16 dense matmul.

    import torch, time
    for size in [4096, 8192]:
        a = torch.randn(size, size, dtype=torch.bfloat16, device="cuda")
        b = torch.randn(size, size, dtype=torch.bfloat16, device="cuda")
        for _ in range(3): _ = a @ b
        torch.cuda.synchronize(); t0 = time.time()
        for _ in range(10): c = a @ b
        torch.cuda.synchronize()
        e = time.time() - t0
        print(f"matmul {size}x{size}: {2 * size**3 * 10 / e / 1e12:.1f} TFLOPS")

    Output:

    matmul 4096x4096: 182.3 TFLOPS
    matmul 8192x8192: 223.7 TFLOPS

    Hardware is fine.
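The throughput printed above comes from counting 2·n³ floating-point operations per n×n matmul. A small sketch of that arithmetic follows; the helper name `matmul_tflops` and the 10 ms elapsed time are made up purely to exercise the formula and are not measurements from the issue.

```python
def matmul_tflops(size: int, iters: int, elapsed_s: float) -> float:
    """TFLOPS for `iters` square matmuls of side `size` that took `elapsed_s` seconds."""
    flops_per_matmul = 2 * size**3  # one multiply and one add per inner-product term
    return flops_per_matmul * iters / elapsed_s / 1e12


# Illustrative input only: 10 matmuls of 4096x4096 finishing in 10 ms.
example = matmul_tflops(4096, 10, 0.010)
print(f"{example:.1f} TFLOPS")
```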

  2. Run a full vision-tower forward at multiple batch sizes, all with default settings.

    import time, torch
    from PIL import Image
    from transformers import AutoModelForImageTextToText, AutoProcessor
    
    torch.set_grad_enabled(False)
    proc = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-4B-Instruct")
    model = AutoModelForImageTextToText.from_pretrained(
        "Qwen/Qwen3-VL-4B-Instruct", dtype=torch.bfloat16,
    ).cuda().eval()
    
    def make_clip():
        return [Image.fromarray(torch.randint(0, 256, (720, 1280, 3),
                dtype=torch.uint8).numpy()) for _ in range(8)]
    
    def time_forward(bs):
        texts, images = [], []
        for _ in range(bs):
            frames = make_clip()
            msgs = [{"role": "user", "content":
                     [{"type": "image", "image": img} for img in frames]
                     + [{"type": "text", "text": "Describe."}]}]
            texts.append(proc.apply_chat_template(
                msgs, tokenize=False, add_generation_prompt=True))
            images.append(frames)
        inputs = proc(text=texts, images=images,
                      return_tensors="pt", padding=True)
        inputs = {k: (v.cuda() if isinstance(v, torch.Tensor) else v)
                  for k, v in inputs.items()}
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
        keys = ("input_ids","attention_mask","pixel_values","image_grid_thw")
        args = {k: inputs[k] for k in keys if k in inputs}
        for rep in range(2):
            torch.cuda.synchronize(); t0 = time.time()
            with torch.amp.autocast("cuda", dtype=torch.bfloat16):
                _ = model.model(**args, use_cache=False, return_dict=True)
            torch.cuda.synchronize()
            e = time.time() - t0
            print(f"  bs={bs} rep={rep}: {e:.2f}s ({e/bs*1000:.0f} ms/sample)")
    
    for bs in [1, 4, 8, 16]: time_forward(bs)

    Output:

    bs=1  rep=0: 16.70s (16700 ms/sample)
    bs=1  rep=1: 16.46s (16458 ms/sample)
    bs=4  rep=0: 65.30s (16325 ms/sample)
    bs=8  rep=0: 148.05s (18506 ms/sample)
    bs=8  rep=1: 148.01s (18501 ms/sample)
    bs=16 rep=0: 148.78s (9299 ms/sample)
    bs=16 rep=1: 148.34s (9271 ms/sample)

    Per-sample time stays at roughly 16–18 s from bs=1 through bs=8 (the bs=16 runs report ~9.3 s/sample), which rules out DataLoader, collate, and padding bugs.
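The ms/sample column above is total wall-clock time divided by batch size. A short sketch of that arithmetic, using the rep=0 totals reported above (the dictionary names are illustrative):

```python
# Wall-clock totals in seconds, copied from the rep=0 rows of the output above.
totals_s = {1: 16.70, 4: 65.30, 8: 148.05, 16: 148.78}

# Per-sample latency in milliseconds for each batch size.
per_sample_ms = {bs: total / bs * 1000 for bs, total in totals_s.items()}

for bs, ms in per_sample_ms.items():
    print(f"bs={bs:<2d} {ms:8.0f} ms/sample")
```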

  3. Eliminate attn_implementation as the cause. Test all three.

    for impl in ["sdpa", "flash_attention_2", "eager"]:
        model = AutoModelForImageTextToText.from_pretrained(
            "Qwen/Qwen3-VL-4B-Instruct", dtype=torch.bfloat16,
            attn_implementation=impl).cuda().eval()
        # ... same time_forward(bs=8) ...

    Output:

    sdpa              bs=8: 148.05s (18506 ms/sample)
    flash_attention_2 bs=8: 147.64s (18455 ms/sample)
    eager             bs=8: 148.20s (18525 ms/sample)

    All three implementations are identically slow → attention is not the cause.

  4. Per-component timing of Qwen3VLVisionModel.forward (bs=1, 8 frames).

    import torch.nn.functional as F
    pv = inputs["pixel_values"]; grid_thw = inputs["image_grid_thw"]
    vt = model.visual
    
    torch.cuda.synchronize(); t0 = time.time()
    h = vt.patch_embed(pv); torch.cuda.synchronize()
    print(f"patch_embed:  {(time.time()-t0)*1000:.1f} ms, shape={tuple(h.shape)}")
    
    t0 = time.time(); pos = vt.fast_pos_embed_interpolate(grid_thw)
    torch.cuda.synchronize(); print(f"pos_embed:    {(time.time()-t0)*1000:.1f} ms")
    h = h + pos
    
    t0 = time.time(); rope = vt.rot_pos_emb(grid_thw)
    torch.cuda.synchronize(); print(f"rot_pos_emb:  {(time.time()-t0)*1000:.1f} ms")
    
    seq_len = h.size(0); h = h.reshape(seq_len, -1)
    rope = rope.reshape(seq_len, -1)
    emb = torch.cat((rope, rope), dim=-1)
    pos_emb = (emb.cos(), emb.sin())
    cu = torch.repeat_interleave(grid_thw[:,1]*grid_thw[:,2],
         grid_thw[:,0]).cumsum(0, dtype=torch.int32)
    cu = F.pad(cu, (1,0), value=0)
    
    times = []
    for i, blk in enumerate(vt.blocks):
        torch.cuda.synchronize(); t0 = time.time()
        h = blk(h, cu_seqlens=cu, position_embeddings=pos_emb)
        torch.cuda.synchronize()
        times.append((time.time()-t0)*1000)
    print(f"24 blocks total: {sum(times):.1f} ms (mean {sum(times)/24:.1f} ms)")
    
    t0 = time.time(); _ = vt.merger(h)
    torch.cuda.synchronize(); print(f"merger:       {(time.time()-t0)*1000:.1f} ms")

    Output:

    patch_embed:  16111.3 ms, shape=(6080, 1024)   ← 96% of total forward
    pos_embed:    22.8 ms
    rot_pos_emb:  20.7 ms
    24 blocks total: 56.4 ms (mean 2.3 ms)
    merger:       0.5 ms

    The 24 ViT blocks run in 56 ms total. The single patch_embed call takes 16,111 ms: about 286× the blocks, and about 160× everything else in the vision tower combined (100.4 ms).
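A quick check of the ratios implied by the timings above. The per-component numbers are copied from that output, and the 16.70 s full-forward time is the bs=1 rep=0 figure from the batch-size run; the variable names are illustrative. It suggests the 96% figure is relative to the full model forward, while patch_embed is roughly 286× the 24 blocks and roughly 160× all other vision-tower components combined.

```python
# Per-component times in milliseconds, copied from the output above.
times_ms = {
    "patch_embed": 16111.3,
    "pos_embed": 22.8,
    "rot_pos_emb": 20.7,
    "blocks": 56.4,
    "merger": 0.5,
}
full_forward_ms = 16.70 * 1000  # bs=1, rep=0 full-model forward from the earlier run

patch = times_ms["patch_embed"]
rest = sum(v for k, v in times_ms.items() if k != "patch_embed")

ratio_vs_blocks = patch / times_ms["blocks"]
ratio_vs_rest = patch / rest
share_of_forward = patch / full_forward_ms

print(f"rest of vision tower:    {rest:.1f} ms")
print(f"patch_embed / 24 blocks: {ratio_vs_blocks:.0f}x")
print(f"patch_embed / rest:      {ratio_vs_rest:.0f}x")
print(f"share of full forward:   {share_of_forward:.1%}")
```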

  5. Inspect Qwen3VLVisionPatchEmbed (file: transformers/models/qwen3_vl/modeling_qwen3_vl.py, lines 59–76).

    class Qwen3VLVisionPatchEmbed(nn.Module):
        def __init__(self, config) -> None:
            super().__init__()
            self.patch_size = config.patch_size                  # 16
            self.temporal_patch_size = config.temporal_patch_size  # 2
            self.in_channels = config.in_channels                 # 3
            self.embed_dim = config.hidden_size                   # 1024
            kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
            self.proj = nn.Conv3d(
                self.in_channels, self.embed_dim,
                kernel_size=kernel_size, stride=kernel_size, bias=True,
            )
    
        def forward(self, hidden_states):
            target_dtype = self.proj.weight.dtype
            hidden_states = hidden_states.view(
                -1, self.in_channels, self.temporal_patch_size,
                self.patch_size, self.patch_size,
            )
            hidden_states = self.proj(
                hidden_states.to(dtype=target_dtype)
            ).view(-1, self.embed_dim)
            return hidden_states

    kernel_size == stride, no padding, no dilation → output windows are disjoint → mathematically equivalent to flatten + nn.Linear.

  6. Isolated benchmark: Conv3d vs equivalent Linear (no checkpoint needed).

    import time, torch, torch.nn as nn
    torch.set_grad_enabled(False)
    
    conv = nn.Conv3d(3, 1024, kernel_size=(2,16,16),
                     stride=(2,16,16), bias=True).cuda().to(torch.bfloat16)
    
    out_dim, in_dim = 1024, 3*2*16*16  # 1536
    lin = nn.Linear(in_dim, out_dim, bias=True)
    lin.weight.data.copy_(conv.weight.detach().reshape(out_dim, in_dim))
    lin.bias.data.copy_(conv.bias.detach())
    lin = lin.cuda().to(torch.bfloat16)
    
    N = 6080  # patches in one 8-frame Qwen3-VL clip
    x_5d  = torch.randn(N, 3, 2, 16, 16, dtype=torch.bfloat16, device="cuda")
    x_flat = x_5d.reshape(N, -1).contiguous()
    
    for _ in range(3): _ = conv(x_5d); _ = lin(x_flat)
    
    torch.cuda.synchronize(); t0 = time.time()
    for _ in range(5): y_conv = conv(x_5d).view(N, -1)
    torch.cuda.synchronize(); t_conv = (time.time()-t0)/5
    
    torch.cuda.synchronize(); t0 = time.time()
    for _ in range(5): y_lin = lin(x_flat)
    torch.cuda.synchronize(); t_lin = (time.time()-t0)/5
    
    print(f"Conv3d:  {t_conv*1000:8.2f} ms")
    print(f"Linear:  {t_lin*1000:8.2f} ms")
    print(f"Speedup: {t_conv/t_lin:8.1f}x")
    diff = (y_conv.float() - y_lin.float()).abs().max().item()
    cos = torch.nn.functional.cosine_similarity(
        y_conv.float().flatten().unsqueeze(0),
        y_lin.float().flatten().unsqueeze(0)).item()
    print(f"max abs diff (bf16): {diff:.2e}")
    print(f"cosine similarity:   {cos:.6f}")

    Output:

    Conv3d:    16111.30 ms
    Linear:        0.30 ms
    Speedup:    53704.3x
    max abs diff (bf16): 1.56e-02
    cosine similarity:   0.999500
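The bf16 `max abs diff` of 1.56e-02 above is the same size as one bfloat16 rounding step for values between 2 and 4, which would be consistent with the two paths differing only in how results are rounded. A minimal sketch of that arithmetic, assuming the standard bfloat16 layout with 7 explicit mantissa bits (the helper name `bf16_ulp` is hypothetical):

```python
import math

BF16_MANTISSA_BITS = 7  # explicit fraction bits in bfloat16


def bf16_ulp(value: float) -> float:
    """Spacing between adjacent bfloat16 values around a normal, nonzero `value`."""
    exponent = math.floor(math.log2(abs(value)))
    return 2.0 ** (exponent - BF16_MANTISSA_BITS)


for v in (1.0, 2.5, 5.0):
    print(f"ulp near {v}: {bf16_ulp(v):.2e}")
```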

  7. Verify mathematical equivalence in fp32 (rules out numerical accident).

import torch, torch.nn as nn
torch.manual_seed(0); N = 100; C, T, P = 3, 2, 16; out_dim = 1024
conv = nn.Conv3d(C, out_dim, (T,P,P), stride=(T,P,P), bias=True)
in_dim = C*T*P*P
lin = nn.Linear(in_dim, out_dim, bias=True)
lin.weight.data.copy_(conv.weight.detach().reshape(out_dim, in_dim))
lin.bias.data.copy_(conv.bias.detach())

x_5d = torch.randn(N, C, T, P, P, dtype=torch.float32)
x_flat = x_5d.reshape(N, -1).contiguous()
with torch.no_grad():
    o_conv = conv(x_5d).view(N, -1)
    o_lin  = lin(x_flat)

abs_diff = (o_conv - o_lin).abs()
print(f"fp32 max abs diff:  {abs_diff.max().item():.2e}")
print(f"fp32 mean abs diff: {abs_diff.mean().item():.2e}")

Output:

fp32 max abs diff:  4.77e-07
fp32 mean abs diff: 7.61e-08

Conv3d with kernel == stride is mathematically equivalent to Linear over the reshaped weights; the fp32 difference (~5e-7) is at the level of floating-point round-off.

  8. Apply the fix (lazy in-place Conv3d → Linear via monkey-patch on Qwen3VLVisionPatchEmbed.forward) and re-benchmark.

import time, torch, torch.nn as nn
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor
from transformers.models.qwen3_vl.modeling_qwen3_vl import (
    Qwen3VLVisionPatchEmbed,
)

def _fast_forward(self, hidden_states):
    target_dtype = self.proj.weight.dtype
    if isinstance(self.proj, nn.Conv3d):
        conv = self.proj
        out_dim = conv.out_channels
        in_dim = (conv.in_channels * conv.kernel_size[0]
                  * conv.kernel_size[1] * conv.kernel_size[2])
        w_flat = conv.weight.detach().reshape(out_dim, in_dim).contiguous()
        bias = conv.bias.detach().clone() if conv.bias is not None else None
        new_proj = nn.Linear(in_dim, out_dim, bias=bias is not None)
        new_proj.weight.data.copy_(w_flat)
        if bias is not None: new_proj.bias.data.copy_(bias)
        new_proj.to(device=conv.weight.device, dtype=conv.weight.dtype)
        self.proj = new_proj
    if hidden_states.dim() > 2 \
            or hidden_states.shape[-1] != self.proj.in_features:
        hidden_states = hidden_states.reshape(-1, self.proj.in_features)
    return self.proj(hidden_states.to(dtype=target_dtype))

Qwen3VLVisionPatchEmbed.forward = _fast_forward

# Reload model and run step-2-style timing again.

Output:

bs=1  rep=0: 0.27s (270 ms/sample)
bs=1  rep=1: 0.29s (290 ms/sample)
bs=8  rep=0: 2.16s (270 ms/sample)
bs=8  rep=1: 2.18s (273 ms/sample)

Speedup vs step 2: 62× at bs=1, 68× at bs=8. VRAM unchanged. Patch embedding goes from 96% of total forward time to <1%.
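The speedup figures above follow directly from the reported end-to-end times. A short sketch of the arithmetic, using the rep=0 rows before and after the patch (the dictionary names are illustrative):

```python
# End-to-end forward times in seconds (rep=0 rows), before and after the patch.
before_s = {1: 16.70, 8: 148.05}
after_s = {1: 0.27, 8: 2.16}

# Ratio of unpatched to patched wall-clock time per batch size.
speedup = {bs: before_s[bs] / after_s[bs] for bs in before_s}

for bs, s in speedup.items():
    print(f"bs={bs}: {s:.1f}x faster")
```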

Expected behavior

Qwen3VLVisionPatchEmbed.forward should run in ~0.3 ms (the time of the equivalent nn.Linear), not ~16 s.

Proposed fix

     class Qwen3VLVisionPatchEmbed(nn.Module):
         def __init__(self, config) -> None:
             super().__init__()
             self.patch_size = config.patch_size
             self.temporal_patch_size = config.temporal_patch_size
             self.in_channels = config.in_channels
             self.embed_dim = config.hidden_size

    -        kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
    -        self.proj = nn.Conv3d(
    -            self.in_channels, self.embed_dim,
    -            kernel_size=kernel_size, stride=kernel_size, bias=True,
    -        )
    +        in_dim = (self.in_channels * self.temporal_patch_size
    +                  * self.patch_size * self.patch_size)
    +        self.proj = nn.Linear(in_dim, self.embed_dim, bias=True)

         def forward(self, hidden_states):
             target_dtype = self.proj.weight.dtype
    -        hidden_states = hidden_states.view(
    -            -1, self.in_channels, self.temporal_patch_size,
    -            self.patch_size, self.patch_size,
    -        )
    -        hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
    +        hidden_states = hidden_states.reshape(-1, self.proj.in_features)
    +        hidden_states = self.proj(hidden_states.to(dtype=target_dtype))
             return hidden_states

Backward compatibility for existing checkpoints

Pretrained Qwen3-VL--Instruct checkpoints save proj.weight in 5-D Conv3d shape (out, in, k_t, k_h, k_w).
To load them into the new</span><span md-inline="softbreak" class="md-softbreak" style="box-sizing: border-box;"> </span><span md-inline="code" spellcheck="false" class="md-pair-s" style="box-sizing: border-box;"><code style="box-sizing: border-box; font-family: var(--monospace); text-align: left; vertical-align: initial; border: 1px solid rgb(231, 234, 237); background-color: rgb(243, 244, 244); border-radius: 3px; padding: 0px 2px; font-size: 0.9em;">Linear</code></span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;"> layer (</span><span md-inline="code" spellcheck="false" class="md-pair-s" style="box-sizing: border-box;"><code style="box-sizing: border-box; font-family: var(--monospace); text-align: left; vertical-align: initial; border: 1px solid rgb(231, 234, 237); background-color: rgb(243, 244, 244); border-radius: 3px; padding: 0px 2px; font-size: 0.9em;">(out, in·k_t·k_h·k_w)</code></span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">), add a </span><span md-inline="code" spellcheck="false" class="md-pair-s" style="box-sizing: border-box;"><code style="box-sizing: border-box; font-family: var(--monospace); text-align: left; vertical-align: initial; border: 1px solid rgb(231, 234, 237); background-color: rgb(243, 244, 244); border-radius: 3px; padding: 0px 2px; font-size: 0.9em;">_load_from_state_dict</code></span><span md-inline="softbreak" class="md-softbreak" style="box-sizing: border-box;"> </span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">hook:</span></p><pre class="md-fences md-end-block ty-contain-cm modeLoaded" spellcheck="false" lang="python" cid="n76" mdtype="fences" style="box-sizing: border-box; overflow: visible; font-family: var(--monospace); font-size: 0.9em; display: block; break-inside: avoid; text-align: left; white-space: pre; background-image: inherit; background-position: inherit; background-size: inherit; background-repeat: inherit; background-attachment: 
inherit; background-origin: inherit; background-clip: inherit; background-color: rgb(248, 248, 248); position: relative; border: 1px solid rgb(231, 234, 237); border-radius: 3px; padding: 8px 4px 6px; margin-bottom: 15px; margin-top: 15px; width: inherit; color: rgb(51, 51, 51); font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; font-weight: 400; letter-spacing: normal; orphans: 2; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-thickness: initial; text-decoration-style: initial; text-decoration-color: initial;"><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;"><span class="cm-keyword" style="box-sizing: border-box; color: rgb(119, 0, 136);">def</span> <span class="cm-def" style="box-sizing: border-box; color: rgb(0, 0, 255);">_load_from_state_dict</span>(<span class="cm-variable-2" style="box-sizing: border-box; color: rgb(0, 85, 170);">self</span>, <span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">state_dict</span>, <span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">prefix</span>, <span class="cm-operator" style="box-sizing: border-box; color: rgb(152, 26, 26);">*</span><span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">args</span>, <span class="cm-operator" style="box-sizing: border-box; color: rgb(152, 26, 26);">**</span><span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">kwargs</span>):</span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;">    <span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">key</span> <span class="cm-operator" style="box-sizing: border-box; color: rgb(152, 26, 26);">=</span> <span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">prefix</span> <span class="cm-operator" style="box-sizing: border-box; color: rgb(152, 26, 
26);">+</span> <span class="cm-string" style="box-sizing: border-box; color: rgb(170, 17, 17);">"proj.weight"</span></span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;">    <span class="cm-keyword" style="box-sizing: border-box; color: rgb(119, 0, 136);">if</span> <span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">key</span> <span class="cm-keyword" style="box-sizing: border-box; color: rgb(119, 0, 136);">in</span> <span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">state_dict</span> <span class="cm-keyword" style="box-sizing: border-box; color: rgb(119, 0, 136);">and</span> <span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">state_dict</span>[<span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">key</span>].<span class="cm-property" style="box-sizing: border-box; color: rgb(0, 0, 0);">dim</span>() <span class="cm-operator" style="box-sizing: border-box; color: rgb(152, 26, 26);">==</span> <span class="cm-number" style="box-sizing: border-box; color: rgb(17, 102, 68);">5</span>:</span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;">        <span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">out_dim</span> <span class="cm-operator" style="box-sizing: border-box; color: rgb(152, 26, 26);">=</span> <span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">state_dict</span>[<span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">key</span>].<span class="cm-property" style="box-sizing: border-box; color: rgb(0, 0, 0);">shape</span>[<span class="cm-number" style="box-sizing: border-box; color: rgb(17, 102, 68);">0</span>]</span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;">        <span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">state_dict</span>[<span 
class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">key</span>] <span class="cm-operator" style="box-sizing: border-box; color: rgb(152, 26, 26);">=</span> <span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">state_dict</span>[<span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">key</span>].<span class="cm-property" style="box-sizing: border-box; color: rgb(0, 0, 0);">reshape</span>(<span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">out_dim</span>, <span class="cm-operator" style="box-sizing: border-box; color: rgb(152, 26, 26);">-</span><span class="cm-number" style="box-sizing: border-box; color: rgb(17, 102, 68);">1</span>).<span class="cm-property" style="box-sizing: border-box; color: rgb(0, 0, 0);">contiguous</span>()</span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;">    <span class="cm-builtin" style="box-sizing: border-box; color: rgb(51, 0, 170);">super</span>().<span class="cm-property" style="box-sizing: border-box; color: rgb(0, 0, 0);">_load_from_state_dict</span>(<span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">state_dict</span>, <span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">prefix</span>, <span class="cm-operator" style="box-sizing: border-box; color: rgb(152, 26, 26);">*</span><span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">args</span>, <span class="cm-operator" style="box-sizing: border-box; color: rgb(152, 26, 26);">**</span><span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">kwargs</span>)</span></pre><p cid="n77" mdtype="paragraph" class="md-end-block md-p" style="box-sizing: border-box; line-height: inherit; orphans: 4; margin: 0.8em 0px; white-space: pre-wrap; position: relative; color: rgb(51, 51, 51); font-family: &quot;Open Sans&quot;, &quot;Clear Sans&quot;, &quot;Helvetica 
Neue&quot;, Helvetica, Arial, &quot;Segoe UI Emoji&quot;, &quot;SF Pro&quot;, sans-serif; font-size: 16px; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; font-weight: 400; letter-spacing: normal; text-align: start; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-thickness: initial; text-decoration-style: initial; text-decoration-color: initial;"><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">This makes the change transparent to existing public checkpoints.</span></p><h3 cid="n78" mdtype="heading" class="md-end-block md-heading" style="box-sizing: border-box; white-space: pre-wrap; break-after: avoid-page; break-inside: avoid; orphans: 4; font-size: 1.5em; margin-top: 1rem; margin-bottom: 1rem; position: relative; font-weight: bold; line-height: 1.43; cursor: text; color: rgb(51, 51, 51); font-family: &quot;Open Sans&quot;, &quot;Clear Sans&quot;, &quot;Helvetica Neue&quot;, Helvetica, Arial, &quot;Segoe UI Emoji&quot;, &quot;SF Pro&quot;, sans-serif; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; letter-spacing: normal; text-align: start; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-thickness: initial; text-decoration-style: initial; text-decoration-color: initial;"><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">Numerical equivalence verified</span></h3><figure class="md-table-fig table-figure" cid="n79" mdtype="table" style="box-sizing: border-box; margin: 1.2em 0px; overflow-x: auto; max-width: calc(100% + 16px); padding: 0px; cursor: default; color: rgb(51, 51, 51); font-family: &quot;Open Sans&quot;, &quot;Clear Sans&quot;, &quot;Helvetica Neue&quot;, Helvetica, Arial, &quot;Segoe UI Emoji&quot;, &quot;SF Pro&quot;, sans-serif; font-size: 16px; font-style: normal; font-variant-ligatures: normal; 
font-variant-caps: normal; font-weight: 400; letter-spacing: normal; orphans: 2; text-align: start; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; white-space: normal; text-decoration-thickness: initial; text-decoration-style: initial; text-decoration-color: initial;">

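<table><thead><tr><th>check</th><th>tolerance</th><th>observed</th></tr></thead><tbody>
<tr><td>fp32 max abs diff (proj output)</td><td>&lt; 1e-5</td><td>&lt; 1e-7</td></tr>
<tr><td>bf16 cosine similarity (proj output)</td><td>&gt; 0.999</td><td>0.9995</td></tr>
<tr><td>bf16 cosine similarity (full 24-layer vision tower)</td><td>&gt; 0.99</td><td>&gt; 0.999 per sample</td></tr>
</tbody></table>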
</figure><h3 cid="n96" mdtype="heading" class="md-end-block md-heading" style="box-sizing: border-box; white-space: pre-wrap; break-after: avoid-page; break-inside: avoid; orphans: 4; font-size: 1.5em; margin-top: 1rem; margin-bottom: 1rem; position: relative; font-weight: bold; line-height: 1.43; cursor: text; color: rgb(51, 51, 51); font-family: &quot;Open Sans&quot;, &quot;Clear Sans&quot;, &quot;Helvetica Neue&quot;, Helvetica, Arial, &quot;Segoe UI Emoji&quot;, &quot;SF Pro&quot;, sans-serif; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; letter-spacing: normal; text-align: start; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-thickness: initial; text-decoration-style: initial; text-decoration-color: initial;"><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">Same fix applies to Qwen2-VL and Qwen2.5-VL</span></h3><p cid="n97" mdtype="paragraph" class="md-end-block md-p" style="box-sizing: border-box; line-height: inherit; orphans: 4; margin: 0.8em 0px; white-space: pre-wrap; position: relative; color: rgb(51, 51, 51); font-family: &quot;Open Sans&quot;, &quot;Clear Sans&quot;, &quot;Helvetica Neue&quot;, Helvetica, Arial, &quot;Segoe UI Emoji&quot;, &quot;SF Pro&quot;, sans-serif; font-size: 16px; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; font-weight: 400; letter-spacing: normal; text-align: start; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-thickness: initial; text-decoration-style: initial; text-decoration-color: initial;"><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">The same </span><span md-inline="code" spellcheck="false" class="md-pair-s" style="box-sizing: border-box;"><code style="box-sizing: border-box; font-family: var(--monospace); text-align: left; vertical-align: initial; border: 1px 
solid rgb(231, 234, 237); background-color: rgb(243, 244, 244); border-radius: 3px; padding: 0px 2px; font-size: 0.9em;">Conv3d</code></span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">-with-</span><span md-inline="code" spellcheck="false" class="md-pair-s" style="box-sizing: border-box;"><code style="box-sizing: border-box; font-family: var(--monospace); text-align: left; vertical-align: initial; border: 1px solid rgb(231, 234, 237); background-color: rgb(243, 244, 244); border-radius: 3px; padding: 0px 2px; font-size: 0.9em;">kernel_size == stride</code></span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;"> pattern exists in</span><span md-inline="softbreak" class="md-softbreak" style="box-sizing: border-box;"> </span><span md-inline="code" spellcheck="false" class="md-pair-s" style="box-sizing: border-box;"><code style="box-sizing: border-box; font-family: var(--monospace); text-align: left; vertical-align: initial; border: 1px solid rgb(231, 234, 237); background-color: rgb(243, 244, 244); border-radius: 3px; padding: 0px 2px; font-size: 0.9em;">Qwen2VLVisionPatchEmbed</code></span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;"> and </span><span md-inline="code" spellcheck="false" class="md-pair-s" style="box-sizing: border-box;"><code style="box-sizing: border-box; font-family: var(--monospace); text-align: left; vertical-align: initial; border: 1px solid rgb(231, 234, 237); background-color: rgb(243, 244, 244); border-radius: 3px; padding: 0px 2px; font-size: 0.9em;">Qwen2_5_VLVisionPatchEmbed</code></span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">. 
Both should</span><span md-inline="softbreak" class="md-softbreak" style="box-sizing: border-box;"> </span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">be patched identically.</span></p><h3 cid="n98" mdtype="heading" class="md-end-block md-heading" style="box-sizing: border-box; white-space: pre-wrap; break-after: avoid-page; break-inside: avoid; orphans: 4; font-size: 1.5em; margin-top: 1rem; margin-bottom: 1rem; position: relative; font-weight: bold; line-height: 1.43; cursor: text; color: rgb(51, 51, 51); font-family: &quot;Open Sans&quot;, &quot;Clear Sans&quot;, &quot;Helvetica Neue&quot;, Helvetica, Arial, &quot;Segoe UI Emoji&quot;, &quot;SF Pro&quot;, sans-serif; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; letter-spacing: normal; text-align: start; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-thickness: initial; text-decoration-style: initial; text-decoration-color: initial;"><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">Why this matters</span></h3><p cid="n99" mdtype="paragraph" class="md-end-block md-p" style="box-sizing: border-box; line-height: inherit; orphans: 4; margin: 0.8em 0px; white-space: pre-wrap; position: relative; color: rgb(51, 51, 51); font-family: &quot;Open Sans&quot;, &quot;Clear Sans&quot;, &quot;Helvetica Neue&quot;, Helvetica, Arial, &quot;Segoe UI Emoji&quot;, &quot;SF Pro&quot;, sans-serif; font-size: 16px; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; font-weight: 400; letter-spacing: normal; text-align: start; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-thickness: initial; text-decoration-style: initial; text-decoration-color: initial;"><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">Anyone running Qwen-VL inference on a Blackwell GPU in 
bf16 silently</span><span md-inline="softbreak" class="md-softbreak" style="box-sizing: border-box;"> </span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">pays a ~50,000× cost on the patch projection. For 30,000-sample feature</span><span md-inline="softbreak" class="md-softbreak" style="box-sizing: border-box;"> </span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">extraction, this is the difference between </span><span md-inline="strong" class="md-pair-s " style="box-sizing: border-box;"><strong style="box-sizing: border-box;"><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">6 days</span></strong></span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;"> and </span><span md-inline="strong" class="md-pair-s " style="box-sizing: border-box;"><strong style="box-sizing: border-box;"><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">~2 hours</span></strong></span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">.</span></p><p cid="n100" mdtype="paragraph" class="md-end-block md-p md-focus" style="box-sizing: border-box; line-height: inherit; orphans: 4; margin: 0.8em 0px; white-space: pre-wrap; position: relative; color: rgb(51, 51, 51); font-family: &quot;Open Sans&quot;, &quot;Clear Sans&quot;, &quot;Helvetica Neue&quot;, Helvetica, Arial, &quot;Segoe UI Emoji&quot;, &quot;SF Pro&quot;, sans-serif; font-size: 16px; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; font-weight: 400; letter-spacing: normal; text-align: start; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-thickness: initial; text-decoration-style: initial; text-decoration-color: initial;"><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">Happy to send a PR with the rewrite, the backward-compat</span><span md-inline="softbreak" 
class="md-softbreak" style="box-sizing: border-box;"> </span><span md-inline="code" spellcheck="false" class="md-pair-s" style="box-sizing: border-box;"><code style="box-sizing: border-box; font-family: var(--monospace); text-align: left; vertical-align: initial; border: 1px solid rgb(231, 234, 237); background-color: rgb(243, 244, 244); border-radius: 3px; padding: 0px 2px; font-size: 0.9em;">_load_from_state_dict</code></span><span md-inline="plain" class="md-plain md-expand" style="box-sizing: border-box;"> hook, a unit test, and a benchmark script.</span></p>

extent analysis

TL;DR

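Replace the Conv3d patch projection in Qwen3VLVisionPatchEmbed with an equivalent nn.Linear. Because stride == kernel_size, the patches do not overlap and the convolution reduces to a linear projection of each flattened patch. On Blackwell GPUs in bf16 this avoids the ~16 s patch_embed forward reported in the issue, against ~0.3 ms for the equivalent nn.Linear.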

Guidance

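  1. Identify the bottleneck: Qwen3VLVisionPatchEmbed.forward dominates the vision tower's forward time, and the cost comes from the Conv3d projection (self.proj).
  2. Replace Conv3d with Linear: In __init__, define self.proj as nn.Linear(in_channels * temporal_patch_size * patch_size * patch_size, embed_dim). In forward, reshape the input to (-1, self.proj.in_features) before applying the projection.
  3. Add backward compatibility: Implement a _load_from_state_dict hook that reshapes a 5-D proj.weight of shape (out, in, k_t, k_h, k_w) to the 2-D shape (out, in·k_t·k_h·k_w), so existing checkpoints still load.
  4. Verify numerical equivalence: Compare the new output against the original Conv3d output. The issue uses a max abs diff below 1e-5 in fp32 and a cosine similarity above 0.999 in bf16 as tolerances.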

Example

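import torch.nn as nn

class Qwen3VLVisionPatchEmbed(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.hidden_size

        # Flattened size of one patch: channels x temporal x height x width.
        in_dim = (self.in_channels * self.temporal_patch_size
                  * self.patch_size * self.patch_size)
        self.proj = nn.Linear(in_dim, self.embed_dim, bias=True)

    def forward(self, hidden_states):
        target_dtype = self.proj.weight.dtype
        # One row per patch; each row is projected independently.
        hidden_states = hidden_states.reshape(-1, self.proj.in_features)
        hidden_states = self.proj(hidden_states.to(dtype=target_dtype))
        return hidden_states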
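
The equivalence that the rewrite relies on can be checked in isolation. The following is a minimal sketch, not the issue author's test: it runs in fp32 on CPU, and the sizes (3 channels, temporal patch 2, spatial patch 4, embedding 8, 5 patches) are illustrative assumptions rather than values from a real config.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes only (assumed, not taken from a real checkpoint config).
in_channels, temporal_patch_size, patch_size, embed_dim = 3, 2, 4, 8
kernel_size = [temporal_patch_size, patch_size, patch_size]
in_dim = in_channels * temporal_patch_size * patch_size * patch_size

conv = nn.Conv3d(in_channels, embed_dim,
                 kernel_size=kernel_size, stride=kernel_size, bias=True)
linear = nn.Linear(in_dim, embed_dim, bias=True)

# Reuse the Conv3d parameters: flatten (out, in, k_t, k_h, k_w)
# to (out, in * k_t * k_h * k_w), the layout nn.Linear expects.
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(embed_dim, -1))
    linear.bias.copy_(conv.bias)

# Five flattened patches, one per row.
patches = torch.randn(5, in_dim)

with torch.no_grad():
    # Original path: view as 5-D, convolve, flatten the 1x1x1 output.
    conv_out = conv(patches.view(-1, in_channels, *kernel_size)).view(-1, embed_dim)
    # Proposed path: a single matrix multiply per patch.
    linear_out = linear(patches)

max_abs_diff = (conv_out - linear_out).abs().max().item()
print(f"max abs diff: {max_abs_diff:.2e}")
```

Both paths sum over the same (channel, time, height, width) elements in the same row-major order, so any difference comes from floating-point accumulation order only.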

Notes

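  • The same Conv3d with kernel_size == stride pattern exists in Qwen2VLVisionPatchEmbed and Qwen2_5_VLVisionPatchEmbed, so the issue proposes patching both identically.
  • Test the modified implementation for numerical equivalence with the Conv3d version, and confirm that existing checkpoints with a 5-D proj.weight still load.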
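
One way to test the checkpoint-loading path is sketched below. PatchEmbedLinear is a hypothetical stand-in for the patched module with illustrative sizes; the hook body follows the one proposed in the issue. Whether the transformers loading path actually calls this hook is not verified here, so treat it as a check of plain torch.nn.Module.load_state_dict only.

```python
import torch
import torch.nn as nn


class PatchEmbedLinear(nn.Module):
    """Hypothetical stand-in for the patched module; sizes are illustrative."""

    def __init__(self, in_channels=3, temporal_patch_size=2, patch_size=4, embed_dim=8):
        super().__init__()
        in_dim = in_channels * temporal_patch_size * patch_size * patch_size
        self.proj = nn.Linear(in_dim, embed_dim, bias=True)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Old checkpoints store a 5-D Conv3d weight; flatten it to 2-D
        # before the child nn.Linear reads it.
        key = prefix + "proj.weight"
        if key in state_dict and state_dict[key].dim() == 5:
            out_dim = state_dict[key].shape[0]
            state_dict[key] = state_dict[key].reshape(out_dim, -1).contiguous()
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


# Simulate an old checkpoint produced by the Conv3d version.
conv = nn.Conv3d(3, 8, kernel_size=[2, 4, 4], stride=[2, 4, 4], bias=True)
old_state = {
    "proj.weight": conv.weight.detach().clone(),  # shape (8, 3, 2, 4, 4)
    "proj.bias": conv.bias.detach().clone(),
}

model = PatchEmbedLinear()
result = model.load_state_dict(old_state)  # strict=True by default
print(tuple(model.proj.weight.shape), result)
```

The hook runs on the parent module before its children are loaded, which is why the reshaped tensor is what nn.Linear sees.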

Recommendation

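Apply the workaround by replacing the Conv3d projection in Qwen3VLVisionPatchEmbed with an equivalent Linear layer and reshaping the input in forward. The issue reports the two as numerically equivalent within tolerance (fp32 max abs diff below 1e-7, bf16 cosine similarity 0.9995 on the projection output), so the speedup should not change model outputs beyond rounding.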
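
To see whether a given machine is affected before patching, the two projections can be timed side by side. This is a rough sketch under assumptions: the time_ms helper is hypothetical, the patch count (6080) and embedding size (1024) come from the profile in the issue, while patch_size = 16 and temporal_patch_size = 2 are assumed. It falls back to a smaller fp32 run on CPU, where no slowdown of that scale is expected.

```python
import time

import torch
import torch.nn as nn


def time_ms(fn, iters=3):
    """Average wall-clock time of fn() in milliseconds (hypothetical helper)."""
    fn()  # warm-up call, excluded from the measurement
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - t0) * 1000 / iters


device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32

# patch_size and temporal_patch_size are assumptions for illustration.
in_channels, temporal_patch_size, patch_size, embed_dim = 3, 2, 16, 1024
n_patches = 6080 if device == "cuda" else 256
kernel_size = [temporal_patch_size, patch_size, patch_size]
in_dim = in_channels * temporal_patch_size * patch_size * patch_size

conv = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size,
                 stride=kernel_size, bias=True).to(device=device, dtype=dtype)
linear = nn.Linear(in_dim, embed_dim, bias=True).to(device=device, dtype=dtype)
patches = torch.randn(n_patches, in_dim, device=device, dtype=dtype)

with torch.no_grad():
    conv_ms = time_ms(lambda: conv(patches.view(-1, in_channels, *kernel_size)))
    linear_ms = time_ms(lambda: linear(patches))

print(f"device={device} dtype={dtype}")
print(f"Conv3d: {conv_ms:.2f} ms   Linear: {linear_ms:.2f} ms")
```

The numbers depend heavily on the GPU, dtype, and library versions, so a small ratio on one machine says nothing about another.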

FAQ

Expected behavior

<p cid="n71" mdtype="paragraph" class="md-end-block md-p md-focus" style="box-sizing: border-box; line-height: inherit; orphans: 4; margin: 0.8em 0px; white-space: pre-wrap; position: relative; color: rgb(51, 51, 51); font-family: &quot;Open Sans&quot;, &quot;Clear Sans&quot;, &quot;Helvetica Neue&quot;, Helvetica, Arial, &quot;Segoe UI Emoji&quot;, &quot;SF Pro&quot;, sans-serif; font-size: 16px; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; font-weight: 400; letter-spacing: normal; text-align: start; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-thickness: initial; text-decoration-style: initial; text-decoration-color: initial;"><span md-inline="code" spellcheck="false" class="md-pair-s md-expand" style="box-sizing: border-box;"><code style="box-sizing: border-box; font-family: var(--monospace); text-align: left; vertical-align: initial; border: 1px solid rgb(231, 234, 237); background-color: rgb(243, 244, 244); border-radius: 3px; padding: 0px 2px; font-size: 0.9em;">Qwen3VLVisionPatchEmbed.forward</code></span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;"> should run in ~0.3 ms (the time of the</span><span md-inline="softbreak" class="md-softbreak" style="box-sizing: border-box;">

</span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">equivalent </span><span md-inline="code" spellcheck="false" class="md-pair-s" style="box-sizing: border-box;"><code style="box-sizing: border-box; font-family: var(--monospace); text-align: left; vertical-align: initial; border: 1px solid rgb(231, 234, 237); background-color: rgb(243, 244, 244); border-radius: 3px; padding: 0px 2px; font-size: 0.9em;">nn.Linear</code></span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">), not ~16 s.</span></p><h3 cid="n72" mdtype="heading" class="md-end-block md-heading" style="box-sizing: border-box; white-space: pre-wrap; break-after: avoid-page; break-inside: avoid; orphans: 4; font-size: 1.5em; margin-top: 1rem; margin-bottom: 1rem; position: relative; font-weight: bold; line-height: 1.43; cursor: text; color: rgb(51, 51, 51); font-family: &quot;Open Sans&quot;, &quot;Clear Sans&quot;, &quot;Helvetica Neue&quot;, Helvetica, Arial, &quot;Segoe UI Emoji&quot;, &quot;SF Pro&quot;, sans-serif; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; letter-spacing: normal; text-align: start; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-thickness: initial; text-decoration-style: initial; text-decoration-color: initial;"><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">Proposed fix</span></h3><pre class="md-fences md-end-block ty-contain-cm modeLoaded" spellcheck="false" lang="diff" cid="n73" mdtype="fences" style="box-sizing: border-box; overflow: visible; font-family: var(--monospace); font-size: 0.9em; display: block; break-inside: avoid; text-align: left; white-space: pre; background-image: inherit; background-position: inherit; background-size: inherit; background-repeat: inherit; background-attachment: inherit; background-origin: inherit; background-clip: inherit; background-color: rgb(248, 248, 248); 
position: relative; border: 1px solid rgb(231, 234, 237); border-radius: 3px; padding: 8px 4px 6px; margin-bottom: 15px; margin-top: 15px; width: inherit; color: rgb(51, 51, 51); font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; font-weight: 400; letter-spacing: normal; orphans: 2; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-thickness: initial; text-decoration-style: initial; text-decoration-color: initial;"><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;"> class Qwen3VLVisionPatchEmbed(nn.Module):</span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;">     def __init__(self, config) -> None:</span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;">         super().__init__()</span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;">         self.patch_size = config.patch_size</span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;">         self.temporal_patch_size = config.temporal_patch_size</span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;">         self.in_channels = config.in_channels</span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;">         self.embed_dim = config.hidden_size</span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;"><span cm-text="" cm-zwsp="" style="box-sizing: border-box;"></span></span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;"><span class="cm-negative" style="box-sizing: border-box; color: rgb(221, 68, 68);">-       kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]</span></span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;"><span class="cm-negative" style="box-sizing: 
border-box; color: rgb(221, 68, 68);">-       self.proj = nn.Conv3d(</span></span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;"><span class="cm-negative" style="box-sizing: border-box; color: rgb(221, 68, 68);">-           self.in_channels, self.embed_dim,</span></span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;"><span class="cm-negative" style="box-sizing: border-box; color: rgb(221, 68, 68);">-           kernel_size=kernel_size, stride=kernel_size, bias=True,</span></span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;"><span class="cm-negative" style="box-sizing: border-box; color: rgb(221, 68, 68);">-       )</span></span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;"><span class="cm-positive" style="box-sizing: border-box; color: rgb(34, 153, 34);">+       in_dim = (self.in_channels * self.temporal_patch_size</span></span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;"><span class="cm-positive" style="box-sizing: border-box; color: rgb(34, 153, 34);">+                 * self.patch_size * self.patch_size)</span></span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;"><span class="cm-positive" style="box-sizing: border-box; color: rgb(34, 153, 34);">+       self.proj = nn.Linear(in_dim, self.embed_dim, bias=True)</span></span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;"><span cm-text="" cm-zwsp="" style="box-sizing: border-box;"></span></span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;">     def forward(self, hidden_states):</span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;">         target_dtype = self.proj.weight.dtype</span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;"><span 
class="cm-negative" style="box-sizing: border-box; color: rgb(221, 68, 68);">-       hidden_states = hidden_states.view(</span></span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;"><span class="cm-negative" style="box-sizing: border-box; color: rgb(221, 68, 68);">-           -1, self.in_channels, self.temporal_patch_size,</span></span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;"><span class="cm-negative" style="box-sizing: border-box; color: rgb(221, 68, 68);">-           self.patch_size, self.patch_size,</span></span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;"><span class="cm-negative" style="box-sizing: border-box; color: rgb(221, 68, 68);">-       )</span></span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;"><span class="cm-negative" style="box-sizing: border-box; color: rgb(221, 68, 68);">-       hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)</span></span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;"><span class="cm-positive" style="box-sizing: border-box; color: rgb(34, 153, 34);">+       hidden_states = hidden_states.reshape(-1, self.proj.in_features)</span></span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;"><span class="cm-positive" style="box-sizing: border-box; color: rgb(34, 153, 34);">+       hidden_states = self.proj(hidden_states.to(dtype=target_dtype))</span></span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;">         return hidden_states</span></pre><h3 cid="n74" mdtype="heading" class="md-end-block md-heading" style="box-sizing: border-box; white-space: pre-wrap; break-after: avoid-page; break-inside: avoid; orphans: 4; font-size: 1.5em; margin-top: 1rem; margin-bottom: 1rem; position: relative; font-weight: bold; line-height: 1.43; cursor: 
text; color: rgb(51, 51, 51); font-family: &quot;Open Sans&quot;, &quot;Clear Sans&quot;, &quot;Helvetica Neue&quot;, Helvetica, Arial, &quot;Segoe UI Emoji&quot;, &quot;SF Pro&quot;, sans-serif; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; letter-spacing: normal; text-align: start; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-thickness: initial; text-decoration-style: initial; text-decoration-color: initial;"><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">Backward compatibility for existing checkpoints</span></h3><p cid="n75" mdtype="paragraph" class="md-end-block md-p" style="box-sizing: border-box; line-height: inherit; orphans: 4; margin: 0.8em 0px; white-space: pre-wrap; position: relative; color: rgb(51, 51, 51); font-family: &quot;Open Sans&quot;, &quot;Clear Sans&quot;, &quot;Helvetica Neue&quot;, Helvetica, Arial, &quot;Segoe UI Emoji&quot;, &quot;SF Pro&quot;, sans-serif; font-size: 16px; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; font-weight: 400; letter-spacing: normal; text-align: start; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-thickness: initial; text-decoration-style: initial; text-decoration-color: initial;"><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">Pretrained </span><span md-inline="code" spellcheck="false" class="md-pair-s" style="box-sizing: border-box;"><code style="box-sizing: border-box; font-family: var(--monospace); text-align: left; vertical-align: initial; border: 1px solid rgb(231, 234, 237); background-color: rgb(243, 244, 244); border-radius: 3px; padding: 0px 2px; font-size: 0.9em;">Qwen3-VL--Instruct</code></span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;"> checkpoints save </span><span md-inline="code" spellcheck="false" 
class="md-pair-s" style="box-sizing: border-box;"><code style="box-sizing: border-box; font-family: var(--monospace); text-align: left; vertical-align: initial; border: 1px solid rgb(231, 234, 237); background-color: rgb(243, 244, 244); border-radius: 3px; padding: 0px 2px; font-size: 0.9em;">proj.weight</code></span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;"> in</span><span md-inline="softbreak" class="md-softbreak" style="box-sizing: border-box;"> </span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">5-D </span><span md-inline="code" spellcheck="false" class="md-pair-s" style="box-sizing: border-box;"><code style="box-sizing: border-box; font-family: var(--monospace); text-align: left; vertical-align: initial; border: 1px solid rgb(231, 234, 237); background-color: rgb(243, 244, 244); border-radius: 3px; padding: 0px 2px; font-size: 0.9em;">Conv3d</code></span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;"> shape </span><span md-inline="code" spellcheck="false" class="md-pair-s" style="box-sizing: border-box;"><code style="box-sizing: border-box; font-family: var(--monospace); text-align: left; vertical-align: initial; border: 1px solid rgb(231, 234, 237); background-color: rgb(243, 244, 244); border-radius: 3px; padding: 0px 2px; font-size: 0.9em;">(out, in, k_t, k_h, k_w)</code></span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">. 
To load them into the new</span><span md-inline="softbreak" class="md-softbreak" style="box-sizing: border-box;"> </span><span md-inline="code" spellcheck="false" class="md-pair-s" style="box-sizing: border-box;"><code style="box-sizing: border-box; font-family: var(--monospace); text-align: left; vertical-align: initial; border: 1px solid rgb(231, 234, 237); background-color: rgb(243, 244, 244); border-radius: 3px; padding: 0px 2px; font-size: 0.9em;">Linear</code></span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;"> layer (</span><span md-inline="code" spellcheck="false" class="md-pair-s" style="box-sizing: border-box;"><code style="box-sizing: border-box; font-family: var(--monospace); text-align: left; vertical-align: initial; border: 1px solid rgb(231, 234, 237); background-color: rgb(243, 244, 244); border-radius: 3px; padding: 0px 2px; font-size: 0.9em;">(out, in·k_t·k_h·k_w)</code></span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">), add a </span><span md-inline="code" spellcheck="false" class="md-pair-s" style="box-sizing: border-box;"><code style="box-sizing: border-box; font-family: var(--monospace); text-align: left; vertical-align: initial; border: 1px solid rgb(231, 234, 237); background-color: rgb(243, 244, 244); border-radius: 3px; padding: 0px 2px; font-size: 0.9em;">_load_from_state_dict</code></span><span md-inline="softbreak" class="md-softbreak" style="box-sizing: border-box;"> </span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">hook:</span></p><pre class="md-fences md-end-block ty-contain-cm modeLoaded" spellcheck="false" lang="python" cid="n76" mdtype="fences" style="box-sizing: border-box; overflow: visible; font-family: var(--monospace); font-size: 0.9em; display: block; break-inside: avoid; text-align: left; white-space: pre; background-image: inherit; background-position: inherit; background-size: inherit; background-repeat: inherit; background-attachment: 
inherit; background-origin: inherit; background-clip: inherit; background-color: rgb(248, 248, 248); position: relative; border: 1px solid rgb(231, 234, 237); border-radius: 3px; padding: 8px 4px 6px; margin-bottom: 15px; margin-top: 15px; width: inherit; color: rgb(51, 51, 51); font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; font-weight: 400; letter-spacing: normal; orphans: 2; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-thickness: initial; text-decoration-style: initial; text-decoration-color: initial;"><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;"><span class="cm-keyword" style="box-sizing: border-box; color: rgb(119, 0, 136);">def</span> <span class="cm-def" style="box-sizing: border-box; color: rgb(0, 0, 255);">_load_from_state_dict</span>(<span class="cm-variable-2" style="box-sizing: border-box; color: rgb(0, 85, 170);">self</span>, <span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">state_dict</span>, <span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">prefix</span>, <span class="cm-operator" style="box-sizing: border-box; color: rgb(152, 26, 26);">*</span><span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">args</span>, <span class="cm-operator" style="box-sizing: border-box; color: rgb(152, 26, 26);">**</span><span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">kwargs</span>):</span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;">    <span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">key</span> <span class="cm-operator" style="box-sizing: border-box; color: rgb(152, 26, 26);">=</span> <span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">prefix</span> <span class="cm-operator" style="box-sizing: border-box; color: rgb(152, 26, 
26);">+</span> <span class="cm-string" style="box-sizing: border-box; color: rgb(170, 17, 17);">"proj.weight"</span></span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;">    <span class="cm-keyword" style="box-sizing: border-box; color: rgb(119, 0, 136);">if</span> <span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">key</span> <span class="cm-keyword" style="box-sizing: border-box; color: rgb(119, 0, 136);">in</span> <span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">state_dict</span> <span class="cm-keyword" style="box-sizing: border-box; color: rgb(119, 0, 136);">and</span> <span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">state_dict</span>[<span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">key</span>].<span class="cm-property" style="box-sizing: border-box; color: rgb(0, 0, 0);">dim</span>() <span class="cm-operator" style="box-sizing: border-box; color: rgb(152, 26, 26);">==</span> <span class="cm-number" style="box-sizing: border-box; color: rgb(17, 102, 68);">5</span>:</span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;">        <span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">out_dim</span> <span class="cm-operator" style="box-sizing: border-box; color: rgb(152, 26, 26);">=</span> <span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">state_dict</span>[<span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">key</span>].<span class="cm-property" style="box-sizing: border-box; color: rgb(0, 0, 0);">shape</span>[<span class="cm-number" style="box-sizing: border-box; color: rgb(17, 102, 68);">0</span>]</span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;">        <span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">state_dict</span>[<span 
class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">key</span>] <span class="cm-operator" style="box-sizing: border-box; color: rgb(152, 26, 26);">=</span> <span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">state_dict</span>[<span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">key</span>].<span class="cm-property" style="box-sizing: border-box; color: rgb(0, 0, 0);">reshape</span>(<span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">out_dim</span>, <span class="cm-operator" style="box-sizing: border-box; color: rgb(152, 26, 26);">-</span><span class="cm-number" style="box-sizing: border-box; color: rgb(17, 102, 68);">1</span>).<span class="cm-property" style="box-sizing: border-box; color: rgb(0, 0, 0);">contiguous</span>()</span><br><span role="presentation" style="box-sizing: border-box; padding-right: 0.1px;">    <span class="cm-builtin" style="box-sizing: border-box; color: rgb(51, 0, 170);">super</span>().<span class="cm-property" style="box-sizing: border-box; color: rgb(0, 0, 0);">_load_from_state_dict</span>(<span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">state_dict</span>, <span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">prefix</span>, <span class="cm-operator" style="box-sizing: border-box; color: rgb(152, 26, 26);">*</span><span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">args</span>, <span class="cm-operator" style="box-sizing: border-box; color: rgb(152, 26, 26);">**</span><span class="cm-variable" style="box-sizing: border-box; color: rgb(0, 0, 0);">kwargs</span>)</span></pre><p cid="n77" mdtype="paragraph" class="md-end-block md-p" style="box-sizing: border-box; line-height: inherit; orphans: 4; margin: 0.8em 0px; white-space: pre-wrap; position: relative; color: rgb(51, 51, 51); font-family: &quot;Open Sans&quot;, &quot;Clear Sans&quot;, &quot;Helvetica 
Neue&quot;, Helvetica, Arial, &quot;Segoe UI Emoji&quot;, &quot;SF Pro&quot;, sans-serif; font-size: 16px; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; font-weight: 400; letter-spacing: normal; text-align: start; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-thickness: initial; text-decoration-style: initial; text-decoration-color: initial;"><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">This makes the change transparent to existing public checkpoints.</span></p><h3 cid="n78" mdtype="heading" class="md-end-block md-heading" style="box-sizing: border-box; white-space: pre-wrap; break-after: avoid-page; break-inside: avoid; orphans: 4; font-size: 1.5em; margin-top: 1rem; margin-bottom: 1rem; position: relative; font-weight: bold; line-height: 1.43; cursor: text; color: rgb(51, 51, 51); font-family: &quot;Open Sans&quot;, &quot;Clear Sans&quot;, &quot;Helvetica Neue&quot;, Helvetica, Arial, &quot;Segoe UI Emoji&quot;, &quot;SF Pro&quot;, sans-serif; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; letter-spacing: normal; text-align: start; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-thickness: initial; text-decoration-style: initial; text-decoration-color: initial;"><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">Numerical equivalence verified</span></h3><figure class="md-table-fig table-figure" cid="n79" mdtype="table" style="box-sizing: border-box; margin: 1.2em 0px; overflow-x: auto; max-width: calc(100% + 16px); padding: 0px; cursor: default; color: rgb(51, 51, 51); font-family: &quot;Open Sans&quot;, &quot;Clear Sans&quot;, &quot;Helvetica Neue&quot;, Helvetica, Arial, &quot;Segoe UI Emoji&quot;, &quot;SF Pro&quot;, sans-serif; font-size: 16px; font-style: normal; font-variant-ligatures: normal; 
font-variant-caps: normal; font-weight: 400; letter-spacing: normal; orphans: 2; text-align: start; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; white-space: normal; text-decoration-thickness: initial; text-decoration-style: initial; text-decoration-color: initial;">

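<table><thead><tr><th>check</th><th>tolerance</th><th>observed</th></tr></thead><tbody>
<tr><td>fp32 max abs diff (proj output)</td><td>&lt; 1e-5</td><td>&lt; 1e-7</td></tr>
<tr><td>bf16 cosine similarity (proj output)</td><td>&gt; 0.999</td><td>0.9995</td></tr>
<tr><td>bf16 cosine similarity (full 24-layer vision tower)</td><td>&gt; 0.99</td><td>&gt; 0.999 per sample</td></tr></tbody></table>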
</figure><h3 cid="n96" mdtype="heading" class="md-end-block md-heading" style="box-sizing: border-box; white-space: pre-wrap; break-after: avoid-page; break-inside: avoid; orphans: 4; font-size: 1.5em; margin-top: 1rem; margin-bottom: 1rem; position: relative; font-weight: bold; line-height: 1.43; cursor: text; color: rgb(51, 51, 51); font-family: &quot;Open Sans&quot;, &quot;Clear Sans&quot;, &quot;Helvetica Neue&quot;, Helvetica, Arial, &quot;Segoe UI Emoji&quot;, &quot;SF Pro&quot;, sans-serif; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; letter-spacing: normal; text-align: start; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-thickness: initial; text-decoration-style: initial; text-decoration-color: initial;"><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">Same fix applies to Qwen2-VL and Qwen2.5-VL</span></h3><p cid="n97" mdtype="paragraph" class="md-end-block md-p" style="box-sizing: border-box; line-height: inherit; orphans: 4; margin: 0.8em 0px; white-space: pre-wrap; position: relative; color: rgb(51, 51, 51); font-family: &quot;Open Sans&quot;, &quot;Clear Sans&quot;, &quot;Helvetica Neue&quot;, Helvetica, Arial, &quot;Segoe UI Emoji&quot;, &quot;SF Pro&quot;, sans-serif; font-size: 16px; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; font-weight: 400; letter-spacing: normal; text-align: start; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-thickness: initial; text-decoration-style: initial; text-decoration-color: initial;"><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">The same </span><span md-inline="code" spellcheck="false" class="md-pair-s" style="box-sizing: border-box;"><code style="box-sizing: border-box; font-family: var(--monospace); text-align: left; vertical-align: initial; border: 1px 
solid rgb(231, 234, 237); background-color: rgb(243, 244, 244); border-radius: 3px; padding: 0px 2px; font-size: 0.9em;">Conv3d</code></span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">-with-</span><span md-inline="code" spellcheck="false" class="md-pair-s" style="box-sizing: border-box;"><code style="box-sizing: border-box; font-family: var(--monospace); text-align: left; vertical-align: initial; border: 1px solid rgb(231, 234, 237); background-color: rgb(243, 244, 244); border-radius: 3px; padding: 0px 2px; font-size: 0.9em;">kernel_size == stride</code></span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;"> pattern exists in</span><span md-inline="softbreak" class="md-softbreak" style="box-sizing: border-box;"> </span><span md-inline="code" spellcheck="false" class="md-pair-s" style="box-sizing: border-box;"><code style="box-sizing: border-box; font-family: var(--monospace); text-align: left; vertical-align: initial; border: 1px solid rgb(231, 234, 237); background-color: rgb(243, 244, 244); border-radius: 3px; padding: 0px 2px; font-size: 0.9em;">Qwen2VLVisionPatchEmbed</code></span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;"> and </span><span md-inline="code" spellcheck="false" class="md-pair-s" style="box-sizing: border-box;"><code style="box-sizing: border-box; font-family: var(--monospace); text-align: left; vertical-align: initial; border: 1px solid rgb(231, 234, 237); background-color: rgb(243, 244, 244); border-radius: 3px; padding: 0px 2px; font-size: 0.9em;">Qwen2_5_VLVisionPatchEmbed</code></span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">. 
Both should</span><span md-inline="softbreak" class="md-softbreak" style="box-sizing: border-box;"> </span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">be patched identically.</span></p><h3 cid="n98" mdtype="heading" class="md-end-block md-heading" style="box-sizing: border-box; white-space: pre-wrap; break-after: avoid-page; break-inside: avoid; orphans: 4; font-size: 1.5em; margin-top: 1rem; margin-bottom: 1rem; position: relative; font-weight: bold; line-height: 1.43; cursor: text; color: rgb(51, 51, 51); font-family: &quot;Open Sans&quot;, &quot;Clear Sans&quot;, &quot;Helvetica Neue&quot;, Helvetica, Arial, &quot;Segoe UI Emoji&quot;, &quot;SF Pro&quot;, sans-serif; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; letter-spacing: normal; text-align: start; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-thickness: initial; text-decoration-style: initial; text-decoration-color: initial;"><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">Why this matters</span></h3><p cid="n99" mdtype="paragraph" class="md-end-block md-p" style="box-sizing: border-box; line-height: inherit; orphans: 4; margin: 0.8em 0px; white-space: pre-wrap; position: relative; color: rgb(51, 51, 51); font-family: &quot;Open Sans&quot;, &quot;Clear Sans&quot;, &quot;Helvetica Neue&quot;, Helvetica, Arial, &quot;Segoe UI Emoji&quot;, &quot;SF Pro&quot;, sans-serif; font-size: 16px; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; font-weight: 400; letter-spacing: normal; text-align: start; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-thickness: initial; text-decoration-style: initial; text-decoration-color: initial;"><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">Anyone running Qwen-VL inference on a Blackwell GPU in 
bf16 silently</span><span md-inline="softbreak" class="md-softbreak" style="box-sizing: border-box;"> </span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">pays a ~50,000× cost on the patch projection. For 30,000-sample feature</span><span md-inline="softbreak" class="md-softbreak" style="box-sizing: border-box;"> </span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">extraction, this is the difference between </span><span md-inline="strong" class="md-pair-s " style="box-sizing: border-box;"><strong style="box-sizing: border-box;"><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">6 days</span></strong></span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;"> and </span><span md-inline="strong" class="md-pair-s " style="box-sizing: border-box;"><strong style="box-sizing: border-box;"><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">~2 hours</span></strong></span><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">.</span></p><p cid="n100" mdtype="paragraph" class="md-end-block md-p md-focus" style="box-sizing: border-box; line-height: inherit; orphans: 4; margin: 0.8em 0px; white-space: pre-wrap; position: relative; color: rgb(51, 51, 51); font-family: &quot;Open Sans&quot;, &quot;Clear Sans&quot;, &quot;Helvetica Neue&quot;, Helvetica, Arial, &quot;Segoe UI Emoji&quot;, &quot;SF Pro&quot;, sans-serif; font-size: 16px; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; font-weight: 400; letter-spacing: normal; text-align: start; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-thickness: initial; text-decoration-style: initial; text-decoration-color: initial;"><span md-inline="plain" class="md-plain" style="box-sizing: border-box;">Happy to send a PR with the rewrite, the backward-compat</span><span md-inline="softbreak" 
class="md-softbreak" style="box-sizing: border-box;"> </span><span md-inline="code" spellcheck="false" class="md-pair-s" style="box-sizing: border-box;"><code style="box-sizing: border-box; font-family: var(--monospace); text-align: left; vertical-align: initial; border: 1px solid rgb(231, 234, 237); background-color: rgb(243, 244, 244); border-radius: 3px; padding: 0px 2px; font-size: 0.9em;">_load_from_state_dict</code></span><span md-inline="plain" class="md-plain md-expand" style="box-sizing: border-box;"> hook, a unit test, and a benchmark script.</span></p>
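
The fp32 equivalence check reported above can be approximated on CPU with a few lines of PyTorch. This is a minimal sketch, not the author's benchmark script: the sizes (`in_channels=3`, `embed_dim=32`, `temporal_patch_size=2`, `patch_size=4`) and the 64 random patches are made-up values chosen to run quickly, not the real Qwen3-VL configuration. It copies a `Conv3d` kernel into a `Linear` layer by flattening it to 2-D, then compares the two outputs on the same input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical small sizes for a quick CPU check (not the real model config).
in_channels, embed_dim = 3, 32
temporal_patch_size, patch_size = 2, 4
kernel_size = (temporal_patch_size, patch_size, patch_size)
in_dim = in_channels * temporal_patch_size * patch_size * patch_size

# Old formulation: Conv3d whose stride equals its kernel size.
conv = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size,
                 stride=kernel_size, bias=True)
# New formulation: a plain Linear over the flattened patch.
linear = nn.Linear(in_dim, embed_dim, bias=True)

# Flatten the 5-D kernel (out, in, k_t, k_h, k_w) to (out, in*k_t*k_h*k_w).
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(embed_dim, -1))
    linear.bias.copy_(conv.bias)

# One row per patch, flattened in (C, T, H, W) order.
patches = torch.randn(64, in_dim)

with torch.no_grad():
    # Each patch is exactly one kernel window, so the conv output is 1x1x1.
    out_conv = conv(patches.view(-1, in_channels, *kernel_size)).reshape(-1, embed_dim)
    out_linear = linear(patches)

max_abs_diff = (out_conv - out_linear).abs().max().item()
print(f"fp32 max abs diff: {max_abs_diff:.3e}")
```

The two paths may accumulate in a different order, so the difference is usually small but not exactly zero; the bf16 and Blackwell timings from the issue are not reproduced by this sketch.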

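The backward-compatibility hook described above can be exercised in isolation. The sketch below assumes a hypothetical stand-in class, `PatchEmbedSketch`, with made-up small dimensions; it is not the transformers class. It only covers plain `torch.nn.Module.load_state_dict`, so whether a given transformers loading path invokes the hook is not confirmed here.

```python
import torch
import torch.nn as nn


class PatchEmbedSketch(nn.Module):
    """Hypothetical stand-in for a Linear-based patch embed (illustration only)."""

    def __init__(self, in_channels=3, temporal_patch_size=2, patch_size=4, embed_dim=32):
        super().__init__()
        in_dim = in_channels * temporal_patch_size * patch_size * patch_size
        self.proj = nn.Linear(in_dim, embed_dim, bias=True)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Runs before the child `proj` is loaded, so the reshaped tensor
        # is what nn.Linear sees when it loads its own weight.
        key = prefix + "proj.weight"
        if key in state_dict and state_dict[key].dim() == 5:
            out_dim = state_dict[key].shape[0]
            state_dict[key] = state_dict[key].reshape(out_dim, -1).contiguous()
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


# Simulate a legacy checkpoint that stores the 5-D Conv3d kernel.
old = nn.Conv3d(3, 32, kernel_size=(2, 4, 4), stride=(2, 4, 4), bias=True)
legacy_state = {
    "proj.weight": old.weight.detach().clone(),
    "proj.bias": old.bias.detach().clone(),
}

model = PatchEmbedSketch()
result = model.load_state_dict(legacy_state)  # strict=True by default
print(tuple(model.proj.weight.shape), result)
```

`load_state_dict` works on a shallow copy of the dictionary it is given, so the caller's `legacy_state` keeps its 5-D tensor; only the copy seen by the hook is reshaped.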