pytorch - 💡(How to fix) `torch.compile` produces different uint8 quantization output than eager mode for the E8M0 bit-manipulation pattern (`view(int32) → bitshift → clamp → uint8`) [1 participant]

GitHub stats: pytorch/pytorch#179569 (fetched 2026-04-08 03:00:18) · Comments: 0 · Participants: 1 · Timeline events: 109 · Reactions: 0


🐛 Describe the bug

torch.compile with inductor backend produces numerically different results compared to eager mode when the model uses E8M0 float-to-uint8 quantization via direct bit manipulation. The pattern reinterprets float32 tensors as int32 (x.view(torch.int32)), extracts exponent bits via bitwise right-shift (>> 23), masks with & 0xFF, extracts mantissa via & 0x7FFFFF, performs ceiling rounding based on mantissa, clamps to [0, 255], and converts to uint8.
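The chain described above can be traced by hand on scalars. The sketch below is a minimal pure-Python mirror of it, using the standard struct module to obtain the IEEE-754 float32 bits. It treats the bits as unsigned, which gives the same result as the signed int32 view once the & 0xFF and & 0x7FFFFF masks are applied. The names fp32_bits and e8m0_rceil are illustrative, not part of PyTorch.

```python
import struct


def fp32_bits(value: float) -> int:
    """Return the IEEE-754 float32 bit pattern of `value` as an unsigned 32-bit int."""
    return struct.unpack("<I", struct.pack("<f", value))[0]


def e8m0_rceil(value: float) -> int:
    """Scalar mirror of fp32_to_e8m0_rceil: exponent field, plus 1 if any mantissa bit is set."""
    bits = fp32_bits(value)
    biased_exp = (bits >> 23) & 0xFF   # bits [30:23]
    mantissa = bits & 0x7FFFFF         # bits [22:0]
    needs_round_up = 1 if mantissa != 0 else 0
    return min(max(biased_exp + needs_round_up, 0), 255)  # clamp to the uint8 range


for v in [0.0, 0.75, 1.0, 1.5, 2.0]:
    print(f"{v:>5} -> bits=0x{fp32_bits(v):08X} e8m0={e8m0_rceil(v)}")
```

For example, 1.0 (bits 0x3F800000) encodes as 127, while 1.5 and 2.0 both encode as 128: for positive normal inputs the result is 127 plus the ceiling of log2 of the input.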

Under Inductor, the quantized uint8 values differ from eager mode at rounding boundaries: small floating-point precision differences in the upstream Conv2d+BN+ReLU pipeline push values across discrete quantization boundaries, and the bit-manipulation chain turns those differences into different uint8 encodings.
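To see why small upstream differences matter only at specific values, the sketch below nudges non-negative float32 values up by one ULP (adding 1 to the bit pattern of a non-negative float gives the next float up). Away from powers of two the code is unchanged; starting from an exact power of two it rises by 1; and a value that is exactly 0 after ReLU in one mode but slightly positive in the other jumps from 0 to a large code (104 for 1e-7). The helper names are illustrative.

```python
import struct


def bits_of(value: float) -> int:
    """IEEE-754 float32 bit pattern of `value` as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", value))[0]


def e8m0_rceil_from_bits(bits: int) -> int:
    """Scalar E8M0 ceiling encoding computed from a raw float32 bit pattern."""
    biased_exp = (bits >> 23) & 0xFF
    needs_round_up = 1 if (bits & 0x7FFFFF) != 0 else 0
    return min(biased_exp + needs_round_up, 255)


# Compare each value with its one-ULP upper neighbor.
for v in [1.5, 3.0, 2.0, 4.0, 0.0]:
    b = bits_of(v)
    print(f"{v}: e8m0={e8m0_rceil_from_bits(b)}, one ULP above: e8m0={e8m0_rceil_from_bits(b + 1)}")

# A ReLU output that is exactly 0 in one mode but a small positive value in the other.
print(f"0.0 vs 1e-7: {e8m0_rceil_from_bits(bits_of(0.0))} vs {e8m0_rceil_from_bits(bits_of(1e-7))}")
```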

Note: this reproduces on CPU (not CUDA), so the issue is not GPU-specific; it lies in Inductor's CPU code generation, either for the bit-manipulation operations or for the upstream Conv2d+BN+ReLU computation.

Eager mode is perfectly deterministic (max_var=0 across runs), ruling out non-determinism.

Minimal reproducer

import torch
import torch.nn as nn
import torch.nn.functional as F

class E8M0QuantizationNetwork(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        # Feature extraction: Conv + BN + ReLU
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool = nn.AdaptiveAvgPool2d(8)
        self.fc = nn.Linear(64 * 8 * 8, num_classes)

    def fp32_to_e8m0_rceil(self, x):
        """E8M0 quantization via bit manipulation with ceiling rounding."""
        x = x.contiguous()
        # Reinterpret float32 bits as int32
        x_bits = x.view(torch.int32)
        # Extract biased exponent: bits [30:23]
        biased_exp = (x_bits >> 23) & 0xFF
        # Extract mantissa: bits [22:0]
        mantissa = x_bits & 0x7FFFFF
        # Ceiling rounding: round up if any mantissa bit is set
        needs_round_up = (mantissa != 0).to(torch.int32)
        e8m0_biased = biased_exp + needs_round_up
        # Clamp to valid uint8 range
        e8m0_biased = torch.clamp(e8m0_biased, 0, 255)
        return e8m0_biased.to(torch.uint8)

    def forward(self, x):
        # Feature extraction
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool(x)

        # E8M0 quantization via bit manipulation
        quantized = self.fp32_to_e8m0_rceil(x)

        # Continue with quantized values as float for classification
        x = quantized.to(torch.float32)
        x = x.flatten(1)
        logits = self.fc(x)
        return logits, quantized

# NOTE: CPU, not CUDA
device = "cpu"
torch.manual_seed(42)
model = E8M0QuantizationNetwork().to(device).eval()
x = torch.randn(4, 3, 64, 64, device=device)

# Eager: deterministic
with torch.no_grad():
    ref_logits, ref_quant = model(x)
    ref_logits2, ref_quant2 = model(x)
print(f"Eager deterministic (logits): {(ref_logits - ref_logits2).abs().max().item():.6e}")
print(f"Eager deterministic (quant):  {(ref_quant.int() - ref_quant2.int()).abs().max().item()}")

# Compiled: different result
torch._dynamo.reset()
compiled = torch.compile(model, backend="inductor")
with torch.no_grad():
    comp_logits, comp_quant = compiled(x)

logit_diff = (ref_logits.float() - comp_logits.float()).abs()
quant_diff = (ref_quant.int() - comp_quant.int()).abs()
print(f"Logits max_diff={logit_diff.max().item():.6e}")
print(f"Quant uint8 mismatches: {(quant_diff > 0).sum().item()} / {ref_quant.numel()}")
print(f"Quant max_diff={quant_diff.max().item()}")
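The report attributes the mismatches to upstream float32 differences crossing quantization boundaries. If the pooled activations were captured in both modes (for example, by also returning the pooled tensor from forward, which the reproducer above does not do), a helper like the hypothetical explain_mismatches below could sort each mismatch into two causes: one side exactly zero after ReLU, or the two values straddling a power of two. A minimal sketch with synthetic stand-in activations:

```python
import torch


def e8m0_rceil(x: torch.Tensor) -> torch.Tensor:
    """Same view(int32) -> shift/mask -> ceil -> clamp -> uint8 chain as the reproducer."""
    bits = x.contiguous().view(torch.int32)
    round_up = ((bits & 0x7FFFFF) != 0).to(torch.int32)
    return torch.clamp(((bits >> 23) & 0xFF) + round_up, 0, 255).to(torch.uint8)


def explain_mismatches(eager_act: torch.Tensor, compiled_act: torch.Tensor) -> dict:
    """Classify positions where the E8M0 codes of two float32 activation tensors differ."""
    q_eager = e8m0_rceil(eager_act).int()
    q_comp = e8m0_rceil(compiled_act).int()
    differ = q_eager != q_comp
    # One side is exactly zero (e.g. clipped by ReLU) while the other is positive
    zero_vs_nonzero = differ & ((eager_act == 0) | (compiled_act == 0))
    # For positive normal inputs, any other mismatch means a power of two lies between the values
    crossing = differ & ~zero_vs_nonzero
    return {
        "mismatches": int(differ.sum()),
        "zero_vs_nonzero": int(zero_vs_nonzero.sum()),
        "power_of_two_crossing": int(crossing.sum()),
        "max_code_diff": int((q_eager - q_comp).abs().max()) if bool(differ.any()) else 0,
    }


# Synthetic stand-ins for pooled activations captured from eager and compiled runs
eager_act = torch.tensor([0.0, 2.0, 1.5, 3.0])
compiled_act = eager_act.clone()
compiled_act[0] = 1e-7  # exactly zero in eager, small positive under compile
compiled_act[1:3] = torch.nextafter(eager_act[1:3], torch.full_like(eager_act[1:3], float("inf")))
print(explain_mismatches(eager_act, compiled_act))
```

With ceiling rounding, a small upstream difference across a power of two changes the code by 1, while a zero-versus-nonzero mismatch changes it by far more.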

Behavior summary

Mode | Result | Notes
Eager | Reference output | Perfectly deterministic across runs (max_var=0)
torch.compile(backend="inductor") | Different output | uint8 quantized values and classification logits differ

Notes

  • Eager mode is perfectly deterministic (max_var=0 for both logits and quantized uint8 values), confirming this is a systematic difference.
  • The bug reproduces on CPU — this is not GPU-specific.
  • The bit manipulation chain (view(int32) → >>23 → &0xFF → &0x7FFFFF → comparison → add → clamp → to(uint8)) is the core pattern where Inductor's optimization produces different intermediate values.
  • Small floating-point precision differences from Conv2d/BN/ReLU cross discrete quantization boundaries at exponent/mantissa edges, flipping uint8 values by ±1.
  • The view(torch.int32) reinterpretation is the critical operation — it exposes raw float bit patterns to integer arithmetic, where any upstream precision change (even ULP-level) can change the exponent extraction result.
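The chain listed in the notes above reduces to a closed form, which shows exactly where ULP-level changes can matter. A short derivation, assuming a positive normal float32 input (subnormals, infinities, and NaNs are not covered):

```latex
% Positive normal float32 input: exponent field E \in [1, 254], mantissa field M \in [0, 2^{23})
x = \left(1 + \frac{M}{2^{23}}\right) 2^{E-127}

% The bit chain computes q(x) = \min(E + [M \neq 0],\ 255), where [\cdot] is 1 if true, else 0.
% M = 0:     x = 2^{E-127},                so \lceil \log_2 x \rceil = E - 127
% M \neq 0:  E - 127 < \log_2 x < E - 126, so \lceil \log_2 x \rceil = E - 126
q(x) = \min\left(127 + \lceil \log_2 x \rceil,\ 255\right), \qquad q(0) = 0
```

So q is constant on each interval (2^{k-1}, 2^k], and an upstream difference changes the encoding only if it moves a value across a power of two or between zero and a positive number.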

Error logs

No error — the compiled model produces a result, but it differs from eager:

Eager deterministic (logits): 0.000000e+00
Eager deterministic (quant):  0
Logits and quantized values differ systematically under torch.compile

Versions

PyTorch version: 2.12.0.dev20260327+cu126
Python: 3.10.12
OS: Ubuntu 22.04.5 LTS (WSL2)
GPU: NVIDIA GeForce RTX 3060 Laptop GPU
CUDA: 12.6

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168 @aditew01 @chauhang @penguinwu @voznesenskym @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo @ezyang @msaroufim @bdhirsh @anijain2305

topic: fuzzer

extent analysis

TL;DR

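A possible workaround is to rewrite fp32_to_e8m0_rceil without bit manipulation. This helps only if Inductor's code generation for the bit operations is at fault: if the differences come from the upstream Conv2d/BN/ReLU output, any exact E8M0 encoding will still disagree wherever a value crosses a power of two.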

Guidance

  • Identify which operations produce the differences: compile fp32_to_e8m0_rceil on its own and feed it the same float32 input as eager mode, so that code generation for the bit manipulation can be separated from differences in the upstream Conv2d/BN/ReLU output.
  • Consider an alternative formulation that does not reinterpret the float bits, for example one based on torch.frexp. Any exact E8M0 encoding still changes value at powers of two, so this helps only if the bit-manipulation code generation is at fault.
  • Verify that the modified function produces the same results in both eager and compiled modes.
  • Test the modified function with a range of inputs, including zeros, exact powers of two, and their one-ULP neighbors.

Example

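import torch

def fp32_to_e8m0_rceil(x):
    m, e = torch.frexp(x)  # x = m * 2**e with m in [0.5, 1)
    # ceil(log2(x)) is e - 1 for exact powers of two (m == 0.5), else e
    ceil_log2 = e - (m == 0.5).to(torch.int32)
    biased = torch.where(x == 0, torch.zeros_like(ceil_log2), ceil_log2 + 127)
    return torch.clamp(biased, 0, 255).to(torch.uint8)

Note that this is only a sketch, not a verified fix. It reproduces the bit-manipulation encoding for zero and positive normal inputs (some subnormal inputs encode as 0 here instead of 1), and because it implements the same step function, it is still sensitive to upstream differences that push a value across a power of two.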
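An arithmetic variant based on torch.frexp (here the hypothetical e8m0_rceil_frexp) can be checked against the reproducer's bit-manipulation chain on the same inputs. This sketch restricts the probe to zero and positive normal float32 values, matching post-ReLU activations; the two encodings are not expected to agree on subnormals. Whether Inductor compiles torch.frexp without its own differences is not checked here.

```python
import torch


def e8m0_rceil_bits(x: torch.Tensor) -> torch.Tensor:
    """Reference: the reproducer's bit-manipulation chain."""
    x_bits = x.contiguous().view(torch.int32)
    biased_exp = (x_bits >> 23) & 0xFF
    round_up = ((x_bits & 0x7FFFFF) != 0).to(torch.int32)
    return torch.clamp(biased_exp + round_up, 0, 255).to(torch.uint8)


def e8m0_rceil_frexp(x: torch.Tensor) -> torch.Tensor:
    """Arithmetic variant: x = m * 2**e with m in [0.5, 1)."""
    m, e = torch.frexp(x)
    # ceil(log2(x)) is e - 1 for exact powers of two (m == 0.5) and e otherwise
    ceil_log2 = e - (m == 0.5).to(torch.int32)
    biased = torch.where(x == 0, torch.zeros_like(ceil_log2), ceil_log2 + 127)
    return torch.clamp(biased, 0, 255).to(torch.uint8)


# Zero and positive normal values only; one-ULP neighbors exclude 0 to avoid subnormals
edge = torch.tensor([0.0, 0.25, 0.75, 1.0, 1.5, 2.0, 3.0, 1e-7, 1e30])
positive = edge[1:]
probe = torch.cat([
    edge,
    torch.nextafter(positive, torch.full_like(positive, float("inf"))),  # one ULP up
    torch.nextafter(positive, torch.zeros_like(positive)),               # one ULP down
    torch.rand(10_000) * 100.0,  # random values: 0 or positive normal floats
])
same = torch.equal(e8m0_rceil_bits(probe), e8m0_rceil_frexp(probe))
print("bit-manipulation and frexp encodings agree:", same)
```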

Notes

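  • The issue reproduces on CPU, so it is not specific to GPU code generation.
  • The view(torch.int32) operation does not introduce the precision difference itself; it exposes upstream differences in the float32 bits to integer arithmetic, where they can change the encoded exponent.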
  • The use of bit manipulation operations can make the code sensitive to floating-point precision differences.

Recommendation

Before applying a workaround, compile fp32_to_e8m0_rceil on its own with identical inputs to check whether the bit-manipulation operations are miscompiled. If they are, rewrite the function without bit reinterpretation and verify that it reproduces the original encoding in both eager and compiled modes; if they are not, the differences originate upstream and changing the quantizer alone will not remove them.
