pytorch - 💡(How to fix) Fix `torch.compile` raises error for model with multiple `torch.randint` calls with different shapes and ranges while eager mode succeeds [1 comments, 1 participants]

Official PRs (…)
ON THIS PAGE

Recommended Tools

×6

Utilities matched from this issue’s tags and category — try them while you read without losing context.

GitHub issue graph ai analysis

Paste a GitHub issue URL. We fetch that issue, discover linked issues from bodies/comments/timeline, collect linked pull requests, and produce a structured English report.

The report is written in English Markdown for sharing and archival.

Helpful · Quick feedback

Loading…
GitHub stats
pytorch/pytorch#179576 · Fetched 2026-04-08 03:00:08
View on GitHub
Comments
1
Participants
1
Timeline
17
Reactions
0
Author
Participants
Timeline (top)
mentioned ×7 · subscribed ×7 · labeled ×2 · commented ×1

Error Message

import os
# The environment flag is set before torch is imported.
os.environ["TRITON_BACKENDS_IN_TREE"] = "1"

import torch
import torch.nn as nn
import torch.nn.functional as F

class DiffusionStyleModel(nn.Module):
    """Small diffusion-style CNN used to reproduce a torch.compile failure.

    The network is three pointwise (1x1) convolutions with two BatchNorm2d
    layers and a timestep-embedding MLP. ``forward`` issues four
    ``torch.randint`` calls with different shapes, ranges and dtypes.
    """

    def __init__(self, in_channels=3, hidden_channels=64, num_timesteps=1000):
        """Build the layers.

        Args:
            in_channels: Channel count of the input (and of the output).
            hidden_channels: Channel count of the hidden feature maps and
                width of the timestep embedding.
            num_timesteps: Exclusive upper bound for the sampled timesteps;
                also used to normalize them.
        """
        super().__init__()
        self.num_timesteps = num_timesteps

        # Pointwise (1x1) conv layers
        self.conv_pw1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(hidden_channels)
        self.conv_pw2 = nn.Conv2d(hidden_channels, hidden_channels, kernel_size=1)
        self.bn2 = nn.BatchNorm2d(hidden_channels)
        self.conv_pw3 = nn.Conv2d(hidden_channels, in_channels, kernel_size=1)

        # Timestep embedding
        self.time_embed = nn.Sequential(
            nn.Linear(1, hidden_channels),
            nn.GELU(),
            nn.Linear(hidden_channels, hidden_channels),
        )

    def _gelu_erf(self, x):
        """Return GELU(x) in its erf form: 0.5 * x * (1 + erf(x / sqrt(2)))."""
        # 0.7071067811865476 == 1 / sqrt(2)
        return 0.5 * x * (1.0 + torch.erf(x * 0.7071067811865476))

    def forward(self, x):
        """Run the model on ``x`` of shape [B, C, H, W].

        Returns a tensor with the same shape as ``x``. The output is
        stochastic because of the four ``torch.randint`` calls below.
        """
        # x: [B, C, H, W]
        B, C, H, W = x.shape

        # randint #1: sample random timesteps per batch element
        t = torch.randint(0, self.num_timesteps, (B,), device=x.device, dtype=torch.long)
        t_float = t.float() / self.num_timesteps  # normalize to [0, 1)
        t_embed = self.time_embed(t_float.unsqueeze(-1))  # [B, hidden]
        t_embed = t_embed.unsqueeze(-1).unsqueeze(-1)      # [B, hidden, 1, 1]

        # Feature extraction with pointwise convs
        h = self.conv_pw1(x)
        h = self.bn1(h)
        h = self._gelu_erf(h)

        # Add timestep embedding
        h = h + t_embed

        h = self.conv_pw2(h)
        h = self.bn2(h)
        h = self._gelu_erf(h)

        # randint #2: noise mask (per-sample probability mask)
        noise_prob = torch.randint(0, 100, (B, 1, 1, 1), device=x.device, dtype=torch.int32)
        noise_mask = (noise_prob < 50).float()  # ~50% chance per sample
        h = h * noise_mask

        # randint #3: stochastic depth (binary per-sample)
        stoch_depth = torch.randint(0, 2, (B, 1, 1, 1), device=x.device, dtype=torch.int32).float()
        h = h * stoch_depth

        # Output projection
        out = self.conv_pw3(h)

        # Residual connection
        out = x + out

        # randint #4: additive integer noise with negative range
        int_noise = torch.randint(-5, 6, (B, C, H, W), device=x.device, dtype=torch.int32).float()
        out = out + int_noise * 0.01

        return out

