transformers - ✅ (Solved) Add PolarQuant quantization: Hadamard-rotated Lloyd-Max optimal weights + KV cache [1 pull request, 1 comment, 1 participant]

GitHub stats
huggingface/transformers#45203 · Fetched 2026-04-08 02:33:13
Comments: 1 · Participants: 1 · Timeline events: 7 · Reactions: 0
Timeline (top): mentioned ×3, subscribed ×3, commented ×1

Root Cause

PolarQuant is a quantization method that combines a Walsh-Hadamard rotation with Lloyd-Max optimal centroids for both weight compression and KV cache compression. The issue claims better PPL per bit than existing methods, for the three reasons listed under Motivation in the issue body below.

PR fix notes

PR #45364: Add PolarQuant backend to QuantizedCache (Hadamard-rotated Lloyd-Max)

Description (problem / solution / changelog)

Summary

Adds a third backend to QuantizedCache: polarquant. Joins the existing quanto and hqq options and implements a Walsh-Hadamard rotation plus Lloyd-Max scalar quantization scheme tuned for KV cache compression. Pure PyTorch, zero new dependencies.

Closes #45203.

Coordination

Scope and direction approved by @SunMarc in #45203:

  • 2026-04-09: weight quantization rejected ("not worth adding unless supported by vLLM etc"), KV cache approved
  • 2026-04-10: "Happy to have a PR for PolarQuantizedCache"

The six-point scope agreed in the issue thread is fully implemented:

  1. PolarQuantizedLayer subclass of QuantizedLayer, mirroring the layer-based pattern of QuantoQuantizedLayer / HQQQuantizedLayer
  2. Hadamard rotation on head_dim before quantization
  3. Lloyd-Max optimal centroids for N(0, 1), hardcoded with exact symmetry — no scipy dependency
  4. Bit-widths 2, 3, 4, 5 (default 3)
  5. Test suite: 10 unit tests + 1 end-to-end integration test
  6. WikiText-style PPL benchmark vs an unquantized baseline (results below)

cc @jagmarques per the cross-check commitment in the #45203 thread — independent E8 lattice VQ implementation at nexusquant, the first-and-last-2-layer observation on Qwen2.5-1.5B, and the Phi-3 head_dim=96 padding path are all referenced in the design.

Not duplicating any existing PR

Searched open PRs against transformers main for polarquant, hadamard quantization, and KV cache backend. No overlapping work found. The most recent KV cache quantization change is the layer-refactor that introduced QuantoQuantizedLayer / HQQQuantizedLayer; this PR plugs into that architecture as a new sibling.

AI assistance disclosure

Code drafted with Claude Code (Anthropic) assistance. Every line was reviewed, tested, and is defensible by the submitter. The math primitives (Hadamard construction, bit packing) were ported from our existing vLLM KV cache module at polarengine-vllm (Apache-2.0, same author). The per-channel z-score handling and the hardcoded symmetric Lloyd-Max table were redesigned during this PR after a chunked-forward PPL benchmark on Qwen2.5-0.5B revealed that a per-vector L2-norm scheme produced unacceptable PPL drift on real attention K/V.


What changed

New file

src/transformers/integrations/polarquant.py (~470 lines, pure PyTorch, zero new dependencies)

Contents:

  • Hardcoded Lloyd-Max centroids for N(0, 1) at 2/3/4/5 bits, computed offline with a symmetry-preserving Lloyd-Max iteration so the table is exactly symmetric around zero
  • build_hadamard(n) — cached Sylvester construction (powers of two only)
  • next_power_of_two(n) — used to zero-pad non-power-of-two head dims (e.g. Phi-3-mini's head_dim=96)
  • BitPacker — dense pack/unpack for 2/3/4/5-bit codes, byte-aligned, with explicit empty-tensor handling
  • PolarQTensor — dataclass carrying packed codes + per-channel mean + per-channel std + the original tensor shape
  • polarquant_quantize() / polarquant_dequantize() — stateless primitives
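
The two helpers named above, build_hadamard(n) and next_power_of_two(n), can be sketched without any framework. This is a minimal standard-library illustration, assuming the behaviour described in the bullets (Sylvester recursion, powers of two only, caching); the bodies are not taken from the PR, and the matmul helper is hypothetical and only there to check orthogonality.

```python
import math
from functools import lru_cache


def next_power_of_two(n: int) -> int:
    """Smallest power of two that is >= n (n must be positive)."""
    if n < 1:
        raise ValueError("n must be a positive integer")
    return 1 << (n - 1).bit_length()


@lru_cache(maxsize=None)
def build_hadamard(n: int) -> tuple:
    """Normalized n x n Hadamard matrix built with the Sylvester recursion.

    Entries are +-1/sqrt(n), so the matrix is symmetric and orthogonal.
    Returned as a tuple of tuples so the cached value cannot be mutated.
    """
    if n < 1 or n & (n - 1):
        raise ValueError("n must be a power of two")
    rows = [[1]]
    while len(rows) < n:
        # H_2k = [[H_k, H_k], [H_k, -H_k]]
        rows = [r + r for r in rows] + [r + [-x for x in r] for r in rows]
    scale = 1.0 / math.sqrt(n)
    return tuple(tuple(x * scale for x in r) for r in rows)


def matmul(a, b):
    """Plain-Python matrix product (hypothetical helper, small sizes only)."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]
```

Multiplying build_hadamard(8) by itself with matmul should give the identity up to floating-point error, which is the property the orthogonality unit test checks.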

Modified files

src/transformers/cache_utils.py (+109 lines)

  • New class PolarQuantizedLayer(QuantizedLayer) with _quantize / _dequantize implementing the per-channel-z-score + Hadamard + Lloyd-Max pipeline. The centroid table and Hadamard matrix are lazily initialized on first use, on the same device and dtype as the incoming tensor
  • New "polarquant" branch in the QuantizedCache.__init__ backend dispatch
  • Docstring update: backend list now ("quanto", "hqq", "polarquant")

src/transformers/__init__.py (+2 lines)

  • Export PolarQuantizedLayer alongside the existing QuantoQuantizedLayer / HQQQuantizedLayer exports, both in _import_structure and in the TYPE_CHECKING block

tests/utils/test_cache_utils.py (+186 lines)

  • New PolarQuantizedCacheUnitTest class with 10 tests:
    • Centroid table is sorted, the right length, and exactly symmetric around zero
    • Hadamard matrix is orthogonal at n ∈ {4, 8, 16, 32, 64, 128, 256}
    • BitPacker roundtrip at every supported bit-width and several head dimensions
    • Quantize/dequantize shape preservation
    • Quantize/dequantize cosine similarity above bit-width-specific thresholds
    • Non-power-of-two head_dim=96 roundtrip via zero-padding (the Phi-3 case)
    • Invalid nbits raises ValueError
    • Invalid axis_key / axis_value raises ValueError
    • axis=-1 accepted as alias for axis=0
    • QuantizedCache(backend="polarquant") correctly dispatches to PolarQuantizedLayer for every transformer layer
  • New test_polarquant_cache_generation in CacheIntegrationTest mirroring the existing quanto / HQQ patterns: drives model.generate(..., cache_implementation="quantized", cache_config={"backend": "polarquant", ...}) end-to-end and asserts the generation completes and starts with the prompt
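
The BitPacker round-trip property in the test list can be illustrated with a small framework-free sketch. The function names pack_codes and unpack_codes are hypothetical and the layout (least-significant bits first, little-endian bytes) is an assumption; the PR's BitPacker operates on tensors and may use a different layout.

```python
def pack_codes(codes, nbits):
    """Pack integers in [0, 2**nbits) into bytes, nbits bits per code.

    Codes are laid out least-significant-bit first. A Python int is used as
    the bit accumulator, which is simple but not fast for long inputs.
    """
    if nbits not in (2, 3, 4, 5):
        raise ValueError("nbits must be one of 2, 3, 4, 5")
    acc = 0
    for i, code in enumerate(codes):
        if not 0 <= code < (1 << nbits):
            raise ValueError(f"code {code} does not fit in {nbits} bits")
        acc |= code << (i * nbits)
    nbytes = (len(codes) * nbits + 7) // 8  # round up to whole bytes
    return acc.to_bytes(nbytes, "little")


def unpack_codes(packed, nbits, count):
    """Inverse of pack_codes; count is the number of codes originally packed."""
    acc = int.from_bytes(packed, "little")
    mask = (1 << nbits) - 1
    return [(acc >> (i * nbits)) & mask for i in range(count)]
```

The code count has to be stored separately because the final byte may contain padding bits, which is one reason a container such as PolarQTensor carries the original shape.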

docs/source/en/kv_cache.md (+16 lines)

  • Documentation for the third backend with a working code sample

Usage

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B", dtype=torch.bfloat16, device_map="cuda"
)
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
inputs = tok("The quick brown fox", return_tensors="pt").to("cuda")

out = model.generate(
    **inputs,
    max_new_tokens=256,
    cache_implementation="quantized",
    cache_config={
        "backend": "polarquant",
        "nbits": 3,                # one of {2, 3, 4, 5}, default 3
        "residual_length": 128,    # recent tokens kept in full precision
        "q_group_size": 64,        # unused by polarquant, kept for parity
        "axis_key": 0,             # 0 or -1 (both mean "last dim")
        "axis_value": 0,
    },
)

Algorithm

For each chunk of head_dim-sized vectors that the cache decides to compress:

  1. Reshape to (N, head_dim) and zero-pad to the next power of two when head_dim is not already a power of two.
  2. Per-channel z-score: subtract a per-channel mean and divide by a per-channel standard deviation, both computed across the batch of N vectors. After this step every channel is approximately unit-Gaussian. This is the same per-channel handling that SmoothQuant, AWQ, and KIVI all rely on, and is essential because real attention K/V tensors exhibit heavy outliers and large per-channel scale variance that a single per-vector L2 norm cannot correct.
  3. Walsh-Hadamard rotation: an orthogonal linear transform that mixes the per-channel components. Each rotated coordinate is a linear combination of the per-channel-Gaussian inputs, so its marginal is also approximately N(0, 1). Because the Hadamard entries are ±1/sqrt(padded_dim), the variance is preserved at 1 by construction, with no extra rescaling step.
  4. Lloyd-Max scalar quantization: each rotated coordinate is mapped to its nearest centroid in a hardcoded Lloyd-Max codebook for N(0, 1). Lloyd-Max centroids are provably MSE-optimal scalar quantizers for a given distribution (Max, 1960), so this step is optimal under the Gaussian prior produced by step 3.
  5. Bit-pack the integer codes into dense uint8 tensors at exactly nbits bits per code.
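
Steps 2 to 4 and their inverse can be sketched end to end in plain Python. This is an illustration under stated assumptions, not the PR's code: the head dimension is already a power of two (no padding), bit-packing is skipped, the function names are hypothetical, and the 3-bit codebook uses the commonly published 8-level Lloyd-Max values for N(0, 1) rounded to four decimals rather than the PR's table.

```python
import math

# Assumed 3-bit codebook: published 8-level Lloyd-Max levels for N(0, 1),
# rounded to 4 decimals and mirrored so the table is exactly symmetric.
_POSITIVE = [0.2451, 0.7560, 1.3439, 2.1520]
CENTROIDS = [-c for c in reversed(_POSITIVE)] + _POSITIVE


def hadamard(n):
    """Normalized Sylvester Hadamard matrix (n must be a power of two)."""
    rows = [[1.0]]
    while len(rows) < n:
        rows = [r + r for r in rows] + [r + [-x for x in r] for r in rows]
    scale = 1.0 / math.sqrt(n)
    return [[x * scale for x in r] for r in rows]


def rotate(vec, h):
    """Matrix-vector product h @ vec."""
    return [sum(a * b for a, b in zip(row, vec)) for row in h]


def quantize(vectors):
    """Z-score each channel across the N vectors, rotate, pick nearest centroid."""
    n, d = len(vectors), len(vectors[0])
    mean = [sum(v[j] for v in vectors) / n for j in range(d)]
    std = [
        max(math.sqrt(sum((v[j] - mean[j]) ** 2 for v in vectors) / n), 1e-6)
        for j in range(d)
    ]
    h = hadamard(d)
    codes = []
    for v in vectors:
        z = [(v[j] - mean[j]) / std[j] for j in range(d)]
        r = rotate(z, h)
        codes.append(
            [min(range(len(CENTROIDS)), key=lambda k: abs(CENTROIDS[k] - x)) for x in r]
        )
    return codes, mean, std


def dequantize(codes, mean, std):
    """Look up centroids, undo the rotation (H @ H = I), undo the z-score."""
    d = len(mean)
    h = hadamard(d)
    out = []
    for row in codes:
        z = rotate([CENTROIDS[k] for k in row], h)
        out.append([z[j] * std[j] + mean[j] for j in range(d)])
    return out


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))
```

The per-channel statistics are computed once per call over all N vectors, so they are shared by every vector in the chunk, which is what keeps their storage cost independent of N.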

The per-channel mean and std are stored as bfloat16 alongside the packed codes. They contribute a constant 2 * head_dim * 2 byte overhead per quantize call, independent of how many vectors are being compressed - so for typical chunks (>= 128 vectors) the overhead is at parity with or smaller than a per-vector L2-norm scheme.

Dequantization inverts each step: unpack codes, apply Hadamard again (the matrix is symmetric and orthogonal so the inverse equals itself), invert the z-score, slice off any padding, reshape.

Non-power-of-two head_dim (e.g. Phi-3-mini's 96) is handled transparently by zero-padding to the next power of two before the rotation and slicing the padding off after. This path is unit-tested.
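
The padding path and the self-inverse property of the rotation can be checked together. The sketch below is illustrative only: pad_and_rotate and unrotate_and_slice are hypothetical names, no quantization is applied, and the point is just that zero-padding 96 to 128, rotating, rotating again, and slicing recovers the input.

```python
import math


def hadamard(n):
    """Normalized Sylvester Hadamard matrix (n must be a power of two)."""
    rows = [[1.0]]
    while len(rows) < n:
        rows = [r + r for r in rows] + [r + [-x for x in r] for r in rows]
    scale = 1.0 / math.sqrt(n)
    return [[x * scale for x in r] for r in rows]


def rotate(vec, h):
    """Matrix-vector product h @ vec."""
    return [sum(a * b for a, b in zip(row, vec)) for row in h]


def pad_and_rotate(vec):
    """Zero-pad to the next power of two, then apply the Hadamard rotation."""
    d = len(vec)
    padded_dim = 1 << (d - 1).bit_length()
    padded = list(vec) + [0.0] * (padded_dim - d)
    return rotate(padded, hadamard(padded_dim)), d


def unrotate_and_slice(rotated, original_dim):
    """Apply the same matrix again (it is its own inverse) and drop the padding."""
    restored = rotate(rotated, hadamard(len(rotated)))
    return restored[:original_dim]
```

Once quantization noise is added between the two rotations, the padded coordinates no longer come back as exact zeros, but they are discarded by the slice either way.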


Benchmarks

Model: Qwen/Qwen2.5-0.5B (494M parameters, head_dim=64, 24 layers, 2 KV heads). This is a small dense model deliberately picked as a stress test - small models are far more sensitive to KV cache quantization noise than 7B+ models.

Harness: chunked forward PPL on 20 long English passages. Each text is split into two contiguous 32-token chunks. The first chunk is consumed by model(..., past_key_values=cache) to populate the cache, which triggers PolarQuantizedLayer._quantize because residual_length is set lower than the chunk size. The second chunk is then forwarded against the cached, dequantized first chunk, and the loss on its tokens is averaged. This isolates the loss contribution of attending to the (de)quantized history and is the worst-case test for the cache: 100% of the first chunk is quantized, with no full-precision residual buffer covering the prefix.

| Backend | nbits | PPL | Δ vs FP16 | Relative |
|---|---|---|---|---|
| FP16 baseline (DynamicCache) | | 7.62 | | |
| polarquant | 5 | 7.94 | +0.31 | +4% |
| polarquant | 4 | 13.44 | +5.82 | +76% |
| polarquant | 3 | 63.81 | +56.19 | +737% |
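
The relative column can be recomputed from the reported PPL values. A minimal sketch, assuming perplexity is the exponential of the mean per-token negative log-likelihood (the usual definition; the benchmark script itself is not shown in the PR notes) and using hypothetical helper names:

```python
import math


def perplexity(token_nlls):
    """exp of the mean per-token negative log-likelihood (natural log)."""
    return math.exp(sum(token_nlls) / len(token_nlls))


def relative_increase(ppl, baseline):
    """Fractional PPL increase over the baseline, e.g. 0.04 for +4%."""
    return (ppl - baseline) / baseline


BASELINE = 7.62  # FP16 DynamicCache PPL as reported
REPORTED = {5: 7.94, 4: 13.44, 3: 63.81}  # nbits -> reported PPL

for nbits, ppl in REPORTED.items():
    pct = 100 * relative_increase(ppl, BASELINE)
    print(f"nbits={nbits}: PPL {ppl:.2f}, relative increase {pct:.0f}%")
```

Recomputing the 5-bit absolute delta from the rounded values gives 0.32 rather than the reported +0.31, presumably because the report subtracted unrounded numbers.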

Headline: PolarQuant 5-bit is essentially lossless on this stress test (+4% PPL relative). 4-bit is acceptable for memory-constrained scenarios. 3-bit is too aggressive on a 0.5B model with no residual buffer; on larger models the same 3-bit configuration would degrade much less, but quantifying that requires gated-model access (Llama 3) that I'll add as a follow-up benchmark when the access request clears.

Round-trip cosine similarity on random bfloat16 KV tensors at head_dim=128, from the unit tests:

| nbits | min threshold | measured | compression ratio (head_dim=128) |
|---|---|---|---|
| 2 | 0.80 | 0.94 | 7.5x |
| 3 | 0.95 | 0.98 | 5.1x |
| 4 | 0.98 | 0.995 | 3.9x |
| 5 | 0.99 | 0.999 | 3.1x |

The per-channel mean and std overhead is constant per quantize call (independent of N), so for batched/long-context workloads the effective compression matches the headline ratios above.
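
The overhead arithmetic can be made concrete. The sketch below assumes bfloat16 source tensors (2 bytes per value), densely packed codes, and one bfloat16 mean plus one bfloat16 std per channel, as described in the Algorithm section; the function name and the choice of N are illustrative assumptions.

```python
def compression_ratio(n_vectors, head_dim, nbits, src_bytes=2, stat_bytes=2):
    """Original size divided by (packed codes + per-channel mean/std)."""
    original = n_vectors * head_dim * src_bytes
    packed = (n_vectors * head_dim * nbits + 7) // 8  # dense bit-packing
    stats = 2 * head_dim * stat_bytes                 # mean and std per channel
    return original / (packed + stats)


for nbits in (2, 3, 4, 5):
    ratio = compression_ratio(n_vectors=256, head_dim=128, nbits=nbits)
    print(f"nbits={nbits}: {ratio:.2f}x")
```

With N = 256 vectors this arithmetic rounds to 7.5x, 5.1x, 3.9x, and 3.1x, which agrees with the table. That suggests, but does not confirm, that the reported ratios were measured at a comparable chunk size. As N grows, the ratio approaches 16 / nbits.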


Testing

# Unit tests (CPU, fast)
pytest -xvs tests/utils/test_cache_utils.py::PolarQuantizedCacheUnitTest

# Integration test (requires a GPU and the SmolLM2-135M test model already
# used by CacheIntegrationTest, plus a Qwen2.5-0.5B for the PPL benchmark)
pytest -xvs tests/utils/test_cache_utils.py::CacheIntegrationTest::test_polarquant_cache_generation

# Style + typing (clean on the modified files)
make style

Tested locally and on Colab RTX PRO 6000 Blackwell (96 GB). All 10 unit tests pass; the integration test passes; the chunked PPL benchmark gives the numbers reported above. ruff check and ruff format --check are clean on the four modified files plus the new polarquant.py. The remaining make style failures all live in utils/get_test_reports.py and utils/create_dummy_models.py and are pre-existing on main - none of them touch files modified by this PR.


TurboQuant note

@SunMarc flagged Google's TurboQuant as potentially complementary in the issue thread. TurboQuant uses random rotations followed by uniform quantization; PolarQuant uses a deterministic Walsh-Hadamard rotation followed by Lloyd-Max scalar quantization. The two approaches share the core insight that "rotation before quantization decorrelates outliers" but land on different choices for the rotation generator and the codebook. A unified "rotation-based cache quantization" path could subsume both in a future PR - happy to explore that as a follow-up once this lands.


Honest limitations

  • Small-model sensitivity at low bit-widths. PolarQuant 3-bit shows large PPL drift on a 0.5B model under a worst-case (no residual buffer) test. This is the regime where every existing cache quantizer also struggles; on 7B+ models the same 3-bit configuration is much better behaved. Until I can re-run the benchmark on a larger model (gated access pending), I'd recommend nbits=5 as the production default and nbits=3 as a memory-constrained option that the user explicitly opts into.
  • Per-channel statistics are computed per quantize call, not per layer. Each call to _quantize recomputes a fresh mean and std from whatever vectors it's compressing. For a residual-overflow re-quantization that includes both old quantized history and new tokens, this means the stats shift over time as more context accumulates. KIVI handles this by keeping per-channel stats stable across the lifetime of the cache; doing so would require a slightly larger surface change to QuantizedLayer and is left as follow-up.
  • No Triton kernels yet. The existing Triton kernels for nearest-centroid search live in the upstream polarengine-vllm repo but depend on Triton's version matrix, which adds CI complexity. I dropped them to keep this PR pure-PyTorch. A follow-up can add an optional Triton fast path behind is_triton_available().
  • First-and-last-layer carve-out not exposed as config. @jagmarques noted in the issue thread that the first and last two decoder layers sometimes need to stay at full precision on small Qwen variants. I did not add a skip-layers config to keep the first PR focused; this is a natural follow-up if needed.
  • Benchmarked only against an unquantized baseline. A direct head-to-head against quanto / HQQ on the same chunked PPL test would be ideal but I hit a huggingface_hub / diffusers dependency conflict in the Colab environment when installing optimum-quanto. Happy to run this comparison in CI if a reviewer can confirm the right environment setup.

Changed files

  • docs/source/en/kv_cache.md (modified, +16/-1)
  • src/transformers/__init__.py (modified, +2/-0)
  • src/transformers/cache_utils.py (modified, +110/-1)
  • src/transformers/integrations/polarquant.py (added, +501/-0)
  • tests/utils/test_cache_utils.py (modified, +186/-0)


🚀 Feature request

Motivation

PolarQuant is a quantization method that uses Walsh-Hadamard rotation + Lloyd-Max optimal centroids for both weight compression and KV cache compression. It achieves better PPL per bit than existing methods because:

  1. Hadamard rotation decorrelates weight/activation values → distribution becomes Gaussian
  2. Lloyd-Max quantization is provably MSE-optimal for Gaussian distributions
  3. No calibration data needed — unlike GPTQ/AWQ, works on any model instantly
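
For readers unfamiliar with point 2, the Lloyd-Max iteration for a standard Gaussian can be written with closed-form updates. This is a minimal sketch and not the project's code: the function name, initialisation, and iteration count are assumptions. It relies on the fact that the mean of N(0, 1) restricted to an interval (a, b) is (pdf(a) - pdf(b)) / (cdf(b) - cdf(a)).

```python
import math


def _pdf(x):
    """Standard normal density; returns 0.0 at +-infinity."""
    return math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)


def _cdf(x):
    """Standard normal cumulative distribution function."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def lloyd_max_gaussian(nbits, iters=500):
    """Lloyd-Max reconstruction levels for N(0, 1) with 2**nbits levels."""
    levels = 1 << nbits
    # Symmetric, evenly spaced starting levels in (-2, 2).
    c = [-2.0 + 4.0 * (i + 0.5) / levels for i in range(levels)]
    for _ in range(iters):
        # Decision boundaries are midpoints between neighbouring levels.
        edges = [-math.inf] + [(c[i] + c[i + 1]) / 2 for i in range(levels - 1)] + [math.inf]
        # Each level moves to the conditional mean of its interval.
        c = [(_pdf(a) - _pdf(b)) / (_cdf(b) - _cdf(a)) for a, b in zip(edges, edges[1:])]
    # Average mirrored pairs so the table is exactly symmetric around zero.
    half = levels // 2
    pos = [(c[half + i] - c[half - 1 - i]) / 2 for i in range(half)]
    return [-p for p in reversed(pos)] + pos
```

For nbits=2 the result should land near ±0.4528 and ±1.510, the commonly published 4-level values. No calibration data enters the computation, which is what point 3 relies on.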

Results

Tested on Qwen3.5-9B (WikiText-2 PPL, lower is better):

| Method | Bits | PPL | Δ vs FP16 |
|---|---|---|---|
| FP16 Baseline | 16 | 6.37 | |
| PolarQuant Q5 + INT4 | ~4 | 6.54 | +0.17 |
| torchao INT4 (absmax) | 4 | 6.68 | +0.31 |
| BnB NF4 | 4 | ~6.7 | +0.33 |

PolarQuant beats torchao INT4 by 0.14 PPL at the same effective bit-width — a significant gap.

Proposal

Add PolarQuantConfig as a native quantization method in transformers, with two components:

1. Weight Quantization (PolarQuantConfig)

from transformers import AutoModelForCausalLM, PolarQuantConfig

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-4-31B-it",
    quantization_config=PolarQuantConfig(weight_bits=5, kv_bits=3),
    device_map="auto",
)

Pipeline: Load BF16 → Hadamard rotate → Lloyd-Max Q5 quantize → dequant → torchao INT4

2. KV Cache Compression (PolarQuantCache)

from transformers import PolarQuantCache

outputs = model.generate(
    input_ids,
    past_key_values=PolarQuantCache(num_layers=60, nbits=3),
    max_new_tokens=512,
)

This provides 5.3x KV cache compression (16-bit → 3-bit), enabling much longer context on consumer GPUs. Currently, the built-in quantized cache options are the quanto and HQQ backends.
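
To put the 5.3x figure in concrete terms, the KV cache footprint can be estimated from model dimensions. The sketch below uses the Qwen2.5-0.5B shape quoted in the PR notes (24 layers, 2 KV heads, head_dim=64); the 32,768-token context length is an assumption, and the estimate ignores the full-precision residual buffer and any per-channel statistics.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bits_per_value, batch=1):
    """Approximate KV cache size in bytes for a given storage bit-width."""
    # Factor 2 accounts for storing both keys and values.
    values = 2 * num_layers * num_kv_heads * head_dim * seq_len * batch
    return values * bits_per_value / 8


fp16 = kv_cache_bytes(24, 2, 64, seq_len=32768, bits_per_value=16)
q3 = kv_cache_bytes(24, 2, 64, seq_len=32768, bits_per_value=3)
print(f"16-bit: {fp16 / 2**20:.0f} MiB, 3-bit: {q3 / 2**20:.0f} MiB, ratio {fp16 / q3:.2f}x")
```

The ratio is 16 / 3 by construction; savings in practice are somewhat lower once the residual buffer and quantization metadata are counted.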

Why not use existing methods?

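Feature comparison table (columns: PolarQuant, torchao, BnB, GPTQ, AWQ; rows: Calibration-free, KV cache compression, Hadamard rotation, Optimal centroids, Best PPL at 4-bit). Surviving cells: PolarQuant KV cache compression ✅ (Q2/Q3/Q4); PolarQuant optimal centroids ✅ (Lloyd-Max); one "NF4" entry in the Optimal centroids row; two "Close" entries in the Best PPL at 4-bit row.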

Implementation plan

  1. PolarQuantConfig in quantization_config.py
  2. PolarQuantHfQuantizer in quantizers/quantizer_polarquant.py
  3. PolarQuantCache in cache_utils.py (extending existing QuantizedCacheConfig)
  4. Tests + documentation

The implementation is self-contained (~500 lines) with scipy as the only additional dependency (for Lloyd-Max centroid computation, runs once at init).

Proven at scale

  • Gemma 4 31B-it: 62.5 GB → 21.5 GB (fits RTX 4090), 24.9 tok/s
  • Qwen3.5 9B: 17.9 GB → 6.5 GB, 43.1 tok/s, PPL 6.54
  • Qwen3.5 27B: → 17.7 GB, PPL 5.37
  • 36 models published on HuggingFace

Happy to implement if there's interest from maintainers!

extent analysis

TL;DR

To add PolarQuant as a native quantization method in transformers, implement PolarQuantConfig and PolarQuantCache with the proposed pipeline and extend existing quantized cache options.

Guidance

  • Review the implementation plan: PolarQuantConfig in quantization_config.py, PolarQuantHfQuantizer in quantizers/quantizer_polarquant.py, and PolarQuantCache in cache_utils.py. Per the PR notes, the maintainer approved only the KV cache part; weight quantization was rejected.
  • Keep the implementation self-contained. The issue proposed scipy as the only extra dependency (for Lloyd-Max centroid computation); PR #45364 instead hardcodes the centroid table and adds no new dependencies.
  • Test the implementation with the provided models and datasets to verify its effectiveness.
  • Consider the use cases reported in the issue, such as Gemma 4 31B-it and Qwen3.5 9B, when evaluating the implementation's performance.

Notes

The implementation requires adding new classes and extending existing ones, which may introduce compatibility issues or dependencies. It's essential to thoroughly test the implementation and evaluate its performance.

Recommendation

Apply the KV cache portion of the proposed plan, which the PR notes describe as approved by the maintainer and implemented in PR #45364; the weight quantization portion was rejected. This recommendation is based on the implementation plan and the use cases reported in the issue.
