pytorch - [MPS] Adding a new MPS loader API for optimized reading of safetensors files [4 comments, 3 participants]

GitHub stats: pytorch/pytorch#179190 (fetched 2026-04-08 02:32:57)
Comments: 4 · Participants: 3 · Timeline events: 27 · Reactions: 0
Timeline (top): subscribed ×9, mentioned ×8, labeled ×6, commented ×4


🚀 The feature, motivation and pitch

Loading models onto the MPS backend is slow with the currently popular code paths and formats, in particular the safetensors file format used by most models on the HuggingFace repository. The default loading approach in downstream libraries such as transformers and diffusers makes poor use of the SSD read bandwidth on macOS.

I propose adding a loader API to the PyTorch MPS backend, exposed as torch.mps.load_safetensors(), that downstream libraries could call in place of the backend-agnostic safetensors.torch.load_file() when the destination device for the tensors is MPS.

The proposed API:

import torch
from torch import Tensor


def load_safetensors(filename: str) -> dict[str, Tensor]:
    r"""Loads tensors from a safetensors file with optimized parallel I/O directly to MPS device.

    Args:
        filename (str): Path to the safetensors file.

    Returns:
        dict[str, Tensor]: Dictionary mapping tensor names to MPS tensors.

    Example::

        >>> # xdoctest: +SKIP("requires safetensors file")
        >>> state_dict = torch.mps.load_safetensors("model.safetensors")
    """
    if not hasattr(torch._C, "_mps_load_safetensors"):
        raise RuntimeError(
            "MPS safetensors loading is not available. "
            "Ensure PyTorch was built with MPS support."
        )
    return torch._C._mps_load_safetensors(filename)

By using macOS Grand Central Dispatch (GCD) for thread coordination and pread() for thread-safe, offset-based reads, the loader avoids the overhead of locking the file descriptor, which the lseek() + read() pattern in the default implementation requires. Initializing all tensors on the MPS device in bulk before the concurrent reads start reduces the time each thread spends on non-I/O work. Together, these changes let the loader use nearly the full read bandwidth of the device's SSD when reading model weights.
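As an illustration only, the pread-based read pattern can be sketched in pure Python. os.pread() takes an explicit offset, so worker threads can read different tensors through one shared file descriptor without lseek() or a lock on the file position. This is not the proposed native implementation: a concurrent.futures.ThreadPoolExecutor stands in for Grand Central Dispatch, a single preallocated bytearray stands in for the bulk-initialized MPS tensors, and load_raw_tensors, read_header, pread_exact, and max_workers are hypothetical names. The header parsing follows the safetensors layout: an 8-byte little-endian header length, then a JSON header whose data_offsets are relative to the end of the header.

```python
import json
import os
import struct
from concurrent.futures import ThreadPoolExecutor


def pread_exact(fd: int, size: int, offset: int) -> bytes:
    """Read exactly `size` bytes at `offset`; pread() may return short reads."""
    chunks = []
    while size > 0:
        chunk = os.pread(fd, size, offset)
        if not chunk:
            raise EOFError(f"unexpected end of file at offset {offset}")
        chunks.append(chunk)
        size -= len(chunk)
        offset += len(chunk)
    return b"".join(chunks)


def read_header(fd: int) -> tuple[dict, int]:
    """Parse the safetensors header: a little-endian u64 length followed by JSON."""
    (header_len,) = struct.unpack("<Q", pread_exact(fd, 8, 0))
    header = json.loads(pread_exact(fd, header_len, 8))
    header.pop("__metadata__", None)  # optional string metadata, not a tensor
    return header, 8 + header_len  # data_offsets are relative to this position


def load_raw_tensors(filename: str, max_workers: int = 8) -> dict[str, memoryview]:
    """Hypothetical sketch: read every tensor's bytes concurrently with pread()."""
    fd = os.open(filename, os.O_RDONLY)
    try:
        header, data_start = read_header(fd)
        data_size = max((info["data_offsets"][1] for info in header.values()), default=0)
        # Allocate all destination storage up front, before any read starts
        # (a stand-in for bulk-initializing the MPS tensors).
        view = memoryview(bytearray(data_size))

        def read_one(name: str) -> None:
            begin, end = header[name]["data_offsets"]
            # pread() carries its own offset, so threads share fd without lseek() or a lock.
            view[begin:end] = pread_exact(fd, end - begin, data_start + begin)

        # The thread pool stands in for Grand Central Dispatch.
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            list(pool.map(read_one, header))  # list() re-raises any worker exception
        return {
            name: view[info["data_offsets"][0] : info["data_offsets"][1]]
            for name, info in header.items()
        }
    finally:
        os.close(fd)
```

CPython releases the GIL inside os.pread(), so the reads can overlap even in this Python sketch. Turning the raw bytes into tensors (for example with torch.frombuffer) is left out.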

As a test case, I loaded Qwen3-8B in bf16 from HuggingFace with the transformers library, calling torch.mps.load_safetensors() instead of safetensors.torch.load_file() in the PreTrainedModel class. Cold-loading the model weights from disk to the MPS device took:

Default load: 9.54 s ± 0.31 s (1.60 GB/s)
New loader:   2.91 s ± 0.01 s (5.24 GB/s)

These measurements were taken on an M2 Max, with the reported bandwidths computed naively as model_size / load_time. According to publicly available external benchmarks, this device's SSD read bandwidth should be between 5.3 and 6.0 GB/s, so the new loader reaches roughly the performance the hardware allows. The main benefit is a shorter time to first model output on a cold start, when the model weights must be read from disk.
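The reported figures can be cross-checked with a few lines of arithmetic. Because bandwidth was computed as model_size / load_time, multiplying each bandwidth by its load time should give the same model size for both loaders (about 15.3 GB here), and the load-time ratio gives the speedup (about 3.3×). The snippet only reuses the numbers above; summarize_benchmark is an illustrative name.

```python
def summarize_benchmark(default_s: float, default_gbps: float,
                        new_s: float, new_gbps: float) -> dict[str, float]:
    """Cross-check reported numbers; bandwidth was computed as model_size / load_time."""
    return {
        "speedup": default_s / new_s,                # ratio of mean load times
        "bandwidth_ratio": new_gbps / default_gbps,  # ratio of naive bandwidths
        # Rearranging bandwidth = size / time gives size = bandwidth * time.
        "implied_size_default_gb": default_gbps * default_s,
        "implied_size_new_gb": new_gbps * new_s,
    }


# Numbers reported in the issue (Qwen3-8B, bf16, cold load on an M2 Max).
summary = summarize_benchmark(default_s=9.54, default_gbps=1.60, new_s=2.91, new_gbps=5.24)
for key, value in summary.items():
    print(f"{key}: {value:.2f}")
```

Against the cited 5.3 to 6.0 GB/s SSD range, 5.24 GB/s is roughly 87% to 99% of the expected peak.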

I'm happy to submit a PR with the proposed changes for review if the maintainers consider this actionable.

Alternatives

Keep relying on the existing backend-agnostic load methods currently used downstream. The speedup mainly makes things feel snappier for developers prototyping short runs with different models, or for programs that swap models in and out of memory during execution. For longer-running programs, such as training, model load time is usually insignificant compared to the total runtime.

Additional context

No response

cc @jerryzh168 @kulinseth @malfet @DenisVieriu97 @aditvenk

extent analysis

TL;DR

The proposed torch.mps.load_safetensors() loader API could significantly improve model-loading performance on the MPS backend (about 3.3× faster cold loads in the reporter's Qwen3-8B benchmark).

Guidance

  • Downstream libraries could call the proposed torch.mps.load_safetensors() in place of the backend-agnostic safetensors.torch.load_file() when the destination device is MPS.
  • Verify the improvement by comparing cold-load times with the default loader and the new loader, as in the benchmark reported in the issue.
  • The proposed API requires a PyTorch build with MPS support; without it, the function raises a RuntimeError.
  • Weigh the benefit against your use case: the speedup matters most for prototyping, cold starts, and swapping models in and out of memory, and is usually negligible for long training runs.
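A minimal timing harness for the comparison suggested above might look like the following sketch. time_loader and its parameters are hypothetical names, not part of PyTorch or safetensors; it reports the mean, the standard deviation, and the same naive model_size / load_time bandwidth used in the issue.

```python
import statistics
import time
from typing import Callable


def time_loader(load: Callable[[str], object], filename: str,
                size_bytes: int, repeats: int = 5) -> tuple[float, float, float]:
    """Time load(filename) `repeats` times; return (mean_s, stdev_s, naive_gb_per_s)."""
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        load(filename)
        timings.append(time.perf_counter() - start)
    mean_s = statistics.mean(timings)
    stdev_s = statistics.stdev(timings) if len(timings) > 1 else 0.0
    # Same naive bandwidth definition as the issue: model_size / load_time.
    return mean_s, stdev_s, size_bytes / mean_s / 1e9
```

For a true cold-load comparison the OS file cache has to be flushed between runs (on macOS, sudo purge does this); otherwise every run after the first measures cached reads. Because MPS work can run asynchronously, a loader that returns MPS tensors should be wrapped so it calls torch.mps.synchronize() before returning, so the timer covers the whole transfer.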

Example

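import torch

state_dict = torch.mps.load_safetensors("model.safetensors")

This example shows how the proposed torch.mps.load_safetensors() API would load all tensors in a safetensors file into a dictionary of MPS tensors. It only works in a PyTorch build that includes the proposed change and MPS support.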

Notes

The proposed API is specific to the MPS backend and does not apply to other backends. Its speedup comes from saturating the SSD read bandwidth on macOS, so it is most noticeable on cold loads, when the weights must actually be read from disk.

Recommendation

Until the proposed API is merged, the backend-agnostic safetensors.torch.load_file() remains the loading path. If torch.mps.load_safetensors() lands, adopting it in downstream loaders for MPS targets should give a significant cold-load speedup over the default loader (about 3.3× in the reported benchmark).