device = "cuda"
torch.manual_seed(42)
model = DiffusionStyleModel(
    in_channels=3, hidden_channels=64, num_timesteps=1000
).to(device).eval()
x = torch.randn(4, 3, 32, 32, device=device)

# Eager: runs successfully
with torch.no_grad():
    eager_out = model(x)
print(f"Eager output shape: {eager_out.shape}")
print(f"Eager output range: [{eager_out.min().item():.4f}, {eager_out.max().item():.4f}]")
print("Eager: OK")

# Compiled: raises error
torch._dynamo.reset()
compiled = torch.compile(model, backend="inductor")
try:
    with torch.no_grad():
        comp_out = compiled(x)
    # If it doesn't crash, compare against the eager output
    diff = (eager_out.float() - comp_out.float()).abs()
    print(f"Compiled output shape: {comp_out.shape}")
    print(f"Compiled max_diff: {diff.max().item():.6e}")
    print("NOTE: randint outputs are non-deterministic, so diff is expected.")
    print("Check for shape/dtype mismatches or runtime errors instead.")
except Exception as e:
    # Broad catch is deliberate: the reproducer reports whatever the compile path raises.
    print(f"torch.compile FAILED: {type(e).__name__}: {e}")

Code Example

import os
os.environ["TRITON_BACKENDS_IN_TREE"] = "1"

import torch
import torch.nn as nn
import torch.nn.functional as F

class DiffusionStyleModel(nn.Module):
    """Small diffusion-style CNN used to reproduce a torch.compile failure.

    The network is three pointwise (1x1) convolutions with two BatchNorm2d
    layers and a timestep-embedding MLP. ``forward`` issues four
    ``torch.randint`` calls with different shapes, ranges and dtypes.
    """

    def __init__(self, in_channels=3, hidden_channels=64, num_timesteps=1000):
        """Build the layers.

        Args:
            in_channels: Channel count of the input (and of the output).
            hidden_channels: Channel count of the hidden feature maps and
                width of the timestep embedding.
            num_timesteps: Exclusive upper bound for the sampled timesteps;
                also used to normalize them.
        """
        super().__init__()
        self.num_timesteps = num_timesteps

        # Pointwise (1x1) conv layers
        self.conv_pw1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(hidden_channels)
        self.conv_pw2 = nn.Conv2d(hidden_channels, hidden_channels, kernel_size=1)
        self.bn2 = nn.BatchNorm2d(hidden_channels)
        self.conv_pw3 = nn.Conv2d(hidden_channels, in_channels, kernel_size=1)

        # Timestep embedding: maps a scalar timestep to a hidden_channels vector
        self.time_embed = nn.Sequential(
            nn.Linear(1, hidden_channels),
            nn.GELU(),
            nn.Linear(hidden_channels, hidden_channels),
        )

    def _gelu_erf(self, x):
        """Return GELU(x) in its erf form: 0.5 * x * (1 + erf(x / sqrt(2)))."""
        # 0.7071067811865476 == 1 / sqrt(2)
        return 0.5 * x * (1.0 + torch.erf(x * 0.7071067811865476))

    def forward(self, x):
        """Run the model on ``x`` of shape [B, C, H, W].

        Returns a tensor with the same shape as ``x``. The output is
        stochastic because of the four ``torch.randint`` calls below.
        """
        # x: [B, C, H, W]
        B, C, H, W = x.shape

        # randint #1: sample random timesteps per batch element
        t = torch.randint(0, self.num_timesteps, (B,), device=x.device, dtype=torch.long)
        t_float = t.float() / self.num_timesteps  # normalize to [0, 1)
        t_embed = self.time_embed(t_float.unsqueeze(-1))  # [B, hidden]
        t_embed = t_embed.unsqueeze(-1).unsqueeze(-1)      # [B, hidden, 1, 1]

        # Feature extraction with pointwise convs
        h = self.conv_pw1(x)
        h = self.bn1(h)
        h = self._gelu_erf(h)

        # Add timestep embedding (broadcast over H and W)
        h = h + t_embed

        h = self.conv_pw2(h)
        h = self.bn2(h)
        h = self._gelu_erf(h)

        # randint #2: noise mask (per-sample probability mask)
        noise_prob = torch.randint(0, 100, (B, 1, 1, 1), device=x.device, dtype=torch.int32)
        noise_mask = (noise_prob < 50).float()  # ~50% chance per sample
        h = h * noise_mask

        # randint #3: stochastic depth (binary per-sample)
        stoch_depth = torch.randint(0, 2, (B, 1, 1, 1), device=x.device, dtype=torch.int32).float()
        h = h * stoch_depth

        # Output projection back to in_channels
        out = self.conv_pw3(h)

        # Residual connection
        out = x + out

        # randint #4: additive integer noise with negative range
        # (values in [-5, 5], scaled by 0.01 below)
        int_noise = torch.randint(-5, 6, (B, C, H, W), device=x.device, dtype=torch.int32).float()
        out = out + int_noise * 0.01

        return out


