pytorch: torch.func.jacfwd gives incorrect Jacobian for torch.pow(torch.abs(complex), complex exponent)

🐛 Describe the bug

For the expression torch.pow(torch.abs(z), complex_exponent), torch.func.jacfwd returns zero for the derivatives of the imaginary part of the output, which are mathematically non-zero, while torch.func.jacrev returns non-zero values for the same entries.

Minimal reproducible example

import torch

def complex_to_real_vec(z):
    real_part = torch.real(z)
    imag_part = torch.imag(z)
    return torch.cat([real_part, imag_part])


def make_real_wrapper(fn):
    def wrapper(coords):
        n = coords.shape[0] // 2
        real_coords = coords[:n]
        imag_coords = coords[n:]
        complex_coords = real_coords + 1j * imag_coords

        complex_res = fn(complex_coords)
        return complex_to_real_vec(complex_res)

    return wrapper


def jacobian_jacfwd(fn, x, input_size, output_size):
    x = x.detach().clone().requires_grad_(True)
    jacobian = torch.func.jacfwd(fn)(x).detach()
    return jacobian


def jacobian_jacrev(fn, x, input_size, output_size):
    x = x.detach().clone().requires_grad_(True)
    jacobian = torch.func.jacrev(fn)(x).detach()
    return jacobian


def fn(x):
    a = torch.tensor(1.0 + 1.0j, dtype=x.dtype, device=x.device)
    return torch.pow(torch.abs(x), a)


x = torch.tensor([1.0 + 1.0j], dtype=torch.complex128)

input_size = 2
output_size = 2

x_real = complex_to_real_vec(x)
real_fn = make_real_wrapper(fn)

jac_fwd = jacobian_jacfwd(real_fn, x_real, input_size, output_size)
jac_rev = jacobian_jacrev(real_fn, x_real, input_size, output_size)

print("jacfwd:")
print(jac_fwd)

print("jacrev:")
print(jac_rev)

print("difference:")
print(jac_fwd - jac_rev)

Actual output

jacfwd:
tensor([[0.2781, 0.2781],
        [0.0000, 0.0000]], dtype=torch.float64)

jacrev:
tensor([[0.2781, 0.2781],
        [0.5471, 0.5471]], dtype=torch.float64)

difference:
tensor([[ 0.0000,  0.0000],
        [-0.5471, -0.5471]], dtype=torch.float64)
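
When the two modes disagree like this, an independent reference helps decide which one to trust. The following is a minimal sketch, assuming the same real wrapping as in the report and a central finite-difference Jacobian as the reference. The names wrapped_fn and finite_difference_jacobian are illustrative helpers, not PyTorch APIs, and the step size and tolerance are arbitrary choices.

```python
import torch


def wrapped_fn(coords):
    # coords = [Re(z), Im(z)]; same real wrapping as in the report.
    n = coords.shape[0] // 2
    z = coords[:n] + 1j * coords[n:]
    a = torch.tensor(1.0 + 1.0j, dtype=z.dtype, device=z.device)
    out = torch.pow(torch.abs(z), a)
    return torch.cat([torch.real(out), torch.imag(out)])


def finite_difference_jacobian(fn, x, eps=1e-6):
    # Central differences, one input coordinate at a time (no autograd involved).
    columns = []
    for j in range(x.numel()):
        step = torch.zeros_like(x)
        step[j] = eps
        columns.append((fn(x + step) - fn(x - step)) / (2 * eps))
    return torch.stack(columns, dim=1)  # entry [i, j] = d out_i / d in_j


x = torch.tensor([1.0, 1.0], dtype=torch.float64)
jac_fd = finite_difference_jacobian(wrapped_fn, x)
jac_fwd = torch.func.jacfwd(wrapped_fn)(x)
jac_rev = torch.func.jacrev(wrapped_fn)(x)

# Report which autograd mode, if any, agrees with the numerical reference.
print("jacfwd agrees with finite differences:", torch.allclose(jac_fwd, jac_fd, atol=1e-5))
print("jacrev agrees with finite differences:", torch.allclose(jac_rev, jac_fd, atol=1e-5))
```

The sketch only reports agreement. It does not assume in advance which mode is correct.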

Expected behavior

torch.func.jacfwd and torch.func.jacrev should agree for this real-wrapped function. The imaginary part of the output should not have a zero derivative here: abs(z) ** (1 + 1j) = exp((1 + 1j) * log(abs(z))) = abs(z) * (cos(log(abs(z))) + 1j * sin(log(abs(z)))), so the imaginary part is abs(z) * sin(log(abs(z))). It depends on abs(z), which in turn depends on both Re(z) and Im(z).
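
The claim that the imaginary row cannot be zero can be checked without PyTorch. Writing r = abs(z), the identity above gives a real part of r * cos(log r) and an imaginary part of r * sin(log r), and the chain rule through r = sqrt(x^2 + y^2) gives the Jacobian. The sketch below, in plain Python, compares that closed form against central finite differences at z = 1 + 1j. The helper names f, analytic_jacobian and finite_difference_jacobian are illustrative only.

```python
import math


def f(x, y):
    # abs(z) ** (1 + 1j) for z = x + iy, returned as (real part, imaginary part).
    w = abs(complex(x, y)) ** (1 + 1j)
    return w.real, w.imag


def analytic_jacobian(x, y):
    r = math.hypot(x, y)
    log_r = math.log(r)
    d_re = math.cos(log_r) - math.sin(log_r)  # d/dr of r * cos(log r)
    d_im = math.sin(log_r) + math.cos(log_r)  # d/dr of r * sin(log r)
    # Chain rule: dr/dx = x / r and dr/dy = y / r.
    return [[d_re * x / r, d_re * y / r],
            [d_im * x / r, d_im * y / r]]


def finite_difference_jacobian(x, y, eps=1e-6):
    fxp, fxm = f(x + eps, y), f(x - eps, y)
    fyp, fym = f(x, y + eps), f(x, y - eps)
    return [[(fxp[i] - fxm[i]) / (2 * eps), (fyp[i] - fym[i]) / (2 * eps)]
            for i in range(2)]


jac_exact = analytic_jacobian(1.0, 1.0)
jac_fd = finite_difference_jacobian(1.0, 1.0)

# The second row holds the derivatives of the imaginary part.
print("imaginary row (closed form):", jac_exact[1])
print("imaginary row (finite differences):", jac_fd[1])
```

Both rows come out non-zero at this point, so a forward-mode result with an all-zero second row cannot be right for this function.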

Versions

Collecting environment information...
PyTorch version: 2.11.0+cu130
Is debug build: False
CUDA used to build PyTorch: 13.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.4 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04.1) 13.3.0
Clang version: 16.0.6 (https://github.com/llvm/llvm-project.git 7cbf1a2591520c2491aa35339f227775f4d3adf6)
CMake version: version 4.1.0
Libc version: glibc-2.39

Python version: 3.10.20 | packaged by conda-forge | (main, Mar 5 2026, 16:42:22) [GCC 14.3.0] (64-bit runtime)
Python platform: Linux-6.8.0-100-generic-x86_64-with-glibc2.39
Is CUDA available: False
CUDA runtime version: 12.8.93
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No devices found.
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn.so.9.8.0
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn_adv.so.9.8.0
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn_cnn.so.9.8.0
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn_engines_precompiled.so.9.8.0
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn_engines_runtime_compiled.so.9.8.0
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn_graph.so.9.8.0
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn_heuristic.so.9.8.0
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn_ops.so.9.8.0
/usr/local/cuda-12.8/targets/x86_64-linux/lib/libcudnn.so.9.8.0
/usr/local/cuda-12.8/targets/x86_64-linux/lib/libcudnn_adv.so.9.8.0
/usr/local/cuda-12.8/targets/x86_64-linux/lib/libcudnn_cnn.so.9.8.0
/usr/local/cuda-12.8/targets/x86_64-linux/lib/libcudnn_engines_precompiled.so.9.8.0
/usr/local/cuda-12.8/targets/x86_64-linux/lib/libcudnn_engines_runtime_compiled.so.9.8.0
/usr/local/cuda-12.8/targets/x86_64-linux/lib/libcudnn_graph.so.9.8.0
/usr/local/cuda-12.8/targets/x86_64-linux/lib/libcudnn_heuristic.so.9.8.0
/usr/local/cuda-12.8/targets/x86_64-linux/lib/libcudnn_ops.so.9.8.0
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 384
On-line CPU(s) list: 0-383
Vendor ID: AuthenticAMD
Model name: AMD EPYC 9654 96-Core Processor
CPU family: 25
Model: 17
Thread(s) per core: 2
Core(s) per socket: 96
Socket(s): 2
Stepping: 1
Frequency boost: enabled
CPU(s) scaling MHz: 48%
CPU max MHz: 3709.3569
CPU min MHz: 1500.0000
BogoMIPS: 4799.86
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc amd_ibpb_ret arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap ibpb_exit_to_user
Virtualization: AMD-V
L1d cache: 6 MiB (192 instances)
L1i cache: 6 MiB (192 instances)
L2 cache: 192 MiB (192 instances)
L3 cache: 768 MiB (24 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-95,192-287
NUMA node1 CPU(s): 96-191,288-383
Vulnerability Gather data sampling: Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; Safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsa: Vulnerable: Clear CPU buffers attempted, no microcode
Vulnerability Tsx async abort: Not affected
Vulnerability Vmscape: Mitigation; IBPB before exit to userspace

Versions of relevant libraries:
[pip3] numpy==2.2.6
[pip3] nvidia-cublas==13.1.0.3
[pip3] nvidia-cuda-cupti==13.0.85
[pip3] nvidia-cuda-nvrtc==13.0.88
[pip3] nvidia-cuda-runtime==13.0.96
[pip3] nvidia-cudnn-cu13==9.19.0.56
[pip3] nvidia-cufft==12.0.0.61
[pip3] nvidia-curand==10.4.0.35
[pip3] nvidia-cusolver==12.0.4.66
[pip3] nvidia-cusparse==12.6.3.3
[pip3] nvidia-cusparselt-cu13==0.8.0
[pip3] nvidia-nccl-cu13==2.28.9
[pip3] nvidia-nvjitlink==13.0.88
[pip3] nvidia-nvtx==13.0.85
[pip3] torch==2.11.0
[pip3] torchvision==0.26.0
[pip3] triton==3.6.0
[conda] numpy 2.2.6 pypi_0 pypi
[conda] nvidia-cublas 13.1.0.3 pypi_0 pypi
[conda] nvidia-cuda-cupti 13.0.85 pypi_0 pypi
[conda] nvidia-cuda-nvrtc 13.0.88 pypi_0 pypi
[conda] nvidia-cuda-runtime 13.0.96 pypi_0 pypi
[conda] nvidia-cudnn-cu13 9.19.0.56 pypi_0 pypi
[conda] nvidia-cufft 12.0.0.61 pypi_0 pypi
[conda] nvidia-curand 10.4.0.35 pypi_0 pypi
[conda] nvidia-cusolver 12.0.4.66 pypi_0 pypi
[conda] nvidia-cusparse 12.6.3.3 pypi_0 pypi
[conda] nvidia-cusparselt-cu13 0.8.0 pypi_0 pypi
[conda] nvidia-nccl-cu13 2.28.9 pypi_0 pypi
[conda] nvidia-nvjitlink 13.0.88 pypi_0 pypi
[conda] nvidia-nvtx 13.0.85 pypi_0 pypi
[conda] torch 2.11.0 pypi_0 pypi
[conda] torchvision 0.26.0 pypi_0 pypi
[conda] triton 3.6.0 pypi_0 pypi

cc @ezyang @albanD @gqchen @nikitaved @soulitzer @Varal7 @bobrenjc93 @Chillee @samdow @kshitij12345
