pytorch - ✅(Solved) Fix `torch.compile` produces wrong results when fusing `adaptive_avg_pool2d` with `flatten + sum` [1 pull requests]



Fix Action

Fixed

PR fix notes

PR #180898: [inductor] Fix bug with contiguous checks and comprehensive_padding

Description (problem / solution / changelog)

Stack from ghstack (oldest at bottom):

  • -> #180898

Fixes #180848

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo

Changed files

  • test/inductor/test_torchinductor.py (modified, +16/-0)
  • torch/_inductor/ir.py (modified, +26/-3)
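The PR title ties the fix to comprehensive_padding, an Inductor option that pads buffer strides for memory alignment. The issue and PR description do not explain the mechanism, so the following is a hypothesis only. The channel table in the issue is consistent with a simple rule: the pooled output of shape (N, C, 7, 7) has a per-sample (batch) stride of C·7·7 elements, and the failing channel counts are exactly those where that stride, in float64 bytes, is not a multiple of an assumed 128-byte alignment. The sketch below checks only that arithmetic. ALIGNMENT_BYTES and the padding rule are assumptions, not Inductor's actual logic.

```python
# Hypothesis check only: compares the issue's channel table with a simple
# "is the batch stride aligned?" rule. ALIGNMENT_BYTES is an ASSUMED value,
# not read from torch._inductor.config.
ALIGNMENT_BYTES = 128
ITEMSIZE_FP64 = 8        # bytes per float64 element
OUT_H = OUT_W = 7        # output_size=7 in the repro

# (channels -> bug observed?) from the issue's table: a nonzero diff means "bug".
OBSERVED = {32: False, 33: True, 48: False, 64: False, 65: True, 128: False, 129: True}


def batch_stride_bytes(channels: int) -> int:
    """Bytes spanned by one sample of a contiguous (C, 7, 7) float64 tensor."""
    return channels * OUT_H * OUT_W * ITEMSIZE_FP64


def predicts_bug(channels: int) -> bool:
    """Hypothetical rule: the bug appears when the batch stride would need padding."""
    return batch_stride_bytes(channels) % ALIGNMENT_BYTES != 0


for c, observed in sorted(OBSERVED.items()):
    print(f"C={c:4d} stride={batch_stride_bytes(c):6d} B "
          f"predicted={predicts_bug(c)} observed={observed}")
```

With these numbers the rule reduces to "C is a multiple of 16": each channel spans 7·7·8 = 392 bytes, and 392 ≡ 8 (mod 128). That would explain why C=48 (16×3) behaves like the powers of 2 in the table, but the issue thread does not confirm this.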


🐛 Describe the bug


torch.compile(backend='inductor') silently produces incorrect results when adaptive_avg_pool2d is fused with downstream flatten + sum operations. In the reported tests, the bug triggers for channel counts that are not powers of 2 (e.g., C=33, 65, 129), although C=48 (16×3) was not affected.

This is a silent correctness bug — no error is raised, but the fused kernel computes wrong output values. The adaptive_avg_pool2d kernel alone produces correct results; the error is introduced during fusion with the subsequent reduction.

Minimal repro

import torch

def f(x):
    y = torch.nn.functional.adaptive_avg_pool2d(x, 7)
    return y.flatten(1).sum(dim=-1)

torch.manual_seed(42)
x = torch.randn(2, 33, 8, 8, dtype=torch.float64, device='cuda')

# Eager — correct
print(f"Eager: {f(x).tolist()}")
# [-52.27041819362077, 30.75847231895647]

# Inductor — wrong
torch._dynamo.reset()
compiled = torch.compile(f, backend='inductor')(x)
print(f"Inductor: {compiled.tolist()}")
# [-52.27041819362077, 34.325837206874446]  ← second element is wrong

print(f"Max diff: {(f(x) - compiled).abs().max().item()}")
# 3.567365  ← catastrophic for fp64

Key observations

1. Pool alone is correct; fusion introduces the error:

# Uses x from the minimal repro above.
import torch.nn.functional as F

# These are all correct (diff ≈ 0):
torch.compile(lambda x: F.adaptive_avg_pool2d(x, 7))(x)            # pool only
torch.compile(lambda x: F.adaptive_avg_pool2d(x, 7).sum())(x)      # pool + sum
torch.compile(lambda x: F.adaptive_avg_pool2d(x, 7).flatten(1))(x) # pool + flatten

# This is WRONG (diff ≈ 3.6):
torch.compile(lambda x: F.adaptive_avg_pool2d(x, 7).flatten(1).sum(-1))(x)

2. Channel dimension determines whether the bug triggers:

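| C (channels) | Power of 2?    | Diff |
|--------------|----------------|------|
| 32           | yes            | 0    |
| 33           | no             | 7.37 |
| 48           | no (but 16×3)  | 0    |
| 64           | yes            | 0    |
| 65           | no             | 4.71 |
| 128          | yes            | 0    |
| 129          | no             | 6.45 |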

3. Spatial dimensions also matter:

With C=33, output_size=7:

  • H=7 (no resize needed): diff=0 ✅
  • H=8 (8→7, non-integer stride): diff=7.37 ❌
  • H=14 (14→7, integer stride): diff=0 ✅ on PyTorch 2.11, diff≠0 on nightly
  • H=15 (15→7, non-integer stride): diff=8.86 ❌

4. CUDA only — CPU produces correct results.

5. Deterministic across random seeds — not a race condition.

Versions

  • PyTorch 2.11.0+cu126: bug present
  • PyTorch 2.12.0.dev20260410+cu126 (nightly): bug present
  • CUDA 12.6, Tesla T4
  • OS: Ubuntu 20.04, Python 3.11

cc @ptrblck @msaroufim @eqy @jerryzh168 @tinglvv @nWEIdia @mikaylagawarecki @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo

Extent analysis

TL;DR

PR #180898 fixes this bug. On builds without that fix, the incorrect results from torch.compile(backend='inductor') when adaptive_avg_pool2d is fused with downstream flatten and sum can be worked around by preventing the fusion or, where the model allows it, by using a channel dimension that is a power of 2.

Guidance

  • Verify that the issue is caused by the fusion of adaptive_avg_pool2d with flatten and sum by comparing eager and compiled outputs for each operation on its own and for the combined pattern.
  • Check whether a channel dimension that is a power of 2 (e.g., 32, 64, 128) avoids the issue; in the reported table, C=48 was also unaffected.
  • Consider compiling with a different backend (for example, backend="aot_eager", which skips Inductor code generation) or preventing the fusion (for example, with a graph break between the pool and the reduction).
  • Run the same code on the CPU to confirm that the issue is specific to the CUDA backend.

Example

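# Using a channel dimension that is a power of 2 (f is the function from the minimal repro)
x = torch.randn(2, 32, 8, 8, dtype=torch.float64, device='cuda')
torch._dynamo.reset()
compiled = torch.compile(f, backend='inductor')(x)
print(f"Inductor: {compiled.tolist()}")
print(f"Max diff: {(f(x) - compiled).abs().max().item()}")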

Notes

The bug seems to be specific to the CUDA backend and does not occur on the CPU. The issue is also deterministic and not affected by random seeds.

Recommendation

Upgrade to a PyTorch build that includes PR #180898, which fixes a bug with contiguous checks and comprehensive_padding in Inductor. Until then, apply a workaround: avoid the fusion of these operations or, where the model allows it, use a channel dimension that is a power of 2.