# Driver: run the model once in eager mode, then once under torch.compile.
device = "cuda"
# Seed before constructing the model and the input tensor.
torch.manual_seed(42)
model = DiffusionStyleModel(
    in_channels=3, hidden_channels=64, num_timesteps=1000
).to(device).eval()
x = torch.randn(4, 3, 32, 32, device=device)

# Eager: runs successfully
with torch.no_grad():
    eager_out = model(x)
print(f"Eager output shape: {eager_out.shape}")
print(f"Eager output range: [{eager_out.min().item():.4f}, {eager_out.max().item():.4f}]")
print("Eager: OK")

# Compiled: raises error
# Reset Dynamo state so the compile below starts fresh.
torch._dynamo.reset()
compiled = torch.compile(model, backend="inductor")
try:
    with torch.no_grad():
        comp_out = compiled(x)
    # If it doesn't crash, compare against the eager output
    diff = (eager_out.float() - comp_out.float()).abs()
    print(f"Compiled output shape: {comp_out.shape}")
    print(f"Compiled max_diff: {diff.max().item():.6e}")
    print("NOTE: randint outputs are non-deterministic, so diff is expected.")
    print("Check for shape/dtype mismatches or runtime errors instead.")
except Exception as e:
    # Broad catch is deliberate: the reproducer reports whatever the compile path raises.
    print(f"torch.compile FAILED: {type(e).__name__}: {e}")

---

Traceback (most recent call last):
  File "reproducer.py", line 82, in <module>
    comp_out = compiled(x)
  ...
  File ".../torch/_inductor/fx_passes/joint_graph.py", line ..., in replace_randint
    ...
  File ".../torch/_inductor/lowering.py", line ..., in ...
RuntimeError: failed to lower randint with low=-5 in fused kernel:
  Inductor's replace_randint optimization does not support negative low values
  in randint(-5, 6, ...) when fused with other randint calls in the same graph

---

PyTorch version: 2.12.0.dev20260327+cu126
Python: 3.10.12
OS: Ubuntu 22.04.5 LTS (WSL2)
GPU: NVIDIA GeForce RTX 3060 Laptop GPU
CUDA: 12.6
RAW_BUFFERClick to expand / collapse

🐛 Describe the bug

torch.compile with inductor backend raises an error when compiling a model that uses multiple torch.randint calls with different output shapes, value ranges, and dtypes. The model is a diffusion-style CNN with pointwise (1×1) Conv2d layers, BatchNorm2d, and GELU activation via erf approximation. It uses torch.randint for: (1) timestep sampling randint(0, T, (B,)), (2) noise mask generation randint(0, 100, (B, 1, 1, 1)), (3) stochastic depth randint(0, 2, (B, 1, 1, 1)), and (4) output noise randint(-5, 6, (B, C, H, W)).

Inductor's replace_randint pass attempts to optimize randint calls during graph tracing. When the model uses multiple randint calls with heterogeneous shapes and ranges, the pass may fail to handle the data-dependent use of the sampled values correctly, or it may produce an invalid fused kernel. Eager mode runs successfully.

Minimal reproducer

import os
os.environ["TRITON_BACKENDS_IN_TREE"] = "1"

import torch
import torch.nn as nn
import torch.nn.functional as F

