pytorch - Fix mps test got Exception: mean(): could not infer output dtype. [2 comments, 2 participants]

GitHub stats
pytorch/pytorch#177225 (fetched 2026-04-08 00:21:49)
Comments: 2
Participants: 2
Timeline events: 37
Reactions: 0
Timeline (top): mentioned ×15, subscribed ×15, labeled ×5, commented ×2

Error Message

From the CI log of 2026-03-11, test TestCommonMPS.test_python_ref_executor__refs_var_mean_executor_aten_mps_int8:

RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: torch.int8

The above exception was the direct cause of the following exception:

Exception: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: torch.int8

Caused by reference input at index 0: SampleInput(input=Tensor[size=(5, 5, 5), device="mps:0", dtype=torch.int8], args=(), kwargs={}, broadcasts_input=False, name='')
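The message states the rule directly: when `mean()` is called without an explicit `dtype`, the output reuses the input dtype, and that dtype must be floating point or complex. The stdlib-only sketch below models that rule so it can be exercised without PyTorch; `infer_mean_output_dtype` and the string dtype names are hypothetical stand-ins, not PyTorch's implementation.

```python
# Hypothetical model: dtypes are plain strings. The sets are an assumption
# covering the common floating-point and complex dtypes.
FLOATING = {"float16", "bfloat16", "float32", "float64"}
COMPLEX = {"complex32", "complex64", "complex128"}


def infer_mean_output_dtype(input_dtype, dtype=None):
    """Model of the output-dtype rule quoted in the error message.

    Without an explicit ``dtype`` the input dtype is reused, so it must be
    floating point or complex; an int8 input therefore has no valid result dtype.
    """
    out = input_dtype if dtype is None else dtype
    if out in FLOATING or out in COMPLEX:
        return out
    # "Input" reproduces the CI message; "Requested" is illustrative wording only.
    source = "Input" if dtype is None else "Requested"
    raise RuntimeError(
        "mean(): could not infer output dtype. "
        f"{source} dtype must be either a floating point or complex dtype. "
        f"Got: torch.{out}"
    )


print(infer_mean_output_dtype("int8", dtype="float32"))  # float32
try:
    infer_mean_output_dtype("int8")
except RuntimeError as err:
    print(err)  # same text as the RuntimeError in the CI log
```

In PyTorch itself, the caller-side counterpart of the `dtype="float32"` call above is `x.mean(dtype=torch.float32)` or `x.float().mean()`.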

Root Cause

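The error comes from a Python-level dtype check in the reference implementation, not from an MPS kernel. While make_fx traces the reference torch._refs.var_mean for the int8 sample, var_mean calls mean(a, dim, keepdim) without a dtype, and the torch._check in torch._refs.mean rejects the call: with no explicit dtype, the output takes the input's dtype, which must be floating point or complex. The test harness then re-raises this RuntimeError as an Exception that names the failing sample (reference input at index 0, a (5, 5, 5) int8 tensor on mps:0).

_ TestCommonMPS.test_python_ref_executor__refs_var_mean_executor_aten_mps_int8 _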
  File "/Users/ec2-user/runner/_work/_temp/venv-3.12-1773200043/lib/python3.12/site-packages/torch/testing/_internal/common_device_type.py", line 1701, in only_fn
    return fn(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ec2-user/runner/_work/_temp/venv-3.12-1773200043/lib/python3.12/site-packages/torch/testing/_internal/common_utils.py", line 1788, in wrapper
    fn(*args, **kwargs)
  File "/Users/ec2-user/runner/_work/pytorch/pytorch/test/test_ops.py", line 764, in test_python_ref_executor
    self._ref_test_helper(contextlib.nullcontext, device, dtype, op)
  File "/Users/ec2-user/runner/_work/pytorch/pytorch/test/test_ops.py", line 609, in _ref_test_helper
    ref_result = op(sample.input, *sample.args, **sample.kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ec2-user/runner/_work/_temp/venv-3.12-1773200043/lib/python3.12/site-packages/torch/testing/_internal/opinfo/core.py", line 1251, in __call__
    return self.op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ec2-user/runner/_work/_temp/venv-3.12-1773200043/lib/python3.12/site-packages/torch/_prims/executor.py", line 65, in _traced
    gm = make_fx(wrapped)(all_args)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ec2-user/runner/_work/_temp/venv-3.12-1773200043/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2848, in wrapped
    return make_fx_tracer.trace(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ec2-user/runner/_work/_temp/venv-3.12-1773200043/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2749, in trace
    return self._trace_inner(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ec2-user/runner/_work/_temp/venv-3.12-1773200043/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2710, in _trace_inner
    t = dispatch_trace(
        ^^^^^^^^^^^^^^^
  File "/Users/ec2-user/runner/_work/_temp/venv-3.12-1773200043/lib/python3.12/site-packages/torch/_compile.py", line 54, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ec2-user/runner/_work/_temp/venv-3.12-1773200043/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1256, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/ec2-user/runner/_work/_temp/venv-3.12-1773200043/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1550, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ec2-user/runner/_work/_temp/venv-3.12-1773200043/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 890, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "/Users/ec2-user/runner/_work/_temp/venv-3.12-1773200043/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 735, in flatten_fn
    tree_out = root_fn(*tree_args)
               ^^^^^^^^^^^^^^^^^^^
  File "/Users/ec2-user/runner/_work/_temp/venv-3.12-1773200043/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1620, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
          ^^^^^^^^^^^
  File "/Users/ec2-user/runner/_work/_temp/venv-3.12-1773200043/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2096, in wrapped
    return func(*fn_args, **fn_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ec2-user/runner/_work/_temp/venv-3.12-1773200043/lib/python3.12/site-packages/torch/_prims_common/wrappers.py", line 314, in _fn
    result = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/Users/ec2-user/runner/_work/_temp/venv-3.12-1773200043/lib/python3.12/site-packages/torch/_refs/__init__.py", line 2695, in var_mean
    m = mean(a, dim, keepdim)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ec2-user/runner/_work/_temp/venv-3.12-1773200043/lib/python3.12/site-packages/torch/_refs/__init__.py", line 2633, in mean
    torch._check(
  File "/Users/ec2-user/runner/_work/_temp/venv-3.12-1773200043/lib/python3.12/site-packages/torch/__init__.py", line 1760, in _check
    _check_with(RuntimeError, cond, message)  # pyrefly: ignore [bad-argument-type]
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ec2-user/runner/_work/_temp/venv-3.12-1773200043/lib/python3.12/site-packages/torch/__init__.py", line 1742, in _check_with
    raise error_type(message_evaluated)
RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: torch.int8

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/ec2-user/runner/_work/_temp/venv-3.12-1773200043/lib/python3.12/site-packages/torch/testing/_internal/common_utils.py", line 3370, in wrapper
    method(*args, **kwargs)
  File "/Users/ec2-user/runner/_work/_temp/venv-3.12-1773200043/lib/python3.12/site-packages/torch/testing/_internal/common_device_type.py", line 430, in instantiated_test
    result = test(self, **param_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ec2-user/runner/_work/_temp/venv-3.12-1773200043/lib/python3.12/site-packages/torch/testing/_internal/common_utils.py", line 1766, in wrapper
    fn(*args, **kwargs)
  File "/Users/ec2-user/runner/_work/_temp/venv-3.12-1773200043/lib/python3.12/site-packages/torch/testing/_internal/common_device_type.py", line 1177, in test_wrapper
    raise e_tracked from e
Exception: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: torch.int8

Caused by reference input at index 0: SampleInput(input=Tensor[size=(5, 5, 5), device="mps:0", dtype=torch.int8], args=(), kwargs={}, broadcasts_input=False, name='')

To execute this test, run the following from the base repo dir:
    PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=0 PYTORCH_TEST_WITH_SLOW=1 python test/test_ops.py TestCommonMPS.test_python_ref_executor__refs_var_mean_executor_aten_mps_int8
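The bottom of the traceback has two parts: `var_mean` passes its int8 input straight to `mean(a, dim, keepdim)`, and the harness then wraps the resulting `RuntimeError` (`raise e_tracked from e`). A stdlib-only sketch of that chain follows, assuming `array` typecode `"b"` (signed char) as a stand-in for `torch.int8` and `"d"` for a floating dtype; `mean_ref`, `var_mean_ref` and `run_reference` are hypothetical names, and the real reference reduces over tensor dimensions rather than a flat buffer.

```python
import statistics
from array import array

# Assumption: typecode "b" (signed char) stands in for torch.int8 and
# "f"/"d" for floating dtypes. Tensors are flattened to 1-D buffers.
FLOAT_TYPECODES = {"f", "d"}


def mean_ref(a):
    # Models the torch._check in torch._refs.mean: no dtype is requested,
    # so the input's own dtype must be floating point.
    if a.typecode not in FLOAT_TYPECODES:
        raise RuntimeError(
            "mean(): could not infer output dtype. Input dtype must be "
            "either a floating point or complex dtype."
        )
    return statistics.fmean(a)


def var_mean_ref(a):
    # Models the frame `m = mean(a, dim, keepdim)` in torch._refs.var_mean:
    # the int8 input reaches mean() unconverted, so the check above fires.
    m = mean_ref(a)
    v = statistics.variance(a, m)  # sample variance, n - 1 denominator
    return v, m


def run_reference(op, samples):
    # Hypothetical harness model: re-raise with the sample index attached,
    # chaining the original error as __cause__ ("raise ... from err").
    for index, sample in enumerate(samples):
        try:
            op(sample)
        except RuntimeError as err:
            raise Exception(
                f"{err}\n\nCaused by reference input at index {index}"
            ) from err


# The failing sample from the log: a (5, 5, 5) int8 tensor, i.e. 125 values.
int8_sample = array("b", range(125))
```

Calling `run_reference(var_mean_ref, [int8_sample])` raises an `Exception` whose `__cause__` is the `RuntimeError`, the same two-exception shape as the CI log ("The above exception was the direct cause of the following exception").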

Fix Action

Fix / Workaround

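The captured page does not record a confirmed fix. Because the reference deliberately rejects integer inputs, one plausible direction (an assumption, not the resolution recorded in the issue) is to keep the int8 MPS variant from running as an ordinary test: either drop int8 from the dtypes that variant is instantiated with, or mark only that variant as an expected failure. In PyTorch's OpInfo tables the latter is typically written as `DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref_executor", device_type="mps", dtypes=(torch.int8,))`; verify this against the current `_refs.var_mean` entry before relying on it. The stdlib sketch below only models how such a rule is scoped so that other devices and dtypes keep running; `ExpectedFailure` and `is_expected_failure` are hypothetical names.

```python
from dataclasses import dataclass
from typing import FrozenSet, Optional


@dataclass(frozen=True)
class ExpectedFailure:
    """Hypothetical, simplified rule loosely modeled on OpInfo's DecorateInfo."""

    test_name: str
    device_type: Optional[str] = None  # None means every device
    dtypes: Optional[FrozenSet[str]] = None  # None means every dtype

    def applies(self, test_name: str, device_type: str, dtype: str) -> bool:
        # Every populated field must match for the rule to apply.
        return (
            test_name == self.test_name
            and (self.device_type is None or device_type == self.device_type)
            and (self.dtypes is None or dtype in self.dtypes)
        )


# Scope the rule to the single failing variant from the CI log.
RULES = [
    ExpectedFailure(
        "test_python_ref_executor", device_type="mps", dtypes=frozenset({"int8"})
    )
]


def is_expected_failure(test_name: str, device_type: str, dtype: str) -> bool:
    return any(rule.applies(test_name, device_type, dtype) for rule in RULES)
```

Unlike a skip, an expected failure is reported as an unexpected success if the reference later gains integer support, which flags the rule for removal.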

Code Example

2026-03-11T05:06:54.6943760Z 
2026-03-11T05:06:54.6943870Z To execute this test, run the following from the base repo dir:
2026-03-11T05:06:54.6944270Z     PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=0 PYTORCH_TEST_WITH_SLOW=1 python test/test_ops.py TestCommonMPS.test_python_ref_executor__refs_var_mean_executor_aten_mps_int8

🐛 Describe the bug

When I generalized the test_python_ref_executor() test in TestCommon (test_ops.py) to run on any accelerator, the MPS CI job hit this failure. This could be an MPS limitation, so I will skip it in op_db in my PR.


Versions

See ghstack #176691

cc @mruberry @ezyang @kulinseth @malfet @DenisVieriu97 @jhavukainen @aditvenk


Fix Plan

Problem Summary

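The test TestCommonMPS.test_python_ref_executor__refs_var_mean_executor_aten_mps_int8 fails with a RuntimeError. While the prims executor traces the Python reference torch._refs.var_mean with make_fx, var_mean calls mean() on a torch.int8 tensor on an MPS device, and mean() rejects the integer input.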

Root Cause Analysis

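var_mean calls torch._refs.mean (torch/_refs/__init__.py, line 2633 in the traceback) without a dtype argument, so mean() takes the input dtype as its output dtype. A torch._check then requires that dtype to be floating point or complex; torch.int8 is neither, so mean() cannot infer an output dtype and raises.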
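The failing call path can be reproduced outside the test suite. This is a minimal reproduction sketch, assuming a recent PyTorch build in which torch._refs.mean and make_fx are importable (both appear in the traceback); the helper names ref_mean and trace_ref_mean are hypothetical. It uses MPS when available and otherwise falls back to CPU, because the dtype check in torch._refs.mean does not depend on the device.

```python
import torch
import torch._refs as refs
from torch.fx.experimental.proxy_tensor import make_fx

# Use MPS when present; the dtype check in refs.mean is device-independent.
device = "mps" if torch.backends.mps.is_available() else "cpu"


def ref_mean(t):
    # Same reference that var_mean calls in the traceback (no dtype argument).
    return refs.mean(t)


def trace_ref_mean(dtype):
    """Trace refs.mean with make_fx on a (5, 5, 5) input, as the prims executor does."""
    x = torch.randint(0, 10, (5, 5, 5), device=device).to(dtype)
    return make_fx(ref_mean)(x), x


# int8 input: tracing raises a RuntimeError from torch._check, as in CI.
try:
    trace_ref_mean(torch.int8)
    int8_error = None
except RuntimeError as e:
    int8_error = str(e)
print("int8:", int8_error)

# float32 input: tracing succeeds and the traced graph matches eager torch.mean.
gm, x = trace_ref_mean(torch.float32)
torch.testing.assert_close(gm(x), torch.mean(x))
print("float32: traced OK")
```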

Fix Plan

To fix this issue, we need to ensure that the input dtype to mean() is a floating point or complex dtype.
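One way to apply this only where needed is to promote integer and boolean inputs while leaving floating-point and complex tensors untouched. A minimal sketch; promote_for_mean is a hypothetical helper, not a PyTorch API, and it assumes the default floating-point dtype (normally torch.float32) is an acceptable output dtype:

```python
import torch


def promote_for_mean(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return t unchanged if it is floating point or complex,
    otherwise cast it to the default floating-point dtype."""
    if t.is_floating_point() or t.is_complex():
        return t
    return t.to(torch.get_default_dtype())


x = torch.randint(0, 10, (5, 5, 5), dtype=torch.int8)
xf = promote_for_mean(x)

m = torch.mean(xf)            # works: the input is now floating point
v, m2 = torch.var_mean(xf)    # var_mean likewise works on the promoted input
print(xf.dtype, m.item(), v.item())
```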

Step 1: Cast the input tensor to a floating point dtype

We can cast the input tensor to torch.float32 before calling mean().

import torch

# Input tensor with dtype=torch.int8
input_tensor = torch.randint(0, 10, (5, 5, 5), dtype=torch.int8)

# Cast the input tensor to torch.float32
cast_tensor = input_tensor.to(torch.float32)

# Call mean() on the cast tensor
mean_value = torch.mean(cast_tensor)
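Step 1 can also be written without a separate cast: torch.mean accepts a dtype keyword, and the input is cast to that dtype before the reduction. A short sketch checking that both forms agree on an int8 tensor (CPU only, for simplicity):

```python
import torch

input_tensor = torch.randint(0, 10, (5, 5, 5), dtype=torch.int8)

# Option A (Step 1): cast explicitly, then reduce.
mean_cast = torch.mean(input_tensor.to(torch.float32))

# Option B: pass an explicit floating-point dtype and let mean() cast internally.
mean_dtype = torch.mean(input_tensor, dtype=torch.float32)

torch.testing.assert_close(mean_cast, mean_dtype)
print(mean_cast.item(), mean_dtype.dtype)
```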

Step 2: Update the test code to cast the input tensor

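Note that test_python_ref_executor__refs_var_mean_executor_aten_mps_int8 is generated by the OpInfo test machinery in test_ops.py, so there is no hand-written method with that name to edit; the issue author instead plans to skip this MPS/int8 variant in op_db. If you want a standalone regression test for the casting approach, it could look like this (the class and test names are hypothetical):

import unittest
import torch


class TestMeanInt8Cast(unittest.TestCase):  # hypothetical standalone test
    def test_mean_int8_cast(self):
        device = "mps" if torch.backends.mps.is_available() else "cpu"
        input_tensor = torch.randint(0, 10, (5, 5, 5), device=device).to(torch.int8)
        # Cast the input tensor to torch.float32, then call mean()
        mean_value = torch.mean(input_tensor.to(torch.float32))
        # torch.mean(input_tensor) alone would raise a RuntimeError for int8,
        # so compare against mean() with an explicit float dtype instead
        expected = torch.mean(input_tensor, dtype=torch.float32)
        torch.testing.assert_close(mean_value, expected)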

Verification

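To verify the fix, rerun the failing test with the command printed in the CI log:

    PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=0 PYTORCH_TEST_WITH_SLOW=1 python test/test_ops.py TestCommonMPS.test_python_ref_executor__refs_var_mean_executor_aten_mps_int8

and confirm that the RuntimeError is no longer raised (or, if the variant is skipped in op_db as the issue author plans, that it is reported as skipped).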
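A quick local sanity check of the casting approach, as a sketch: it runs the Python reference torch._refs.var_mean on the cast (5, 5, 5) sample and compares it with eager torch.var_mean. This only exercises the reference eagerly on a float32 input; it does not replace rerunning the OpInfo test, which traces the reference through the prims executor.

```python
import torch
import torch._refs as refs

device = "mps" if torch.backends.mps.is_available() else "cpu"

# Same shape as the failing sample input (index 0, size (5, 5, 5)).
sample = torch.randint(-10, 10, (5, 5, 5), device=device).to(torch.int8)
sample_f = sample.to(torch.float32)  # the cast proposed in Step 1

# Python reference vs. eager ATen op on the cast input (default correction).
ref_var, ref_mean = refs.var_mean(sample_f)
aten_var, aten_mean = torch.var_mean(sample_f)
torch.testing.assert_close(ref_var, aten_var)
torch.testing.assert_close(ref_mean, aten_mean)
print("var_mean reference matches on", device)
```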

Extra Tips

  • Cast integer inputs to a floating-point dtype such as torch.float32 before calling mean(), or pass dtype=torch.float32 to mean().
  • If you are using a specific version
