pytorch - ✅(Solved) Fix torch.cond()'s true_fn and false_fn are not allowed to return None [1 pull requests, 1 participants]

GitHub stats
pytorch/pytorch#181891 (fetched 2026-04-30 06:17:55)
Comments: 0 · Participants: 1 · Timeline events: 51 · Reactions: 0
Timeline (top): mentioned ×21, subscribed ×21, labeled ×7, assigned ×1

Fix Action

Fix / Workaround

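# Warmup / capture / replay loop excerpted from the full reproduction in the
# bug description further down this page; `train` and `data_loader` are defined there.
import torch
from torch.utils.data import DataLoader, Dataset
from torch._higher_order_ops.cudagraph_conditional_nodes import (
    ControlFlowOpWarmupDispatchMode,
    CUDAGraphCaptureControlFlowOpDispatchMode,
)

for i, (X, y) in enumerate(data_loader):
    if i == 0:
        # Iteration 0: run one eager warmup step under the warmup dispatch mode.
        with ControlFlowOpWarmupDispatchMode():
            loss = train(X.cuda(), y.cuda())
        print(loss.item())
        del loss
        continue
    if i == 1:
        # Iteration 1: allocate static input buffers and capture the training
        # step, including the torch.cond() call, into a CUDA graph.
        X_static = torch.empty_like(X, device="cuda")
        y_static = torch.empty_like(y, device="cuda")
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g, capture_error_mode="thread_local"), CUDAGraphCaptureControlFlowOpDispatchMode():
            loss_static = train(X_static, y_static)
    if i == 2:
        # Start the CUDA profiler once capture is done.
        torch.cuda.cudart().cudaProfilerStart()
    # From iteration 1 onward: copy the batch into the static buffers and replay.
    X_static.copy_(X, non_blocking=True)
    y_static.copy_(y, non_blocking=True)
    g.replay()
    print(loss_static.item())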

PR fix notes

PR #181915: Support returning None from torch.cond() true_fn and false_fn

Description (problem / solution / changelog)

This allows us to run conditional code that only mutates inputs or captured variables, but does not return anything.

While making this change, I also realized that true_fn and false_fn should not be allowed to return different non-tensor outputs in the same output slot when torch.cond() is used during CUDA graph capture. This is banned because differing non-tensor outputs could make the captured CUDA graph depend on which path was taken. I now raise an explicit error alerting the user. Previously, returning any non-tensor output at all caused a runtime crash, even if it was the same for both true_fn and false_fn.
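The slot-by-slot rule described above can be sketched in plain Python. This is a hypothetical illustration, not the code added in PR #181915: the helper name check_non_tensor_outputs_match, the is_tensor predicate, and the error messages are all assumptions made for the example, and a stand-in class replaces real tensors so the sketch runs without PyTorch.

```python
from typing import Any, Callable, Sequence


class FakeTensor:
    """Stand-in for a tensor; only the fact that it 'is a tensor' matters here."""


def check_non_tensor_outputs_match(
    true_outs: Sequence[Any],
    false_outs: Sequence[Any],
    is_tensor: Callable[[Any], bool] = lambda v: isinstance(v, FakeTensor),
) -> None:
    """Raise if the two branches disagree on a non-tensor output slot (hypothetical)."""
    if len(true_outs) != len(false_outs):
        raise ValueError("branches returned a different number of outputs")
    for slot, (t, f) in enumerate(zip(true_outs, false_outs)):
        # Tensor slots may hold different values in each branch; skip them.
        if is_tensor(t) and is_tensor(f):
            continue
        # Anything else (None, ints, strings, ...) must be identical in both
        # branches, otherwise the result would depend on the path taken.
        if is_tensor(t) or is_tensor(f) or t != f:
            raise ValueError(
                f"output slot {slot}: true_fn returned {t!r}, false_fn returned {f!r}"
            )


# Both branches return None in the same slot: accepted.
check_non_tensor_outputs_match([None], [None])

# The branches return different Python ints in slot 0: rejected.
try:
    check_non_tensor_outputs_match([1], [2])
    mismatch_rejected = False
except ValueError:
    mismatch_rejected = True
```

Under these assumptions, a pair of branches that both return None passes the check, which is the case the PR sets out to support.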

I use either claude-code or codex to help make these edits. Forgot which though.

Fixes #181891

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @chauhang @amjames @Lucaskabela @jataylo @azahed98 @ydwu4

Changed files

  • test/functorch/test_control_flow.py (modified, +64/-0)
  • torch/_dynamo/graph_break_registry.json (modified, +8/-0)
  • torch/_dynamo/variables/higher_order_ops.py (modified, +2/-2)
  • torch/_higher_order_ops/cudagraph_conditional_nodes.py (modified, +10/-0)


🐛 Describe the bug

torch.cond() currently fails if its true_fn and false_fn return None.

I realized this when I wrote the following code, intended to implement loss scaling for mixed-precision training inside a CUDA graph:

import torch
from torch.utils.data import DataLoader, Dataset
from torch._higher_order_ops.cudagraph_conditional_nodes import (
    ControlFlowOpWarmupDispatchMode,
    CUDAGraphCaptureControlFlowOpDispatchMode,
)

class RandomDataset(Dataset):
    def __init__(self, n_samples=256, in_dim=16, out_dim=4):
        self.n_samples = n_samples
        self.in_dim = in_dim
        self.out_dim = out_dim

    def __len__(self):
        return self.n_samples

    def __getitem__(self, _):
        return torch.randn(self.in_dim), torch.randn(self.out_dim)

model = torch.nn.Linear(16, 4).cuda()
criterion = torch.nn.MSELoss()
data_loader = DataLoader(RandomDataset(), batch_size=32, shuffle=False,
                         pin_memory=True, num_workers=1)


lr = torch.tensor(0.1, device="cuda")
loss_scale = torch.tensor(2.0**8, device="cuda")

def train(X, y):
    with torch.autograd.grad_mode.set_multithreading_enabled(False):
        global loss_scale
        model.zero_grad(set_to_none=True)
        y_pred = model(X)
        loss = criterion(y_pred, y)
        (loss * loss_scale).backward()
        all_gradients_finite = torch.stack([
            torch.all(param.grad.isfinite()) for param in model.parameters() if param.grad is not None
        ]).all()
        with torch.no_grad():
            def true_fn():
                global lr, loss_scale
                lr.mul_(0.99999)
                inv_scale = loss_scale.reciprocal()
                for param in model.parameters():
                    if param.grad is not None:
                        param.sub_(lr * param.grad * inv_scale)
                loss_scale.mul_(2.0)

            def false_fn():
                loss_scale.mul_(0.5)

            _ = torch.cond(all_gradients_finite, true_fn, false_fn)
        return loss

for i, (X, y) in enumerate(data_loader):
    if i == 0:
        with ControlFlowOpWarmupDispatchMode():
            loss = train(X.cuda(), y.cuda())
        print(loss.item())
        del loss
        continue
    if i == 1:
        X_static = torch.empty_like(X, device="cuda")
        y_static = torch.empty_like(y, device="cuda")
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g, capture_error_mode="thread_local"), CUDAGraphCaptureControlFlowOpDispatchMode():
            loss_static = train(X_static, y_static)
    if i == 2:
        torch.cuda.cudart().cudaProfilerStart()
    X_static.copy_(X, non_blocking=True)
    y_static.copy_(y, non_blocking=True)
    g.replay()
    print(loss_static.item())

The current restriction is a reasonable oversight, since input mutations in torch.cond() were not supported until https://github.com/pytorch/pytorch/pull/172836. Without input mutations, it doesn't make any sense to return nothing. However, things have changed now. For gradient NaN checks like the one above, where both branches only mutate existing tensors in place, there is no obvious value to return.
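To make the point concrete, the branch logic from the reproduction can be mimicked in plain Python. This is only a sketch: eager_cond is a hypothetical stand-in for the eager meaning of torch.cond() (run one branch or the other), and a dictionary of floats replaces the CUDA tensors, so nothing here exercises PyTorch itself.

```python
from typing import Any, Callable


def eager_cond(pred: bool, true_fn: Callable[[], Any], false_fn: Callable[[], Any]) -> Any:
    """Hypothetical stand-in: call exactly one branch and return its result."""
    return true_fn() if pred else false_fn()


# Floats stand in for the `lr` and `loss_scale` tensors from the issue.
state = {"lr": 0.1, "loss_scale": 2.0 ** 8}


def true_fn() -> None:
    # Gradients were finite: decay the learning rate and grow the loss scale.
    state["lr"] *= 0.99999
    state["loss_scale"] *= 2.0


def false_fn() -> None:
    # Gradients overflowed: skip the update and shrink the loss scale.
    state["loss_scale"] *= 0.5


# Each branch works purely through mutation, so each call yields None.
result_finite = eager_cond(True, true_fn, false_fn)
result_overflow = eager_cond(False, true_fn, false_fn)
```

Both calls return None because the branches communicate only through the state they mutate, which is the pattern the issue asks torch.cond() to accept.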

I will upload a fix for this shortly. FYI @ydwu4 @kshitij12345 @kiya00 It is likely that torch.while_loop() has a similar issue, though I haven't tested that yet.

Versions

top of tree

cc @mcarilli @ezyang @eellison @penguinwu @BoyuanFeng @chauhang @ydwu4 @bdhirsh @bobrenjc93 @aorenste

extent analysis

TL;DR

On builds without the fix from PR #181915, the issue can be worked around by having true_fn and false_fn return a value, even if it is just a placeholder.

Guidance

  • Identify the functions passed to torch.cond() as true_fn and false_fn and make sure each returns a value rather than None.
  • If there is no meaningful value to return, return a placeholder (e.g., torch.tensor(0)); both branches should return the same structure, with tensors that match in shape and dtype.
  • Check that the caller ignores the placeholder, as in _ = torch.cond(...), so that it has no side effects on the rest of the code.
  • Be aware that torch.while_loop() may have a similar issue; the issue author had not tested it.
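One way to apply the guidance above without editing every branch by hand is a small wrapper that substitutes a placeholder whenever a branch returns None. The sketch below is an assumption-laden illustration: returning_placeholder and make_placeholder are hypothetical names, and the demonstration uses plain floats rather than tensors so that it runs without PyTorch.

```python
from typing import Any, Callable


def returning_placeholder(
    fn: Callable[..., Any],
    make_placeholder: Callable[[], Any],
) -> Callable[..., Any]:
    """Wrap a side-effect-only branch so that it never returns None (hypothetical)."""

    def wrapped(*args: Any) -> Any:
        result = fn(*args)
        # Keep a real return value if the branch has one; otherwise substitute
        # the placeholder so the caller always receives a value.
        return make_placeholder() if result is None else result

    return wrapped


# Plain-Python demonstration; a float stands in for the loss_scale tensor.
state = {"loss_scale": 2.0 ** 8}


def halve_scale() -> None:
    state["loss_scale"] *= 0.5  # mutates captured state and returns None


safe_false_fn = returning_placeholder(halve_scale, lambda: 0)
placeholder = safe_false_fn()
```

With torch.cond(), make_placeholder would need to produce a tensor with the same shape, dtype, and device in both branches. Whether a given PyTorch build traces through a wrapper like this has not been verified here, so returning the placeholder directly from each branch is the more conservative choice.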

Notes

The provided code snippet is specific to the user's use case, and the suggested workaround may need to be adapted for other scenarios.

Recommendation

Apply workaround: Return a placeholder value from the functions passed to torch.cond() to avoid the current restriction.
