pytorch - ✅(Solved) Fix [MPS] In-place `self.add_(other, alpha)` `RuntimeError`s with type promotion [1 pull requests, 2 comments, 3 participants]

GitHub stats
pytorch/pytorch#178709 (fetched 2026-04-08 01:45:03)
Comments: 2 · Participants: 3 · Timeline events: 33 · Reactions: 0
Timeline (top): mentioned ×10, subscribed ×10, referenced ×5, labeled ×4

Error Message

import torch

device="mps"

a = torch.arange(16, dtype=torch.float16, device=device)
b = torch.arange(16, dtype=torch.float32, device=device)
alpha = torch.tensor(0.33, dtype=torch.float32)

a.add_(b, alpha=alpha)

Output:

RuntimeError: Failed to create function state object for: add_alpha_dense_cast_half_float

Fix Action

Fixed

PR fix notes

PR #178724: [MPS] fix in-place self.add_(other, alpha) RuntimeErrors with type promotion

Description (problem / solution / changelog)

Fixes #178709

Changed files

  • aten/src/ATen/native/mps/OperationUtils.mm (modified, +2/-1)
  • test/test_mps.py (modified, +12/-0)


🐛 Describe the bug

In-place add_ with alpha raises a RuntimeError when type promotion would promote self (for example, a float16 self with a float32 other). The op tries to use a kernel that isn't defined for those dtypes. The issue does not occur when alpha is omitted, when the other tensor is the one being promoted, or when using the out-of-place add with alpha.
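The four cases described above can be checked side by side. The following is a minimal sketch, not part of the original report: the helper name `probe_add_variants` and the variant labels are made up for illustration, and "other tensor is promoted" is read here as a float32 self with a float16 other.

```python
import torch


def probe_add_variants(device: str) -> dict:
    """Run each variant from the report; map its label to None or the error text."""
    alpha = torch.tensor(0.33, dtype=torch.float32)

    def fresh():
        # New tensors for every variant, since the in-place ops mutate them.
        a = torch.arange(16, dtype=torch.float16, device=device)
        b = torch.arange(16, dtype=torch.float32, device=device)
        return a, b

    variants = {
        "in-place, alpha, self promoted": lambda a, b: a.add_(b, alpha=alpha),
        "in-place, no alpha": lambda a, b: a.add_(b),
        "in-place, alpha, other promoted": lambda a, b: b.add_(a, alpha=alpha),
        "out-of-place, alpha": lambda a, b: torch.add(a, b, alpha=alpha),
    }

    results = {}
    for name, op in variants.items():
        try:
            op(*fresh())
            results[name] = None
        except RuntimeError as exc:
            results[name] = str(exc)
    return results


# Fall back to CPU when MPS is not available, so the script runs anywhere.
device = "mps" if torch.backends.mps.is_available() else "cpu"
for name, err in probe_add_variants(device).items():
    print(f"{name}: {'ok' if err is None else err}")
```

According to the report, only the first variant should fail on an affected MPS build; on CPU all four are expected to succeed.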

On CPU and CUDA, add_ with the alpha parameter promotes types, then stores the result back in the original tensor's dtype.
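The CPU behavior described here can be reproduced directly. This sketch is an illustration rather than code from the issue: it writes out "promote, compute, store back" by hand and compares it with the in-place call. The tolerance is an assumption chosen to absorb float16 rounding differences.

```python
import torch

a = torch.arange(16, dtype=torch.float16)        # self: float16, on CPU
b = torch.arange(16, dtype=torch.float32)        # other: float32
alpha = torch.tensor(0.33, dtype=torch.float32)

# Hand-written reference: compute in float32, then cast back to self's dtype.
reference = (a.float() + alpha * b).to(a.dtype)

# The in-place op promotes internally but keeps self's dtype.
a.add_(b, alpha=alpha)

print(a.dtype)
print(torch.allclose(a.float(), reference.float(), rtol=1e-2, atol=1e-2))
```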

This is a regression. Last known good PyTorch version is 2.8.0.

Other ops using the same code path might be affected.

Discovered while working on cifar10 speedrun (context: cifar10-airbench).

MRE

import torch

device="mps"

a = torch.arange(16, dtype=torch.float16, device=device)
b = torch.arange(16, dtype=torch.float32, device=device)
alpha = torch.tensor(0.33, dtype=torch.float32)

a.add_(b, alpha=alpha)
# Output:
# RuntimeError: Failed to create function state object for: add_alpha_dense_cast_half_float

Versions

PyTorch version: 2.12.0a0+gitab45b0e
Is debug build: True
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 26.4 (arm64)
GCC version: Could not collect
Clang version: 21.0.0 (clang-2100.0.123.102)
CMake version: version 4.1.2
Libc version: N/A

Python version: 3.12.8 | packaged by conda-forge | (main, Dec 5 2024, 14:19:53) [Clang 18.1.8 ] (64-bit runtime)
Python platform: macOS-26.4-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU: Apple M3 Max

Versions of relevant libraries:
[pip3] flake8==7.2.0
[pip3] mypy==1.13.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==2.4.3
[pip3] onnx==1.17.0
[pip3] onnx2torch==1.5.15
[pip3] onnxruntime==1.21.1
[pip3] onnxscript==0.1.0.dev20240817
[pip3] optree==0.13.0
[pip3] pytorch_sphinx_theme==0.0.24
[pip3] torch==2.12.0a0+gitab45b0e
[pip3] torchaudio==2.6.0a0+1a8f621
[pip3] torchbench==0.1
[pip3] torchvision==0.26.0a0+6285457
[conda] numpy 2.4.3 pypi_0 pypi
[conda] onnx2torch 1.5.15 pypi_0 pypi
[conda] optree 0.13.0 pypi_0 pypi
[conda] pytorch-sphinx-theme 0.0.24 pypi_0 pypi
[conda] torch 2.12.0a0+gitab45b0e pypi_0 pypi
[conda] torchaudio 2.6.0a0+1a8f621 dev_0 <develop>
[conda] torchbench 0.1 dev_0 <develop>
[conda] torchvision 0.26.0a0+6285457 dev_0 <develop>

cc @kulinseth @malfet @DenisVieriu97 @jhavukainen @aditvenk

extent analysis

Fix Plan

The upstream fix is PR #178724. On a build that does not include it, the error can be worked around by keeping the in-place add_ on a single dtype, so that self is never promoted: cast both the other tensor and alpha to the dtype of self before the call.

Here are the steps:

  • Cast the other tensor and the alpha tensor to the same dtype as self.
  • Perform the add_ operation with the cast operands.

Example code:

import torch

device = "mps"

a = torch.arange(16, dtype=torch.float16, device=device)
b = torch.arange(16, dtype=torch.float32, device=device)
alpha = torch.tensor(0.33, dtype=torch.float32)

# Cast alpha to the same dtype as a
alpha_cast = alpha.to(a.dtype)

# Perform add_ operation with cast alpha
a.add_(b.to(a.dtype), alpha=alpha_cast)

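Note that b is also cast to the same dtype as a; this is what keeps both tensor operands in float16 and avoids type promotion. Because b and alpha are rounded to float16 before the add, the result may differ slightly from the promote-then-store result that CPU and CUDA produce.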
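The issue text also states that the out-of-place add with alpha is not affected. That suggests a second workaround which keeps the promoted arithmetic: compute out of place, then copy the result back into a. The sketch below assumes this path behaves on MPS as the report describes; it falls back to CPU when MPS is unavailable.

```python
import torch

# Fall back to CPU when MPS is not available, so the script runs anywhere.
device = "mps" if torch.backends.mps.is_available() else "cpu"

a = torch.arange(16, dtype=torch.float16, device=device)
b = torch.arange(16, dtype=torch.float32, device=device)
alpha = torch.tensor(0.33, dtype=torch.float32)

# torch.add promotes to float32; copy_ casts the result back into a (float16).
a.copy_(torch.add(a, b, alpha=alpha))

print(a.dtype)
```

Compared with casting b and alpha down first, this allocates a temporary float32 tensor but leaves the inputs unrounded.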

Verification

To verify that the fix worked, you can check that the add_ operation completes without raising a RuntimeError. You can also verify that the result is correct by comparing it with the expected result.

Example code:

import torch

device = "mps"

a = torch.arange(16, dtype=torch.float16, device=device)
b = torch.arange(16, dtype=torch.float32, device=device)
alpha = torch.tensor(0.33, dtype=torch.float32)

# Cast alpha to the same dtype as a
alpha_cast = alpha.to(a.dtype)

# Perform add_ operation with cast alpha
a.add_(b.to(a.dtype), alpha=alpha_cast)

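# Verify against the same computation done out of place. Use a tolerance sized
# for float16, since the in-place kernel may round intermediates differently.
expected_result = torch.arange(16, dtype=torch.float16, device=device) + alpha_cast * b.to(a.dtype)
assert torch.allclose(a, expected_result, rtol=1e-2, atol=1e-2)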
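On a build that includes the upstream fix, the original mixed-dtype call can be checked against CPU, which the issue names as the reference behavior. This is a minimal sketch; the helper name `matches_cpu_reference` and the tolerances are assumptions, not taken from the PR's test.

```python
import torch


def matches_cpu_reference(device: str) -> bool:
    """Run the mixed-dtype in-place add on `device` and compare with CPU."""
    alpha = torch.tensor(0.33, dtype=torch.float32)

    # Reference: the same call on CPU.
    ref = torch.arange(16, dtype=torch.float16)
    ref.add_(torch.arange(16, dtype=torch.float32), alpha=alpha)

    a = torch.arange(16, dtype=torch.float16, device=device)
    b = torch.arange(16, dtype=torch.float32, device=device)
    a.add_(b, alpha=alpha)  # raises RuntimeError on affected MPS builds

    same_dtype = a.dtype == ref.dtype
    close = torch.allclose(a.cpu().float(), ref.float(), rtol=1e-2, atol=1e-2)
    return same_dtype and close


# Fall back to CPU when MPS is not available, so the script runs anywhere.
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(matches_cpu_reference(device))
```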

Extra Tips

  • When an operation mixes dtypes, check which dtype PyTorch will promote to. In this issue the mixed float16/float32 in-place add selected a kernel that was not defined for that dtype pair.
  • Casting the operands to one dtype up front sidesteps promotion, at the cost of doing the arithmetic on values already rounded to that dtype.
  • Verify results against a reference, such as the same operation on CPU, using a tolerance suited to the dtype.
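To see in advance whether a call will involve promotion, PyTorch exposes the promotion rules directly. The sketch below uses `torch.result_type`, `torch.promote_types`, and `torch.can_cast` on the dtypes from this issue; it runs on CPU and is shown only to illustrate the tips above.

```python
import torch

a = torch.arange(16, dtype=torch.float16)
b = torch.arange(16, dtype=torch.float32)

# Dtype the operation is computed in when a and b are combined.
common = torch.result_type(a, b)
print(common)                                               # torch.float32

# The same rule expressed on dtypes alone.
print(torch.promote_types(torch.float16, torch.float32))    # torch.float32

# Whether the promoted result may be stored back into self's dtype.
print(torch.can_cast(common, a.dtype))                      # True

# Promotion is involved whenever the common dtype differs from self's dtype.
needs_promotion = common != a.dtype
print(needs_promotion)                                      # True
```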
