vllm - ✅(Solved) Fix [Bug]: Gemma4 MoE routing closure captures `per_expert_scale` Parameter, breaking `torch.func.functional_call` substitution [1 pull request, 1 participant]

GitHub stats
vllm-project/vllm#42239 (fetched 2026-05-11 03:13:40)
Comments: 0
Participants: 1
Timeline: 2 (cross-referenced ×1, labeled ×1)
Reactions: 0

Error Message

AttributeError: 'Parameter' object has no attribute '_env'

Root Cause

Gemma4 MoE captures the registered per_expert_scale parameter in a Python closure in vllm/model_executor/models/gemma4.py. This prevents torch.func.functional_call-style parameter substitution from reaching the routing function, because the closure keeps the original Parameter object even when the module attribute is substituted.

Fix Action

Fixed

PR fix notes

PR #42250: [Bugfix][Model] Gemma4 MoE routing closure captures per_expert_scale, breaking functional_call substitution

Description (problem / solution / changelog)

Why this is not a duplicate

No existing open PR addresses this fix. Checked via:

  • gh pr list --repo vllm-project/vllm --state open --search "42239 in:body" → no results
  • gh pr list --repo vllm-project/vllm --state open --search "gemma4 routing closure" → no results

Fixes #42239.

What

Gemma4MoE.__init__ builds a routing_function closure over a local variable that holds self.per_expert_scale:

per_expert_scale = self.per_expert_scale   # ← captured once at construction

def routing_function(...):
    return gemma4_routing_function_torch(gating_output, topk, per_expert_scale)

Because the closure holds the original Parameter object, torch.func.functional_call parameter substitution is invisible to the routing function — the module attribute is patched, but the closure-captured reference is not.
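The mechanism is ordinary Python closure behavior and can be seen without PyTorch. The sketch below is a torch-free stand-in, not vLLM code: `Router` is an illustrative class, and `substituted` is a hypothetical helper that temporarily swaps an attribute, loosely imitating how a functional transform patches module state for the duration of one call.

```python
from contextlib import contextmanager


class Router:
    """Toy stand-in for a module that owns a scale value."""

    def __init__(self):
        self.scale = [1.0]

        captured = self.scale  # bound once, at construction

        def routing_function():
            # Returns the object captured above, never re-reads self.scale.
            return captured

        self.routing_function = routing_function


@contextmanager
def substituted(obj, name, value):
    """Hypothetical helper: swap obj.<name> while the block runs."""
    original = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield
    finally:
        setattr(obj, name, original)


router = Router()
with substituted(router, "scale", [2.0]):
    attr_seen = router.scale
    closure_seen = router.routing_function()

print("attribute lookup:", attr_seen)
print("closure lookup:", closure_seen)
```

Inside the block, the attribute lookup sees the substituted list while the closure still returns the original one.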

Fix

Remove the local capture; read self.per_expert_scale at call time:

def routing_function(...):
    return gemma4_routing_function_torch(gating_output, topk, self.per_expert_scale)

The closure now captures self rather than the Parameter, so every call resolves per_expert_scale through normal module attribute lookup, and module-level patching (including functional_call) reaches the routing function.
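As a rough check of this pattern, the issue's minimal reproducer can be adapted so the inner function reads `self.p` at call time. This is a sketch under the same assumptions as the reproducer (a toy module `Fixed` with a single parameter `p`, both illustrative names), not the Gemma4 code itself.

```python
import torch
from torch import nn
from torch.func import functional_call


class Fixed(nn.Module):
    def __init__(self):
        super().__init__()
        self.p = nn.Parameter(torch.tensor([1.0]))

        def f():
            # Resolved through module attribute lookup on every call.
            return self.p

        self.f = f

    def forward(self):
        return self.p, self.f()


m = Fixed()
new_p = torch.tensor([2.0])

# Substitute "p" for the duration of this single call.
out_attr, out_closure = functional_call(m, {"p": new_p}, ())

print("module attr lookup:", out_attr.item())
print("closure lookup:", out_closure.item())
```

With the lookup deferred to call time, both values should reflect the substituted tensor, and `m.p` returns to its original value once the call finishes.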

Tests

Added test_gemma4_moe_routing_functional_call_substitution to tests/kernels/moe/test_gemma4router.py. The test constructs a minimal module replicating the closure pattern, substitutes per_expert_scale via functional_call, and asserts the substituted value is used by the routing function.

Run locally (CPU, no GPU required):

.venv/bin/python -m pytest tests/kernels/moe/test_gemma4router.py::test_gemma4_moe_routing_functional_call_substitution -v

Result: PASSED
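For readers without the repository at hand, a test of this shape might look roughly like the sketch below. It is not the test added in the PR: `toy_routing` and `ToyMoE` are hypothetical stand-ins, and the routing math (softmax, top-k, per-expert scaling) is an assumption chosen only to make the substitution observable.

```python
import torch
from torch import nn
from torch.func import functional_call


def toy_routing(gating_output, topk, per_expert_scale):
    """Hypothetical stand-in for the real routing function."""
    probs = torch.softmax(gating_output, dim=-1)
    weights, ids = torch.topk(probs, k=topk, dim=-1)
    # Scale each selected weight by its expert's scale factor.
    return weights * per_expert_scale[ids], ids


class ToyMoE(nn.Module):
    def __init__(self, num_experts, topk):
        super().__init__()
        self.topk = topk
        self.per_expert_scale = nn.Parameter(torch.ones(num_experts))

        def routing_function(gating_output):
            # Fixed pattern: read the parameter at call time.
            return toy_routing(gating_output, self.topk, self.per_expert_scale)

        self.routing_function = routing_function

    def forward(self, gating_output):
        return self.routing_function(gating_output)


def test_substitution_reaches_routing():
    torch.manual_seed(0)
    moe = ToyMoE(num_experts=4, topk=2)
    gating_output = torch.randn(3, 4)

    base_weights, base_ids = moe(gating_output)

    new_scale = torch.full((4,), 2.0)
    sub_weights, sub_ids = functional_call(
        moe, {"per_expert_scale": new_scale}, (gating_output,)
    )

    # Expert selection is unchanged; weights follow the substituted scale.
    assert torch.equal(sub_ids, base_ids)
    assert torch.allclose(sub_weights, 2.0 * base_weights)


test_substitution_reaches_routing()
```

If `routing_function` closed over the original Parameter instead, the second assertion would fail because the weights would still be scaled by ones.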

AI assistance disclosure

This PR was developed with AI assistance (Claude). All changed lines were reviewed and the test was run locally before submission.

Changed files

  • tests/kernels/moe/test_gemma4router.py (modified, +59/-0)
  • vllm/model_executor/models/gemma4.py (modified, +7/-4)


Your current environment

<details> <summary>The output of `python collect_env.py`</summary>
Collecting environment information...
==============================
        System Info
==============================
OS                           : Ubuntu 22.04.5 LTS (x86_64)
GCC version                  : (Ubuntu 11.4.0-1ubuntu1~22.04.3) 11.4.0
Clang version                : Could not collect
CMake version                : Could not collect
Libc version                 : glibc-2.35

==============================
       PyTorch Info
==============================
PyTorch version              : 2.10.0+cpu
Is debug build               : False
CUDA used to build PyTorch   : None
ROCM used to build PyTorch   : N/A
XPU used to build PyTorch    : N/A

==============================
      Python Environment
==============================
Python version               : 3.12.13 (main, May  8 2026, 18:32:55) [Clang 22.1.3 ] (64-bit runtime)
Python platform              : Linux-6.8.0-1015-gcp-x86_64-with-glibc2.35

==============================
          CPU Info
==============================
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        52 bits physical, 57 bits virtual
Byte Order:                           Little Endian
CPU(s):                               180
On-line CPU(s) list:                  0-179
Vendor ID:                            AuthenticAMD
Model name:                           AMD EPYC 9B14
CPU family:                           25
Model:                                17
Thread(s) per core:                   1
Core(s) per socket:                   90
Socket(s):                            2
Stepping:                             1
BogoMIPS:                             5199.99
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 clzero xsaveerptr wbnoinvd arat avx512vbmi umip avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm
Hypervisor vendor:                    KVM
Virtualization type:                  full

==============================
Versions of relevant libraries
==============================
[pip3] mypy_extensions==1.1.0
[pip3] numpy==2.3.5
[pip3] pyzmq==27.1.0
[pip3] torch==2.10.0+cpu
[pip3] torchax==0.0.11
[pip3] torchvision==0.25.0+cpu
[pip3] transformers==5.5.3
[pip3] triton==3.6.0
[conda] Could not collect

==============================
         vLLM Info
==============================
ROCM Version                 : Could not collect
vLLM Version                 : 0.20.2rc1.dev168+gecd0b60aa (git sha: ecd0b60aa)
vLLM Build Flags:
  CUDA Archs: Not Set; ROCm: Disabled; XPU: Disabled
GPU Topology:
  Could not collect

==============================
     Environment Variables
==============================
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1
TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_kanna
</details>

🐛 Describe the bug

Gemma4 MoE captures the registered per_expert_scale parameter in a Python closure in vllm/model_executor/models/gemma4.py. This prevents torch.func.functional_call-style parameter substitution from reaching the routing function, because the closure keeps the original Parameter object even when the module attribute is substituted.

This is not specific to TorchAX or TPU execution. TorchAX is one downstream consumer that hits the issue, but the underlying behavior can be demonstrated with plain PyTorch torch.func.functional_call: module attribute lookup sees the substituted parameter, while a closure-captured reference keeps pointing at the original Parameter.

The relevant code is in Gemma4MoE.__init__, around:

https://github.com/vllm-project/vllm/blob/215e2f7990d9bb8788555a49036002e69ce14eaa/vllm/model_executor/models/gemma4.py#L323-L341

Current pattern:

self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts))

per_expert_scale = self.per_expert_scale

def routing_function(...):
    return gemma4_routing_function_torch(
        gating_output, topk, per_expert_scale
    )

per_expert_scale is registered on the module, so the issue is not that the parameter is missing from module state. The issue is that Gemma4 routing bypasses normal module-state lookup by closing over the original Parameter reference.

In a TorchAX/JAX-backed execution path, this causes Gemma4 MoE routing to mix substituted TorchAX/JAX-backed tensors with the original PyTorch Parameter, leading to:

AttributeError: 'Parameter' object has no attribute '_env'
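The source does not show TorchAX internals, so the following is only a toy illustration of how this kind of error can arise when a wrapped operand meets an unwrapped one. `EnvTensor` and the local `Parameter` class are hypothetical; the latter is deliberately named to mimic the message and is not `torch.nn.Parameter`.

```python
class EnvTensor:
    """Toy wrapper that carries an `_env` attribute, like a backend tensor might."""

    def __init__(self, value, env):
        self.value = value
        self._env = env

    def __mul__(self, other):
        # The toy backend assumes every operand is already wrapped.
        if other._env is not self._env:
            raise ValueError("operands belong to different environments")
        return EnvTensor(self.value * other.value, self._env)


class Parameter:
    """Toy unwrapped parameter: it has no `_env` attribute."""

    def __init__(self, value):
        self.value = value


env = object()
substituted = EnvTensor(2.0, env)  # what the transform passes in
captured = Parameter(3.0)          # what the closure still holds

try:
    substituted * captured
except AttributeError as exc:
    message = str(exc)
    print(message)
```

The wrapped operand expects its peer to carry `_env`; the stale closure-captured object does not, so the attribute access fails before any arithmetic happens.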

Minimal reproducer of the underlying behavior

This reproduces the closure-capture behavior with plain PyTorch functional_call:

import torch
from torch import nn
from torch.func import functional_call


class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.p = nn.Parameter(torch.tensor([1.0]))

        p = self.p

        def f():
            return p

        self.f = f

    def forward(self):
        return self.p, self.f()


m = M()
new_p = torch.tensor([2.0])

out_attr, out_closure = functional_call(m, {"p": new_p}, ())

print("module attr lookup:", out_attr.item())
print("closure lookup:", out_closure.item())

Output:

module attr lookup: 2.0
closure lookup: 1.0

This shows why the Gemma4 routing closure is problematic: parameter substitution updates module attribute lookup, but does not rewrite the closure-captured object.
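One way to confirm which object a function has closed over is to inspect its closure variables with the standard library. The sketch below uses a plain Python class (`Holder`, an illustrative name) rather than an `nn.Module`, and compares an early-capturing function with one that looks the attribute up through `self`.

```python
import inspect


class Holder:
    def __init__(self):
        self.scale = object()

        scale = self.scale  # local alias, captured by `early`

        def early():
            return scale

        def late():
            return self.scale

        self.early = early
        self.late = late


h = Holder()

# `nonlocals` maps each free variable name to the object currently in its cell.
early_vars = inspect.getclosurevars(h.early).nonlocals
late_vars = inspect.getclosurevars(h.late).nonlocals

print("early closes over:", sorted(early_vars))
print("late closes over:", sorted(late_vars))
```

`early` holds the parameter-like object directly, whereas `late` holds only `self`, which is why patching the attribute on the instance affects `late` but not `early`.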

Expected behavior

Gemma4 routing should read per_expert_scale from module state at call time instead of closing over the original Parameter object.

That would allow functional/stateless transform systems to substitute the parameter normally.

Suggested minimal fix

Avoid capturing self.per_expert_scale into a local closure variable. Let routing_function read self.per_expert_scale at call time:

def routing_function(...):
    if current_platform.is_cuda_alike() or current_platform.is_xpu():
        return gemma4_fused_routing_kernel_triton(
            gating_output, topk, self.per_expert_scale
        )

    return gemma4_routing_function_torch(
        gating_output, topk, self.per_expert_scale
    )

A module-method or callable-object refactor would also work, but the state lookup change above is the smallest fix.
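The two alternative refactors mentioned above could look roughly like this. It is a torch-free sketch with hypothetical names (`ToyMoE`, `RoutingCallable`, `_route`) and a trivial scaling step standing in for real routing; a plain attribute assignment stands in for parameter substitution.

```python
class RoutingCallable:
    """Callable object that references its owner, not the parameter itself."""

    def __init__(self, owner):
        self.owner = owner

    def __call__(self, gating_output):
        # The scale is read from the owner on every call.
        return [g * self.owner.per_expert_scale for g in gating_output]


class ToyMoE:
    def __init__(self):
        self.per_expert_scale = 1.0
        self.routing_via_method = self._route            # bound method
        self.routing_via_object = RoutingCallable(self)  # callable object

    def _route(self, gating_output):
        return [g * self.per_expert_scale for g in gating_output]


moe = ToyMoE()
moe.per_expert_scale = 2.0  # stand-in for a parameter substitution

method_out = moe.routing_via_method([1.0, 3.0])
object_out = moe.routing_via_object([1.0, 3.0])

print("module-method routing:", method_out)
print("callable-object routing:", object_out)
```

Both variants defer the lookup to call time, so they behave like the closure-over-`self` fix; the trade-off is a slightly larger change to the class.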

Happy to send a PR with this change if this approach looks reasonable.

