pytorch - 💡(How to fix) Fix Non-blocking DtH/HtD not ordered properly under compile [2 comments, 1 participants]

GitHub stats: pytorch/pytorch#182190 (fetched 2026-05-02 05:26:49)
Comments: 2 · Participants: 1 · Timeline events: 110 · Reactions: 1
Timeline (top): mentioned ×47, subscribed ×47, labeled ×11, unlabeled ×3



🐛 Describe the bug

https://github.com/pytorch/torchtitan/pull/2951

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Minimal repro for torch.compile scheduling of async D2H split-size reads.

Run on a single CUDA device:

    TORCH_LOGS=output_code python scripts/repro_compile_async_d2h_splits.py

By default the compiled function uses the problematic pattern:

    split_sizes_gpu.to("cpu", non_blocking=True).tolist()

Inspect the generated code for ``copy_(..., True)`` immediately followed by
``.item()`` reads. Pass ``--blocking`` to compare with ``non_blocking=False``.
"""

from __future__ import annotations

import argparse

import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import (
    all_to_all_single,
    all_to_all_single_autograd,
)
from torch.testing._internal.distributed.fake_pg import FakeStore

FAKE_WORLD_SIZE = 4
NUM_LOCAL_EXPERTS = 4
HIDDEN_SIZE = 8


def repro_step(
    routed_input: torch.Tensor,
    num_tokens_per_expert: torch.Tensor,
    ep_degree: int,
    non_blocking_input_copy: bool,
) -> torch.Tensor:
    num_tokens_per_expert_group = all_to_all_single(
        num_tokens_per_expert,
        None,
        None,
        group=dist.group.WORLD,
    )
    num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
        num_tokens_per_expert_group
    )

    input_splits = (
        num_tokens_per_expert.view(ep_degree, -1)
        .sum(dim=1)
        .to(torch.device("cpu"), non_blocking=non_blocking_input_copy)
    )
    output_splits = (
        num_tokens_per_expert_group.view(ep_degree, -1)
        .sum(dim=1)
        .to(torch.device("cpu"), non_blocking=False)
    )

    input_splits_list = input_splits.tolist()
    output_splits_list = output_splits.tolist()

    return all_to_all_single_autograd(
        routed_input,
        output_splits_list,
        input_splits_list,
        group=dist.group.WORLD,
    )


parser = argparse.ArgumentParser()
parser.add_argument(
    "--blocking",
    action="store_true",
    help="Use non_blocking=False for the input_splits D2H copy for comparison.",
)
args = parser.parse_args()

if not torch.cuda.is_available():
    raise RuntimeError("This repro requires CUDA.")

torch.cuda.set_device(0)
dist.init_process_group(
    "fake",
    rank=0,
    world_size=FAKE_WORLD_SIZE,
    store=FakeStore(),
)
rank = dist.get_rank()
world_size = dist.get_world_size()

# Capture .tolist()/.item() as unbacked SymInts instead of graph-breaking.
torch._dynamo.config.capture_scalar_outputs = True

# Layout matches AllToAllTokenDispatcher: rows are grouped by destination EP
# rank, and each group contains this rank's token counts for that rank's
# local experts.
counts = [
    ((rank + 1) * (dst_rank + 2) + expert_idx) % 5 + 1
    for dst_rank in range(world_size)
    for expert_idx in range(NUM_LOCAL_EXPERTS)
]
num_tokens_per_expert = torch.tensor(counts, dtype=torch.int64, device="cuda")
routed_input = torch.randn(
    sum(counts),
    HIDDEN_SIZE,
    device="cuda",
    dtype=torch.float32,
)

compiled_step = torch.compile(
    repro_step,
    backend="inductor",
    fullgraph=True,
    dynamic=True,
)
out = compiled_step(
    routed_input,
    num_tokens_per_expert,
    world_size,
    not args.blocking,
)
torch.cuda.synchronize()

dist.destroy_process_group()
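You can check the counts layout used by the repro on the CPU, without a GPU or a process group. The pure-Python sketch below re-implements the repro's counts formula and mirrors `view(ep_degree, -1).sum(dim=1)` with a plain row sum. The helper names `expert_counts` and `input_splits` are hypothetical. Rank 0 with a world size of 4 matches the fake process group the repro initializes.

```python
NUM_LOCAL_EXPERTS = 4  # same value as in the repro


def expert_counts(rank: int, world_size: int) -> list[int]:
    """Hypothetical helper: the repro's counts list for one rank.

    Entry dst_rank * NUM_LOCAL_EXPERTS + expert_idx is this rank's token
    count for local expert expert_idx on destination rank dst_rank.
    """
    return [
        ((rank + 1) * (dst_rank + 2) + expert_idx) % 5 + 1
        for dst_rank in range(world_size)
        for expert_idx in range(NUM_LOCAL_EXPERTS)
    ]


def input_splits(counts: list[int], ep_degree: int) -> list[int]:
    """Pure-Python equivalent of counts.view(ep_degree, -1).sum(dim=1).tolist()."""
    row_len = len(counts) // ep_degree
    # One row per destination rank; each row sum is the tokens sent there.
    return [sum(counts[i * row_len:(i + 1) * row_len]) for i in range(ep_degree)]


counts = expert_counts(rank=0, world_size=4)
print(counts)                   # [3, 4, 5, 1, 4, 5, 1, 2, 5, 1, 2, 3, 1, 2, 3, 4]
print(input_splits(counts, 4))  # [13, 12, 11, 10]
print(sum(counts))              # 46, the number of rows in routed_input
```

Each input split is the number of routed_input rows this rank sends to one destination rank. For that reason, the host-side values must be correct before all_to_all_single_autograd consumes them.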

Error logs

Run the repro with TORCH_LOGS="output_code".

The only difference between the two versions of the generated code is the input_splits copy.

Non-blocking, problematic:

  triton_red_fused_sum_view_1.run(arg1_1, buf19, ...)
  buf20.copy_(buf19, True)

  u0 = reinterpret_tensor(buf20, (), (), 0).item()
  u1 = reinterpret_tensor(buf20, (), (), 1).item()
  u2 = reinterpret_tensor(buf20, (), (), 2).item()
  u3 = reinterpret_tensor(buf20, (), (), 3).item()

  buf25 = torch.ops._c10d_functional.all_to_all_single.default(
      arg5_1,
      [u4, u5, u6, u7],  # output_splits
      [u0, u1, u2, u3],  # input_splits
      "0",
  )

Blocking, expected safe version:

  triton_red_fused_sum_view_1.run(arg1_1, buf19, ...)
  buf20.copy_(buf19, False)

  u0 = reinterpret_tensor(buf20, (), (), 0).item()
  u1 = reinterpret_tensor(buf20, (), (), 1).item()
  u2 = reinterpret_tensor(buf20, (), (), 2).item()
  u3 = reinterpret_tensor(buf20, (), (), 3).item()

  buf25 = torch.ops._c10d_functional.all_to_all_single.default(
      arg5_1,
      [u4, u5, u6, u7],
      [u0, u1, u2, u3],
      "0",
  )

So in the blocking version, the host synchronizes at copy_(..., False). This happens before .item() materializes the split sizes and before the routed-input all-to-all consumes them.

In the non-blocking version, Inductor emits copy_(..., True) and reads u0-u3 immediately afterwards, before the copy is guaranteed to have completed. That is the bad ordering.
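PyTorch's guidance for non-blocking device-to-host copies is that the host must synchronize before it reads the destination tensor. In eager code, one way to express that ordering is to record a CUDA event after the copy and wait on it before calling .tolist(). The helper below is a hypothetical sketch of that approach; the name `splits_to_host` is not from the issue. The issue does not establish whether Dynamo can trace the event calls with fullgraph=True. Treat this as an eager-mode illustration of the ordering the issue expects, not as a fix for the compiled code.

```python
import torch


def splits_to_host(num_tokens_per_expert: torch.Tensor, ep_degree: int) -> list[int]:
    """Hypothetical helper: per-destination split sizes, read on the host
    only after the device-to-host copy has completed."""
    per_rank = num_tokens_per_expert.view(ep_degree, -1).sum(dim=1)
    if not per_rank.is_cuda:
        # Already on the host; there is no copy to wait for.
        return per_rank.tolist()
    # Pinned destination so the device-to-host copy can run asynchronously.
    host = torch.empty(per_rank.shape, dtype=per_rank.dtype, pin_memory=True)
    host.copy_(per_rank, non_blocking=True)
    copy_done = torch.cuda.Event()
    # Record on the stream the copy was enqueued on (the device's current stream).
    copy_done.record(torch.cuda.current_stream(per_rank.device))
    copy_done.synchronize()  # the CPU blocks here until the copy has finished
    # Only now is it safe to materialize the values on the host.
    return host.tolist()
```

As written, the helper waits immediately, so its ordering is the same as a blocking copy. The event form only pays off when other host work sits between record() and synchronize().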

Versions

main

cc @ezyang @gchanan @kadeng @msaroufim @ptrblck @eqy @jerryzh168 @tinglvv @nWEIdia @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @muchulee8 @amjames @aakhundov @coconutruben @jataylo @Lucaskabela @azahed98

Extent analysis

TL;DR

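The .item() reads of the split sizes are only safe once the device-to-host copy has completed. A proper fix would need Inductor to order the non-blocking copy_ before those reads. Until then, the workaround is to make the input_splits copy blocking (non_blocking=False), which is what the --blocking flag does in the repro.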

Guidance

  • The problematic pattern is a non-blocking copy_ whose destination is read with .item() immediately afterwards. The host can read u0-u3 before the copy into buf20 has finished, so all_to_all_single may receive wrong input split sizes.
  • As a workaround, use a blocking copy for input_splits (non_blocking=False). In the repro, non_blocking_input_copy is an argument of repro_step, and passing --blocking sets it to False.
  • Rerun the script with TORCH_LOGS="output_code" and confirm that the generated code now emits copy_(..., False) before the .item() reads.
  • Compare against the default (non-blocking) run, which emits copy_(..., True) followed directly by the reads.
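You can automate the output-code check from the guidance by scanning the text that TORCH_LOGS="output_code" prints. In tests, the private helper torch._inductor.utils.run_and_get_code also returns the generated source, though as a private API it may change. The scanner below is a hypothetical heuristic. It assumes the generated code has the same shape as the excerpts in the issue, and its list of lines that count as a host sync is an assumption, not a list of what Inductor actually emits.

```python
import re

# Matches Inductor output lines such as "buf20.copy_(buf19, True)".
NON_BLOCKING_COPY = re.compile(r"(\w+)\.copy_\(\w+, True\)")
# Assumption: substrings that would indicate an explicit host-side wait.
SYNC_MARKERS = ("synchronize(",)


def unsynced_item_reads(output_code: str) -> list[str]:
    """Hypothetical check: buffers read via .item() after copy_(..., True)
    with no sync marker in between."""
    pending: set[str] = set()  # destinations of copies not yet synced
    flagged: list[str] = []
    for line in output_code.splitlines():
        text = line.strip()
        if any(marker in text for marker in SYNC_MARKERS):
            pending.clear()  # treat a sync as covering every outstanding copy
            continue
        match = NON_BLOCKING_COPY.search(text)
        if match:
            pending.add(match.group(1))
            continue
        if ".item()" not in text:
            continue
        for buf in sorted(pending):
            if re.search(rf"\b{re.escape(buf)}\b", text):
                flagged.append(buf)
                pending.discard(buf)  # report each buffer only once
    return flagged


NON_BLOCKING_EXCERPT = """\
  buf20.copy_(buf19, True)
  u0 = reinterpret_tensor(buf20, (), (), 0).item()
  u1 = reinterpret_tensor(buf20, (), (), 1).item()
"""
print(unsynced_item_reads(NON_BLOCKING_EXCERPT))  # ['buf20']
print(unsynced_item_reads(NON_BLOCKING_EXCERPT.replace("True", "False")))  # []
```

On the non-blocking excerpt from the issue it reports buf20. On the blocking excerpt it reports nothing, because copy_(..., False) does not match the non-blocking pattern.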

Example

input_splits = (
    num_tokens_per_expert.view(ep_degree, -1)
    .sum(dim=1)
    .to(torch.device("cpu"), non_blocking=False)  # Set non_blocking to False
)

Notes

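  • The generated code in the report comes from the Inductor backend under torch.compile (fullgraph=True, dynamic=True, with capture_scalar_outputs enabled). The report does not cover other backends.
  • The workaround makes the host wait for the input_splits transfer, so that copy no longer overlaps with host-side work.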

Recommendation

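Until the ordering is fixed in Inductor, use non_blocking=False for device-to-host copies like input_splits whose values are read on the host inside the compiled region; in the repro, pass --blocking. This guarantees that the copy has completed before the split sizes are read.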

