pytorch - ✅ (Solved) Add device_count() hook to DeviceTypeTestBase [1 pull request, 3 comments, 3 participants]

GitHub stats

pytorch/pytorch#184325 (fetched 2026-05-20 03:39:17)

  • Comments: 3
  • Participants: 3
  • Timeline events: 45
  • Reactions: 0
  • Timeline (top): mentioned ×17, subscribed ×17, labeled ×6, commented ×3

Fix Action

Fixed

PR fix notes

PR #184333: [Test] Add device_count() hook to DeviceTypeTestBase

Description (problem / solution / changelog)

Fixes #184325

Summary

Add device_count() hook to DeviceTypeTestBase as a device-level extension point for querying available device count.

Changed files

  • test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_testing.py (modified, +17/-0)
  • torch/testing/_internal/common_device_type.py (modified, +5/-0)

Code Example

torch.cuda.device_count()
torch.hpu.device_count()
torch.xpu.device_count()
torch.accelerator.device_count()

---

class DeviceTypeTestBase(TestCase):
    @classmethod
    def device_count(cls) -> int:
        return torch.get_device_module(cls.device_type).device_count()

🚀 The feature, motivation and pitch

Motivation

Distributed tests currently contain many hardcoded device count lookups, for example:

torch.cuda.device_count()
torch.hpu.device_count()
torch.xpu.device_count()
torch.accelerator.device_count()

These lookups are scattered across distributed test bases, decorators, helpers, and world-size setup logic.

This makes distributed test infrastructure harder to reuse for out-of-tree backends and harder to make device-generic.

torch.cuda.device_count(), torch.hpu.device_count(), and torch.xpu.device_count() hardcode specific device families, while torch.accelerator.device_count() is tied to the current accelerator rather than the instantiated device-specific test class's device_type.

Device count is fundamentally a device-type-level property. For instantiated device-specific test classes, device count should instead be derived from the test class's device_type.

Proposal

Add a device_count() hook to DeviceTypeTestBase:

class DeviceTypeTestBase(TestCase):
    @classmethod
    def device_count(cls) -> int:
        return torch.get_device_module(cls.device_type).device_count()

This provides a single device-type-aware entry point for distributed test infrastructure to query device count.
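As a rough illustration of how such a hook could be consumed, the sketch below mirrors the proposed classmethod and shows an out-of-tree backend overriding it. It is not the PR's code: to stay runnable without PyTorch or any accelerator, `get_device_module` here is a local stand-in for `torch.get_device_module`, the device counts are invented, and the `world_size` helper, the `openreg` count, and the subclass names are hypothetical.

```python
import unittest
from types import SimpleNamespace

# Stand-in for torch.get_device_module(); the counts below are made up
# so the sketch runs on a machine without PyTorch or accelerators.
_FAKE_MODULES = {
    "cpu": SimpleNamespace(device_count=lambda: 1),
    "cuda": SimpleNamespace(device_count=lambda: 4),
}


def get_device_module(device_type: str):
    return _FAKE_MODULES[device_type]


class DeviceTypeTestBase(unittest.TestCase):
    """Simplified stand-in for the real base class."""

    device_type = "cpu"

    @classmethod
    def device_count(cls) -> int:
        # Same shape as the proposed hook: resolve the count from the
        # class's own device_type, not from a hardcoded device family.
        return get_device_module(cls.device_type).device_count()

    @classmethod
    def world_size(cls, cap: int = 4) -> int:
        # Hypothetical helper: world-size setup asks the hook
        # instead of calling torch.cuda.device_count() directly.
        return min(cap, cls.device_count())


class CUDATestBase(DeviceTypeTestBase):
    device_type = "cuda"


class OpenRegTestBase(DeviceTypeTestBase):
    """Hypothetical out-of-tree backend that overrides the hook."""

    device_type = "openreg"

    @classmethod
    def device_count(cls) -> int:
        return 2  # assumed value, for illustration only
```

With this shape, shared helpers such as `world_size` need no per-backend branches: a backend either relies on the default lookup or overrides `device_count()` on its own test base.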

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @weifengpy @fffrog

Alternatives

No response

Additional context

No response