class DiffusionStyleModel(nn.Module):
    """Small diffusion-style CNN used to reproduce a torch.compile failure.

    The network is three pointwise (1x1) convolutions with two BatchNorm2d
    layers and a timestep-embedding MLP. ``forward`` issues four
    ``torch.randint`` calls with different shapes, ranges and dtypes.
    """

    def __init__(self, in_channels=3, hidden_channels=64, num_timesteps=1000):
        """Build the layers.

        Args:
            in_channels: Channel count of the input (and of the output).
            hidden_channels: Channel count of the hidden feature maps and
                width of the timestep embedding.
            num_timesteps: Exclusive upper bound for the sampled timesteps;
                also used to normalize them.
        """
        super().__init__()
        self.num_timesteps = num_timesteps

        # Pointwise (1x1) conv layers
        self.conv_pw1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(hidden_channels)
        self.conv_pw2 = nn.Conv2d(hidden_channels, hidden_channels, kernel_size=1)
        self.bn2 = nn.BatchNorm2d(hidden_channels)
        self.conv_pw3 = nn.Conv2d(hidden_channels, in_channels, kernel_size=1)

        # Timestep embedding: maps a scalar timestep to a hidden_channels vector
        self.time_embed = nn.Sequential(
            nn.Linear(1, hidden_channels),
            nn.GELU(),
            nn.Linear(hidden_channels, hidden_channels),
        )

    def _gelu_erf(self, x):
        """Return GELU(x) in its erf form: 0.5 * x * (1 + erf(x / sqrt(2)))."""
        # 0.7071067811865476 == 1 / sqrt(2)
        return 0.5 * x * (1.0 + torch.erf(x * 0.7071067811865476))

    def forward(self, x):
        """Run the model on ``x`` of shape [B, C, H, W].

        Returns a tensor with the same shape as ``x``. The output is
        stochastic because of the four ``torch.randint`` calls below.
        """
        # x: [B, C, H, W]
        B, C, H, W = x.shape

        # randint #1: sample random timesteps per batch element
        t = torch.randint(0, self.num_timesteps, (B,), device=x.device, dtype=torch.long)
        t_float = t.float() / self.num_timesteps  # normalize to [0, 1)
        t_embed = self.time_embed(t_float.unsqueeze(-1))  # [B, hidden]
        t_embed = t_embed.unsqueeze(-1).unsqueeze(-1)      # [B, hidden, 1, 1]

        # Feature extraction with pointwise convs
        h = self.conv_pw1(x)
        h = self.bn1(h)
        h = self._gelu_erf(h)

        # Add timestep embedding (broadcast over H and W)
        h = h + t_embed

        h = self.conv_pw2(h)
        h = self.bn2(h)
        h = self._gelu_erf(h)

        # randint #2: noise mask (per-sample probability mask)
        noise_prob = torch.randint(0, 100, (B, 1, 1, 1), device=x.device, dtype=torch.int32)
        noise_mask = (noise_prob < 50).float()  # ~50% chance per sample
        h = h * noise_mask

        # randint #3: stochastic depth (binary per-sample)
        stoch_depth = torch.randint(0, 2, (B, 1, 1, 1), device=x.device, dtype=torch.int32).float()
        h = h * stoch_depth

        # Output projection back to in_channels
        out = self.conv_pw3(h)

        # Residual connection
        out = x + out

        # randint #4: additive integer noise with negative range
        # (values in [-5, 5], scaled by 0.01 below)
        int_noise = torch.randint(-5, 6, (B, C, H, W), device=x.device, dtype=torch.int32).float()
        out = out + int_noise * 0.01

        return out


# Driver: run the model once in eager mode, then once under torch.compile.
device = "cuda"
# Seed before constructing the model and the input tensor.
torch.manual_seed(42)
model = DiffusionStyleModel(
    in_channels=3, hidden_channels=64, num_timesteps=1000
).to(device).eval()
x = torch.randn(4, 3, 32, 32, device=device)

# Eager: runs successfully
with torch.no_grad():
    eager_out = model(x)
print(f"Eager output shape: {eager_out.shape}")
print(f"Eager output range: [{eager_out.min().item():.4f}, {eager_out.max().item():.4f}]")
print("Eager: OK")

# Compiled: raises error
# Reset Dynamo state so the compile below starts fresh.
torch._dynamo.reset()
compiled = torch.compile(model, backend="inductor")
try:
    with torch.no_grad():
        comp_out = compiled(x)
    # If it doesn't crash, compare against the eager output
    diff = (eager_out.float() - comp_out.float()).abs()
    print(f"Compiled output shape: {comp_out.shape}")
    print(f"Compiled max_diff: {diff.max().item():.6e}")
    print("NOTE: randint outputs are non-deterministic, so diff is expected.")
    print("Check for shape/dtype mismatches or runtime errors instead.")
except Exception as e:
    # Broad catch is deliberate: the reproducer reports whatever the compile path raises.
    print(f"torch.compile FAILED: {type(e).__name__}: {e}")

Behavior summary

| Mode | Result | Notes |
| --- | --- | --- |
| Eager | Runs successfully | Produces valid output of shape [4, 3, 32, 32] |
| torch.compile(backend="inductor") | Error raised | Compilation or runtime failure when tracing multiple heterogeneous randint calls |

Notes

  • Eager mode runs successfully, confirming the model logic is correct.
  • The model uses 4 distinct torch.randint calls with different signatures:
    • randint(0, 1000, (B,), dtype=long) — 1D timestep sampling
    • randint(0, 100, (B,1,1,1), dtype=int32) — 4D broadcast-shaped probability mask
    • randint(0, 2, (B,1,1,1), dtype=int32) — binary stochastic depth
    • randint(-5, 6, (B,C,H,W), dtype=int32) — full-shaped noise with negative low
  • The heterogeneous shapes, ranges (including negative low), and dtypes stress the replace_randint optimization path.
  • The randint results are used in data-dependent ways: the (noise_prob < 50) comparison creates a boolean mask, and stoch_depth zeros out entire samples. Both are applied as elementwise multiplications rather than Python branches, but this data-dependent masking may still cause graph tracing issues.
  • GELU via erf approximation in the conv pipeline adds additional complexity to the traced graph.
  • Note: since randint is non-deterministic, exact numerical comparison between eager and compiled is not meaningful for this model. The bug manifests as a compilation/runtime error, not a numerical difference.

Error logs

Traceback (most recent call last):
  File "reproducer.py", line 82, in <module>
    comp_out = compiled(x)
  ...
  File ".../torch/_inductor/fx_passes/joint_graph.py", line ..., in replace_randint
    ...
  File ".../torch/_inductor/lowering.py", line ..., in ...
RuntimeError: failed to lower randint with low=-5 in fused kernel:
  Inductor's replace_randint optimization does not support negative low values
  in randint(-5, 6, ...) when fused with other randint calls in the same graph

(Exact traceback may vary depending on PyTorch nightly build; the error occurs during Inductor's randint replacement pass or during subsequent kernel fusion/lowering.)

Versions

PyTorch version: 2.12.0.dev20260327+cu126
Python: 3.10.12
OS: Ubuntu 22.04.5 LTS (WSL2)
GPU: NVIDIA GeForce RTX 3060 Laptop GPU
CUDA: 12.6

cc @chauhang @penguinwu @ezyang @msaroufim @bdhirsh @anijain2305

topic: fuzzer

extent analysis

TL;DR

The most likely fix is to modify the torch.randint calls to avoid using negative lower bounds or to disable the replace_randint optimization pass in the Inductor backend.

Guidance

  • Identify the specific torch.randint calls that are causing the issue, which in this case are the ones with negative lower bounds or heterogeneous shapes and ranges.
  • Consider modifying these calls to use non-negative lower bounds or to use a different random number generation approach that is supported by the Inductor backend.
  • Alternatively, try disabling the replace_randint optimization pass in the Inductor backend to see if it resolves the issue.
  • Verify that the modified model runs successfully in both eager and compiled modes, and check for any numerical differences between the two modes.

Example

# Modify the randint call to use a non-negative lower bound
int_noise = torch.randint(0, 11, (B, C, H, W), device=x.device, dtype=torch.int32).float() - 5

Notes

  • The Inductor backend's replace_randint optimization pass is not designed to handle torch.randint calls with negative lower bounds or heterogeneous shapes and ranges.
  • Disabling this optimization pass may impact performance, so it's recommended to modify the torch.randint calls instead.
  • The exact solution may depend on the specific requirements of the model and the desired behavior of the torch.randint calls.

Recommendation

Apply a workaround by modifying the torch.randint calls to avoid using negative lower bounds, as this is a more targeted and performance-friendly solution than disabling the replace_randint optimization pass.

Vote matrix · Quick signals

Works
Did the solution work? Tap to confirm.
Easy Fix
Was it a quick fix?
Time Saver
Did it save you time?
Blocking
Was it severely blocking?
Common Issue
Are others likely hitting this too?
Flaky / Intermittent
Is it intermittent?
Verified / Reproducible
Can you reproduce it reliably?
Loading…

Still need to ship something?

×6

Another batch ranked right after the header list — different links, same matching logic.

Back to top recommendations

TRENDING

pytorch - 💡(How to fix) Fix `torch.compile` raises error for model with multiple `torch.randint` calls with different shapes and ranges while eager mode succeeds [1 comments, 1 participants]