ollama - ✅(Solved) Fix mlxrunner: gated_delta_step kernel writes recurrent state in InT (bf16) instead of StT (fp32), corrupting Qwen3.5/3.6 GatedDeltaNet output [2 pull requests, 2 comments, 1 participant]

GitHub stats: ollama/ollama#15865, fetched 2026-04-29 06:11:40
Comments: 2 · Participants: 1 · Timeline events: 6 (cross-referenced ×4, commented ×2) · Reactions: 0


PR fix notes

PR #14968: mlx: qwen3.5 vision support

Description (problem / solution / changelog)

This change adds vision support to qwen3.5 for the mlx runner.

Changed files

  • x/create/client/create.go (modified, +58/-123)
  • x/create/client/create_test.go (modified, +31/-0)
  • x/mlxrunner/cache.go (modified, +17/-0)
  • x/mlxrunner/client.go (modified, +2/-0)
  • x/mlxrunner/mlx/ops_extra.go (modified, +12/-0)
  • x/mlxrunner/model/base/multimodal.go (added, +32/-0)
  • x/mlxrunner/pipeline.go (modified, +43/-3)
  • x/mlxrunner/runner.go (modified, +3/-1)
  • x/models/qwen3_5/multimodal.go (added, +354/-0)
  • x/models/qwen3_5/qwen3_5.go (modified, +147/-37)
  • x/models/qwen3_5/qwen3_5_test.go (modified, +396/-6)
  • x/models/qwen3_5/vision.go (added, +854/-0)

PR #15870: fix(mlxrunner): preserve fp32 precision in gated_delta_step recurrent state

Description (problem / solution / changelog)

Summary

Fixes incoherent generation in Qwen3.5/3.6 GatedDeltaNet (linear-attention) layers by preserving the fp32 recurrent-state accumulator across kernel invocations, matching MLX-LM's reference. Refs #15865, #15866.

The gated_delta_step Metal/CUDA kernel computed state in float (fp32) inside the inner loop but cast it back to InT (the input dtype, typically bf16) when writing to o_state. On every recurrent step, that reduced the state from fp32's 23 explicit mantissa bits to bf16's 7. Across 30 linear-attention layers and N tokens of a real prompt, the state degrades enough that generation becomes incoherent (e.g. "Copyright ofusr =" for a "What is 2+2?" prompt against mlx-community/Qwen3.6-35B-A3B-4bit).
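The Go sketch below makes the per-step loss concrete. It is not code from the PR. It emulates an fp32 to bf16 round-trip by keeping the top 16 bits of the float32 bit pattern, then measures the worst relative error over a sweep of values.

Round-to-nearest-even is an assumption here. Whether a given backend rounds or truncates on this cast is not stated in the source, and truncation would roughly double the bound. The helper names `roundToBF16` and `maxRelErr` are hypothetical.

```go
package main

import (
	"fmt"
	"math"
)

// roundToBF16 emulates storing a float32 as bf16 and widening it back:
// bf16 keeps the sign, all 8 exponent bits and the top 7 mantissa bits.
// Rounding is to nearest, ties to even. NaN/Inf are not special-cased.
func roundToBF16(f float32) float32 {
	bits := math.Float32bits(f)
	lsb := (bits >> 16) & 1                     // lowest mantissa bit that survives
	bits += 0x7FFF + lsb                        // round to nearest, ties to even
	return math.Float32frombits(bits &^ 0xFFFF) // drop the low 16 bits
}

// maxRelErr samples n evenly spaced float32 values in [lo, hi) and returns
// the largest relative error |bf16(x) - x| / |x| among them.
func maxRelErr(lo, hi float32, n int) float64 {
	worst := 0.0
	for i := 0; i < n; i++ {
		x := lo + (hi-lo)*float32(i)/float32(n)
		if x == 0 {
			continue
		}
		rel := math.Abs(float64(roundToBF16(x)-x)) / math.Abs(float64(x))
		worst = math.Max(worst, rel)
	}
	return worst
}

func main() {
	third := float32(1.0 / 3)
	fmt.Printf("fp32 1/3 = %.8f, after bf16 round-trip = %.8f\n", third, roundToBF16(third))
	fmt.Printf("max relative error on [0.01, 10): %.6f (bound 2^-8 = %.6f)\n",
		maxRelErr(0.01, 10, 100000), 1.0/256)
}
```

With round-to-nearest, one round-trip costs at most 2^-8 (about 0.39%) relative error. The damage comes from paying that cost again on every recurrent step.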

Changes

Two surgical changes in two files:

x/mlxrunner/mlx/gated_delta.go

  • Add an StT template arg to both Metal and CUDA kernels (separate from InT)
  • Cast state writes via static_cast<StT>(state[i]) instead of static_cast<InT>(state[i])
  • Loosen the state.DType() == dtype precondition so the kernel accepts fp32 state alongside bf16 inputs
  • Set the state output_arg dtype to state.DType() instead of the input dtype
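The four bullets above amount to a simple dtype contract. The plain-Go sketch below models that contract without cgo, as a hedged illustration.

`DType`, `kernelConfig` and `configureGatedDelta` are hypothetical stand-ins, not the real mlx bindings. The real code registers these values through the `mlx_fast_*_kernel_config_*` calls in the patch.

```go
package main

import "fmt"

// DType is a stand-in for mlx.DType (hypothetical, illustration only).
type DType string

const (
	BF16 DType = "bfloat16"
	F32  DType = "float32"
)

// kernelConfig is a hypothetical mirror of what the real code registers on
// the mlx kernel config: dtype template args and the dtypes of the outputs.
type kernelConfig struct {
	templateDTypes map[string]DType // "InT", "StT"
	outputDTypes   []DType          // [y, state_out]
}

// configureGatedDelta applies the post-fix rules: q, k, v, g and beta must
// share one dtype (InT); state may differ (StT, typically fp32); y is written
// as InT and state_out as StT. It reports false when the kernel should not
// be used, mirroring the early "return nil, nil, false" in the real code.
func configureGatedDelta(q, k, v, g, beta, state DType) (kernelConfig, bool) {
	for _, d := range []DType{k, v, g, beta} {
		if d != q {
			return kernelConfig{}, false
		}
	}
	// state is deliberately not compared with q: that was the old precondition.
	return kernelConfig{
		templateDTypes: map[string]DType{"InT": q, "StT": state},
		outputDTypes:   []DType{q, state},
	}, true
}

func main() {
	cfg, ok := configureGatedDelta(BF16, BF16, BF16, BF16, BF16, F32)
	fmt.Println(ok, cfg.templateDTypes, cfg.outputDTypes)
}
```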

x/mlxrunner/cache/recurrent.go

  • Hardcode deltaState to fp32 in ensure(). Conv state continues to track the activation dtype (typically bf16); only the recurrent accumulator is widened.
  • Documented why with a comment pointing at the kernel side and the MLX-LM reference.
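Below is a minimal sketch of the cache-side change under simplified assumptions. `recurrentCache`, `Array` and `ensure` are hypothetical simplifications of x/mlxrunner/cache/recurrent.go. The conv-state shape is a placeholder, and the toy tracks only dtype and shape rather than allocating zeroed buffers.

The point it shows is that the dtype argument drives only the conv state. The delta state is always fp32.

```go
package main

import "fmt"

// DType is a stand-in for mlx.DType (hypothetical).
type DType string

const (
	BF16 DType = "bfloat16"
	F32  DType = "float32"
)

// Array is a toy stand-in for mlx.Array that only tracks dtype and shape.
type Array struct {
	DType DType
	Shape []int
}

// recurrentCache holds one linear-attention layer's state (hypothetical,
// simplified). Shapes exclude the leading batch dimension.
type recurrentCache struct {
	convShape, deltaShape []int
	convState, deltaState *Array
}

// ensure (re)allocates state for batch size b. The requested dtype applies
// only to the conv state; the delta (recurrent) state is always fp32.
func (c *recurrentCache) ensure(b int, dtype DType) {
	if c.convState == nil || c.convState.DType != dtype || c.convState.Shape[0] != b {
		c.convState = &Array{DType: dtype, Shape: append([]int{b}, c.convShape...)}
	}
	if c.deltaState == nil || c.deltaState.Shape[0] != b {
		c.deltaState = &Array{DType: F32, Shape: append([]int{b}, c.deltaShape...)}
	}
}

func main() {
	// Delta shape [num_v_heads, head_v_dim, head_k_dim] as quoted in the PR;
	// the conv shape is a placeholder.
	c := &recurrentCache{convShape: []int{3, 64}, deltaShape: []int{32, 128, 128}}
	c.ensure(1, BF16)
	fmt.Println("conv:", c.convState.DType, c.convState.Shape)
	fmt.Println("delta:", c.deltaState.DType, c.deltaState.Shape)
}
```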

No API changes, no qwen3_5.go changes -- call sites still pass a single dtype to RecurrentCache.Get(b, dtype), which now applies only to conv state. Full-attention layers and other models that don't use RecurrentCache are unaffected.

Why hardcode fp32 vs. add a parameter

MLX-LM's reference always allocates the recurrent state as mx.float32. There's no current model that wants a different precision for it, and adding a knob would just push the decision to every call site without a use case to justify it. Easier to revisit if a model ever needs bf16 state for memory reasons.

Memory cost

Delta state is [B, num_v_heads, head_v_dim, head_k_dim]. For Qwen3.6-35B-A3B (32 v-heads x 128 x 128 per head x 30 linear-attention layers x B=1), that is ~63 MB at fp32 vs ~31 MB at bf16, an extra ~31 MB. Negligible relative to the model weights.
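The arithmetic behind these figures can be checked with a few lines of Go. `deltaStateBytes` is a hypothetical helper, and the dimensions are the ones quoted above.

```go
package main

import "fmt"

// deltaStateBytes returns the size of the GatedDeltaNet recurrent state
// [B, numVHeads, headVDim, headKDim], summed over all linear-attention layers.
func deltaStateBytes(batch, numVHeads, headVDim, headKDim, linearLayers, bytesPerElem int) int {
	return batch * numVHeads * headVDim * headKDim * linearLayers * bytesPerElem
}

func main() {
	// 32 v-heads, 128x128 per head, 30 linear-attention layers, batch 1.
	fp32 := deltaStateBytes(1, 32, 128, 128, 30, 4)
	bf16 := deltaStateBytes(1, 32, 128, 128, 30, 2)
	fmt.Printf("fp32: %d bytes (%.1f MB)\n", fp32, float64(fp32)/1e6) // fp32: 62914560 bytes (62.9 MB)
	fmt.Printf("bf16: %d bytes (%.1f MB)\n", bf16, float64(bf16)/1e6) // bf16: 31457280 bytes (31.5 MB)
}
```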

Verification

  • go build ./x/... -- clean
  • go test ./x/mlxrunner/cache/... ./x/mlxrunner/mlx/... -- all pass

Functional verification (M4 Max / 36 GB / macOS 26.4) using mlx-community/Qwen3.6-35B-A3B-4bit imported via ollama create --experimental:

| Prompt / check | Before | After |
|---|---|---|
| "What is 2+2?" | "\n\n// Copyright ofusr = ..." | "<think>\n\n</think>\n\nThe result of $2 + 2$ is **4**." |
| "Reverse a string in Go" | gibberish | clean idiomatic Go one-liner |
| Throughput | 110 tok/s | 110 tok/s |
| nomic-embed-text (regression check) | works | works |

mlx-lm's reference on the same checkpoint runs ~112 tok/s for comparison.

Test plan

  • Reviewer can reproduce on Apple Silicon by importing any mlx-community/Qwen3.* MoE checkpoint via ollama create --experimental and running a short prompt -- output should be coherent rather than the Copyright/static/-1999... pattern reported in #15866.
  • CI: existing cache + mlx tests should continue to pass (verified locally).

Generated with Claude Code

Changed files

  • x/mlxrunner/cache/recurrent.go (modified, +10/-2)
  • x/mlxrunner/mlx/gated_delta.go (modified, +22/-6)


mlxrunner: gated_delta_step kernel writes recurrent state in InT (bf16) instead of StT (fp32), corrupting Qwen3.5/3.6 GatedDeltaNet output

What is the issue?

The Metal/CUDA gated_delta_step kernel in x/mlxrunner/mlx/gated_delta.go casts the recurrent state output back to InT (bf16 for our model) before writing it to the next-step state buffer. The reference implementation in mlx_lm.models.gated_delta uses a separate StT template arg for the state dtype and keeps state in fp32 to preserve precision across the recurrent decode loop.

Result on Qwen3.5/Qwen3.6 35B-A3B (any of the bf16/mxfp8/nvfp4 variants): model loading, MoE routing, the full-attention layers, the embedding and lm_head all work correctly. However, the recurrent state of the linear-attention layers degrades with every decode step. The model emits one or two semi-correct tokens, then collapses into degenerate repetition: either an <|im_start|><|im_start|> echo, or the last sampled token repeated, as in 4\n4\n4\n4...

Reproducer

  • macOS 26.4.1, M4 Max 36 GB unified memory, Xcode 26.4
  • Ollama v0.22.0 plus PR #15793 (mlx 0.31.2) and PR #15760 (per-tensor quant overrides) applied, so the panics and crashes those PRs fix are already resolved
  • Model: qwen3.6:35b-a3b-coding-nvfp4 from the official library
curl -s http://localhost:11434/api/chat -d '{
  "model": "qwen3.6:35b-a3b-coding-nvfp4",
  "messages":[{"role":"user","content":"What is 2+2?"}],
  "stream": false,
  "think": false,
  "options": {"num_predict": 50, "temperature": 0.0}
}'

Before this fix

content: '<|im_start|><|im_start|>'
eval_count: 2
done_reason: stop

The model runs the forward pass cleanly (all 1436 tensors loaded, peak 19.7 GiB VRAM) and reaches the sampler, but the first few token logits put <|im_start|> (id 248045) at the top instead of any sensible response.

After this fix

content: '\n\n4\n4\n4\n4\n4...'
eval_count: 50
done_reason: length

The model now correctly identifies 4 as the answer to 2+2, although it still repeats the token afterwards. EOS detection and a couple of remaining linear-attention precision points (the gDecay cast and the conv state dtype) are likely separate issues. Even so, the recurrent-state precision bug on its own prevents the model from producing any meaningful tokens at all.

Root cause

Comparing to mlx-lm/main/mlx_lm/models/gated_delta.py:

# MLX-LM kernel template
return mx.fast.metal_kernel(
    name=f"gated_delta_step{suffix}",
    input_names=inputs,
    output_names=["y", "state_out"],
    source=source,  # source uses InT for y cast and StT for state cast
)
# MLX-LM kernel source — state is cast back as StT, NOT InT:
source = f"""
  ...
  for (int i = 0; i < n_per_t; ++i) {{
    auto s_idx = n_per_t * dk_idx + i;
    o_state[s_idx] = static_cast<StT>(state[i]);  // ← StT
  }}
"""
# MLX-LM gated_delta_update — state explicitly fp32:
if state is None:
    state = mx.zeros((B, Hv, Dv, Dk), dtype=mx.float32)

Ollama's kernel source has static_cast<InT>(state[i]), and Ollama's qwen3_5.go initializes the state with x.DType() (the bf16 input dtype). The kernel template never declares StT.

Net effect: every recurrent step round-trips the state through bf16 (7 explicit mantissa bits), introducing up to ~0.4% (2^-8) relative rounding error per step. Decay terms close to 1.0 (typical) compound this error across hundreds of decode steps until the state is meaningless.
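A deliberately simplified model shows why decay close to 1.0 is the bad case. The Go sketch below is not the gated delta rule. It runs a scalar leaky accumulator, s = decay*s + input, with decay = 0.999.

Each update is computed in fp32, like the kernel's inner loop. The stored state can optionally be rounded to bf16 between steps, mimicking the static_cast<InT> write-back. `roundToBF16` and `leakyState` are hypothetical helpers.

```go
package main

import (
	"fmt"
	"math"
)

// roundToBF16 emulates an fp32 -> bf16 -> fp32 round-trip (round to nearest,
// ties to even; NaN/Inf not handled). Hypothetical helper.
func roundToBF16(f float32) float32 {
	bits := math.Float32bits(f)
	lsb := (bits >> 16) & 1
	bits += 0x7FFF + lsb
	return math.Float32frombits(bits &^ 0xFFFF)
}

// leakyState runs the toy recurrence s = decay*s + input for n steps.
// The update is computed in fp32; when storeBF16 is true the state is
// rounded to bf16 between steps, as the buggy write-back did.
func leakyState(n int, decay, input float32, storeBF16 bool) float32 {
	var s float32
	for i := 0; i < n; i++ {
		s = decay*s + input
		if storeBF16 {
			s = roundToBF16(s)
		}
	}
	return s
}

func main() {
	const steps = 10000
	decay, input := float32(0.999), float32(0.001) // fixed point: input/(1-decay) = 1
	fmt.Printf("fp32-stored state: %.5f\n", leakyState(steps, decay, input, false))
	fmt.Printf("bf16-stored state: %.5f\n", leakyState(steps, decay, input, true))
}
```

With decay near 1.0, each step changes the state by only a small fraction of its value. Once that change drops below half a bf16 ulp it is rounded away entirely, so in this toy the bf16-stored state stalls far below the fixed point the fp32 state approaches. The real state is a matrix with data-dependent decay and writes, so treat this as an illustration of the mechanism, not a reproduction of the failure.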

Patch

Three files (115 lines diff against v0.22.0). Summary:

  1. x/mlxrunner/mlx/gated_delta.go: the kernel source (Metal and CUDA) uses StT for the state output cast; the kernel config adds an StT template arg set to state.DType(); the state output_arg uses state.DType(); dtype validation no longer requires state.DType() == dtype.
  2. x/models/qwen3_5/qwen3_5.go: GatedDeltaNet.Forward initializes the recurrent state as mlx.DTypeFloat32 (matching MLX-LM's mx.zeros(..., dtype=mx.float32)).
  3. x/mlxrunner/cache/recurrent.go: split ensure into ensureConv and ensureDelta so the conv state can stay at inputs.dtype (bf16) while the delta/recurrent state is fp32.

Validation

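All tests at temperature=0:

  • "What is 2+2?" first token: before, <|im_start|>; after, the output begins '\n\n4' (the correct answer)
  • Tokens evaluated before done: before, 2 (immediate stop); after, 50+ (proper inference)
  • MLX panic / load crash: before, none (PR #15793 already fixed); after, none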

There are remaining issues with multi-token coherence (the model gets the first token right then often loops) — those look like additional precision-leak points (conv state, gDecay) that I'm continuing to investigate, but the state-precision bug is the single biggest one and gates everything else.

Related

  • PR #15793 (MLX 0.31.2 + threading) — fixes load-time panic, complementary to this fix
  • PR #15759 / #15760 (per-tensor quant overrides) — fixes a separate SparseMoE.Forward panic, complementary to this fix
  • PR #14968 (mlx: qwen3.5 vision support) — touches GatedDeltaNet but is currently CONFLICTING with main; the patch above applies cleanly to v0.22.0 + #15793 + #15760 and could merge ahead of the broader vision work
  • Issue #15822 (MLX runner failed with qwen3.6:35b-a3b-coding-bf16 format=json) — possibly the same root cause, since bf16 variants also exercise the same kernel path

Environment

  • Ollama version: v0.22.0 + #15793 + #15759 + #15760 + this patch
  • OS: macOS 26.4.1
  • Hardware: Apple M4 Max, 36 GB unified
  • MLX: 0.31.2 (via #15793)

Patch

<details> <summary>full diff</summary>
diff --git a/x/mlxrunner/mlx/gated_delta.go b/x/mlxrunner/mlx/gated_delta.go
index c691f05..ec8fe82 100644
--- a/x/mlxrunner/mlx/gated_delta.go
+++ b/x/mlxrunner/mlx/gated_delta.go
@@ -83,7 +83,7 @@ for (int t = 0; t < T; ++t) {
 
 for (int i = 0; i < n_per_t; ++i) {
   auto s_idx = n_per_t * dk_idx + i;
-  o_state[s_idx] = static_cast<InT>(state[i]);
+  o_state[s_idx] = static_cast<StT>(state[i]);
 }
 `
 
@@ -163,7 +163,7 @@ for (int t = 0; t < T_val; ++t) {
 
 for (int i = 0; i < n_per_t; ++i) {
   auto s_idx = n_per_t * dk_idx + i;
-  o_state[s_idx] = static_cast<InT>(state[i]);
+  o_state[s_idx] = static_cast<StT>(state[i]);
 }
 `
 
@@ -263,9 +263,11 @@ func gatedDeltaKernel(q, k, v, g, beta, state *Array) (y, nextState *Array, ok b
 	}
 
 	dtype := q.DType()
-	if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype || state.DType() != dtype {
+	if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype {
 		return nil, nil, false
 	}
+	// state may have a different dtype (typically fp32) than q/k/v/g/beta (typically bf16)
+	// — this matches MLX-LM, where state stays fp32 to preserve recurrent precision.
 
 	gatedDeltaMetalKernelOnce.Do(initGatedDeltaMetalKernel)
 	if gatedDeltaMetalDisabled {
@@ -281,6 +283,12 @@ func gatedDeltaKernel(q, k, v, g, beta, state *Array) (y, nextState *Array, ok b
 		gatedDeltaMetalDisabled = true
 		return nil, nil, false
 	}
+	cStT := C.CString("StT")
+	defer C.free(unsafe.Pointer(cStT))
+	if C.mlx_fast_metal_kernel_config_add_template_arg_dtype(cfg, cStT, C.mlx_dtype(state.DType())) != 0 {
+		gatedDeltaMetalDisabled = true
+		return nil, nil, false
+	}
 	for _, tpl := range []struct {
 		name  string
 		value int
@@ -305,7 +313,7 @@ func gatedDeltaKernel(q, k, v, g, beta, state *Array) (y, nextState *Array, ok b
 		gatedDeltaMetalDisabled = true
 		return nil, nil, false
 	}
-	if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(dtype)) != 0 {
+	if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(state.DType())) != 0 {
 		gatedDeltaMetalDisabled = true
 		return nil, nil, false
 	}
@@ -517,9 +525,11 @@ func gatedDeltaCUDAKernelApply(q, k, v, g, beta, state *Array) (y, nextState *Ar
 	}
 
 	dtype := q.DType()
-	if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype || state.DType() != dtype {
+	if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype {
 		return nil, nil, false
 	}
+	// state may have a different dtype (typically fp32) than q/k/v/g/beta (typically bf16)
+	// — this matches MLX-LM, where state stays fp32 to preserve recurrent precision.
 
 	gatedDeltaCUDAKernelOnce.Do(initGatedDeltaCUDAKernel)
 	if gatedDeltaCUDADisabled {
@@ -535,6 +545,12 @@ func gatedDeltaCUDAKernelApply(q, k, v, g, beta, state *Array) (y, nextState *Ar
 		gatedDeltaCUDADisabled = true
 		return nil, nil, false
 	}
+	cStT := C.CString("StT")
+	defer C.free(unsafe.Pointer(cStT))
+	if C.mlx_fast_cuda_kernel_config_add_template_arg_dtype(cfg, cStT, C.mlx_dtype(state.DType())) != 0 {
+		gatedDeltaCUDADisabled = true
+		return nil, nil, false
+	}
 	for _, tpl := range []struct {
 		name  string
 		value int
@@ -559,7 +575,7 @@ func gatedDeltaCUDAKernelApply(q, k, v, g, beta, state *Array) (y, nextState *Ar
 		gatedDeltaCUDADisabled = true
 		return nil, nil, false
 	}
-	if C.mlx_fast_cuda_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(dtype)) != 0 {
+	if C.mlx_fast_cuda_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(state.DType())) != 0 {
 		gatedDeltaCUDADisabled = true
 		return nil, nil, false
 	}
diff --git a/x/models/qwen3_5/qwen3_5.go b/x/models/qwen3_5/qwen3_5.go
index f29563f..6ff586d 100644
--- a/x/models/qwen3_5/qwen3_5.go
+++ b/x/models/qwen3_5/qwen3_5.go
@@ -1231,12 +1231,16 @@ func (g *GatedDeltaNet) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
 
 	beta := mlx.Sigmoid(b)
 
+	// Recurrent state must be fp32 to match MLX-LM's reference and avoid bf16
+	// precision loss across many recurrent steps. This is the canonical
+	// gated-delta-rule precision contract; the kernel internally accumulates
+	// in float and now correctly casts to StT (fp32) when writing back.
 	var state *mlx.Array
 	if rc != nil {
-		state = rc.DeltaState(int(B), x.DType())
+		state = rc.DeltaState(int(B), mlx.DTypeFloat32)
 	}
 	if state == nil {
-		state = mlx.Zeros(x.DType(), int(B), int(cfg.LinearNumValueHeads), int(cfg.LinearValueHeadDim), int(cfg.LinearKeyHeadDim))
+		state = mlx.Zeros(mlx.DTypeFloat32, int(B), int(cfg.LinearNumValueHeads), int(cfg.LinearValueHeadDim), int(cfg.LinearKeyHeadDim))
 	}
 
 	out, state := mlx.GatedDelta(q, k, v, gDecay, beta, state)
</details>

Extent analysis

TL;DR

The most likely fix for the issue is to update the gated_delta_step kernel in x/mlxrunner/mlx/gated_delta.go to cast the recurrent state output to StT (fp32) instead of InT (bf16) to preserve precision across the recurrent decode loop.

Guidance

  • Verify that the gated_delta_step kernel is correctly casting the recurrent state output to StT (fp32) by checking the kernel source code.
  • Ensure that the GatedDeltaNet.Forward function initializes the recurrent state as mlx.DTypeFloat32 to match the reference implementation.
  • Check that ensure in x/mlxrunner/cache/recurrent.go allocates the conv state and the delta/recurrent state with their own dtypes (the activation dtype and fp32, respectively).

Example

No separate snippet is given here: the fix consists of the kernel-source and GatedDeltaNet.Forward changes in the patch above.

Notes

The issue is confined to the gated_delta_step kernel and the recurrent-state allocation in GatedDeltaNet.Forward; the fix makes both keep the recurrent state in fp32. The conv state and the gDecay cast may be additional precision-leak points that still need investigation.

Recommendation

Apply the patch from the issue, which updates the gated_delta_step kernel and GatedDeltaNet.Forward so that the recurrent state stays in fp32.
