pytorch - ✅(Solved) Fix [PERFORMANCE] [NN] Change `torch._C._get_tracing_state()` to `torch._C._is_tracing()` for around 15% improved performance of `module.py` `_call_impl()` [1 pull request, 1 comment, 1 participant]

GitHub stats
pytorch/pytorch#178070 (fetched 2026-04-08 01:12:17)
Comments: 1 · Participants: 1 · Timeline: 21 · Reactions: 0
Timeline (top): labeled ×6, mentioned ×6, subscribed ×6, closed ×1

Root Cause

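In _call_impl() in torch/nn/modules/module.py, PyTorch calls torch._C._get_tracing_state() on every module call only to test whether JIT tracing is active. That function returns the full tracing-state object (or None when not tracing), which _call_impl() never uses afterward; only its truthiness matters. Retrieving the object instead of a plain boolean is avoidable per-call overhead.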

Fix Action

Fix / Workaround

Replace the torch._C._get_tracing_state() check in _call_impl() (torch/nn/modules/module.py) with torch._C._is_tracing(), which returns a plain boolean. The change touches a single line (+1/-1); the benchmark scripts and results are listed in the PR notes below.

PR fix notes

PR #11: [PERFORMANCE] [NN] Change torch._C._get_tracing_state() to torch._C._is_tracing() for around 15% improved performance of module.py _call_impl()

Description (problem / solution / changelog)

Summary

This pull request implements #178070 by replacing torch._C._get_tracing_state() with torch._C._is_tracing() inside _call_impl() in torch/nn/modules/module.py. This avoids fetching an unused tracing-state object and improves performance by around 15% (a 1.159x speedup in the benchmark below).

Explanation

The original implementation calls torch._C._get_tracing_state(), which returns the full tracing-state object. _call_impl() only needs to know whether tracing is active and does not use the object afterward, so retrieving it is unnecessary overhead. torch._C._is_tracing() returns that answer directly as a boolean. The benchmark results below show a roughly 15% improvement for _call_impl() in module.py.

Benchmarks

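I tested the change with two scripts: torchcontroller.py runs testtorchperf.py 100 times under each build (unpatched A and patched B, alternating within each iteration) and summarizes the timings, and testtorchperf.py times 200,000 calls to a trivial identity module on the CPU.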

torchcontroller.py

import subprocess
import statistics
import math
import time

pythontorch_A = r"C:\Users\User\pythontorch\PCbuild\amd64\python.exe"
pythontorch_B = r"C:\Users\User\pythontorch_patch\PCbuild\amd64\python.exe"
benchmark_script = r"C:\Users\User\Desktop\testtorchperf.py"

runs = 100

results_A = []
results_B = []

for i in range(runs):
    out_A = subprocess.check_output([pythontorch_A, benchmark_script])
    time_A = float(out_A.strip())
    results_A.append(time_A)
    print(f"Run {i+1}/{runs} completed for PyTorch A | Time: {time_A} s")
    
    out_B = subprocess.check_output([pythontorch_B, benchmark_script])
    time_B = float(out_B.strip())
    results_B.append(time_B)
    print(f"Run {i+1}/{runs} completed for PyTorch B | Time: {time_B} s")

def summarize(data, label):
    n = len(data)
    mean = statistics.mean(data)
    stdev = statistics.stdev(data)
    median = statistics.median(data)

    z = 1.96
    margin = z * (stdev / math.sqrt(n))
    ci_lower = mean - margin
    ci_upper = mean + margin

    print(f"\n--- {label} ---")
    print(f"Mean: {mean:.6f} s")
    print(f"Median: {median:.6f} s")
    print(f"Std Dev: {stdev:.6f} s")
    print(f"95% CI: [{ci_lower:.6f}, {ci_upper:.6f}] s")

print("\n===== STATS =====")
summarize(results_A, "PyTorch")
summarize(results_B, "PyTorch PATCH")

speedup = statistics.mean(results_A) / statistics.mean(results_B)
print(f"\nSpeedup A/B: {speedup:.4f}x")
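Because torchcontroller.py runs build A and build B back to back in every iteration, the samples can also be compared pairwise, which cancels slow drift (thermal throttling, background load) that affects both builds within the same iteration. The helper below is a hypothetical addition, not part of the original script: it summarizes the per-run ratios time_A / time_B, working in log space so the interval refers to a geometric-mean speedup, and it keeps the same z = 1.96 normal approximation as summarize().

```python
import math
import statistics


def paired_speedup(results_a, results_b, z=1.96):
    """Summarize per-run speedups results_a[i] / results_b[i].

    Hypothetical helper. Assumes results_a[i] and results_b[i] were measured
    in the same iteration, as in the torchcontroller.py loop.
    """
    if len(results_a) != len(results_b) or len(results_a) < 2:
        raise ValueError("need two equal-length lists with at least 2 runs each")
    ratios = [a / b for a, b in zip(results_a, results_b)]
    # Log-ratios: their mean maps back to a geometric-mean speedup.
    logs = [math.log(r) for r in ratios]
    mean_log = statistics.mean(logs)
    margin = z * statistics.stdev(logs) / math.sqrt(len(logs))
    return {
        "median": statistics.median(ratios),
        "geo_mean": math.exp(mean_log),
        "ci": (math.exp(mean_log - margin), math.exp(mean_log + margin)),
    }
```

Appended to torchcontroller.py, it could be called as paired_speedup(results_A, results_B) after the loop.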

testtorchperf.py

import torch
import torch.nn as nn
import time

class TestModule(nn.Module):
    def forward(self, x):
        return x

def main():
    device = torch.device('cpu')
    model = TestModule().to(device)
    x = torch.tensor([[1.0, 2.0]], device=device)
    num_iter = 200_000

    start = time.perf_counter()
    for _ in range(num_iter):
        model(x)
    end = time.perf_counter()

    elapsed = end - start
    print(f"{elapsed:.6f}")
    
if __name__ == "__main__":
    main()
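testtorchperf.py starts timing at the very first call, so any one-time costs in the first iterations (for example, lazy initialization or cache warm-up) are folded into the measurement. A hypothetical variant could separate them with an untimed warm-up phase. The helper below sketches that idea; the name time_calls and the warmup default are illustrative assumptions, not values from the original benchmark. It works with any single-argument callable, including model from the script.

```python
import time


def time_calls(fn, arg, num_iter=200_000, warmup=1_000):
    """Return seconds spent on num_iter calls of fn(arg), after warm-up calls.

    Hypothetical helper; warmup is an illustrative default.
    """
    for _ in range(warmup):
        fn(arg)  # untimed: lets one-time costs settle before measuring
    start = time.perf_counter()
    for _ in range(num_iter):
        fn(arg)
    return time.perf_counter() - start
```

In testtorchperf.py, the timed loop in main() could then become elapsed = time_calls(model, x), keeping the same print(f"{elapsed:.6f}") output that torchcontroller.py parses.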

Results

The full log (100 runs of each build, unpatched and patched) is too long to paste on GitHub; it is available at https://justpaste.it/fqjqe.

Summary (for full log see https://justpaste.it/fqjqe):

--- PyTorch ---
Mean: 0.255768 s
Median: 0.252001 s
Std Dev: 0.011083 s
95% CI: [0.253596, 0.257940] s

--- PyTorch PATCH ---
Mean: 0.220687 s
Median: 0.217648 s
Std Dev: 0.009732 s
95% CI: [0.218780, 0.222594] s

Speedup A/B: 1.1590x
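As a quick check of the headline figure, the arithmetic below recomputes the speedup from the two reported means and converts it into time saved per module call, using the 200,000 iterations per run set in testtorchperf.py. The ~15% in the title corresponds to the throughput ratio (1.159x); measured as wall-clock time, the patched build needs roughly 13.7% less. The per-call figure assumes the whole difference comes from _call_impl(), which the identity-module benchmark makes plausible but does not prove.

```python
# Reported means from the Results summary (seconds per run of testtorchperf.py).
mean_unpatched = 0.255768
mean_patched = 0.220687
num_iter = 200_000  # module calls per run, as set in testtorchperf.py

speedup = mean_unpatched / mean_patched          # throughput ratio, ~1.159x
time_saved = 1 - mean_patched / mean_unpatched   # fraction of wall-clock time saved, ~13.7%
ns_saved_per_call = (mean_unpatched - mean_patched) / num_iter * 1e9  # ~175 ns

print(f"speedup {speedup:.4f}x, {time_saved:.1%} less time, "
      f"{ns_saved_per_call:.0f} ns saved per call")
```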

Contributed by Benedikt Johannes

Changed files

  • torch/nn/modules/module.py (modified, +1/-1)


Summary

This issue proposes replacing torch._C._get_tracing_state() with torch._C._is_tracing() inside _call_impl() in torch/nn/modules/module.py. This avoids fetching an unused tracing-state object and improves performance by around 15% (Speedup A/B: 1.1590x in the author's benchmark).


Implementation and benchmarks

I've implemented the suggested change (see implementation) and tested it with the same torchcontroller.py and testtorchperf.py scripts listed in the PR notes above, with the same results: mean 0.255768 s for PyTorch versus 0.220687 s for PyTorch PATCH, Speedup A/B: 1.1590x (full log: https://justpaste.it/fqjqe).

Contributed by Benedikt Johannes

cc @jerryzh168 @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki


Fix Plan

To fix the performance issue, replace torch._C._get_tracing_state() with torch._C._is_tracing() in torch/nn/modules/module.py inside _call_impl().

Code Changes

# Before
if torch._C._get_tracing_state():
    ...  # tracing path (body elided)

# After
if torch._C._is_tracing():
    ...  # tracing path (body elided)

torch._C._is_tracing() returns a boolean directly, so the check no longer retrieves the tracing-state object; the benchmark above measured a 1.159x speedup for this micro-benchmark.

Verification

Run the provided benchmark scripts (torchcontroller.py and testtorchperf.py) to verify the performance improvement. Compare the mean execution times of the original and patched PyTorch versions.
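The comparison can also be made from the reported summary statistics alone. The sketch below computes the difference of means, its standard error under unequal variances (the Welch form, sqrt(sd_A²/n_A + sd_B²/n_B)), and a 95% interval, using the same large-sample z = 1.96 approximation as summarize() in torchcontroller.py. The function name diff_of_means is hypothetical; the inputs are the numbers reported under Results, and the runs are assumed to be independent.

```python
import math


def diff_of_means(mean_a, sd_a, n_a, mean_b, sd_b, n_b, z=1.96):
    """Difference of two means with an unequal-variance (Welch-form) standard error.

    Large-sample normal approximation, like summarize() in torchcontroller.py.
    Returns (difference, standard error, (ci_lower, ci_upper), z-statistic).
    """
    diff = mean_a - mean_b
    se = math.sqrt(sd_a ** 2 / n_a + sd_b ** 2 / n_b)
    return diff, se, (diff - z * se, diff + z * se), diff / se


# Reported summary statistics: unpatched (A) and patched (B), 100 runs each.
diff, se, (lo, hi), stat = diff_of_means(0.255768, 0.011083, 100,
                                         0.220687, 0.009732, 100)
print(f"difference {diff * 1e3:.2f} ms, 95% CI [{lo * 1e3:.2f}, {hi * 1e3:.2f}] ms, "
      f"z = {stat:.1f}")
```

With these inputs the interval for the difference lies well above zero (roughly 32 to 38 ms per 200,000-call run), so under the independence assumption the improvement is not explained by run-to-run noise in this setup.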

Extra Tips

  • Ensure that the patched PyTorch version is correctly installed and imported in the benchmark scripts.
  • Run the benchmarks multiple times to account for any variability in execution times.
  • Consider adding tests that confirm the change does not alter behavior, in particular that modules called under torch.jit.trace still take the tracing path.
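On run-to-run variability: torchcontroller.py always runs build A before build B within an iteration, so any systematic order effect (for example, the second process benefiting from a warm file cache) would bias the comparison in one direction. A hypothetical alternative is to shuffle the order in every round, as sketched below. run_interleaved and make_runner are illustrative names, not part of the original scripts; the runner mirrors the subprocess.check_output call from torchcontroller.py.

```python
import random
import subprocess


def make_runner(python_exe, script):
    """Return a callable that runs `script` under `python_exe` and parses its output."""
    def run():
        # The benchmark script prints a single float: elapsed seconds.
        return float(subprocess.check_output([python_exe, script]).strip())
    return run


def run_interleaved(runners, runs, seed=0):
    """Call every runner once per round, in a freshly shuffled order each round.

    runners: dict mapping a label to a zero-argument callable returning seconds.
    Returns a dict mapping each label to its list of times, one entry per round.
    """
    rng = random.Random(seed)  # seeded so the run order is reproducible
    labels = list(runners)
    results = {label: [] for label in labels}
    for _ in range(runs):
        rng.shuffle(labels)  # avoid always running the same build first
        for label in labels:
            results[label].append(runners[label]())
    return results
```

With the original paths, a call could look like run_interleaved({"A": make_runner(pythontorch_A, benchmark_script), "B": make_runner(pythontorch_B, benchmark_script)}, runs=100). Because each round appends exactly one time per build, results["A"][i] and results["B"][i] still come from the same round, so the lists can be compared pairwise.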
