[Discussion] Parallelizing serial for-loops with free-threaded Python, and the `torch.distributed` challenge

pytorch/pytorch#180550 (fetched 2026-04-17 08:26:27)

With free-threaded Python (PEP 703, officially supported in CPython 3.14), true multi-threaded parallelism in pure Python is becoming practical. One pattern that immediately benefits is the serial for-loop over N independent items — each iteration doing a mix of CPU-side logic and GPU ops. With nogil, you can just run them in parallel threads.

Here's a self-contained demo. work() does some GPU compute (matrix multiplications + relu) plus CPU-side Python work; cpu_burn() stands in for the target matching, label encoding, etc. of real training code. We define a trivial parallel_map using raw threads and compare it against a serial for-loop:

# bench.py — run with: python3.14t bench.py (requires free-threaded Python + CUDA)
import sys
import time
import threading
import torch

def parallel_map(fn, items):
    results = [None] * len(items)
    def worker(i):
        results[i] = fn(items[i])
    threads = [threading.Thread(target=worker, args=(i,)) for i in range(len(items))]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results

def cpu_burn(n):
    s = 0
    for i in range(n):
        s += i * i
    return s

def main():
    device = torch.device("cuda:0")
    wa = torch.randn(256, 256, device=device)
    wb = torch.randn(256, 256, device=device)
    wc = torch.randn(256, 1, device=device)

    N = 6
    inputs = [torch.randn(256, 256, device=device) for _ in range(N)]

    def work(x):
        h = x
        for _ in range(10):
            h = h @ wa
            h = torch.relu(h)
            h = h @ wb
            h = torch.relu(h)
        cpu_burn(16_000)
        return (h @ wc).sum()

    # serial
    def run_serial():
        return [work(inputs[i]) for i in range(N)]

    # parallel
    def run_parallel():
        return parallel_map(work, inputs)

    for _ in range(3):
        run_serial()
        run_parallel()
    torch.cuda.synchronize()

    times = []
    for _ in range(7):
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        run_serial()
        torch.cuda.synchronize()
        times.append(time.perf_counter() - t0)
    serial_ms = sum(times) / len(times) * 1000

    times = []
    for _ in range(7):
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        run_parallel()
        torch.cuda.synchronize()
        times.append(time.perf_counter() - t0)
    parallel_ms = sum(times) / len(times) * 1000

    print(f"Python {sys.version}")
    print(f"GIL enabled: {sys._is_gil_enabled()}")
    print(f"  serial (for-loop):    {serial_ms:7.2f} ms")
    print(f"  parallel (6 threads): {parallel_ms:7.2f} ms  ({serial_ms/parallel_ms:.1f}x)")

if __name__ == "__main__":
    main()

Output:

Python 3.14.3 free-threading build
GIL enabled: False
  serial (for-loop):      6.70 ms
  parallel (6 threads):   2.72 ms  (2.5x)
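
The "GIL enabled" line above comes from sys._is_gil_enabled(), which only exists on Python 3.13 and later, so the benchmark raises AttributeError on older interpreters. A free-threaded build can also run with the GIL re-enabled at runtime. The helper below is a minimal sketch of a guarded check; the function name free_threading_status is our own and not part of any library.

```python
import sys
import sysconfig

def free_threading_status():
    """Return (is_free_threaded_build, gil_enabled_at_runtime)."""
    # Py_GIL_DISABLED is 1 on free-threaded builds and 0 or None otherwise.
    build = bool(sysconfig.get_config_var("Py_GIL_DISABLED"))
    # sys._is_gil_enabled() was added in Python 3.13; older versions always have the GIL.
    probe = getattr(sys, "_is_gil_enabled", None)
    gil_enabled = probe() if probe is not None else True
    return build, gil_enabled

build, gil_enabled = free_threading_status()
print(f"free-threaded build: {build}, GIL enabled: {gil_enabled}")
```

Only when the first value is True and the second is False can the threaded version be expected to run its Python-level work in parallel.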

The workload mix (GPU compute + CPU-side logic) approximates a pattern we've found in real training code that's well-suited for multi-threaded parallelization. In many cases the GPU kernels themselves execute fast, but the CPU-side work between them — Python logic, pre/postprocessing, data preparation — becomes the bottleneck under the GIL. With nogil, these CPU portions genuinely run in parallel across threads. Since this kind of forward computation is typically on the critical path of training, the speedup translates directly to end-to-end iteration time.

When threads operate on independent data with no shared mutable state, this kind of parallelization is straightforward: as the demo above shows, a few lines of threading code are all you need. Things get complicated, however, when work() contains cross-process communication code.
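
One caveat with the raw-thread parallel_map in the demo: an exception raised inside a worker is not propagated to the caller, so the corresponding result silently stays None. A hedged alternative, using only the standard library's concurrent.futures, is sketched below. The max_workers parameter mirrors the call signature used later in this post, but this is our own sketch and not the prototype's implementation.

```python
from concurrent.futures import ThreadPoolExecutor

def parallel_map(fn, items, max_workers=None):
    """Apply fn to every item in worker threads; results keep the input order."""
    items = list(items)
    if not items:
        return []
    workers = max_workers or len(items)
    with ThreadPoolExecutor(max_workers=workers) as pool:
        # Executor.map yields results in input order and re-raises a
        # worker's exception when that worker's result is reached.
        return list(pool.map(fn, items))

def square(x):
    if x < 0:
        raise ValueError("negative input")
    return x * x

print(parallel_map(square, [1, 2, 3], max_workers=3))  # [1, 4, 9]
```

Like the raw-thread version, this is only safe when fn touches no shared mutable state and issues no collectives.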

The problem: what if work() calls dist.all_reduce?

In distributed training, these loop iterations may contain collective operations — e.g. dist.all_reduce to sync a loss scalar. Something like:

def work(x):
    ...  # same GPU compute + Python logic
    loss = (h @ wc).sum().reshape(1)
    dist.all_reduce(loss)   # <-- this breaks multi-threading
    return loss

The simple thread-based parallel_map above can't handle this. Two fundamental issues:

First, NCCL is not thread-safe for concurrent calls on the same communicator (docs). Using separate communicators per thread on the same device doesn't work either — cudaMemcpy implicit synchronization and shared memory key conflicts make it impractical (NCCL #195, #1520, #1174).

Second, there's a collective ordering problem: collectives are matched across ranks by call order. With threads, rank 0 might all_reduce item 3 while rank 1 all_reduces item 1 — silent data corruption or hangs.
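
The ordering problem can be illustrated without any GPU or NCCL. The sketch below simulates two ranks with plain threads; the sleep durations are made-up stand-ins for per-item work of varying length, and the list append marks the point where dist.all_reduce would be called.

```python
import threading
import time

def arrival_order(delays):
    """Start one thread per item; return item ids in the order they reach the 'collective'."""
    order = []
    lock = threading.Lock()

    def worker(item, delay):
        time.sleep(delay)        # stand-in for CPU/GPU work of varying length
        with lock:
            order.append(item)   # the point where dist.all_reduce would be called

    threads = [threading.Thread(target=worker, args=(i, d)) for i, d in enumerate(delays)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return order

# Same four items on both simulated ranks, but different per-item timing.
rank0 = arrival_order([0.00, 0.05, 0.10, 0.15])
rank1 = arrival_order([0.15, 0.10, 0.05, 0.00])
mismatched = [(a, b) for a, b in zip(rank0, rank1) if a != b]
print("rank 0 order:", rank0)
print("rank 1 order:", rank1)
print("mismatched pairs:", mismatched)
```

Each mismatched pair is a position where the two ranks would pair up collectives that belong to different items, which is exactly the corruption-or-hang scenario described above.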

For users, this is a frustrating situation. The Python/GPU parts would parallelize trivially with the threading approach shown above, but a single dist.all_reduce per iteration makes the whole thing unsafe. In the loss computation pattern we're trying to optimize, the tensors involved in these collectives are very small — typically just a scalar loss or a short vector. The communication itself is negligible compared to the CPU-side compute that dominates each iteration. So the actual data transfer is tiny, but its mere presence in the code path blocks you from using threads altogether. Working around this manually means extracting collective calls out of work(), building cross-thread synchronization, enforcing deterministic ordering, and ensuring consistency across all ranks — non-trivial effort for what is otherwise clean, serial code.
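
For the simplest case, where the collective is the very last step of work(), the manual workaround can be sketched as follows: run the collective-free part in threads, then issue the collectives on the calling thread in index order. Everything here is illustrative. map_then_reduce and fake_all_reduce are hypothetical names, and fake_all_reduce merely simulates summing identical values across a pretend world of two ranks; real code would wrap dist.all_reduce instead.

```python
from concurrent.futures import ThreadPoolExecutor

def map_then_reduce(compute, reduce_fn, items, max_workers=None):
    """Run compute() in threads, then apply reduce_fn serially in index order."""
    items = list(items)
    if not items:
        return []
    with ThreadPoolExecutor(max_workers=max_workers or len(items)) as pool:
        partials = list(pool.map(compute, items))  # parallel, no collectives inside
    # Collectives run on the calling thread, always in the same order on every rank.
    return [reduce_fn(p) for p in partials]

WORLD_SIZE = 2  # pretend world size for the stand-in below

def fake_all_reduce(value):
    # Stand-in for dist.all_reduce(SUM) when every rank holds the same value.
    return value * WORLD_SIZE

print(map_then_reduce(lambda x: x + 1, fake_all_reduce, [0, 1, 2]))  # [2, 4, 6]
```

This requires splitting work() by hand and stops working as soon as a collective sits in the middle of an iteration, which is the refactoring burden the paragraph above describes.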

Our prototype: collective interception in parallel_map

We prototyped a parallel_map that handles this transparently. The key idea is intercepting collective calls at the Python API level:

  • Workers run work(inputs[i]) in parallel, just like the threaded version above.
  • When a worker calls dist.all_reduce (or any collective), the call is intercepted via a thread-local hook and redirected to the main thread.
  • The main thread executes collectives in strict worker-id order (worker 0, then 1, then 2, ...), so all ranks see the same call sequence.
  • Each worker blocks only for its own collective, then resumes.

The interception is a lightweight thread-local check in distributed_c10d.py, inserted right after the existing __torch_function__ check — it doesn't affect tracing or FakeTensor paths. The whole implementation is about 400 lines (interceptor + thread coordinator + public API).
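
To make the "lightweight thread-local check" concrete, here is a minimal sketch of the pattern using only the standard library. It is not the code in distributed_c10d.py: _tls, all_reduce, _all_reduce_impl and recording_hook are all hypothetical stand-ins, and the stand-in collective just doubles its input.

```python
import threading

_tls = threading.local()  # hypothetical per-thread state, not a torch.distributed attribute

def _all_reduce_impl(value):
    return value * 2  # stand-in: SUM across a pretend world of 2 identical ranks

def all_reduce(value):
    # The check: one attribute lookup on thread-local storage.
    hook = getattr(_tls, "interceptor", None)
    if hook is not None:
        return hook(_all_reduce_impl, value)  # redirected path (worker threads)
    return _all_reduce_impl(value)            # normal path (hook not installed)

calls = []

def recording_hook(impl, value):
    calls.append(("intercepted", value))
    return impl(value)

def worker():
    _tls.interceptor = recording_hook  # visible only inside this thread
    calls.append(("worker result", all_reduce(3)))

t = threading.Thread(target=worker)
t.start()
t.join()
calls.append(("main result", all_reduce(5)))
print(calls)
```

Because the hook lives in thread-local storage, threads that never install it (here, the main thread) take the normal path at the cost of a single attribute lookup.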

From the user's perspective, the only change is swapping the for-loop:

# before
results = [work(inputs[i]) for i in range(N)]

# after
results = parallel_map(work, inputs, max_workers=N)

work() itself doesn't change at all: the dist.all_reduce call inside stays exactly where it is. With this approach and the same workload as above (plus a scalar all_reduce per item), we measure a similar speedup on 2 GPUs with the NCCL backend.

Questions

We looked through issues and dev-discuss but didn't find prior work on this. A few questions:

  • More broadly, what's PyTorch's roadmap for leveraging free-threaded Python beyond bug fixes and compatibility? We've seen discussion around multi-threaded DataLoader as one direction — are there other areas being explored, e.g. parallel execution within the training loop itself?
  • For the specific problem of multi-threaded collectives in torch.distributed — is there a known best practice or recommended approach? We went with intercepting collective calls at the Python API level and serializing them on the main thread, which works but we're curious if there's a better way.
  • If this approach seems reasonable, would it make sense to upstream something like this into PyTorch? We have a working prototype with tests (~400 lines) and would be glad to put up a PR if there's interest.

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @xmfan @weifengpy

extent analysis

TL;DR

To enable multi-threaded parallelization with collective operations like dist.all_reduce in PyTorch, consider using a parallel_map function that intercepts collective calls at the Python API level and serializes them on the main thread.

Guidance

  • Identify the specific collective operations (e.g., dist.all_reduce) in your work() function that are blocking multi-threading.
  • Implement a parallel_map function that intercepts these collective calls and redirects them to the main thread, ensuring strict worker-id order.
  • Use a thread-local hook to intercept collective calls, and consider using a thread coordinator to manage the execution of collectives on the main thread.
  • Test your implementation with a variety of workloads and collective operations to ensure correctness and performance.

Example

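# Sketch only, not the prototype from the issue: work() must call intercepted(dist.all_reduce, loss) unless dist.all_reduce is patched to do so.
import queue
import threading
from concurrent.futures import Future

_tls = threading.local()   # per-thread request queue; unset outside parallel_map workers
_DONE = object()           # sentinel: the worker has finished

def intercepted(func, *args, **kwargs):
    requests = getattr(_tls, "requests", None)
    if requests is None:                     # not inside a parallel_map worker
        return func(*args, **kwargs)
    fut = Future()
    requests.put((func, args, kwargs, fut))  # hand the call to the main thread
    return fut.result()                      # block until the main thread has run it

def parallel_map(fn, inputs, max_workers=None):
    # One thread per input; max_workers is accepted for API compatibility but not enforced here.
    queues = [queue.Queue() for _ in inputs]
    results = [None] * len(inputs)
    def worker(i):
        _tls.requests = queues[i]
        try:
            results[i] = fn(inputs[i])
        finally:
            queues[i].put(_DONE)
    threads = [threading.Thread(target=worker, args=(i,)) for i in range(len(inputs))]
    for t in threads:
        t.start()
    for q in queues:                         # serve collectives in strict worker-id order
        while (req := q.get()) is not _DONE:
            func, args, kwargs, fut = req
            try:
                fut.set_result(func(*args, **kwargs))
            except BaseException as exc:
                fut.set_exception(exc)       # re-raised in the worker by fut.result()
    for t in threads:
        t.join()
    return results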

Notes

  • The implementation of the parallel_map function will depend on the specific requirements of your use case, including the type of collective operations and the desired level of parallelism.
  • Upstreaming this implementation into PyTorch may require additional testing, documentation, and review to ensure compatibility and correctness.

Recommendation

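Consider the workaround described in the issue: a parallel_map function that intercepts collective calls at the Python API level and serializes them on the main thread. The authors report that their prototype (about 400 lines, with tests) reaches a speedup similar to the threaded demo on 2 GPUs with the NCCL backend. It is a prototype rather than a PyTorch API, and the issue had received no comments when this page was fetched.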
