pytorch - ✅(Solved) Fix Way to have DTensor ops go through torch_function rather than go directly through torch_dispatch? [1 pull requests, 7 comments, 4 participants]

GitHub stats
pytorch/pytorch#177059 (fetched 2026-04-08 00:22:23)
Comments: 7 · Participants: 4 · Timeline events: 66 · Reactions: 0
Timeline (top): mentioned ×22, subscribed ×22, unsubscribed ×12, commented ×7

I am working on tensor parallel (TP) support for the MXFP8WeightWrapperTensor subclass used for MXFP8 training in torchao: https://github.com/pytorch/ao/pull/3985

The basic design of this tensor subclass is the following:

  • In __torch_function__, intercept linear or mm.default and execute an autograd function in its place, which performs differentiable MXFP8 quantization + scaled_mm.
  • For all other ops, continue to __torch_dispatch__ and behave like a regular tensor.
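A minimal sketch of this design without DTensor, for illustration only. The class names WeightWrapperSketch and _QuantizedMMStandIn are hypothetical, and the autograd function runs a plain matmul where the real wrapper would do MXFP8 quantization + scaled_mm. It assumes the wrapper is created with as_subclass, and it uses torch._C.DisableTorchFunctionSubclass, the private context manager the default Tensor.__torch_function__ also relies on, so the weight behaves as a plain tensor inside the custom function.

```python
import torch
import torch.nn.functional as F

# Records calls into the stand-in autograd function (illustration only)
CALLS = []


class _QuantizedMMStandIn(torch.autograd.Function):
    """Hypothetical stand-in for differentiable MXFP8 quantization + scaled_mm.

    It runs a plain matmul so the sketch needs no MXFP8 kernels; the point is
    that forward and backward are user-defined.
    """

    @staticmethod
    def forward(ctx, x, w):
        CALLS.append("mx_mm forward (stand-in)")
        ctx.save_for_backward(x, w)
        return x @ w.t()

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        grad_x = grad_w = None
        # Keep the saved wrapper from re-entering __torch_function__
        with torch._C.DisableTorchFunctionSubclass():
            if ctx.needs_input_grad[0]:
                grad_x = grad_out @ w
            if ctx.needs_input_grad[1]:
                # Fold leading batch dims before forming the weight gradient
                g2d = grad_out.reshape(-1, grad_out.shape[-1])
                grad_w = g2d.t() @ x.reshape(-1, x.shape[-1])
        return grad_x, grad_w


class WeightWrapperSketch(torch.Tensor):
    """Hypothetical weight wrapper: intercept F.linear, act normally otherwise."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, w = args[0], args[1]
            bias = args[2] if len(args) > 2 else kwargs.get("bias")
            # Treat the wrapper as a plain tensor inside the autograd function
            with torch._C.DisableTorchFunctionSubclass():
                out = _QuantizedMMStandIn.apply(x, w)
                return out if bias is None else out + bias
        # Every other op falls through to the default subclass handling
        return super().__torch_function__(func, types, args, kwargs)


w_plain = torch.randn(4, 3)
w = w_plain.as_subclass(WeightWrapperSketch)
x = torch.randn(2, 3, requires_grad=True)
out = F.linear(x, w)
out.sum().backward()
```

Called directly, F.linear(x, w) reaches the wrapper's __torch_function__ because w is one of its arguments, which mirrors the "[TORCH_FUNCTION] linear / mx_mm forward" logs; ops such as w.t() fall through and keep the wrapper type.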

For TP, we wrap weights in the tensor subclass first, then wrap in DTensor when TP is applied, so it looks like:

DTensor(MXFP8WeightWrapperTensor(...))

When DTensor is not applied, the op override for linear in __torch_function__ works as expected. Logs:

(Pdb) toy_model_fp8(x_bf16)
[TORCH_FUNCTION] linear
mx_mm forward
[TORCH_FUNCTION] linear
mx_mm forward
[TORCH_FUNCTION] linear
mx_mm forward

However, when DTensor wraps the subclass, the logs show linear decomposed into __get__, t.default, and mm.default. The transpose and mm ops skip __torch_function__ and go directly to __torch_dispatch__, which prevents us from intercepting the call and running our autograd function in its place. Logs:

[TORCH_FUNCTION] __get__
[TORCH_DISPATCH]:  mm.default
[TORCH_DISPATCH]:  t.default
[TORCH_FUNCTION] __get__
[TORCH_DISPATCH]:  mm.default
[TORCH_DISPATCH]:  t.default
[TORCH_FUNCTION] __get__
[TORCH_DISPATCH]:  mm.default
[TORCH_DISPATCH]:  t.default

...
tp out comparison not close, shapes: tp=torch.Size([2, 128, 64]), global=torch.Size([2, 128, 64])
tp_out stats: min=-0.134765625, max=0.150390625, mean=0.000507354736328125
global_out stats: min=-0.1318359375, max=0.150390625, mean=0.000659942626953125
diff stats: min=0.0, max=0.010009765625, mean=0.0016021728515625
tp_out SQNR: 23.625
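The t.default/mm.default lines are consistent with how F.linear is registered in ATen: it is a CompositeImplicitAutograd op, so it decomposes before reaching the dispatch layer, and for a 2-D input without bias that means a transpose followed by mm (with a bias, or with 3-D inputs, the exact ops differ). A small sketch that logs the aten ops seen at the Python dispatch layer, using TorchDispatchMode from the private torch.utils._python_dispatch module; AtenOpLogger is a hypothetical name:

```python
import torch
import torch.nn.functional as F
from torch.utils._python_dispatch import TorchDispatchMode


class AtenOpLogger(TorchDispatchMode):
    """Records every aten overload that reaches the Python dispatch layer."""

    def __init__(self):
        super().__init__()
        self.ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.ops.append(str(func))  # e.g. "aten.mm.default"
        return func(*args, **(kwargs or {}))


x = torch.randn(4, 8)   # 2-D activation
w = torch.randn(16, 8)  # weight, no bias

logger = AtenOpLogger()
with logger:
    F.linear(x, w)
print(logger.ops)
```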

I don't know what __get__ is doing, but because these linear ops skip __torch_function__, we can't intercept them, and a regular bf16 mm runs instead, which is not the desired behavior.
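Which attribute triggers __get__ is not identified in the logs. One way to find out, sketched under the assumption that the logging wrapper is a plain as_subclass subclass: property reads on a tensor subclass are also routed to __torch_function__, with func being the property's __get__ method wrapper, whose __self__ is the descriptor carrying the property name. LoggingTensor and describe are hypothetical helpers.

```python
import torch
import torch.nn.functional as F

LOG = []


def describe(func):
    """Name a __torch_function__ target; property reads arrive as '__get__'."""
    name = getattr(func, "__name__", repr(func))
    if name == "__get__":
        # For a property getter, __self__ is the descriptor, which knows its name
        prop = getattr(getattr(func, "__self__", None), "__name__", "?")
        return f"__get__ ({prop})"
    return name


class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        LOG.append(describe(func))
        return super().__torch_function__(func, types, args, kwargs)


t = torch.randn(2, 3).as_subclass(LoggingTensor)
_ = t.shape                    # property read
F.linear(torch.randn(5, 3), t)  # regular function call
print(LOG)
```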

We can't intercept in __torch_dispatch__ instead: it runs below autograd, so by the time mm.default reaches it, autograd has already recorded its own backward for the op, and the backward pass we define in the custom autograd function would not be the one used.
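This can be seen without DTensor: for a 2-D input without bias, the autograd graph built for F.linear already contains MmBackward0 and TBackward0 nodes, recorded by the autograd kernels of mm and t before those ops are redispatched to any __torch_dispatch__ handler. A sketch on plain tensors:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
out = F.linear(x, w)  # 2-D input, no bias

# Node that autograd recorded for the decomposed linear
top = type(out.grad_fn).__name__
# Its inputs: the gradient accumulator for x and the transpose node for w
children = [type(fn).__name__ for fn, _ in out.grad_fn.next_functions if fn is not None]
print(top, children)
```

A custom autograd function invoked later, from inside __torch_dispatch__, would not replace the MmBackward0 node that is already in the graph.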

Is there a way we can have DTensor send these mm ops through torch_function so we can intercept?

Root Cause

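When DTensor wraps the subclass, linear reaches the inner MXFP8WeightWrapperTensor already decomposed into t.default and mm.default, and those ops arrive only through __torch_dispatch__, so the linear override in __torch_function__ never fires and a regular bf16 mm runs. Intercepting in __torch_dispatch__ is not an option: it runs below autograd, so the backward pass defined in the custom autograd function would not be the one used.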

PR fix notes

PR #3985: [not for land] [mxfp8 training] fix TP bug

Description (problem / solution / changelog)

(No description)

Changed files

  • test/prototype/mx_formats/test_mx_dtensor.py (modified, +32/-17)
  • torchao/prototype/moe_training/tensor.py (modified, +44/-13)
  • torchao/prototype/mx_formats/mx_linear.py (modified, +1/-1)
  • torchao/prototype/mx_formats/mx_tensor.py (modified, +2/-1)
  • torchao/testing/training/dtensor_utils.py (modified, +66/-23)

Context

Repro

  • On a B200 devgpu, check out this torchao PR: https://github.com/pytorch/ao/pull/3985
  • Run torchrun --nproc_per_node 2 test/prototype/mx_formats/test_mx_dtensor.py
  • torch version: 2.12.0.dev20260305+cu130

Fix Plan

Modify DTensor to use __torch_function__ for mm.default and t.default

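We need to modify the DTensor class so that mm.default and t.default go through __torch_function__ instead of straight to __torch_dispatch__, which would let the linear ops be intercepted and routed to the custom autograd function. The code is a hypothetical sketch, not the real torch.distributed.tensor.DTensor: as the logs above show, DTensor currently issues these aten ops below the __torch_function__ layer, so this plan would require changes inside PyTorch itself.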

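import torch
import torch.nn.functional as F


class DTensor(torch.Tensor):  # hypothetical sketch, not torch.distributed.tensor.DTensor
    # ...

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.ops.aten.mm.default:
            # Intercept mm.default
            return cls.mm_default_autograd(*args, **kwargs)
        if func is F.linear:
            # Intercept linear
            return cls.linear_autograd(*args, **kwargs)
        # t.default and all other ops still pass through __torch_function__,
        # then fall back to the default handling
        return super().__torch_function__(func, types, args, kwargs)

    @staticmethod
    def mm_default_autograd(a, b):
        # Placeholder for a custom autograd.Function around mm.default;
        # disabling subclass torch_function keeps this from re-entering the handler
        with torch._C.DisableTorchFunctionSubclass():
            return torch.ops.aten.mm.default(a, b)

    @staticmethod
    def linear_autograd(input, weight, bias=None):
        # Placeholder for a custom autograd.Function around linear
        with torch._C.DisableTorchFunctionSubclass():
            return F.linear(input, weight, bias)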

Update MXFP8WeightWrapperTensor to use DTensor's new __torch_function__ behavior

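We need to update the MXFP8WeightWrapperTensor class to take advantage of DTensor's new __torch_function__ behavior: since linear reaches the wrapper already decomposed, its __torch_function__ must intercept mm.default as well as linear.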

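class MXFP8WeightWrapperTensor(torch.Tensor):
    # ...

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if func in (F.linear, torch.ops.aten.mm.default):
            # Hypothetical helper: differentiable MXFP8 quantization + scaled_mm
            return cls._mx_autograd_mm(func, args, kwargs)
        # All other ops behave like a regular tensor
        return super().__torch_function__(func, types, args, kwargs)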
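The plan assumes __torch_function__ can see an aten overload such as mm.default at all. When the overload is called directly from Python it can, as this sketch with a hypothetical AtenWatcher subclass shows; the DTensor logs above indicate that DTensor's own calls do not take this route, which is why the plan would need changes inside DTensor.

```python
import torch

SEEN = []


class AtenWatcher(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        SEEN.append(func)  # records the OpOverload object itself
        return super().__torch_function__(func, types, args, kwargs)


a = torch.randn(2, 3)
w = torch.randn(3, 4).as_subclass(AtenWatcher)
# Calling the aten overload directly from Python goes through __torch_function__
out = torch.ops.aten.mm.default(a, w)
```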

Verification

  1. Run the test test/prototype/mx_formats/test_mx_dtensor.py with the
