transformers - Qwen3.5 GatedDeltaNet: Large logit divergence between full-sequence forward and prefill+decode with cache

System Info

  • transformers version: 5.9.0
  • Platform: macOS (Apple Silicon, arm64)
  • Python version: 3.13
  • PyTorch version: 2.12.0 (CPU)
  • Tokenizers version: 0.22.2
  • Using GPU in script?: No
  • causal-conv1d: NOT installed
  • mamba-ssm: NOT installed

Who can help?

@ArthurZucker @yzhangcs

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Expected Behavior

When computing the logits for the last token of a sequence, the results should be (nearly) identical regardless of whether the model processes the full sequence at once or uses prefill+decode with cache. For reference, Llama-3.2-1B shows max diff = 3.5e-5 and Mamba-130M shows max diff = 2.2e-4.
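One way to make "nearly identical" concrete is to look beyond absolute differences and check whether the ranking of the top tokens survives. The sketch below is a hypothetical helper (`compare_logits` is not part of transformers) that takes two 1-D logit vectors as NumPy arrays; the default `k=5` is an assumption chosen for illustration.

```python
import numpy as np


def compare_logits(full, cached, k=5):
    """Summarize how far two logit vectors are apart (hypothetical helper)."""
    full = np.asarray(full, dtype=np.float64)
    cached = np.asarray(cached, dtype=np.float64)
    diff = np.abs(full - cached)
    # Indices of the k largest logits in each vector, as sets for overlap counting.
    top_full = set(np.argsort(full)[-k:].tolist())
    top_cached = set(np.argsort(cached)[-k:].tolist())
    return {
        "max_abs_diff": float(diff.max()),
        "mean_abs_diff": float(diff.mean()),
        "same_argmax": bool(full.argmax() == cached.argmax()),
        "topk_overlap": len(top_full & top_cached),
    }


# Synthetic stand-in data, not model output.
demo_full = np.array([0.0, 1.0, 3.0, 2.0])
demo_cached = demo_full + np.array([0.125, 0.0, 0.0, -0.125])
print(compare_logits(demo_full, demo_cached, k=2))
```

With the tensors from the minimal reproduction, the inputs would be `logits_full.numpy()` and `logits_cached.numpy()`.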

Actual Behavior

Qwen3.5-0.8B produces a max absolute difference of 0.125 between full-sequence forward and prefill+decode. This is roughly 560x the Mamba-130M figure and 3,500x the Llama-3.2-1B figure given above. The error is identical whether loading in bfloat16 or float32, ruling out precision as the cause.

Minimal Reproduction

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen3.5-0.8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map="cpu")
model.eval()

input_ids = tokenizer.encode("The capital of France is", return_tensors="pt")

# Mode A: Full sequence, no cache
with torch.no_grad():
    logits_full = model(input_ids=input_ids, use_cache=False).logits[0, -1, :].float()

# Mode B: Prefill (first N-1 tokens) + decode last token with cache
with torch.no_grad():
    out_prefix = model(input_ids=input_ids[:, :-1], use_cache=True)
    logits_cached = model(
        input_ids=input_ids[:, -1:],
        past_key_values=out_prefix.past_key_values,
        use_cache=True,
    ).logits[0, -1, :].float()

diff = torch.abs(logits_full - logits_cached)
print(f"Max diff: {diff.max().item():.6e}")   # 1.250000e-01
print(f"Mean diff: {diff.mean().item():.6e}")  # 1.944993e-02

Detailed Test Results

Two inference modes are compared:

  • Mode A — Full-sequence forward, no cache (use_cache=False)
  • Mode B — Prefill first N-1 tokens with cache, then decode last token using the cache

Mode B isolates the issue cleanly: the prefill stage uses the same chunk-based computation as Mode A, and only the final decode step switches to the recurrent path.
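To illustrate why the two paths are expected to agree, here is a toy sketch of a gated linear recurrence computed both ways: a parallel full-sequence form, and a prefill that builds the state in one shot followed by a single recurrent step. This is a deliberately simplified stand-in, not the gated delta rule implemented in transformers (the delta-rule correction term is omitted), and all function names are hypothetical. NumPy is used so the sketch runs without model weights.

```python
import numpy as np


def full_sequence(q, k, v, g):
    """Parallel form: outputs for all T positions at once."""
    log_g = np.cumsum(np.log(g))                      # G_t = sum_{r<=t} log g_r
    decay = np.exp(log_g[:, None] - log_g[None, :])   # decay[t, s] = prod_{r=s+1..t} g_r
    scores = np.tril((q @ k.T) * decay)               # causal mask keeps s <= t
    return scores @ v


def prefill_state(k, v, g):
    """State after the prefix: S = sum_s (prod_{r>s} g_r) * outer(k_s, v_s)."""
    log_g = np.cumsum(np.log(g))
    weights = np.exp(log_g[-1] - log_g)
    return (k * weights[:, None]).T @ v


def decode_step(state, q_t, k_t, v_t, g_t):
    """One recurrent step: S_t = g_t * S_{t-1} + outer(k_t, v_t); o_t = q_t @ S_t."""
    state = g_t * state + np.outer(k_t, v_t)
    return q_t @ state, state


def max_divergence(dtype, T=6, d=8, seed=0):
    """Max abs gap between the two paths for the last position, on random inputs."""
    rng = np.random.default_rng(seed)
    q, k, v = (rng.standard_normal((T, d)).astype(dtype) for _ in range(3))
    g = rng.uniform(0.5, 1.0, size=T).astype(dtype)   # gates in (0, 1)
    out_full = full_sequence(q, k, v, g)[-1]
    state = prefill_state(k[:-1], v[:-1], g[:-1])
    out_cached, _ = decode_step(state, q[-1], k[-1], v[-1], g[-1])
    return float(np.abs(out_full - out_cached).max())


for dtype in (np.float32, np.float64):
    print(dtype.__name__, max_divergence(dtype))
```

In exact arithmetic the two results are equal, so any gap in this toy comes from rounding and operation order. What the report questions is the size of the gap, not that a gap exists.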

Qwen3.5-0.8B (bfloat16)

A vs B:  Max=1.250000e-01  Mean=1.944993e-02  Equal=False

Top-5 largest differences:
Index    Full Forward     Prefill+Decode   Abs Diff       Token
6927     -6.343750        -6.218750        1.250000e-01   'oles'
28100    -1.132812        -1.007812        1.250000e-01   'ois'
124639   -9.125000        -9.250000        1.250000e-01   '版主'
107466   -8.937500        -8.812500        1.250000e-01   '了指'
16697    -8.062500        -7.937500        1.250000e-01   'olen'
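For scale, the gaps in this table can be expressed in bfloat16 units in the last place (ULPs). The sketch below assumes the standard bfloat16 layout with 8 bits of significand precision (7 stored) and ignores subnormals; `bf16_ulp` is a hypothetical helper.

```python
import math


def bf16_ulp(x):
    """Spacing between adjacent bfloat16 values around x (normal range only)."""
    _, exponent = math.frexp(abs(x))   # abs(x) = m * 2**exponent with 0.5 <= m < 1
    return math.ldexp(1.0, exponent - 8)


# Three "Full Forward" logits from the table above and the observed gap of 0.125.
for logit in (-6.343750, -1.132812, -9.125000):
    print(logit, bf16_ulp(logit), 0.125 / bf16_ulp(logit))
```

For the three logits shown, a gap of 0.125 corresponds to 4, 16, and 2 bfloat16 ULPs respectively.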

Qwen3.5-0.8B (float32)

A vs B:  Max=1.250000e-01  Mean=1.944993e-02  Equal=False

Loading in float32 produces identical error magnitudes, ruling out bf16 precision as the cause.
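Because the bfloat16 and float32 runs report the same figures to seven digits, anyone reproducing this may want to confirm that the float32 load really produced float32 parameters. A minimal sketch, assuming only that each object exposes a `.dtype` attribute as PyTorch parameters do; `dtype_histogram` is a hypothetical helper, and the demo uses stand-in objects instead of a real model.

```python
from collections import Counter
from types import SimpleNamespace


def dtype_histogram(tensors):
    """Count how many of the given tensors have each dtype."""
    return Counter(str(t.dtype) for t in tensors)


# Stand-in objects for illustration only; a real check would pass model parameters.
fake_params = [SimpleNamespace(dtype="torch.float32")] * 3 + [
    SimpleNamespace(dtype="torch.bfloat16")
]
print(dtype_histogram(fake_params))
```

On the real model this would be called as `dtype_histogram(model_fp32.parameters())`. Checking `out_a.logits.dtype` before the `.float()` cast is a complementary test.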

Mamba-130M (float32, comparison)

Mamba also uses conv1d + recurrent state (SSM), a similar architectural pattern to GatedDeltaNet. Neither causal-conv1d nor mamba-ssm is installed, so the torch fallback path is used (same as Qwen3.5).

A vs B:  Max=2.231598e-04  Mean=3.242540e-05  Equal=False

Llama-3.2-1B (float32, baseline)

A vs B:  Max=3.528595e-05  Mean=5.242007e-06  Equal=False

Per-layer hidden state divergence (Qwen3.5-0.8B, bfloat16, A vs B)

The error grows through layers, with a large jump at the final layer:

Layer    Max Abs Diff     Mean Abs Diff
0        0.000000e+00     0.000000e+00     (embedding, no divergence)
1        1.953125e-03     1.360402e-04
2        3.906250e-03     1.846030e-04
3        1.464844e-03     2.327561e-04
4        1.953125e-03     2.701916e-04
5        1.953125e-03     2.991073e-04
6        3.906250e-03     3.104694e-04
7        2.197266e-03     3.373846e-04
8        2.441406e-03     3.151372e-04
9        2.441406e-03     3.224313e-04
10       2.441406e-03     3.342256e-04
11       2.441406e-03     3.449395e-04
12       3.906250e-03     3.726929e-04
13       3.906250e-03     4.318133e-04
14       3.906250e-03     4.789680e-04
15       7.812500e-03     5.791914e-04
16       7.812500e-03     7.362664e-04
17       1.562500e-02     8.317083e-04
18       1.562500e-02     9.588450e-04
19       1.562500e-02     1.106501e-03
20       1.562500e-02     1.208007e-03
21       1.562500e-02     1.413584e-03
22       1.562500e-02     1.631334e-03
23       3.125000e-02     1.877025e-03
24       2.500000e-01     3.162303e-02     (final layer, large jump)

Summary of Observations

  1. The error magnitude is ~3 orders of magnitude larger than Llama and Mamba. Qwen3.5 shows max diff = 0.125, while Llama shows 3.5e-5 and Mamba shows 2.2e-4.

  2. Float32 does not help. The error is identical in bf16 and fp32, indicating this is not a floating-point precision issue.

  3. The error appears at layer 1 (max 2.0e-3) and grows, though not monotonically, through subsequent layers, with a large jump at the final layer (3.1e-2 to 0.25).
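The "large jump" in observation 3 can be quantified directly from the per-layer table. The sketch below copies the Max Abs Diff column verbatim and looks for the largest growth factor between consecutive entries; `largest_jump` is a hypothetical helper.

```python
# Max Abs Diff column from the per-layer table (list index = hidden-state index).
max_abs_diff = [
    0.000000e+00, 1.953125e-03, 3.906250e-03, 1.464844e-03, 1.953125e-03,
    1.953125e-03, 3.906250e-03, 2.197266e-03, 2.441406e-03, 2.441406e-03,
    2.441406e-03, 2.441406e-03, 3.906250e-03, 3.906250e-03, 3.906250e-03,
    7.812500e-03, 7.812500e-03, 1.562500e-02, 1.562500e-02, 1.562500e-02,
    1.562500e-02, 1.562500e-02, 1.562500e-02, 3.125000e-02, 2.500000e-01,
]


def largest_jump(values):
    """Return (index, ratio) of the largest growth between consecutive entries."""
    best_idx, best_ratio = None, 0.0
    for i in range(1, len(values)):
        prev = values[i - 1]
        if prev == 0.0:
            continue  # skip the step out of the embedding row, where the ratio is undefined
        ratio = values[i] / prev
        if ratio > best_ratio:
            best_idx, best_ratio = i, ratio
    return best_idx, best_ratio


print(largest_jump(max_abs_diff))
```

On these numbers the largest step is the last one, from index 23 to 24, a factor of 8. No earlier step grows by more than a factor of 2.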

Full Reproduction Script

<details> <summary>Click to expand full script (tests all three models with detailed output)</summary>
"""
Reproduce: Qwen3.5 GatedDeltaNet produces significantly different logits
between full-sequence forward and prefill+decode with KV cache.

Environment:
    pip install torch transformers

Usage:
    python reproduce_qwen35_kv_cache.py
"""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def test_kv_cache_consistency(model, tokenizer, model_name, dtype):
    """Compare full-sequence forward vs prefill+decode with KV cache."""
    model.eval()

    prompt = "The capital of France is"
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    seq_len = input_ids.shape[1]

    print(f"\n{'=' * 70}")
    print(f"Model: {model_name} (dtype={dtype})")
    print(f"{'=' * 70}")
    print(f"Input: \"{prompt}\" (seq_len={seq_len})")

    # Mode A: Full sequence, no cache
    with torch.no_grad():
        out_a = model(input_ids=input_ids, use_cache=False)
        logits_a = out_a.logits[0, -1, :].float()

    # Mode B: Prefill (first N-1 tokens) + Decode (last token with cache)
    with torch.no_grad():
        out_prefix = model(input_ids=input_ids[:, :-1], use_cache=True)
        out_b = model(
            input_ids=input_ids[:, -1:],
            past_key_values=out_prefix.past_key_values,
            use_cache=True,
        )
        logits_b = out_b.logits[0, -1, :].float()

    # Print comparison
    diff = torch.abs(logits_a - logits_b)
    print(f"\n  A vs B:  Max={diff.max().item():.6e}  Mean={diff.mean().item():.6e}  Equal={torch.equal(logits_a, logits_b)}")

    # Top differences
    top5 = torch.topk(diff, 5)
    print(f"\n  Top-5 largest differences:")
    print(f"  {'Index':<8} {'Full Forward':<16} {'Prefill+Decode':<16} {'Abs Diff':<14} {'Token'}")
    for i in range(5):
        idx = top5.indices[i].item()
        print(
            f"  {idx:<8} {logits_a[idx].item():<16.6f} {logits_b[idx].item():<16.6f} "
            f"{top5.values[i].item():<14.6e} '{tokenizer.decode([idx])}'"
        )

    # Per-layer hidden state divergence
    print(f"\n  Per-layer hidden state divergence (last token, A vs B):")
    with torch.no_grad():
        out_full = model(input_ids=input_ids, use_cache=False, output_hidden_states=True)
        out_pfx = model(input_ids=input_ids[:, :-1], use_cache=True, output_hidden_states=True)
        out_dec = model(
            input_ids=input_ids[:, -1:],
            past_key_values=out_pfx.past_key_values,
            use_cache=True,
            output_hidden_states=True,
        )

    num_layers = len(out_full.hidden_states)
    print(f"  {'Layer':<8} {'Max Abs Diff':<16} {'Mean Abs Diff':<16}")
    for layer_idx in range(num_layers):
        h_full = out_full.hidden_states[layer_idx][0, -1, :].float()
        h_dec = out_dec.hidden_states[layer_idx][0, -1, :].float()
        d = torch.abs(h_full - h_dec)
        print(f"  {layer_idx:<8} {d.max().item():<16.6e} {d.mean().item():<16.6e}")


def test_mamba_kv_cache_consistency():
    """Test Mamba-130M (also uses conv1d + recurrent state, like GatedDeltaNet)."""
    model_name = "state-spaces/mamba-130m-hf"
    print(f"\n{'=' * 70}")
    print(f"Model: {model_name} (dtype=float32)")
    print(f"Architecture: MambaMixer (conv1d + SSM recurrent state)")
    print(f"{'=' * 70}")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.float32, device_map="cpu")
    model.eval()

    prompt = "The capital of France is"
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    seq_len = input_ids.shape[1]
    print(f"Input: \"{prompt}\" (seq_len={seq_len})")

    # Mode A: Full sequence, no cache
    with torch.no_grad():
        out_a = model(input_ids=input_ids, use_cache=False)
        logits_a = out_a.logits[0, -1, :].float()

    # Mode B: Prefill + Decode (Mamba uses cache_params)
    with torch.no_grad():
        out_prefix = model(input_ids=input_ids[:, :-1], use_cache=True)
        cache = out_prefix.cache_params if hasattr(out_prefix, "cache_params") else out_prefix.past_key_values
        out_b = model(input_ids=input_ids[:, -1:], cache_params=cache, use_cache=True)
        logits_b = out_b.logits[0, -1, :].float()

    diff = torch.abs(logits_a - logits_b)
    print(f"\n  A vs B:  Max={diff.max().item():.6e}  Mean={diff.mean().item():.6e}  Equal={torch.equal(logits_a, logits_b)}")


def main():
    print("Loading models...")

    # --- Qwen3.5-0.8B ---
    model_id = "Qwen/Qwen3.5-0.8B"
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Test with bfloat16 (model's native dtype)
    model_bf16 = AutoModelForCausalLM.from_pretrained(
        model_id, dtype=torch.bfloat16, device_map="cpu"
    )
    test_kv_cache_consistency(model_bf16, tokenizer, model_id, "bfloat16")
    del model_bf16

    # Test with float32 (to rule out bf16 precision as the cause)
    model_fp32 = AutoModelForCausalLM.from_pretrained(
        model_id, dtype=torch.float32, device_map="cpu"
    )
    test_kv_cache_consistency(model_fp32, tokenizer, model_id, "float32")
    del model_fp32

    # --- Mamba-130M ---
    print("\n\n" + "#" * 70)
    print("# Comparison: Mamba-130M (conv1d + SSM, similar architecture pattern)")
    print("#" * 70)
    test_mamba_kv_cache_consistency()

    # --- Llama-3.2-1B ---
    print("\n\n" + "#" * 70)
    print("# Baseline: Llama-3.2-1B (standard Transformer)")
    print("#" * 70)
    llama_id = "meta-llama/Llama-3.2-1B"
    tokenizer_llama = AutoTokenizer.from_pretrained(llama_id)
    model_llama = AutoModelForCausalLM.from_pretrained(
        llama_id, dtype=torch.float32, device_map="cpu"
    )
    test_kv_cache_consistency(model_llama, tokenizer_llama, llama_id, "float32")


if __name__ == "__main__":
    main()
</details>

Additional Notes

  • The model weights are stored in bfloat16 on disk. Loading in float32 does not change the error magnitude.
  • causal-conv1d and mamba-ssm packages are not installed; the torch fallback paths are used for both Qwen3.5 and Mamba.
  • Qwen3.5 uses Qwen3_5GatedDeltaNet layers which have two computation paths: torch_chunk_gated_delta_rule (used during prefill for sequences > 1 token) and torch_recurrent_gated_delta_rule (used during decode for single-token steps).

Who can help?

@ArthurZucker @Cyrilvallez I am not sure whether this behavior is expected.

Reproduction

See the script under Full Reproduction Script above.

Expected behavior

Described above
