pytorch - ✅(Solved) Fix: propagate_single_input_strategy copies input tensor_meta to output, producing wrong dtype for ops like convert_element_type [1 pull request, 3 comments, 3 participants]

GitHub stats
pytorch/pytorch#178091 (fetched 2026-04-08 01:16:46)
Comments: 3
Participants: 3
Timeline: 53
Reactions: 0
Timeline (top): mentioned ×22, subscribed ×22, labeled ×4, commented ×3

Fix Action

Fixed

PR fix notes

PR #386: Unify sharding rule registration into a single mechanism

Description (problem / solution / changelog)

propagation_rules.py had two rule registration systems with separate code paths: _op_rules (via @register_rule) for self-contained rules taking (mesh, specs), and _op_partial_rules (via @register_opschema_rule) for rules taking (mesh, op_schema) that went through a shared post-processing pipeline. This PR consolidates them into a single _op_rules dict with a single @register_rule decorator. All rules now take (mesh, op_schema) and go through one code path in get_placement_options.

The main subtlety is propagate_tensor_meta: the old _op_rules path skipped it entirely, while the old _op_partial_rules path always ran it. With the unified path, propagate_tensor_meta now always runs. That is the correct behavior, because some rules (e.g. convert_element_type via PyTorch's propagate_single_input_strategy) set the output tensor_meta with the wrong dtype (copied from the input) and rely on propagate_tensor_meta to correct it. The only exception is operator.getitem, which is skipped because its input structure (a tuple indexed by position) doesn't match what propagate_tensor_meta expects.
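The "always recompute, except for operator.getitem" rule can be modelled in pure Python. This is a hedged illustration only. `TensorMeta` and `Strategy` are simplified stand-ins, and `propagate_tensor_meta` here only models convert_element_type rather than re-running ops on fake tensors. The `get_placement_options` signature is hypothetical, and op_schema is reduced to a list of input metas.

```python
import operator
from dataclasses import dataclass, replace
from typing import Any, Callable, List, Tuple


@dataclass(frozen=True)
class TensorMeta:
    """Simplified stand-in for DTensor's TensorMeta (stride omitted)."""
    shape: Tuple[int, ...]
    dtype: str


@dataclass(frozen=True)
class Strategy:
    """Hypothetical: a single output placement plus the metadata a rule reported."""
    placement: str
    tensor_meta: TensorMeta


def propagate_tensor_meta(op: Any, input_metas: List[TensorMeta], op_args: List[Any]) -> TensorMeta:
    # Hypothetical stand-in for re-running the op on fake tensors. Only
    # convert_element_type is modelled: same shape, dtype taken from op_args[0].
    if op == "convert_element_type":
        return TensorMeta(input_metas[0].shape, op_args[0])
    return input_metas[0]


def get_placement_options(op: Any, rule: Callable, mesh: Any,
                          op_schema: List[TensorMeta], op_args: List[Any]) -> Strategy:
    # op_schema is modelled as just the list of input metas.
    strategy = rule(mesh, op_schema)
    # One code path for every rule: recompute tensor_meta afterwards, except for
    # operator.getitem, whose input is a tuple indexed by position.
    if op is not operator.getitem:
        corrected = propagate_tensor_meta(op, op_schema, op_args)
        strategy = replace(strategy, tensor_meta=corrected)
    return strategy


def copy_input_meta_rule(mesh: Any, op_schema: List[TensorMeta]) -> Strategy:
    # Mimics propagate_single_input_strategy: the input's meta is copied verbatim.
    return Strategy(placement="Replicate()", tensor_meta=op_schema[0])


x = TensorMeta(shape=(4, 2), dtype="bfloat16")
raw = copy_input_meta_rule("mesh", [x])
out = get_placement_options("convert_element_type", copy_input_meta_rule, "mesh", [x], ["float32"])
print(raw.tensor_meta.dtype, out.tensor_meta.dtype)  # bfloat16 float32
```

The rule on its own reports the copied bfloat16 meta. The shared post-processing step replaces it with float32, which is exactly the dependency the PR description calls out.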

Start reading from propagation_rules.py (registration changes), then placement_options.py (unified code path).

Also adds unit tests for the key utility functions: remove_invalid_configs, keep_unique_configs, fill_missing_redistribute_cost, and propagate_tensor_meta.

Authored with Claude.

Changed files

  • autoparallel/shardings/placement_options.py (modified, +11/-20)
  • autoparallel/shardings/propagation_rules.py (modified, +43/-63)
  • examples/native_ds3/moe_ops.py (modified, +5/-5)
  • tests/test_placement_options_utils.py (added, +217/-0)


🐛 Describe the bug

propagate_single_input_strategy in torch/distributed/tensor/_ops/_tensor_ops.py copies the input's tensor_meta directly to the output spec:

  OpSpec(
      output_specs=DTensorSpec(
          mesh=first_input_strategy.mesh,
          placements=strategy.output_spec.placements,
          tensor_meta=strategy.output_spec.tensor_meta,  # <-- copies input's tensor_meta
      ),
      ...
  )

For ops that change dtype or shape (like prims.convert_element_type), this produces incorrect output tensor_meta. For example, convert_element_type(bfloat16 → float32) returns an output strategy with tensor_meta.dtype = bfloat16 instead of float32.

In autoparallel, we work around this by running propagate_tensor_meta after every rule; it re-executes the op on fake tensors and overwrites the incorrect tensor_meta. But this is fragile: a caller that sees tensor_meta present and assumes it is correct (a reasonable assumption) will get the wrong dtype.

Concretely, this caused a downstream failure: view_as_complex received a strategy with bfloat16 tensor_meta (propagated from the incorrect convert_element_type output), created a bfloat16 fake tensor during strategy propagation, and hit RuntimeError: view_as_complex is only supported for half, float and double tensors.

A possible fix would be to leave tensor_meta=None on the output spec and let the caller fill it in, or to actually compute the correct output tensor_meta by running the op on a meta tensor.
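A small pure-Python model shows both the fragility and why leaving tensor_meta=None helps. Consider a caller that only fills in missing metadata. It passes a stale dtype through when the rule copies the input's meta, and it computes the right dtype when the rule leaves tensor_meta unset. This is a sketch under stated assumptions: `resolve_output_meta` and `compute_output_meta` are hypothetical helpers, and dtypes are plain strings instead of torch.dtype.

```python
from typing import NamedTuple, Optional, Tuple


class TensorMeta(NamedTuple):
    """Mirrors the (shape, stride, dtype) fields of DTensor's TensorMeta; dtype is a str here."""
    shape: Tuple[int, ...]
    stride: Tuple[int, ...]
    dtype: str


def compute_output_meta(input_meta: TensorMeta, target_dtype: str) -> TensorMeta:
    # Hypothetical: what running convert_element_type on a meta tensor would
    # report (same shape and stride, new dtype).
    return input_meta._replace(dtype=target_dtype)


def resolve_output_meta(reported: Optional[TensorMeta], input_meta: TensorMeta,
                        target_dtype: str) -> TensorMeta:
    # A caller that trusts any tensor_meta already present and only fills in
    # the missing case.
    if reported is not None:
        return reported
    return compute_output_meta(input_meta, target_dtype)


inp = TensorMeta(shape=(4, 2), stride=(2, 1), dtype="bfloat16")
buggy = resolve_output_meta(inp, inp, "float32")   # rule copied the input's meta
fixed = resolve_output_meta(None, inp, "float32")  # rule left tensor_meta=None
print(buggy.dtype, fixed.dtype)  # bfloat16 float32
```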

Versions

PyTorch nightly

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @xmfan


Fix Plan

To fix the issue, propagate_single_input_strategy must stop reporting the input's tensor_meta as the output's. There are two options:

  • Modify the OpSpec creation to leave tensor_meta=None and let the caller fill it in:
OpSpec(
    output_specs=DTensorSpec(
        mesh=first_input_strategy.mesh,
        placements=strategy.output_spec.placements,
        tensor_meta=None,  # <-- leave tensor_meta as None
    ),
    ...
)
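  • Alternatively, compute the correct output tensor_meta by running the op on a meta tensor built from the input's metadata:
import torch
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta

in_meta = strategy.output_spec.tensor_meta  # the input's (shape, stride, dtype)
# Meta tensors carry shape/stride/dtype only; no data is allocated.
meta_in = torch.empty_strided(in_meta.shape, in_meta.stride, dtype=in_meta.dtype, device="meta")
# Assumes the non-tensor args (e.g. the target dtype) follow the single tensor input.
meta_out = op_schema.op(meta_in, *op_schema.args_schema[1:], **op_schema.kwargs_schema)
OpSpec(
    output_specs=DTensorSpec(
        mesh=first_input_strategy.mesh,
        placements=strategy.output_spec.placements,
        tensor_meta=TensorMeta(meta_out.shape, meta_out.stride(), meta_out.dtype),  # <-- computed from the op's output
    ),
    ...
)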

Verification

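To verify the fix, test the propagate_single_input_strategy function with ops that change dtype or shape, such as prims.convert_element_type. Check that the output tensor_meta is either None (first option) or matches the op's actual output (second option), e.g. dtype float32 for a bfloat16 → float32 conversion.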
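Assuming PyTorch is installed, one way to obtain the reference values for such a test is to run the conversion on meta-device tensors. These track shape, stride, and dtype without allocating data. `reference_output_meta` is a hypothetical helper. A real test would compare its result with the output tensor_meta on the strategy returned by propagate_single_input_strategy.

```python
import torch


def reference_output_meta(fn, *meta_args):
    """Run fn on meta-device tensors and return the output's (shape, stride, dtype).

    Meta tensors carry only metadata, so no real data is allocated or computed.
    """
    out = fn(*meta_args)
    return tuple(out.shape), out.stride(), out.dtype


# bfloat16 input on the meta device.
x = torch.empty(4, 8, 2, dtype=torch.bfloat16, device="meta")

# A dtype-only Tensor.to typically decomposes to prims.convert_element_type in
# traced graphs; .to is used here so the snippet needs only public API.
shape, stride, dtype = reference_output_meta(lambda t: t.to(torch.float32), x)

# What copying the input's tensor_meta would report instead.
copied_dtype = x.dtype

print(shape, stride, dtype, copied_dtype)
```

The test should then assert that the strategy's output tensor_meta.dtype equals `dtype`, not `copied_dtype`.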

Extra Tips

  • Make sure to update the documentation and tests to reflect the changes.
  • Consider adding a test case to verify that the fix works correctly for the view_as_complex op.
