pytorch - ✅ (Solved) Fix [symm_mem] small-M FP8 fused_scaled_matmul_reduce_scatter is slower than unfused baseline on B300/SM100 [2 pull requests, 1 comment, 2 participants]

GitHub stats

pytorch/pytorch#179770 (fetched 2026-04-09 07:50:07)

  • Comments: 1
  • Participants: 2
  • Timeline events: 31
  • Reactions: 0
  • Timeline (top): mentioned ×14, subscribed ×14, labeled ×2, commented ×1

PR fix notes

PR #179918: [SymmetricMemory] Add reduce_scatter_tensor_out and native async-TP fast paths #179770

Description (problem / solution / changelog)

Issue: https://github.com/pytorch/pytorch/issues/179770

This PR improves SymmetricMemory for async tensor parallel workloads by:

  • enabling the native fused_all_gather_matmul fast path based on runtime heuristics alone
  • adding a native fast path for fused_scaled_matmul_reduce_scatter
  • adding a new symm_mem.reduce_scatter_tensor_out CUDA operator

Tested on the vLLM side in https://github.com/vllm-project/vllm/pull/39505

python test/distributed/test_symmetric_memory.py -k reduce_scatter_tensor_out
I0410 14:31:02.698000 2646070 /root/zdj/pytorch/torch/testing/_internal/common_distributed.py:1986] Testing class SymmMemCollectiveTest on 8 cuda
....I0410 14:31:14.221000 2646070 /root/zdj/pytorch/torch/testing/_internal/common_distributed.py:2019] Class SymmMemCollectiveTest finished

----------------------------------------------------------------------
Ran 4 tests in 11.524s

OK

Changed files

  • test/distributed/test_symmetric_memory.py (modified, +34/-1)
  • torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu (modified, +75/-0)
  • torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp (modified, +2/-0)
  • torch/distributed/_symmetric_memory/__init__.py (modified, +179/-2)


🐛 Describe the bug

https://github.com/vllm-project/vllm/issues/27893

This blocks vLLM from reliably enabling AsyncTP scaled_mm+comms on SM100, since the public symm_mem fused RS+scaled_mm path regresses on small-M FP8 workloads.

Benchmark: https://paste.ubuntu.com/p/tVNzPgy2Gp/

torchrun --nproc_per_node=8 /root/zdj/pytorch/tmp/bench_symm_mem_rs.py -- \
  --mode fp8 \
  --m-list 32 64 96 128 160 192 224 256 320 384 448 512 768 1024 2048 4096 8192 16384 \
  --scatter-dim-list 0 \
  --k 7168 \
  --n 16384 \
  --warmup 20 \
  --iters 100 \
  --out /tmp/symm_mem_fp8_sd0_main.json

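Results (speedup = baseline_ms / fused_ms; values below 1.0 mean the fused path is slower than the unfused baseline):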

mode  m      scatter_dim  fused_ms   baseline_ms  speedup
fp8   32     0            0.249789   0.096237     0.3853
fp8   64     0            0.243963   0.073858     0.3027
fp8   96     0            0.289270   0.087035     0.3009
fp8   128    0            0.246243   0.074093     0.3009
fp8   160    0            0.247176   0.081163     0.3284
fp8   192    0            0.247684   0.073901     0.2984
fp8   224    0            0.246034   0.081759     0.3323
fp8   256    0            0.248132   0.083836     0.3379
fp8   320    0            0.252470   0.087673     0.3473
fp8   384    0            0.250542   0.093923     0.3749
fp8   448    0            0.277906   0.100980     0.3634
fp8   512    0            0.256095   0.101829     0.3976
fp8   768    0            0.277957   0.140830     0.5067
fp8   1024   0            0.292938   0.191193     0.6527
fp8   2048   0            0.344706   0.277644     0.8055
fp8   4096   0            0.473940   0.503326     1.0620
fp8   8192   0            0.832994   0.995565     1.1952
fp8   16384  0            1.743395   1.910754     1.0960
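
The table can be checked mechanically. The following is a minimal sketch, not part of the issue or the benchmark script: it copies the rows above into a list, recomputes speedup as baseline_ms / fused_ms, and reports the smallest measured M at which the fused path wins. The names ROWS, speedups, and first_m_where_fused_wins are illustrative only.

```python
# Rows copied from the results table above: (m, fused_ms, baseline_ms).
ROWS = [
    (32, 0.249789, 0.096237),
    (64, 0.243963, 0.073858),
    (96, 0.289270, 0.087035),
    (128, 0.246243, 0.074093),
    (160, 0.247176, 0.081163),
    (192, 0.247684, 0.073901),
    (224, 0.246034, 0.081759),
    (256, 0.248132, 0.083836),
    (320, 0.252470, 0.087673),
    (384, 0.250542, 0.093923),
    (448, 0.277906, 0.100980),
    (512, 0.256095, 0.101829),
    (768, 0.277957, 0.140830),
    (1024, 0.292938, 0.191193),
    (2048, 0.344706, 0.277644),
    (4096, 0.473940, 0.503326),
    (8192, 0.832994, 0.995565),
    (16384, 1.743395, 1.910754),
]


def speedups(rows):
    """Return (m, baseline_ms / fused_ms) for each row, in input order."""
    return [(m, baseline / fused) for m, fused, baseline in rows]


def first_m_where_fused_wins(rows):
    """Return the smallest measured m with speedup > 1.0, or None if there is none.

    Assumes rows are sorted by ascending m, as in the table.
    """
    for m, s in speedups(rows):
        if s > 1.0:
            return m
    return None


if __name__ == "__main__":
    for m, s in speedups(ROWS):
        verdict = "fused faster" if s > 1.0 else "baseline faster"
        print(f"m={m:<6} speedup={s:.4f}  {verdict}")
    print("first m where fused wins:", first_m_where_fused_wins(ROWS))
```

On these numbers the fused path first wins at M = 4096, so the crossover for this configuration (k = 7168, n = 16384, scatter_dim 0, 8 ranks) lies somewhere between 2048 and 4096.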

Versions

2.12.0a0+gite3473e8

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @xmfan

extent analysis

TL;DR

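In this benchmark the fused symm_mem RS+scaled_mm path is slower than the unfused baseline for every M up to 2048 (speedup roughly 0.30 to 0.81) and only pulls ahead from M = 4096. PR #179918 targets this by adding a native fast path for fused_scaled_matmul_reduce_scatter.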

Guidance

  • Locate the crossover: it lies between M = 2048 (speedup 0.8055) and M = 4096 (speedup 1.0620). Rerunning with a denser --m-list in that range would pin the threshold down more precisely.
  • Note the flat fused latency at small M: for M ≤ 512 the fused time stays at about 0.24 to 0.29 ms while the baseline runs in about 0.07 to 0.10 ms. That pattern suggests a fixed overhead in the fused path rather than a cost that scales with M; the issue does not identify its source.
  • Consider extending the benchmark script with additional diagnostics (for example, per-iteration latency spread instead of a single figure) to test specific hypotheses about the cause of the regression.
  • Compare the fused and baseline timings across the full range of M, and across other values of k, n, and scatter_dim, before generalizing; the numbers above cover one configuration only.
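
As one way to gather the extra diagnostics mentioned above, the following is a minimal sketch of a timing helper that records per-iteration latency and reports min, median, p95, and max. It is an assumption about how one might instrument a benchmark, not the method used by bench_symm_mem_rs.py, whose source is not shown in the issue. The name time_callable and the sync parameter are hypothetical; for CUDA work a device synchronization function such as torch.cuda.synchronize would be passed as sync so that asynchronous kernels are included in the measurement.

```python
import statistics
import time


def time_callable(fn, warmup=20, iters=100, sync=None):
    """Time fn() and return per-call latency statistics in milliseconds.

    sync is an optional zero-argument callable invoked around each timed call.
    It is a placeholder for a device synchronization function (assumption);
    leave it as None for plain CPU code.
    """
    def _sync():
        if sync is not None:
            sync()

    # Warm-up iterations are executed but not recorded.
    for _ in range(warmup):
        fn()

    samples = []
    for _ in range(iters):
        _sync()  # make sure earlier work has finished before starting the clock
        t0 = time.perf_counter()
        fn()
        _sync()  # wait for the work launched by fn() to finish
        samples.append((time.perf_counter() - t0) * 1e3)

    samples.sort()
    p95_index = min(len(samples) - 1, int(0.95 * len(samples)))
    return {
        "min_ms": samples[0],
        "median_ms": statistics.median(samples),
        "p95_ms": samples[p95_index],
        "max_ms": samples[-1],
    }


if __name__ == "__main__":
    # Stand-in workload; a real run would call the fused or unfused op here.
    print(time_callable(lambda: sum(range(10_000))))
```

The defaults mirror the --warmup 20 and --iters 100 flags from the benchmark command. A wide gap between median and p95 at small M would point to jitter, while a uniformly high floor would point to fixed overhead.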

Example

The issue includes no code beyond the benchmark command. The benchmark script (bench_symm_mem_rs.py) is worth examining for assumptions or simplifications that may not hold for small-M FP8 workloads.

Notes

The issue documents a performance regression in one configuration but does not identify a root cause. The linked PR #179918 adds a native fast path for fused_scaled_matmul_reduce_scatter; the PR notes on this page do not include before/after timings, so its effect on the small-M rows is not shown here.

Recommendation

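Until the fused path is fixed for small-M FP8 workloads, a caller can work around the regression by using the unfused baseline at small M, which the table shows to be faster for M ≤ 2048 in this configuration. The underlying fix is to optimize the symm_mem fused RS+scaled_mm path itself, which is what PR #179918 sets out to do.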
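
A size-based gate is one way a caller could avoid the slow region in the meantime. The sketch below is hypothetical and is not what PR #179918 implements; the PR notes mention runtime heuristics but do not show their form. The function names and the fused_min_m parameter are invented for illustration, and the default of 4096 is simply the first M in the table where the fused path wins for k = 7168, n = 16384, scatter_dim 0 on 8 ranks. Other shapes or hardware would need their own measurement.

```python
def choose_path(m, fused_min_m=4096):
    """Return "fused" or "unfused" for a given M.

    fused_min_m is a hypothetical threshold taken from one benchmark
    configuration; it is not a value defined by PyTorch or vLLM.
    """
    return "fused" if m >= fused_min_m else "unfused"


def run_scaled_mm_reduce_scatter(m, fused_fn, unfused_fn, fused_min_m=4096):
    """Call fused_fn() or unfused_fn() depending on M.

    fused_fn and unfused_fn are caller-supplied zero-argument callables that
    wrap the real operators; they are placeholders in this sketch.
    """
    if choose_path(m, fused_min_m) == "fused":
        return fused_fn()
    return unfused_fn()


if __name__ == "__main__":
    for m in (32, 2048, 4096, 16384):
        print(m, choose_path(m))
```

Keeping the threshold as a parameter rather than a constant makes it straightforward to retune once the fused path improves or the problem shape changes.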
