pytorch - ✅ (Solved) Fix [RFC] Introduce the unified RNG-related API to torch.accelerator [1 pull request, 2 comments, 3 participants]

GitHub stats
pytorch/pytorch#177010 (fetched 2026-04-08 00:22:55)
Comments: 2
Participants: 3
Timeline: 44
Reactions: 2
Timeline (top): mentioned ×15, subscribed ×15, labeled ×6, unsubscribed ×3

Fix Action

Fixed

PR fix notes

PR #169039: port freeze_rng_state() to accelerators

Description (problem / solution / changelog)

Refactor freeze_rng_state() for accelerators.

Changed files

  • test/nn/test_dropout.py (modified, +0/-3)
  • torch/testing/_utils.py (modified, +14/-4)

Motivation

This RFC proposes adding unified RNG-related APIs to torch.accelerator, including:

  • torch.accelerator.initial_seed
  • torch.accelerator.seed
  • torch.accelerator.seed_all
  • torch.accelerator.manual_seed
  • torch.accelerator.manual_seed_all
  • torch.accelerator.set_rng_state
  • torch.accelerator.set_rng_state_all
  • torch.accelerator.get_rng_state
  • torch.accelerator.get_rng_state_all

The goal is to provide a backend-agnostic RNG surface that works consistently across in-tree backends (CUDA, XPU, MPS, MTIA, and others, where implemented) and out-of-tree backends that integrate with the accelerator runtime.

Background

PyTorch's basic runtime responsibilities for accelerator execution include:

  1. Device management (enumeration, current device, context switching)
  2. Streams (current stream, stream switching, synchronization)
  3. Events (record/wait/timing)
  4. Allocator and memory statistics
  5. Generators / RNG states

torch.accelerator already unifies the first four categories in a backend-agnostic way. Generalizing the advanced runtime functionality (graph capture and replay) for accelerators is tracked in https://github.com/pytorch/pytorch/issues/158827.

Today, generator APIs are still mostly backend-specific (torch.cuda.*, torch.xpu.*). This fragments the API and forces higher-level code to branch on backend type for common RNG tasks (seeding, RNG state save/restore, deterministic replay). Developers and users face the following pain points:

  • Users writing backend-agnostic code must do backend-specific imports/branches for RNG.
  • Framework code (distributed, checkpointing, reproducibility utilities) cannot rely on a single RNG API across accelerators.
  • Porting code from one accelerator backend to another requires mechanical API migration.
  • In-tree backends duplicate similar generator helper logic.
  • Out-of-tree backends lack a clear contract to expose generator semantics under torch.accelerator.
  • Testing for backend-agnostic reproducibility is harder without a common API.
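The first pain point can be sketched as follows. This is an illustrative helper, not code from the RFC: seed_accelerator is a hypothetical name, and the branches cover only the in-tree backends whose seeding functions are used here.

```python
import torch


def seed_accelerator(seed):
    """Seed whichever accelerator is present and return its backend name."""
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.manual_seed_all(seed)
        return "xpu"
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)  # torch.mps has no manual_seed_all
        return "mps"
    # Out-of-tree backends would each need another branch here.
    return None


backend = seed_accelerator(1234)
```

Every new backend adds a branch to helpers like this one. With the proposed API the whole function would reduce to a single torch.accelerator.manual_seed_all(seed) call.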

Now, torch.accelerator has already established a unified model for device, stream, event, and allocator/memory APIs. Extending it to RNG fills a visible gap and completes the runtime API story.

cc @pbelevich @albanD @EikanWang

extent analysis

Unified RNG APIs for torch.accelerator

Fix Plan

Step 1: Implement Unified RNG APIs

  • Create a new file torch/accelerator/rng.py with the following code:

```python
import torch


def _backend_rng():
    # torch.cuda and torch.xpu keep their RNG helpers in a `random` submodule;
    # this sketch assumes every supported backend does the same.
    device = torch.accelerator.current_accelerator()
    if device is None:
        raise RuntimeError("no accelerator is available")
    return torch.get_device_module(device).random


class AcceleratorRNG:
    """Forward each RNG call to the current accelerator backend."""

    def initial_seed(self):
        return _backend_rng().initial_seed()

    def seed(self):
        _backend_rng().seed()

    def seed_all(self):
        _backend_rng().seed_all()

    def manual_seed(self, seed):
        _backend_rng().manual_seed(seed)

    def manual_seed_all(self, seed):
        _backend_rng().manual_seed_all(seed)

    def set_rng_state(self, state):
        _backend_rng().set_rng_state(state)

    def set_rng_state_all(self, states):
        _backend_rng().set_rng_state_all(states)

    def get_rng_state(self):
        return _backend_rng().get_rng_state()

    def get_rng_state_all(self):
        return _backend_rng().get_rng_state_all()
```

  • Update torch/accelerator/__init__.py to expose the new RNG class:

```python
from .rng import AcceleratorRNG

# Module-level instance, so callers can write torch.accelerator.rng.manual_seed(seed).
rng = AcceleratorRNG()
```

Step 2: Update torch.cuda and torch.xpu to use torch.accelerator

  • Update torch/cuda/__init__.py to use torch.accelerator for RNG operations:

```python
from torch import accelerator


# accelerator.rng forwards to torch.cuda.random, so these wrappers do not recurse.
def manual_seed(seed):
    accelerator.rng.manual_seed(seed)


def manual_seed_all(seed):
    accelerator.rng.manual_seed_all(seed)
```
