pytorch - Perf regression: nn.Module init ~3.4× slower in PT 2.12 due to trunc_normal_ rewrite (BltModel example)

GitHub stats
pytorch/pytorch#182116 · Fetched 2026-05-02 05:27:06
Comments: 0 · Participants: 1 · Timeline: 24 · Reactions: 0
Timeline (top): labeled ×8, mentioned ×7, subscribed ×7, cross-referenced ×2


Describe the bug

nn.Module instantiation is ~3.4× slower in PyTorch 2.12 than in 2.11 for transformers.BltModel. The slowdown is in torch.nn.init._no_grad_trunc_normal_, which PR #174997 rewrote from a single-pass analytical formula (uniform_ + erfinv_) to rejection sampling.

The rewrite is a real correctness fix (PR #174997 fixes #145498 — fp32 numerical instability of erfinv at extreme bounds destroying weight-init kurtosis). Filing this issue to record the performance cost of the fix and to ask whether the rejection-sampling implementation has optimization opportunities.

Reproducer (copy-paste ready)

Setup (assumes you already have PT 2.11 stable + a recent PT 2.12 nightly venv):

# Save the reproducer script
cat > /tmp/blt_init_repro.py << 'EOF'
import time, torch
print(f"torch: {torch.__version__}")
print(f"torch_git: {torch.version.git_version}")

from transformers import BltConfig, BltModel
config = BltConfig(
    vocab_size=128, hidden_size=64, num_hidden_layers=2,
    num_attention_heads=2, intermediate_size=128, max_position_embeddings=64,
)

t0 = time.perf_counter()
m = BltModel(config)
print(f"BltModel(config) instantiation: {time.perf_counter()-t0:.3f}s")
EOF

# Run on PT 2.11
${PT_211_VENV}/bin/python /tmp/blt_init_repro.py

# Run on PT 2.12 (use whatever recent nightly you have — the relevant change is in main since 2026-03-02)
${PT_212_VENV}/bin/python /tmp/blt_init_repro.py

Requires transformers==5.5.3 (or later — transformers/models/blt/modeling_blt.py is byte-identical across these versions, so model code is not a variable here).

Benchmark results

| Metric | PT 2.11.0+cu128 | PT 2.12.0.dev20260407+cu128 | Δ |
|---|---|---|---|
| BltModel(config) instantiation wall time | 111.18s | 373.96s | 3.4× |
| _no_grad_trunc_normal_ cumtime (cProfile) | 88.3s (772 calls) | 367.9s (772 calls) | 4.2× |
| Tensor.normal_ calls (rejection-sampling overhead) | (not in top) | 230.9s (2383 calls) | |
| torch.where calls (mask construction) | (not in top) | 41.1s (1608 calls) | |
| Tensor.any() calls (rejection convergence check) | (not in top) | 23.8s (2380 calls) | |

cProfile breakdown (top by cumtime, PT 2.12):

ncalls  cumtime  function
  2383  230.862  {method 'normal_' of 'torch._C.TensorBase' objects}
   772   91.972  /lib/.../torch/nn/init.py:86(_no_grad_trunc_normal_)
  1608   41.135  {built-in method torch.where}
  2380   23.768  {method 'any' of 'torch._C.TensorBase' objects}
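
A breakdown like the one above can be collected with the standard-library cProfile and pstats modules. The sketch below is a minimal illustration, not the exact command used for this report. `build_model` is a hypothetical stand-in workload so the snippet runs without transformers installed; swap in `BltModel(config)` from the reproducer to profile the real initialization.

```python
import cProfile
import io
import pstats


def build_model():
    # Hypothetical stand-in workload; replace the body with
    # `BltModel(config)` to profile the real model construction.
    return [sum(range(1000)) for _ in range(200)]


def profile_by_cumtime(fn, top=10):
    """Run fn under cProfile and return the top entries sorted by cumtime."""
    profiler = cProfile.Profile()
    profiler.enable()
    fn()
    profiler.disable()
    buf = io.StringIO()
    stats = pstats.Stats(profiler, stream=buf)
    stats.sort_stats("cumulative").print_stats(top)
    return buf.getvalue()


report = profile_by_cumtime(build_model)
print(report)
```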

Average: ~3.1 rejection passes per _no_grad_trunc_normal_ call (2383 / 772 ≈ 3.09).
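
The ratios quoted in this report follow directly from the measured figures. The short check below only re-derives them from the numbers in the benchmark table and the cProfile output; the variable names are illustrative.

```python
# Figures copied from the benchmark table and cProfile output in this report.
wall_211, wall_212 = 111.18, 373.96    # BltModel(config) wall time, seconds
init_211, init_212 = 88.3, 367.9       # _no_grad_trunc_normal_ cumtime, seconds
normal_calls, init_calls = 2383, 772   # Tensor.normal_ calls, init calls

wall_ratio = wall_212 / wall_211
init_ratio = init_212 / init_211
passes_per_call = normal_calls / init_calls

print(f"wall-time slowdown:    {wall_ratio:.2f}x")
print(f"init cumtime slowdown: {init_ratio:.2f}x")
print(f"normal_ calls per init: {passes_per_call:.2f}")
```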

Validation against latest nightly

Verified that torch/nn/init.py is byte-identical between our installed 2.12.0.dev20260407 and pytorch/pytorch:main@HEAD (sourced via direct fetch). The _no_grad_trunc_normal_ change has been stable since PR #174997 merged on 2026-03-02. Latest available 2.12 nightly is 2.12.0.dev20260408 — same source for this function.

Suspected cause

PR #174997 by @albanD ("trunc_normal_ low precision fix", merged 2026-03-02) replaced the single-pass analytical approach (uniform_ + erfinv_) with rejection sampling (normal_ + where inside a while True loop). On Blt's bounds (default a=-2, b=2 with std≈0.02), each tensor needs ~3 rejection passes on average (2383 normal_ calls / 772 init calls), which accounts for the ~4.2× increase in _no_grad_trunc_normal_ cumtime.
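
The difference between the two sampling strategies can be illustrated on scalars with the standard library alone. This is a sketch under stated assumptions, not the code from PR #174997 or the earlier torch.nn.init implementation: the function names `sample_inverse_cdf` and `sample_rejection` are hypothetical, and unit std with a=-2, b=2 is chosen so that both methods are numerically well-behaved.

```python
import random
from statistics import NormalDist


def sample_inverse_cdf(rng, mean, std, a, b):
    """Single pass: draw u uniformly between CDF(a) and CDF(b), then invert."""
    dist = NormalDist(mean, std)
    lo, hi = dist.cdf(a), dist.cdf(b)
    u = rng.uniform(lo, hi)
    # Clamp guards against floating-point rounding at the interval edges.
    return min(max(dist.inv_cdf(u), a), b)


def sample_rejection(rng, mean, std, a, b):
    """Multi pass: redraw from the untruncated normal until the value fits."""
    draws = 0
    while True:
        x = rng.gauss(mean, std)
        draws += 1
        if a <= x <= b:
            return x, draws


rng = random.Random(0)
inv_samples = [sample_inverse_cdf(rng, 0.0, 1.0, -2.0, 2.0) for _ in range(2000)]
rej = [sample_rejection(rng, 0.0, 1.0, -2.0, 2.0) for _ in range(2000)]
rej_samples = [x for x, _ in rej]
avg_draws = sum(d for _, d in rej) / len(rej)
print(f"average draws per accepted sample: {avg_draws:.3f}")
```

For a single scalar, the expected number of draws under rejection sampling is 1/p, where p is the probability that an untruncated draw lands inside [a, b]. A tensor-level loop behaves differently, because it repeats until every element has been accepted.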

The PR body explains the rationale clearly: the old approach generated outliers clamped at ±2.0 (1000σ values) that destroyed weight-init kurtosis. The fix is correct.

Question for the maintainers

Is there a fast-path optimization for the common case where p = norm_cdf((b - mean) / std) - norm_cdf((a - mean) / std) > 0.99? In that regime, virtually all samples are accepted on pass 1, but the while True loop still incurs the per-iteration overhead of mask.any() + torch.where + empty_like + normal_. For weight initialization with the default a=-2, b=2 bounds and a small std (with std≈0.02 the bounds sit at ±100σ), this is the overwhelmingly common case (and is what nn.Linear, nn.Embedding, etc. exercise via the standard initializer paths).
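
The acceptance probability p in the question above can be evaluated with math.erf alone. The sketch below assumes a and b are absolute cutoffs, as documented for torch.nn.init.trunc_normal_; the helper name `acceptance_probability` is hypothetical.

```python
import math


def norm_cdf(x):
    # Standard normal CDF expressed through the error function.
    return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0


def acceptance_probability(mean, std, a, b):
    """Probability that one untruncated N(mean, std^2) draw lands in [a, b]."""
    return norm_cdf((b - mean) / std) - norm_cdf((a - mean) / std)


# std≈0.02 with a=-2, b=2 (the Blt case described in this report).
p_small_std = acceptance_probability(0.0, 0.02, -2.0, 2.0)
# Unit std with the same bounds, i.e. a ±2σ truncation.
p_unit_std = acceptance_probability(0.0, 1.0, -2.0, 2.0)

print(f"std=0.02: p = {p_small_std:.6f}")
print(f"std=1.0:  p = {p_unit_std:.6f}")
```

With std=0.02 the cutoffs are 100 standard deviations from the mean, so p is indistinguishable from 1 in double precision. With unit std the same cutoffs give p of roughly 0.954, which falls below the 0.99 threshold.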

Concrete question: could the loop short-circuit when mask.any() is false on the first pass without paying the torch.where + empty_like + normal_ cost?
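
One possible shape for such a short-circuit is sketched below. It is an illustration of the idea only, not the implementation in torch/nn/init.py and not a proposed patch; the function name `trunc_normal_fast_path_` and the returned pass counter are hypothetical.

```python
import torch


@torch.no_grad()
def trunc_normal_fast_path_(t, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Fill t in place from N(mean, std^2) truncated to [a, b].

    Returns (t, passes), where passes counts the normal_ calls made.
    """
    t.normal_(mean, std)
    passes = 1
    while True:
        mask = (t < a) | (t > b)
        if not mask.any():
            # Fast path: nothing is out of bounds, so the
            # empty_like + normal_ + where work is skipped entirely.
            return t, passes
        # Slow path: redraw a full tensor, keep it only where rejected.
        fresh = torch.empty_like(t).normal_(mean, std)
        t.copy_(torch.where(mask, fresh, t))
        passes += 1


torch.manual_seed(0)
w_small = torch.empty(256, 64)
_, passes_small = trunc_normal_fast_path_(w_small, std=0.02)
w_unit = torch.empty(256, 64)
_, passes_unit = trunc_normal_fast_path_(w_unit, std=1.0)
print(f"std=0.02: {passes_small} pass(es); std=1.0: {passes_unit} pass(es)")
```

The `if not mask.any()` check converts a tensor to a Python bool, which on CUDA tensors implies a device-to-host synchronization per iteration, so any real fast path would need to be measured on the target device.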

Versions

PT 2.11:

PyTorch version: 2.11.0+cu128
torch.version.git_version: 70d99e998b4955e0049d13a98d77ae1b14db1f45
OS: CentOS Stream 9 (x86_64)
Python: 3.12.13
GPU: NVIDIA PG509-210

PT 2.12 (latest 2.12 nightly available):

PyTorch version: 2.12.0.dev20260407+cu128
torch.version.git_version: 6a13e444ee88996ff01cd2bab41d7f2857291646
OS: CentOS Stream 9 (x86_64)
Python: 3.12.13
GPU: NVIDIA PG509-210

cc @jerryzh168 @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @pbelevich (PR author)

extent analysis

TL;DR

The slowdown in nn.Module instantiation can be mitigated by optimizing the rejection-sampling implementation in torch.nn.init._no_grad_trunc_normal_.

Guidance

  • Investigate the possibility of adding a fast-path optimization for the common case where p = norm_cdf((b - mean) / std) - norm_cdf((a - mean) / std) > 0.99.
  • Consider short-circuiting the while True loop when mask.any() is false on the first pass to avoid unnecessary overhead.
  • Review the cProfile breakdown to identify other potential optimization opportunities, such as reducing the number of Tensor.normal_ calls.
  • Verify that any optimizations do not compromise the correctness of the fix introduced in PR #174997.

Example

No code snippet is provided as the issue is focused on optimizing an existing implementation rather than introducing new code.

Notes

The optimization opportunities are specific to the rejection-sampling implementation in torch.nn.init._no_grad_trunc_normal_ and may not be applicable to other parts of the codebase.

Recommendation

The slowdown is a known consequence of the correctness fix introduced in PR #174997, so the recommended path is to optimize the rejection-sampling implementation while keeping that fix intact.
