pytorch - [MPS] For accumulating ops allow the correctness tolerance to scale within the envelope of the reference implementations drift

GitHub stats: pytorch/pytorch#183621 (fetched 2026-05-14 03:28:03)
Comments: 0 · Participants: 1 · Reactions: 0 · Timeline: 16 events (labeled ×6, mentioned ×5, subscribed ×5)





🐛 Describe the bug

Context

Related to the discussion in issue #182817, this is an orthogonal issue with the current testing approach, together with a suggestion for how we could address it:

When testing op correctness for fp16/bf16 implementations in the MPS backend, we compare against CPU computations done in the same precision. For ops that accumulate many partial results into each output element, such as convolutions or reductions, the order in which the accumulation happens causes the final result to drift.
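To illustrate why the accumulation order alone can change a half-precision result, here is a minimal sketch that emulates fp16 accumulation in pure Python by rounding the running sum to IEEE half precision after every addition (via the standard-library `struct` half-float format `'e'`). The input is a deliberately extreme construction, not data from the OpInfo samples, and the helper names are hypothetical.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def fp16_sum(values):
    """Sequentially accumulate values, rounding the running sum to fp16 after each add."""
    acc = 0.0
    for v in values:
        acc = to_fp16(acc + v)
    return acc

# One large term plus 1000 terms that are each below half an fp16 ulp of 1.0.
values = [1.0] + [2.0 ** -12] * 1000

forward = fp16_sum(values)             # large term first: every tiny term is rounded away
backward = fp16_sum(reversed(values))  # tiny terms first: they accumulate exactly
exact = sum(values)                    # float64 reference

print(forward, backward, exact)  # → 1.0 1.244140625 1.244140625
```

Real kernels are far less adversarial, but the mechanism is the same: two backends that accumulate the same terms in a different order can legitimately disagree in fp16 or bf16.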

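One can estimate the typical per-element error as

err ≈ sqrt(K) * eps * |x|

where K is the reduction depth (the number of partial products accumulated into one output element), eps is the machine epsilon of the dtype, and |x| is the magnitude of the partial products. This estimate treats the individual rounding errors as independent; a strict worst-case bound grows linearly in K rather than with sqrt(K). Either way, the expected error grows with the number of accumulated partial results and with machine epsilon, which is orders of magnitude larger for fp16 (2^-10 ≈ 9.8e-4) and bf16 (2^-7 ≈ 7.8e-3) than for fp32 (2^-23 ≈ 1.2e-7).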
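As a rough sketch of how this estimate scales across dtypes, the snippet below plugs the machine epsilons of fp32, fp16 and bf16 into sqrt(K) * eps * |x|. The reduction depth (a hypothetical conv_transpose3d with 4 input channels and a 3×3×3 kernel, so K = 108) and the unit magnitude |x| = 1 are illustrative assumptions, not values taken from the OpInfo samples.

```python
import math

# Machine epsilon (gap between 1.0 and the next representable value) per dtype.
EPS = {
    "float32": 2.0 ** -23,   # ~1.19e-07
    "float16": 2.0 ** -10,   # ~9.77e-04
    "bfloat16": 2.0 ** -7,   # ~7.81e-03
}

def accumulation_error_estimate(k: int, eps: float, magnitude: float) -> float:
    """Typical per-element error sqrt(K) * eps * |x| for a K-term accumulation."""
    return math.sqrt(k) * eps * magnitude

# Hypothetical reduction depth: 4 input channels x a 3x3x3 kernel.
K = 4 * 3 * 3 * 3

for name, eps in EPS.items():
    print(f"{name:>8}: ~{accumulation_error_estimate(K, eps, 1.0):.2e}")
```

Under these assumptions the estimate is roughly 1.2e-6 for fp32, 1.0e-2 for fp16 and 8.1e-2 for bf16 per unit magnitude, so for half types it is already comparable to (or above) a fixed atol of 5e-2 once the partial products are not small.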

A small script that demonstrates the issue (note: this requires registering ConvTranspose3d for fp16 locally first; it is left unimplemented in the current nightly because of this testing issue):

import torch
from torch.testing._internal.common_methods_invocations import op_db

for op in op_db:
    if op.name != 'nn.functional.conv_transpose3d':
        continue
    print('Looking at all fp16 samples for conv_transpose3d...')
    samples = list(op.sample_inputs('mps', torch.float16, requires_grad=False, set_seed=True))
    for i, sample in enumerate(samples):
        try:
            inp = sample.input
            args = list(sample.args)
            mps_out = op(inp, *args, **sample.kwargs)
            cpu_fp32_in = inp.to('cpu', dtype=torch.float32)
            cpu_fp32_args = [a.to('cpu', dtype=torch.float32) if hasattr(a, 'to') else a for a in args]
            cpu_fp32_out = op(cpu_fp32_in, *cpu_fp32_args, **sample.kwargs)
            cpu_fp16_in = inp.cpu()
            cpu_fp16_args = [a.cpu() if hasattr(a, 'cpu') else a for a in args]
            cpu_fp16_out = op(cpu_fp16_in, *cpu_fp16_args, **sample.kwargs)
            torch.mps.synchronize()
            mps_vs_fp32 = (mps_out.cpu().float() - cpu_fp32_out).abs().max().item()
            cpu_vs_fp32 = (cpu_fp16_out.float() - cpu_fp32_out).abs().max().item()
            mps_vs_cpu  = (mps_out.cpu().float() - cpu_fp16_out.float()).abs().max().item()
            flag = ""
            if mps_vs_cpu > 5e-2:
                flag = "  OPINFO FAILS"
            if mps_vs_fp32 > cpu_vs_fp32 * 2 + 1e-3:
                flag += "  DRIFT TEST FAILS: LIKELY AN ACTUAL BUG"
            print(f"  sample {i}: MPS/fp32={mps_vs_fp32:.4f}  CPU/fp32={cpu_vs_fp32:.4f}  MPS/CPU={mps_vs_cpu:.4f}{flag}")
        except Exception as e:
            print(f"  sample {i}: ERROR {e}")
    break

On my machine this prints something like the following (each value is the maximum absolute difference over the output tensor):

sample   |MPS - fp32|   |CPU - fp32|   |MPS - CPU|   assertEqual (atol=5e-2)
------   ------------   ------------   -----------   -----------------------
     0          0.167          0.230         0.250   FAILS
     1          0.167          0.230         0.250   FAILS
     2          0.083          0.116         0.125   FAILS
     3          0.083          0.084         0.125   FAILS
     4          0.043          0.061         0.063   FAILS
     5          0.043          0.061         0.063   FAILS
     6          0.057          0.140         0.125   FAILS
     7          0.057          0.140         0.125   FAILS
     8          0.244          1.138         1.000   FAILS
     9          0.244          1.138         1.000   FAILS

So for all 10 samples that op_db provides for this op, the canonical correctness test fails even with the generous atol of 5e-2, even though the |MPS - fp32| column shows that in every case the MPS result is closer to the CPU fp32 reference than the CPU fp16 result is (|CPU - fp32|)!
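As a sanity check on these numbers, the sketch below feeds the rounded values from the table through both criteria the script uses: the fixed-tolerance check (|MPS - CPU| ≤ 5e-2, the "OPINFO FAILS" flag) and the drift check (|MPS - fp32| ≤ 2 · |CPU - fp32| + 1e-3). The values are rounded to three decimals, so this only re-checks the reported summary; it is not a substitute for running the script.

```python
# (|MPS - fp32|, |CPU - fp32|, |MPS - CPU|) per sample, copied from the table above.
ROWS = [
    (0.167, 0.230, 0.250), (0.167, 0.230, 0.250),
    (0.083, 0.116, 0.125), (0.083, 0.084, 0.125),
    (0.043, 0.061, 0.063), (0.043, 0.061, 0.063),
    (0.057, 0.140, 0.125), (0.057, 0.140, 0.125),
    (0.244, 1.138, 1.000), (0.244, 1.138, 1.000),
]

ATOL = 5e-2

def passes_fixed_tolerance(mps_vs_cpu: float, atol: float = ATOL) -> bool:
    """Fixed-tolerance check, mirroring the script's OPINFO FAILS flag."""
    return mps_vs_cpu <= atol

def passes_drift_check(mps_vs_fp32: float, cpu_vs_fp32: float) -> bool:
    """Drift check from the script: MPS may drift up to 2x the CPU drift plus 1e-3."""
    return mps_vs_fp32 <= 2 * cpu_vs_fp32 + 1e-3

fixed = [passes_fixed_tolerance(m_c) for _, _, m_c in ROWS]
drift = [passes_drift_check(m_f, c_f) for m_f, c_f, _ in ROWS]
closer = [m_f < c_f for m_f, c_f, _ in ROWS]

print(f"fixed-tolerance passes: {sum(fixed)}/{len(ROWS)}")  # 0/10
print(f"drift-check passes:     {sum(drift)}/{len(ROWS)}")  # 10/10
print(f"MPS closer to fp32:     {sum(closer)}/{len(ROWS)}")  # 10/10
```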

Proposal

The half-precision MPS implementations should be allowed to drift from the fp32 reference within the same envelope as the CPU half-precision reference does. This can be implemented as a fall-through check for the cases that fail the traditional assertion: we first verify that this is indeed a case where the CPU result also drifts significantly from the fp32 reference, and only then accept an MPS result that stays within that drift. Since in the worst case MPS drifts in exactly the opposite direction from the CPU reference, the allowed envelope should be 2 × |CPU - fp32| + atol. The same as a code sketch:

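# Sketch of the proposed check. Assumption: torch.testing.assert_close stands in
# for the test suite's assertEqual as the fixed-tolerance comparison.
import torch

def assert_mps_match_or_drift(cpu, mps, fp32_ref, atol, rtol, dtype):
    mps = mps.cpu()  # assert_close also checks devices, so compare on the CPU
    try:
        torch.testing.assert_close(mps, cpu, atol=atol, rtol=rtol)
        return
    except AssertionError:
        if dtype not in (torch.float16, torch.bfloat16):
            raise  # Not a half type: the regression is probably real

    fp32_ref = fp32_ref.cpu().float()
    cpu_err = (cpu.float() - fp32_ref).abs()
    mps_err = (mps.float() - fp32_ref).abs()
    budget = atol + rtol * fp32_ref.abs()

    # Does the CPU result meet the expected precision w.r.t. fp32? Then there is
    # no accumulation drift that could excuse the MPS mismatch.
    if (cpu_err <= budget).all():
        raise AssertionError("MPS exceeds budget where CPU meets it: regression")

    # Worst case: CPU and MPS drift in exactly opposite directions, so SLACK = 2.0
    SLACK = 2.0

    # Neither implementation can meet the fixed tolerance at this precision.
    # Require MPS to stay within the CPU implementation's drift envelope.
    if mps_err.max() > SLACK * cpu_err.max() + atol:
        raise AssertionError(
            f"MPS drift {mps_err.max().item():.4g} exceeds "
            f"{SLACK} x CPU drift {cpu_err.max().item():.4g} + atol {atol}"
        )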
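To make the possible outcomes of this fall-through explicit, here is a dependency-free sketch of the same decision logic on plain Python lists. The function name and the returned labels are hypothetical, and it assumes a per-element tolerance of |a - b| ≤ atol + rtol · |b| (the rule torch.testing.assert_close uses); unlike assert_close, it does no NaN, dtype or shape checking.

```python
def classify_half_precision_result(cpu, mps, ref, atol, rtol, slack=2.0):
    """Hypothetical sketch of the proposed check on per-element lists of floats.

    cpu, mps: half-precision results converted to float; ref: fp32 reference.
    """
    # 1. Classic check: MPS matches CPU within atol + rtol * |cpu| everywhere.
    if all(abs(m - c) <= atol + rtol * abs(c) for c, m in zip(cpu, mps)):
        return "match"

    cpu_err = [abs(c - r) for c, r in zip(cpu, ref)]
    mps_err = [abs(m - r) for m, r in zip(mps, ref)]

    # 2. CPU meets the budget against fp32, so there is no drift to excuse MPS.
    if all(e <= atol + rtol * abs(r) for e, r in zip(cpu_err, ref)):
        return "regression"

    # 3. Both drift: accept MPS only inside the (slack-scaled) CPU drift envelope.
    if max(mps_err) <= slack * max(cpu_err) + atol:
        return "within drift envelope"
    return "drift exceeded"


print(classify_half_precision_result([1.0, 2.0], [1.0, 2.001], [1.0, 2.0], 1e-2, 0.0))  # match
print(classify_half_precision_result([1.0], [1.5], [1.0], 0.05, 0.0))   # regression
print(classify_half_precision_result([1.25], [0.9], [1.0], 0.05, 0.0))  # within drift envelope
print(classify_half_precision_result([1.25], [2.0], [1.0], 0.05, 0.0))  # drift exceeded
```

The last two calls differ only in the MPS value: with the CPU drifting by 0.25, MPS is allowed up to 2 · 0.25 + 0.05 = 0.55 of drift, so an MPS error of 0.1 passes while an error of 1.0 is still reported.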

Replacing the current assertEqual call with this check would also let the currently skipped fp16 ConvTranspose1D and ConvTranspose2D tests pass, since the same issue has kept us from enabling them. There are probably other similar entries on the MPS skiplist that should be re-evaluated. The upside of this approach is that we would not need to manually identify the cases where the error accumulates into a measurable drift: the CPU result tells us. If the CPU shows little to no drift, we still raise the correctness error as we do today.

I can open a PR for this if the reasoning makes sense.

Versions

Current nightly, macOS 26.4.

cc @mruberry @kulinseth @malfet @DenisVieriu97 @aditvenk

