pytorch - ✅(Solved) Fix Poor BF16 torch.compile + Freezing perf on AArch64 CPUs [2 pull requests, 2 comments, 1 participant]

GitHub stats
pytorch/pytorch#180447 (fetched 2026-04-17 08:22:29)
Comments: 2 · Participants: 1 · Timeline events: 88 · Reactions: 0
Timeline (top): mentioned ×35, subscribed ×35, labeled ×9, referenced ×4

Fix Action

Fix / Workaround

When a BF16 model is compiled with torch.compile and frozen (freezing enables weight prepacking) on AArch64, the matmul is dispatched to oneDNN's reference kernels instead of the optimized ACL or JIT oneDNN kernels. As a result, the compiled model runs more than 100x slower than eager mode.

The issue reproduces with any torch version. For fp32, the matmul is already dispatched to optimized kernels.


PR fix notes

PR #1279 (ARM-software/ComputeLibrary): feat: Enable BF16 I/O for CpuFullyConnected in the experimental Operator API

Description (problem / solution / changelog)

feat: Enable BF16 CpuFullyConnected

BF16 support is already there, so this PR just removes the defensive checks. Partially fixes: https://github.com/pytorch/pytorch/issues/180447

Changed files

  • src/runtime/experimental/operators/CpuFullyConnected.cpp (modified, +5/-2)

PR #5024 (oneDNN): cpu: aarch64: enable ACL's inner-product for BF16

Description (problem / solution / changelog)


cpu: aarch64: enable ACL's inner-product for BF16

With ACL built from https://github.com/ARM-software/ComputeLibrary/pull/1279, this benchdnn reproducer is now dispatched to ACL, which is more than 100x faster:

ONEDNN_VERBOSE=all ./tests/benchdnn/benchdnn --ip --mode=C --dir=FWD_I  --bia-dt=bf16  --dt=bf16 mb1024ic1024oc1024

Together, this PR and https://github.com/ARM-software/ComputeLibrary/pull/1279 fix https://github.com/pytorch/pytorch/issues/180447.
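To check which kernel a run actually used, the ONEDNN_VERBOSE=all output (from benchdnn as above, or from the PyTorch reproducer run with the same environment variable) can be summarized per primitive. The sketch below uses hypothetical helper names and is not part of oneDNN. It assumes the comma-separated verbose layout in which the engine, primitive kind and implementation name directly follow the `exec` field, and it treats implementation names starting with `ref` as reference kernels, which is a heuristic.

```python
from collections import Counter


def summarize_onednn_verbose(lines):
    """Count executed oneDNN primitives by (primitive kind, implementation name).

    Assumption: in each "exec" line, the engine, primitive kind and
    implementation name are the three fields right after "exec",
    e.g. "onednn_verbose,...,exec,cpu,inner_product,<impl>,...".
    """
    counts = Counter()
    for line in lines:
        if not line.startswith("onednn_verbose"):
            continue  # skip non-verbose output such as the script's own prints
        fields = line.strip().split(",")
        if "exec" not in fields:
            continue  # only count executions, not primitive creation
        i = fields.index("exec")
        if len(fields) < i + 4:
            continue  # malformed or truncated line
        kind, impl = fields[i + 2], fields[i + 3]
        counts[(kind, impl)] += 1
    return counts


def reference_impls(counts):
    """Keep only entries whose implementation name looks like a reference kernel."""
    # Heuristic: oneDNN reports reference implementations with names starting with "ref".
    return {key: n for key, n in counts.items() if key[1].startswith("ref")}
```

For example, capture stdout with `ONEDNN_VERBOSE=all python repro.py > verbose.log` (where `repro.py` is a hypothetical file holding the reproducer) and pass `open("verbose.log")` to `summarize_onednn_verbose`. A non-empty `reference_impls(...)` result for `inner_product` or `matmul` suggests the slow path is still taken.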


Changed files

  • src/cpu/aarch64/acl_inner_product.cpp (modified, +4/-2)

Code Example

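import torch
import torch._inductor.config

# Freezing treats the parameters as constants, which lets Inductor prepack the weights.
torch._inductor.config.freezing = True

# Run with ONEDNN_VERBOSE=all to log which oneDNN implementation each primitive uses.

# This dispatches to optimized oneDNN kernels
#dtype = torch.float32

# This dispatches to reference oneDNN kernels -> more than 100x slower than eager
dtype = torch.bfloat16

def main():
    # Freezing is only applied when grad is disabled, hence no_grad.
    with torch.no_grad():
        x = torch.rand(size=(1024, 1024), dtype=dtype)
        linear = torch.nn.Linear(1024, 1024).to(dtype).eval()
        # compile and warm up (the first call triggers compilation and freezing)
        linear = torch.compile(linear)
        linear(x)
        print("should have no reorders after this", flush=True)
        for _ in range(10):
            linear(x)
            print("=="*20, flush=True)

if __name__ == "__main__":
    main()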

cc @jerryzh168 @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @aditew01 @gujinghui @PenghuiCheng @jianyuh @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen @snadampal @milpuz01 @nikhil-arm @robert-hardwick @nWEIdia @chauhang @penguinwu @voznesenskym @EikanWang @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo

extent analysis

TL;DR

Until a PyTorch build includes both PRs above, the issue can be worked around by running the model in torch.float32 instead of torch.bfloat16, because fp32 matmuls are already dispatched to optimized oneDNN kernels.

Guidance

  • The slowdown comes from kernel dispatch for torch.bfloat16: after freezing, BF16 matmuls on AArch64 fall back to oneDNN reference kernels instead of ACL, which makes them more than 100x slower than eager mode.
  • To verify the issue, run the reproducer once with torch.bfloat16 and once with torch.float32, and compare each against eager mode.
  • If possible, use torch.float32 for the model, since it already dispatches to optimized oneDNN kernels.
  • If torch.bfloat16 is required, use a build that includes ComputeLibrary PR #1279 and oneDNN PR #5024, which together route BF16 inner products to ACL.
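The comparison suggested in the guidance can be scripted. This is a minimal sketch, assuming a PyTorch build with torch.compile running on the AArch64 host in question; `median_latency_ms` and `compare_eager_vs_frozen` are hypothetical helper names. It reuses the reproducer's setup (freezing on, `no_grad`, a 1024x1024 `nn.Linear`) and reports median latency in milliseconds for eager and compiled runs in each dtype. torch is imported inside the function so the timing helper also works on its own.

```python
import statistics
import time


def median_latency_ms(fn, *args, warmup=3, iters=10):
    """Call fn(*args) `warmup` times, then return the median of `iters` timed calls in ms."""
    for _ in range(warmup):
        fn(*args)  # for a compiled module, the first call also pays the compile cost
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn(*args)
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


def compare_eager_vs_frozen(dtype_names=("float32", "bfloat16"), size=1024):
    """Time an nn.Linear in eager mode and under torch.compile with freezing."""
    # Imported lazily so that median_latency_ms stays usable without torch.
    import torch
    import torch._inductor.config

    # Note: this sets a global Inductor option for the rest of the process.
    torch._inductor.config.freezing = True
    results = {}
    with torch.no_grad():  # freezing is only applied when grad is disabled
        for name in dtype_names:
            dtype = getattr(torch, name)
            x = torch.rand(size, size, dtype=dtype)
            eager = torch.nn.Linear(size, size).to(dtype).eval()
            compiled = torch.compile(eager)
            results[name] = {
                "eager_ms": median_latency_ms(eager, x),
                "compiled_ms": median_latency_ms(compiled, x),
            }
    return results
```

Calling `print(compare_eager_vs_frozen())` on an affected build should show bfloat16 `compiled_ms` far above its `eager_ms`, while float32 stays comparable. No reference numbers are given here; measure on your own hardware.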

Example

# Using torch.float32 for the model
dtype = torch.float32

Notes

The reproducer above is a good starting point for investigating the issue. The linked PRs point to the root cause: BF16 inner products were not routed to ACL, because oneDNN's ACL inner-product path was not enabled for BF16 and ACL's experimental CpuFullyConnected operator blocked BF16 through defensive checks, so oneDNN fell back to its reference kernels.

Recommendation

Apply workaround: use torch.float32 instead of torch.bfloat16 for the model, since fp32 already dispatches to optimized oneDNN kernels. This is a temporary measure until your PyTorch build picks up oneDNN PR #5024 and ComputeLibrary PR #1279.
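One way to apply the workaround without hard-coding the dtype is to pick it per host. This is an illustrative sketch under stated assumptions, not an official API: `pick_inference_dtype_name` is a hypothetical helper, and it treats both machine names that `platform.machine()` reports for 64-bit Arm (`aarch64` on Linux, `arm64` on macOS) as affected.

```python
import platform

# Assumption: machine names reported by platform.machine() on 64-bit Arm (AArch64) hosts.
AARCH64_MACHINES = {"aarch64", "arm64"}


def pick_inference_dtype_name(machine=None, freezing=True, want_bf16=True):
    """Return "bfloat16" or "float32" for a torch.compile inference run.

    Workaround sketch: when freezing is enabled on AArch64, fall back to
    float32, which is already dispatched to optimized oneDNN kernels.
    """
    machine = (machine or platform.machine()).lower()
    if want_bf16 and not (freezing and machine in AARCH64_MACHINES):
        return "bfloat16"
    return "float32"
```

In the reproducer, `dtype = getattr(torch, pick_inference_dtype_name())` would then select float32 on AArch64 while freezing is enabled. Once the build includes both PRs, the check can be dropped.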
