pytorch - 💡(How to fix) Fix [Inductor] Block-wise quantization (MXFP8) fusion rejects FloorDiv broadcast index before reindex runs [1 comments, 1 participants]

GitHub stats
pytorch/pytorch#183542 · Fetched 2026-05-14 03:28:28
Comments: 1 · Participants: 1 · Timeline events: 87 · Reactions: 1
Timeline (top): mentioned ×38, subscribed ×38, labeled ×7, commented ×1

Root Cause

fusable_read_and_write() requires read.index == write.index. The block-amax reduction writes its buffer with index 256*p0 + p1, while the fp8 pointwise reads it with index 256*p0 + (p1 // 32), so the pair is hard-rejected because FloorDiv(p1, 32) ≠ p1. The fusion log shows "memory deps did not match". Result: RMSNorm+MXFP8 produces 3 kernels instead of 2 (M ≥ 4); standalone MXFP8 produces 2 kernels instead of 1 (M = block_size).

🐛 Describe the bug

Edited by claude. While profiling Llama4-Maverick-17B-128E inference with TorchAO MXFP8 quantization under torch.compile, we found that block-wise MXFP8 quantization produces extra Triton kernels that should be fused.

Step 1 — Observe the fusion gap. Block-wise MXFP8 computes a per-block scale via reduction (amax over block_size=32), then broadcasts it back to element level via unsqueeze(-1) + division. Inductor decomposes this into:

  • op1 (reduction): writes buf1 with index 256*p0 + p1, ranges p1 ∈ [0, 256)
  • op2 (pointwise): reads buf1 with index 256*p0 + (p1 // 32), ranges p1 ∈ [0, 8192)

fusable_read_and_write() requires read.index == write.index, which hard-rejects this pair because FloorDiv(p1, 32) ≠ p1. The fusion log shows "memory deps did not match". Result: RMSNorm+MXFP8 produces 3 kernels instead of 2 (M ≥ 4); standalone MXFP8 produces 2 kernels instead of 1 (M = block_size).
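
To make the mismatch concrete, the following pure-Python sketch enumerates both index expressions for one fixed row p0, using the sizes quoted above. It is an illustration only, not Inductor code: the helper names write_index and read_index, and the split p1 = 32*q + r, are assumptions introduced here to show why a reindexed pointwise domain would line up with the reduction's writes.

```python
BLOCK = 32          # MXFP8 block size
NUM_BLOCKS = 256    # K // BLOCK for K = 8192
K = NUM_BLOCKS * BLOCK


def write_index(p0: int, p1: int) -> int:
    """Index written by the block-amax reduction, p1 in [0, 256)."""
    return NUM_BLOCKS * p0 + p1


def read_index(p0: int, p1: int) -> int:
    """Index read by the fp8 pointwise op, p1 in [0, 8192)."""
    return NUM_BLOCKS * p0 + (p1 // BLOCK)


p0 = 3  # any fixed row works

# Evaluated at the same p1, the two expressions agree only when
# p1 // 32 == p1, so an exact expression comparison cannot succeed.
same_p1_matches = [
    p1 for p1 in range(NUM_BLOCKS)
    if write_index(p0, p1) == read_index(p0, p1)
]

# Splitting the pointwise variable as p1 = 32*q + r, with q in [0, 256)
# and r in [0, 32), turns the read index into 256*p0 + q, which has the
# same form as the write index.
reindexed_ok = all(
    read_index(p0, BLOCK * q + r) == write_index(p0, q)
    for q in range(NUM_BLOCKS)
    for r in range(BLOCK)
)

# Count how many pointwise elements read each written scale.
readers_per_write: dict[int, int] = {}
for p1 in range(K):
    idx = read_index(p0, p1)
    readers_per_write[idx] = readers_per_write.get(idx, 0) + 1

print(same_p1_matches)                   # [0]
print(reindexed_ok)                      # True
print(set(readers_per_write.values()))   # {32}
```

Each written scale is consumed by exactly one block of 32 elements, which is the broadcast structure that a reindexing step can exploit.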

Step 2 — Check existing PR #176927. This PR added _try_reindex_pointwise_for_reduction which can resolve the index mismatch by reindexing the pointwise's iteration domain to match the reduction's. It successfully handles standalone MXFP8 for most M values. But two gaps remain:

  • Gap A (RMSNorm + MXFP8, M ≥ 4): The reindex path is only wired into shared_data_after_reordering_loop (the scoring fallback). When the scheduler enters the vertical fusion branch, it calls can_fuse_vertical() directly — the reindex path never gets a chance to run. The FloorDiv mismatch is rejected at fusable_read_and_write before any reindexing.

  • Gap B (standalone MXFP8, M = block_size): Even when the code reaches shared_data_after_reordering_loop, PR #179090's write→read size subset heuristic in _try_reorder_loops_for_candidates produces a false positive that short-circuits before reaching the reindex fallback.

Step 3 — Trace root cause. Gap A is a wiring issue: _try_reindex_pointwise_for_reduction exists but is not called before can_fuse_vertical. Gap B is a coincidence in the size subset heuristic — when M = block_size (B), the reduction write normalizes to (M×N/B,) = (N,), and the subset check {N} ⊆ {M, N} evaluates True (false positive), causing _try_reorder_loops_for_candidates to return early without reaching the reindex fallback. When M ≠ B, M×N/B ≠ M and ≠ N, so the subset check correctly returns False and falls through to reindex.

Experimental verification:

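| M  | block | w_sizes  | r_sizes    | subset? | kernels |
|----|-------|----------|------------|---------|---------|
| 2  | 32    | (512,)   | (2, 8192)  | False   | 1       |
| 32 | 32    | (8192,)  | (32, 8192) | True    | 2       |
| 64 | 32    | (16384,) | (64, 8192) | False   | 1       |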
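
The size arithmetic behind these rows can be checked without PyTorch. The sketch below models the heuristic as a plain set-subset test on the normalized sizes, which is an assumption based on the description in Step 3; the actual check inside _try_reorder_loops_for_candidates may differ in detail. The names normalized_sizes and subset_check are hypothetical.

```python
def normalized_sizes(M: int, N: int, B: int):
    """Return (w_sizes, r_sizes) as described in Step 3.

    The reduction write is flattened to one dimension of M*N/B elements;
    the pointwise read keeps the (M, N) shape.
    """
    w_sizes = (M * N // B,)
    r_sizes = (M, N)
    return w_sizes, r_sizes


def subset_check(w_sizes, r_sizes) -> bool:
    """Simplified stand-in for the write->read size subset heuristic."""
    return set(w_sizes) <= set(r_sizes)


N, B = 8192, 32
results = {}
for M in (2, 32, 64):
    w, r = normalized_sizes(M, N, B)
    results[M] = (w, r, subset_check(w, r))
    print(f"M={M:3d}  w_sizes={w}  r_sizes={r}  subset={results[M][2]}")
```

Only M = B makes M*N/B collapse to N, so only that row passes the subset test under this model, matching the single 2-kernel row in the table.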

Reproducer

#!/usr/bin/env python3
"""
Reproducer: Inductor fails to fuse block-wise quantization with preceding reduction.

Pattern: RMSNorm → block-wise MXFP8 quantization
  - RMSNorm: variance reduction over K (rnumel=K)
  - MXFP8 quant: block amax reduction over block_size=32 (rnumel=32),
    then pointwise fp8 conversion using broadcast scale (FloorDiv index)

Expected: 2 kernels (variance reduction + fused block_amax/scale/quantize)
Actual:   3 kernels (variance reduction, block_amax+scale_encode, fp8_quantize)

Root cause: fusable_read_and_write() in scheduler.py requires exact index
expression match between producer writes and consumer reads. The block amax
reduction writes with index `256*p0 + p1` (ranges p1∈[0,256)), but the fp8
quantize pointwise reads with `256*p0 + (p1//32)` (ranges p1∈[0,8192)) —
the FloorDiv broadcast pattern is not recognized as fusable.

Affects: Any block-wise quantization pattern (MXFP8, MXFP4, etc.) where
a per-block scale is broadcast back to element-level granularity.

Usage:
  python reproducer.py               # auto-detect device
  python reproducer.py --device cuda  # force CUDA
  python reproducer.py --device xpu   # force XPU
  python reproducer.py --device cpu   # CPU (for reference, no GPU needed)
  python reproducer.py --debug        # also dump TORCH_COMPILE_DEBUG
"""

import argparse
import os
import sys

import torch
import torch._inductor.metrics as metrics


# ---------------------------------------------------------------------------
# Pattern: RMSNorm + block-wise MXFP8 quantization
# Verbatim from vLLM MXFP8 quantization path — DO NOT simplify
# ---------------------------------------------------------------------------

MXFP8_BLOCK_SIZE = 32


def mxfp8_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Block-wise MXFP8 E4M3 quantization.

    Decomposed op sequence from real vLLM FX graph:
      to_fp32 → unflatten(K → num_blocks × 32) → abs → amax(dim=-1)
      → clamp(min=FLT_MIN) → log2 → floor → +127 → clamp(0,254) → to_uint8
      → exp2(scale−127) → unsqueeze(-1) → div → flatten → to_fp8e4m3
    """
    K = x.shape[-1]
    num_blocks = K // MXFP8_BLOCK_SIZE

    x_fp32 = x.to(torch.float32)
    x_blocked = x_fp32.unflatten(-1, (num_blocks, MXFP8_BLOCK_SIZE))

    # Per-block amax — reduction over block_size=32
    amax = x_blocked.abs().amax(dim=-1)
    amax = amax.clamp(min=torch.finfo(torch.float32).tiny)

    # Scale encoding (e8m0 format)
    scale_biased = torch.floor(torch.log2(amax)) + 127.0
    scale_biased = scale_biased.clamp(0, 254)
    scales_uint8 = scale_biased.to(torch.uint8)

    # Quantize: broadcast scale back to elements via unsqueeze
    descale = torch.exp2(scale_biased - 127.0)
    x_scaled = x_blocked / descale.unsqueeze(-1)

    # Final dtype conversions
    x_fp8 = x_scaled.flatten(-2).to(torch.float8_e4m3fn)
    scales = scales_uint8.view(torch.float8_e8m0fnu)
    return x_fp8, scales


def rms_norm_then_mxfp8_quant(
    x: torch.Tensor, norm_weight: torch.Tensor, eps: float = 1e-6
) -> tuple[torch.Tensor, torch.Tensor]:
    """RMSNorm followed by block-wise MXFP8 quantization."""
    x_fp32 = x.to(torch.float32)
    variance = x_fp32.pow(2).mean(dim=-1, keepdim=True)
    normed = x_fp32 * torch.rsqrt(variance + eps)
    normed = normed.to(torch.bfloat16) * norm_weight
    return mxfp8_quantize(normed)


# ---------------------------------------------------------------------------
# Standalone MXFP8 quantize (no RMSNorm) — isolates the block-broadcast gap
# ---------------------------------------------------------------------------


def mxfp8_quant_only(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Standalone block-wise MXFP8 quantization (no preceding norm)."""
    return mxfp8_quantize(x)


# ---------------------------------------------------------------------------
# Test harness
# ---------------------------------------------------------------------------


def detect_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"


def sync(device: str) -> None:
    if device == "cuda":
        torch.cuda.synchronize()
    elif device == "xpu":
        torch.xpu.synchronize()


def count_kernels(fn, *args, device: str) -> int:
    torch._dynamo.reset()
    metrics.reset()
    compiled = torch.compile(fn, fullgraph=True)
    compiled(*args)
    sync(device)
    return metrics.generated_kernel_count


def main() -> None:
    parser = argparse.ArgumentParser(description="Inductor block-broadcast fusion gap reproducer")
    parser.add_argument("--device", type=str, default=None, help="Device: cuda, xpu, or cpu")
    parser.add_argument("--debug", action="store_true", help="Enable TORCH_COMPILE_DEBUG")
    args = parser.parse_args()

    device = args.device or detect_device()
    if args.debug:
        os.environ["TORCH_COMPILE_DEBUG"] = "1"

    print(f"PyTorch version: {torch.__version__}")
    print(f"Device: {device}")
    print(f"float8_e4m3fn support: {hasattr(torch, 'float8_e4m3fn')}")
    print(f"float8_e8m0fnu support: {hasattr(torch, 'float8_e8m0fnu')}")
    print()

    K = 8192
    torch.manual_seed(42)

    print("=" * 70)
    print("Test 1: RMSNorm + MXFP8 block-wise quantization")
    print("=" * 70)
    print()
    print(f"  Pattern: variance_red(rnumel={K}) → norm_pw → block_amax_red(rnumel={MXFP8_BLOCK_SIZE})")
    print(f"           → scale_encode_pw → fp8_quantize_pw(reads scale via FloorDiv)")
    print()
    print(f"  Note: M=1 may produce 2 kernels (optimal) because Inductor collapses")
    print(f"  the M dimension, simplifying the iteration domain. The FloorDiv gap")
    print(f"  manifests at M≥8 where the block structure is fully preserved.")
    print()

    for M in [1, 2, 4, 8, 32, 64]:
        x = torch.randn(M, K, dtype=torch.bfloat16, device=device)
        w = torch.randn(K, dtype=torch.bfloat16, device=device)
        n = count_kernels(rms_norm_then_mxfp8_quant, x, w, device=device)
        status = "OK" if n <= 2 else "SUBOPTIMAL"
        print(f"  M={M:5d}  K={K}  kernels={n}  (optimal=2)  [{status}]")

    print()
    print("=" * 70)
    print("Test 2: Standalone MXFP8 block-wise quantization (no RMSNorm)")
    print("=" * 70)
    print()
    print(f"  Pattern: block_amax_red(rnumel={MXFP8_BLOCK_SIZE}) → scale_encode_pw")
    print(f"           → fp8_quantize_pw(reads scale via FloorDiv)")
    print()

    for M in [1, 2, 4, 8, 32, 64]:
        x = torch.randn(M, K, dtype=torch.bfloat16, device=device)
        n = count_kernels(mxfp8_quant_only, x, device=device)
        status = "OK" if n <= 1 else "SUBOPTIMAL"
        print(f"  M={M:5d}  K={K}  kernels={n}  (optimal=1)  [{status}]")

    print()
    print("=" * 70)
    print("Fusion log (set TORCH_LOGS='+fusion' to see rejection details)")
    print("=" * 70)
    print()
    print("Expected rejection messages:")
    print("  - 'numel/rnumel mismatch (reduce)' between variance and block_amax")
    print("  - 'memory deps did not match' between block_amax and fp8_quantize")
    print("    (due to FloorDiv broadcast index: write=256*p0+p1 vs read=256*p0+(p1//32))")
    print()
    print("To see full fusion trace:")
    print(f"  TORCH_LOGS='+fusion' python {sys.argv[0]}")


if __name__ == "__main__":
    main()
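
For readers who want to trace the scale encoding in mxfp8_quantize by hand, here is a pure-Python sketch of the same arithmetic on a single block. It assumes a shortened 4-element block for readability (the real block size is 32) and omits the final cast to float8_e4m3fn; encode_block is a hypothetical helper, not part of the reproducer.

```python
import math

# Smallest positive normal float32, i.e. torch.finfo(torch.float32).tiny
FLT_MIN = 2.0 ** -126


def encode_block(block):
    """Mirror the per-block steps: amax -> biased exponent -> descale -> divide."""
    amax = max(abs(v) for v in block)
    amax = max(amax, FLT_MIN)                           # clamp(min=FLT_MIN)
    scale_biased = math.floor(math.log2(amax)) + 127    # e8m0 biased exponent
    scale_biased = min(max(scale_biased, 0), 254)       # clamp(0, 254)
    descale = 2.0 ** (scale_biased - 127)               # exp2(scale - 127)
    scaled = [v / descale for v in block]               # broadcast divide
    return scale_biased, descale, scaled


block = [0.5, -6.0, 3.0, 1.25]
scale_biased, descale, scaled = encode_block(block)

print(scale_biased)  # 129
print(descale)       # 4.0
print(scaled)        # [0.125, -1.5, 0.75, 0.3125]
```

The single descale value is shared by every element of the block, which is why the compiled pointwise kernel reads the scale buffer through a p1 // 32 index.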

Versions

Collecting environment information...
PyTorch version: N/A
Is debug build: N/A
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: N/A

OS: Ubuntu 25.04 (x86_64)
GCC version: (Ubuntu 14.2.0-19ubuntu2) 14.2.0
Clang version: Could not collect
CMake version: version 3.31.6
Libc version: glibc-2.41

Python version: 3.13.12 | packaged by conda-forge | (main, Feb  5 2026, 05:53:46) [GCC 14.3.0] (64-bit runtime)
Python platform: Linux-6.14.0-37-generic-x86_64-with-glibc2.41
Is CUDA available: N/A
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Could not collect
Is XPU available: N/A
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: N/A
Caching allocator config: N/A

CPU:
Architecture:                            x86_64
CPU op-mode(s):                          32-bit, 64-bit
Address sizes:                           46 bits physical, 48 bits virtual
Byte Order:                              Little Endian
CPU(s):                                  20
On-line CPU(s) list:                     0-19
Vendor ID:                               GenuineIntel
BIOS Vendor ID:                          Intel(R) Corporation
Model name:                              Intel(R) Core(TM) Ultra 7 265K
BIOS Model name:                         Intel(R) Core(TM) Ultra 7 265K To Be Filled By O.E.M. CPU @ 3.9GHz
BIOS CPU family:                         774
CPU family:                              6
Model:                                   198
Thread(s) per core:                      1
Core(s) per socket:                      20
Socket(s):                               1
Stepping:                                2
CPU(s) scaling MHz:                      44%
CPU max MHz:                             5500.0000
CPU min MHz:                             800.0000
BogoMIPS:                                7756.80
Flags:                                   fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdt_a rdseed adx smap clflushopt clwb intel_pt sha_ni xsaveopt xsavec xgetbv1 xsaves split_lock_detect user_shstk avx_vnni lam wbnoinvd dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp hwp_pkg_req hfi vnmi umip pku ospke waitpkg gfni vaes vpclmulqdq rdpid bus_lock_detect movdiri movdir64b fsrm md_clear serialize arch_lbr ibt flush_l1d arch_capabilities
Virtualization:                          VT-x
L1d cache:                               704 KiB (18 instances)
L1i cache:                               1.1 MiB (18 instances)
L2 cache:                                36 MiB (11 instances)
L3 cache:                                30 MiB (1 instance)
NUMA node(s):                            1
NUMA node0 CPU(s):                       0-19
Vulnerability Gather data sampling:      Not affected
Vulnerability Ghostwrite:                Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit:             Not affected
Vulnerability L1tf:                      Not affected
Vulnerability Mds:                       Not affected
Vulnerability Meltdown:                  Not affected
Vulnerability Mmio stale data:           Not affected
Vulnerability Reg file data sampling:    Not affected
Vulnerability Retbleed:                  Not affected
Vulnerability Spec rstack overflow:      Not affected
Vulnerability Spec store bypass:         Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:                Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:                Mitigation; Enhanced / Automatic IBRS; IBPB conditional; PBRSB-eIBRS Not affected; BHI BHI_DIS_S
Vulnerability Srbds:                     Not affected
Vulnerability Tsa:                       Not affected
Vulnerability Tsx async abort:           Not affected
Vulnerability Vmscape:                   Mitigation; IBPB before exit to userspace

Versions of relevant libraries:
[pip3] No relevant packages
[conda] No relevant packages

cc @jerryzh168 @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo
