pytorch - 💡(How to fix) Fix `torch.compile` produces different classification output for model using E8M0 bit-manipulation quantization (`view(int32) → bitshift → clamp → uint8`) [1 comments, 1 participants]

GitHub stats
pytorch/pytorch#178880 · Fetched 2026-04-08 01:57:17
Comments: 1 · Participants: 1 · Timeline events: 77 · Reactions: 0
Timeline (top): mentioned ×35, subscribed ×35, labeled ×6, commented ×1


🐛 Describe the bug

torch.compile with inductor backend produces numerically different classification logits for a model that includes E8M0 (8-bit exponent-only) quantization via direct bit manipulation. The model reinterprets float32 tensors as int32 (x.view(torch.int32)), extracts exponent bits via bitwise shift, applies ceiling rounding based on mantissa, clamps, and converts to uint8. The classification output after quantization differs between eager and compiled modes.
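To make the encoding concrete, here is a minimal scalar sketch of the same bit arithmetic in pure Python. The helper name `e8m0_code` is hypothetical and does not appear in the issue; it uses `struct` instead of `x.view(torch.int32)`, so it only illustrates the arithmetic and says nothing about how Inductor lowers it.

```python
import struct

def e8m0_code(value: float) -> int:
    """Scalar reference for the reproducer's e8m0_quantize (hypothetical helper)."""
    # Reinterpret the float32 bit pattern as an unsigned 32-bit integer.
    (bits,) = struct.unpack("<I", struct.pack("<f", value))
    biased_exp = (bits >> 23) & 0xFF   # 8-bit biased exponent (sign bit masked off)
    mantissa = bits & 0x7FFFFF         # 23-bit mantissa
    # Any nonzero mantissa rounds the exponent up (ceiling rounding).
    code = biased_exp + (1 if mantissa != 0 else 0)
    return max(0, min(code, 255))      # clamp to the uint8 range

for v in (0.0, 1.0, 1.5, 2.0, 3.0):
    print(v, e8m0_code(v))
# 0.0 0
# 1.0 127
# 1.5 128
# 2.0 128
# 3.0 129
```

Exact powers of two keep their biased exponent, while anything in between is pushed up to the next exponent, so the mapping is a step function of the input.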

Minimal reproducer

import os
os.environ["TRITON_BACKENDS_IN_TREE"] = "1"

import torch
import torch.nn as nn
import torch.nn.functional as F

class E8M0QuantizationBlock(nn.Module):
    def __init__(self, in_channels=16, out_channels=32):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def e8m0_quantize(self, x):
        x = x.contiguous()
        x_bits = x.view(torch.int32)
        biased_exp = (x_bits >> 23) & 0xFF
        mantissa = x_bits & 0x7FFFFF
        needs_round_up = mantissa != 0
        e8m0_biased = biased_exp + needs_round_up.to(torch.int32)
        e8m0_biased = torch.clamp(e8m0_biased, 0, 255)
        return e8m0_biased.to(torch.uint8)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        quantized = self.e8m0_quantize(x)
        x = quantized.to(torch.float32)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        return x, quantized


class MultiScaleE8M0Model(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.quant_block1 = E8M0QuantizationBlock(32, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.quant_block2 = E8M0QuantizationBlock(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.quant_block3 = E8M0QuantizationBlock(128, 256)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x, q1 = self.quant_block1(x)
        x = self.pool1(x)
        x, q2 = self.quant_block2(x)
        x = self.pool2(x)
        x, q3 = self.quant_block3(x)
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)
        logits = self.fc(x)
        return logits, (q1, q2, q3)


device = "cuda"
torch.manual_seed(42)
model = MultiScaleE8M0Model().to(device).eval()
x = torch.randn(4, 3, 64, 64, device=device)

# Eager
with torch.no_grad():
    eager_logits, eager_quants = model(x)

# Compiled
torch._dynamo.reset()
compiled = torch.compile(model, backend="inductor")
with torch.no_grad():
    comp_logits, comp_quants = compiled(x)

# Compare
diff = (eager_logits - comp_logits).abs().max().item()
print(f"Logits max diff: {diff}")
print(f"Match: {torch.allclose(eager_logits, comp_logits, atol=1e-5)}")
# Expected: max_diff > 0.1 — significant mismatch

Behavior summary

| Component | Eager | torch.compile | Consistent? |
| --- | --- | --- | --- |
| Classification logits | Reference | Differs (max_diff observed) | No |
| E8M0 uint8 encoding | Reference | May differ due to float precision | No |

Root cause hypothesis

The Inductor backend's handling of x.view(torch.int32) reinterpretation combined with bitwise operations (>>, &) introduces precision differences in the upstream Conv2d/BatchNorm2d pipeline (similar to GELU fusion precision). These tiny float differences propagate through the discrete quantization boundary, causing different uint8 encodings and ultimately different classification outputs.
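The hypothesis depends on the encoding being discontinuous: a float difference of one ULP can change the uint8 code. The following pure-Python sketch illustrates that sensitivity under the same bit arithmetic as the reproducer. The helper names `f32_to_bits` and `e8m0_from_bits` are hypothetical, and the sketch does not show that Inductor actually produces such differences.

```python
import struct

def f32_to_bits(value: float) -> int:
    # Bit pattern of the float32 nearest to `value`.
    (bits,) = struct.unpack("<I", struct.pack("<f", value))
    return bits

def e8m0_from_bits(bits: int) -> int:
    biased_exp = (bits >> 23) & 0xFF
    mantissa = bits & 0x7FFFFF
    # Ceiling rounding on the exponent, clamped to the uint8 maximum.
    return min(biased_exp + (1 if mantissa != 0 else 0), 255)

one = f32_to_bits(1.0)
cases = {
    "1.0 exactly": one,
    "1.0 plus one ULP": one + 1,   # adjacent float32 above 1.0
    "0.0 exactly": f32_to_bits(0.0),
    "1e-6": f32_to_bits(1e-6),
}
codes = {name: e8m0_from_bits(bits) for name, bits in cases.items()}
for name, code in codes.items():
    print(f"{name}: {code}")
# 1.0 exactly: 127
# 1.0 plus one ULP: 128
# 0.0 exactly: 0
# 1e-6: 108
```

The jump near zero is the largest: in the reproducer the code is cast back to float32 and fed to `conv2`, so an activation that ReLU clamps to exactly 0 in one mode but leaves slightly positive in the other would enter the next layer as 0.0 versus a value around 100.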

Error logs

No error — outputs silently differ.

Versions

PyTorch version: 2.12.0.dev20260327+cu126
CUDA used to build PyTorch: 12.6
OS: Ubuntu 22.04.5 LTS (x86_64) — WSL2
Python version: 3.10.12
GPU: NVIDIA GeForce RTX 3060 Laptop GPU

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo

extent analysis

TL;DR

The most likely fix is to modify the e8m0_quantize method to reduce precision differences introduced by the Inductor backend.

Guidance

  • Investigate the e8m0_quantize method and consider using a more robust quantization approach that minimizes precision differences.
  • Verify the fix by comparing the classification logits between eager and compiled modes using torch.allclose with a suitable tolerance.
  • To mitigate the issue, consider adding a small tolerance to the comparison of classification logits to account for minor precision differences.
  • Review the PyTorch version and CUDA configuration to ensure compatibility and consider upgrading to a newer version if available.
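One way to act on the verification step above is to compare the intermediate uint8 encodings as well as the logits, which shows whether the divergence already exists at a quantization stage. This is a minimal sketch: `compare_outputs` is a hypothetical helper, and it assumes both results have the `(logits, (q1, q2, q3))` shape returned by the reproducer's model.

```python
import torch

def compare_outputs(eager_out, compiled_out, atol=1e-5):
    """Summarize where two (logits, quants) results diverge (hypothetical helper)."""
    eager_logits, eager_quants = eager_out
    comp_logits, comp_quants = compiled_out
    report = {
        "logits_max_diff": (eager_logits - comp_logits).abs().max().item(),
        "logits_allclose": torch.allclose(eager_logits, comp_logits, atol=atol),
        "quant_mismatches": [],
    }
    for eq, cq in zip(eager_quants, comp_quants):
        # uint8 codes are discrete, so count exact mismatches instead of using a tolerance.
        report["quant_mismatches"].append(int((eq != cq).sum().item()))
    return report
```

With the reproducer's variables this would be called as `compare_outputs((eager_logits, eager_quants), (comp_logits, comp_quants))`. A nonzero count in the first entry of `quant_mismatches` would point at the layers before the first quantization block rather than at later stages.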

Example

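def e8m0_quantize(self, x):
    # Same logic as the reproducer. The analysis suggests revisiting the manual
    # round-up step (for example with torch.round) but gives no tested replacement.
    x = x.contiguous()
    x_bits = x.view(torch.int32)            # reinterpret float32 bits as int32
    biased_exp = (x_bits >> 23) & 0xFF      # 8-bit biased exponent
    mantissa = x_bits & 0x7FFFFF            # 23-bit mantissa
    needs_round_up = mantissa != 0          # nonzero mantissa rounds the exponent up
    e8m0_biased = biased_exp + needs_round_up.to(torch.int32)
    e8m0_biased = torch.clamp(e8m0_biased, 0, 255)
    return e8m0_biased.to(torch.uint8)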

Notes

The provided code snippet and analysis suggest that the issue is related to the precision differences introduced by the Inductor backend. However, without further information or testing, it is difficult to provide a definitive fix. The suggested modifications to the e8m0_quantize method are intended to reduce precision differences, but may require additional testing and verification.

Recommendation

Apply a workaround by modifying the e8m0_quantize method to reduce precision differences, as the root cause is likely related to the Inductor backend's handling of bitwise operations and float precision.
