pytorch - ✅(Solved) Fix Graph break on use_deterministic_algorithms [3 pull requests, 1 participant]

GitHub stats
pytorch/pytorch#179194 (fetched 2026-04-08 02:32:56)
Comments: 0 · Participants: 1 · Timeline events: 70 · Reactions: 0
Timeline (top): mentioned ×32, subscribed ×32, labeled ×4, cross-referenced ×2

Fix Action

Fixed

PR fix notes

PR #2775: [MoE] change torch.bmm back to scatter add

Description (problem / solution / changelog)

Change

scatter_add was replaced by torch.bmm in https://github.com/pytorch/torchtitan/pull/1974 because of its nondeterminism. However, the bmm backward kernel was reported to be slow in https://github.com/pytorch/torchtitan/issues/2225. Although https://github.com/pytorch/pytorch/pull/176552 from @drisspg greatly reduces the bmm backward time under torch.compile, the eager-mode backward is still slow.

Therefore, we switch back to scatter_add with torch.use_deterministic_algorithms(True); see https://docs.pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms for details. We tested determinism using this test script (see also the updated test).
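A minimal sketch of the kind of determinism check described above, assuming an MoE-style combine step in which each routed expert output row is summed back into its source token with scatter_add. The helper names `combine_with_scatter_add` and `check_determinism` and all tensor shapes are hypothetical, not torchtitan's code. The sketch runs the op several times with deterministic mode on and compares the results bitwise. On CPU, scatter_add is already deterministic, so the check only says something interesting on CUDA, where the flag makes scatter_add use a deterministic kernel.

```python
import torch


def combine_with_scatter_add(expert_out, token_idx, num_tokens):
    """Sum routed expert outputs back into per-token rows (hypothetical helper).

    expert_out: (num_routed, dim), one row per (token, expert) pair
    token_idx:  (num_routed,) int64 index of the token each row belongs to
    """
    dim = expert_out.shape[1]
    out = torch.zeros(num_tokens, dim, dtype=expert_out.dtype, device=expert_out.device)
    # scatter_add needs an index with the same shape as src.
    index = token_idx.unsqueeze(1).expand(-1, dim)
    return out.scatter_add(0, index, expert_out)


def check_determinism(device="cpu", num_tokens=64, top_k=2, dim=32, runs=3):
    """Return (all_runs_bitwise_equal, first_result) with deterministic mode on."""
    prev_mode = torch.are_deterministic_algorithms_enabled()
    prev_warn = torch.is_deterministic_algorithms_warn_only_enabled()
    torch.use_deterministic_algorithms(True)
    try:
        gen = torch.Generator().manual_seed(0)
        expert_out = torch.randn(num_tokens * top_k, dim, generator=gen).to(device)
        # Every token receives top_k rows, in shuffled order, so rows collide on the same index.
        token_idx = torch.arange(num_tokens).repeat_interleave(top_k)
        token_idx = token_idx[torch.randperm(token_idx.numel(), generator=gen)].to(device)
        results = [combine_with_scatter_add(expert_out, token_idx, num_tokens) for _ in range(runs)]
        return all(torch.equal(results[0], r) for r in results[1:]), results[0]
    finally:
        # Restore whatever global setting was active before the check.
        torch.use_deterministic_algorithms(prev_mode, warn_only=prev_warn)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ok, _ = check_determinism(device)
    print(f"bitwise identical across runs on {device}: {ok}")
```

Saving and restoring the global flag keeps the check from leaking deterministic mode into the rest of the program.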

Test with topk=2

Models tested: llama4 debug model (topk=2) and qwen3 debug model (topk=8).

NCCL_NVLS_ENABLE=0 NGPU=8 ./run_train.sh \
    --module llama4 \
    --config llama4_debugmodel \
    --parallelism.pipeline_parallel_degree 1 \
    --parallelism.data_parallel_shard_degree 4 \
    --parallelism.tensor_parallel_degree 2 \
    --parallelism.expert_parallel_degree 2 \
    --parallelism.expert_tensor_parallel_degree 2 \
    --metrics.enable_wandb --training.steps 100

Loss compare

NCCL_NVLS_ENABLE=0 python scripts/loss_compare.py . main \
    --baseline-module='llama4' \
    --baseline-config='llama4_debugmodel' \
    --baseline-options="--parallelism.pipeline_parallel_degree 1 --parallelism.data_parallel_shard_degree 2 --parallelism.tensor_parallel_degree 2 --parallelism.expert_parallel_degree 4 --parallelism.expert_tensor_parallel_degree 1" \
    --baseline-ngpus=4 \
    --test-module='llama4' \
    --test-config='llama4_debugmodel' \
    --test-options="--parallelism.pipeline_parallel_degree 1 --parallelism.data_parallel_shard_degree 2 --parallelism.tensor_parallel_degree 2 --parallelism.expert_parallel_degree 4 --parallelism.expert_tensor_parallel_degree 1" \
    --test-ngpus=4 \
    --assert-equal \
    --no-seed-checkpoint

In testing with the llama4 debug model, scatter_add with deterministic mode enabled still appears slightly faster than torch.bmm.

cc @garrett361 @chelsea0x3b @rakkit @drisspg
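To reproduce a comparison like this on your own hardware, here is a rough forward-plus-backward micro-benchmark sketch. The bmm formulation assumes routed outputs are laid out token-major (each token's top_k rows contiguous) and combined with routing scores via one batched matmul; that layout and both helpers (`combine_bmm`, `combine_scatter_add`) are assumptions for illustration, not torchtitan's implementation. Timings depend heavily on device, shapes, and PyTorch version, so they need not match the PR's observation.

```python
import torch
from torch.utils.benchmark import Timer


def combine_bmm(scores, expert_out):
    # scores: (N, K) routing weights; expert_out: (N*K, dim), token-major layout (assumed).
    n, k = scores.shape
    return torch.bmm(scores.view(n, 1, k), expert_out.view(n, k, -1)).squeeze(1)


def combine_scatter_add(scores, expert_out):
    # Weight each routed row by its score, then sum rows that share a token index.
    n, k = scores.shape
    dim = expert_out.shape[1]
    weighted = expert_out * scores.reshape(-1, 1)
    token_idx = torch.arange(n, device=expert_out.device).repeat_interleave(k)
    out = torch.zeros(n, dim, dtype=expert_out.dtype, device=expert_out.device)
    return out.scatter_add(0, token_idx.unsqueeze(1).expand(-1, dim), weighted)


def fwd_bwd(combine, scores, expert_out):
    s = scores.detach().requires_grad_()
    e = expert_out.detach().requires_grad_()
    combine(s, e).sum().backward()


def benchmark(device="cpu", n=4096, k=2, dim=256):
    scores = torch.rand(n, k, device=device)
    expert_out = torch.randn(n * k, dim, device=device)
    # Both strategies compute the same weighted sum, up to floating-point rounding.
    assert torch.allclose(combine_bmm(scores, expert_out),
                          combine_scatter_add(scores, expert_out), atol=1e-5)

    def time_fwd_bwd(fn):
        # Note: Timer runs single-threaded by default (num_threads=1).
        timer = Timer("fwd_bwd(fn, s, e)",
                      globals={"fwd_bwd": fwd_bwd, "fn": fn, "s": scores, "e": expert_out})
        return timer.blocked_autorange().median

    # bmm is timed with the default (nondeterministic-allowed) setting.
    bmm_time = time_fwd_bwd(combine_bmm)
    prev_mode = torch.are_deterministic_algorithms_enabled()
    prev_warn = torch.is_deterministic_algorithms_warn_only_enabled()
    torch.use_deterministic_algorithms(True)
    try:
        scatter_time = time_fwd_bwd(combine_scatter_add)
    finally:
        torch.use_deterministic_algorithms(prev_mode, warn_only=prev_warn)
    return {"bmm": bmm_time, "scatter_add (deterministic)": scatter_time}


if __name__ == "__main__":
    print(benchmark("cuda" if torch.cuda.is_available() else "cpu"))
```

Only scatter_add is timed under deterministic mode, mirroring the comparison above; on CUDA, running bmm under deterministic mode would additionally require the cuBLAS workspace configuration described in the reproducibility notes.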

Changed files

  • torchtitan/config/configs.py (modified, +1/-1)
  • torchtitan/distributed/expert_parallel.py (modified, +5/-1)
  • torchtitan/models/common/moe/moe.py (modified, +18/-25)


Repro:

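import torch

# Toggling the global deterministic flag inside a fullgraph-compiled function
# causes a Dynamo graph break, which fullgraph=True turns into an error.
@torch.compile(fullgraph=True)
def fn(x):
    torch.use_deterministic_algorithms(True, warn_only=True)  # reported graph break
    res = x.scatter_add(0, torch.tensor([0, 1, 0]), torch.tensor([1.0, 2.0, 3.0]))
    torch.use_deterministic_algorithms(False, warn_only=True)
    return res

fn(torch.zeros(2))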

This originally came up in https://github.com/pytorch/torchtitan/pull/2775#discussion_r3026156120.

IMO, this is a niche use case: you want to isolate a particularly bad op's numerics. It doesn't matter if you trace in non-strict mode. If you run into this and you care about numerics, you should probably just enable it for the whole run.

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @amjames @Lucaskabela @jataylo @azahed98
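The workaround suggested above, as a minimal sketch: set the flag once, before the compiled function runs, and keep the call out of the traced body. Using `backend="eager"` is our choice so the example only exercises Dynamo's graph capture (which is where graph breaks happen) and needs no compiler toolchain; it is not required for the workaround.

```python
import torch

# Enable deterministic algorithms once for the whole run, outside any compiled region.
torch.use_deterministic_algorithms(True, warn_only=True)


# No toggling inside the traced function, so fullgraph=True has nothing to break on.
@torch.compile(fullgraph=True, backend="eager")
def fn(x):
    return x.scatter_add(0, torch.tensor([0, 1, 0]), torch.tensor([1.0, 2.0, 3.0]))


print(fn(torch.zeros(2)))  # tensor([4., 2.])
```

Index 0 receives 1.0 + 3.0 and index 1 receives 2.0, which gives the printed result.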

extent analysis

TL;DR

Enable deterministic algorithms once for the entire run instead of toggling them inside a compiled function.

Guidance

  • The graph break comes from calling torch.use_deterministic_algorithms inside a torch.compile'd function to isolate a specific operation's numerics.
  • To verify the workaround, call torch.use_deterministic_algorithms(True) once before the compiled function runs and remove the calls from inside it.
  • If numerics are a concern, enable deterministic algorithms for the whole run, as suggested in the issue.
  • Deterministic implementations can be slower than their nondeterministic counterparts, so weigh the performance cost.
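If you still want to isolate a single op, one option is to toggle the flag in eager code, outside any compiled function. A minimal sketch, assuming eager execution; `deterministic_region` is a hypothetical helper, not a PyTorch API. It saves and restores both the mode and the `warn_only` setting. Keep it out of compiled functions, since it toggles the same global flag the repro toggles.

```python
import contextlib

import torch


@contextlib.contextmanager
def deterministic_region(warn_only=False):
    """Temporarily enable deterministic algorithms in eager mode (hypothetical helper)."""
    prev_mode = torch.are_deterministic_algorithms_enabled()
    prev_warn = torch.is_deterministic_algorithms_warn_only_enabled()
    torch.use_deterministic_algorithms(True, warn_only=warn_only)
    try:
        yield
    finally:
        # Restore the previous global setting even if the body raises.
        torch.use_deterministic_algorithms(prev_mode, warn_only=prev_warn)


x = torch.zeros(2)
with deterministic_region(warn_only=True):
    res = x.scatter_add(0, torch.tensor([0, 1, 0]), torch.tensor([1.0, 2.0, 3.0]))
print(res)  # tensor([4., 2.])
```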

Notes

The provided code snippet and issue discussion imply that enabling deterministic algorithms for the entire run is a viable workaround. However, this might not be suitable for all use cases, and the performance impact should be considered.

Recommendation

Apply workaround: Enable deterministic algorithms for the entire run. This avoids toggling the flag inside the compiled region (the source of the graph break) and makes ops that have deterministic implementations use them consistently.
