pytorch - ✅ (Solved) Fix [ONNX] Optimize should not fold DequantizeLinear [1 pull request, 5 comments, 4 participants]

GitHub stats
pytorch/pytorch#177611 (fetched 2026-04-08 00:47:14)
Comments: 5 · Participants: 4 · Timeline: 36 · Reactions: 1
Timeline (top): mentioned ×8, subscribed ×8, commented ×5, labeled ×4

Fix Action: Fixed

PR fix notes

PR #2865: optimizer: Prevent constant folding of DynamicQuantizeLinear

Description (problem / solution / changelog)

The constant folding pass was eliminating DequantizeLinear nodes that operated on constant weight tensors during optimize(), collapsing the quantization structure into a plain Conv and losing quantization semantics in QAT-exported models.

Changes

  • optimizer/_constant_folding.py: Add DynamicQuantizeLinear to DEFAULT_CONSTANT_FOLD_BLACKLIST alongside the existing QuantizeLinear and DequantizeLinear entries; reorder alphabetically for consistency
  • optimizer/_constant_folding_test.py: Add tests verifying QuantizeLinear and DequantizeLinear are not folded when all inputs are constant initializers
  • Fixes pytorch/pytorch#177611

Changed files

  • onnxscript/optimizer/_constant_folding.py (modified, +2/-1)
  • onnxscript/optimizer/_constant_folding_test.py (modified, +30/-0)
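The effect of the blacklist can be pictured with a toy constant-folding pass. This is a hypothetical sketch, far simpler than onnxscript's FoldConstantsPass: nodes are small dataclasses, the only op it can evaluate is DequantizeLinear (computed per tensor as (x - zero_point) * scale, following the ONNX operator definition), and FOLD_BLACKLIST merely mirrors the idea of DEFAULT_CONSTANT_FOLD_BLACKLIST. Like the tests added in the PR, it checks what happens to a DequantizeLinear whose inputs are all constant initializers, with and without the quantization ops blacklisted.

```python
from dataclasses import dataclass

# Hypothetical stand-in for onnxscript's DEFAULT_CONSTANT_FOLD_BLACKLIST,
# holding the three quantization ops named in the PR.
FOLD_BLACKLIST = frozenset({"DequantizeLinear", "DynamicQuantizeLinear", "QuantizeLinear"})


@dataclass
class Node:
    op_type: str
    inputs: list
    output: str


def dequantize(x, scale, zero_point):
    """Per-tensor DequantizeLinear: y = (x - zero_point) * scale."""
    return [(v - zero_point) * scale for v in x]


def fold_constants(nodes, initializers, blacklist=FOLD_BLACKLIST):
    """Toy pass: evaluate DequantizeLinear nodes whose inputs are all constants.

    Folded outputs are stored in `initializers`; the surviving nodes are returned.
    """
    kept = []
    for node in nodes:
        foldable = (
            node.op_type == "DequantizeLinear"  # the only op this toy can evaluate
            and node.op_type not in blacklist
            and all(name in initializers for name in node.inputs)
        )
        if foldable:
            x, scale, zero_point = (initializers[name] for name in node.inputs)
            initializers[node.output] = dequantize(x, scale, zero_point)
        else:
            kept.append(node)
    return kept


def build_graph():
    """int8 weight -> DequantizeLinear -> Conv, with the weight stored as initializers."""
    initializers = {"w_q": [-2, 0, 3], "w_scale": 0.5, "w_zero_point": 0}
    nodes = [
        Node("DequantizeLinear", ["w_q", "w_scale", "w_zero_point"], "w"),
        Node("Conv", ["input", "w"], "y"),
    ]
    return nodes, initializers


# Quantization ops not blacklisted: the DequantizeLinear becomes a plain float weight.
nodes_old, inits_old = build_graph()
kept_old = fold_constants(nodes_old, inits_old, blacklist=frozenset())
print([n.op_type for n in kept_old], inits_old["w"])  # ['Conv'] [-1.0, 0.0, 1.5]

# Quantization ops blacklisted: the QDQ structure in front of Conv survives.
nodes_new, inits_new = build_graph()
kept_new = fold_constants(nodes_new, inits_new)
print([n.op_type for n in kept_new], "w" in inits_new)  # ['DequantizeLinear', 'Conv'] False
```

In the first case the graph collapses to a bare Conv on a float weight, which is the symptom described in the issue; the int8 values, scale, and zero point are no longer visible to a downstream runtime.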


🐛 Describe the bug

After a QAT model goes through onnx_program.optimize(), its quantization nodes are lost. In the figure, the left graph is the normal export and the right graph is the broken one.

<img width="898" height="884" alt="Image" src="https://github.com/user-attachments/assets/481bc3c0-38fe-45f6-9fde-bc1a287617a3" />

This bug occurred in torch/onnx/_internal/exporter/_onnx_program.py:

```python
def optimize(self) -> None:
    self.model = onnxscript_apis.optimize(self.model)
```

It internally calls the optimize_ir function in onnxscript/optimizer/_optimizer.py, whose FoldConstantsPass receives input_size_limit (default 512). Nodes whose input size is less than this value are folded into constants.

```python
def optimize_ir(
    model: ir.Model,
    num_iterations: int = 2,
    *,
    onnx_shape_inference: bool = True,
    stop_if_no_change: bool = True,
    input_size_limit: int = _constant_folding.DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
    output_size_limit: int = _constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
    inline: bool = True,
) -> None:
    passes = [
        ir.passes.PassManager(
            [
                _constant_folding.FoldConstantsPass(
                    shape_inference=onnx_shape_inference,
                    input_size_limit=input_size_limit,
                    output_size_limit=output_size_limit,
                ),
                rewriter.RewritePass(rewriter._DEFAULT_REWRITE_RULES),
                common_passes.RemoveUnusedNodesPass(),
                common_passes.RemoveUnusedFunctionsPass(),
                common_passes.RemoveUnusedOpsetsPass(),
            ],
            steps=num_iterations,
            early_stop=stop_if_no_change,
        ),
    ......
```

⭐ Please expose the optimizer's parameters (such as input_size_limit) through optimize() in torch/onnx/_internal/exporter/_onnx_program.py. Otherwise, the only way to change them is to edit the installed onnxscript source code.

Minimal reproducible example:

```python
import copy
import torch
import torch.nn as nn
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from onnxslim import slim
import onnx


class ConvBnReluModel(nn.Module):
    def __init__(self, eps=1e-3, momentum=0.03):
        super().__init__()
        self.conv = nn.Conv2d(4, 4, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(4, eps=eps, momentum=momentum)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))


def get_batch_norm_node_args(gm):
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.batch_norm.default:
            return tuple(node.args)
    raise RuntimeError("No aten.batch_norm.default node found")


torch.manual_seed(0)
device = 'cuda'

model = ConvBnReluModel().train().to(device)
inputs = (torch.randn(2, 4, 8, 8).to(device),)
exported = torch.export.export_for_training(copy.deepcopy(model), inputs).module()
print("before prepare:", get_batch_norm_node_args(exported))

quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_qat=True))
prepared = prepare_qat_pt2e(exported, quantizer)
prepared.to(device)
torch.ao.quantization.move_exported_model_to_eval(prepared)
torch.ao.quantization.allow_exported_model_train_eval(prepared)
prepared.eval()

#---export quantized model to onnx---
qat_onnx_sp = './quant.onnx'
quantized_model = convert_pt2e(prepared)
print('convert_pt2e done!')
onnx_program = torch.onnx.export(quantized_model, inputs, dynamo=True, opset_version=21)
# bug: quantization nodes are lost after this call
onnx_program.optimize()
onnx_program.save(qat_onnx_sp)
print(f'export qat model to [{qat_onnx_sp}] done!')

model_simp = slim(onnx_program.model_proto)
sim_path = qat_onnx_sp.replace('.onnx', '_slim.onnx')
onnx.save(model_simp, sim_path)
print(f"save onnx model to [{sim_path}] Successfully!")
```
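A quick size count for the model above suggests why its weight-side DequantizeLinear is folded. This is a hypothetical sketch that assumes, as the issue states, that a node is folded when its constant inputs are smaller than input_size_limit, and further assumes the size is counted in elements; the exact comparison inside onnxscript may differ. The helper names numel and is_foldable are illustrative only, and the scale and zero point are shown as per-tensor scalars (per-channel quantization would make them 4-element tensors, still well under the limit).

```python
from math import prod

# Default quoted in the issue for input_size_limit.
DEFAULT_INPUT_SIZE_LIMIT = 512


def numel(shape):
    """Number of elements in a tensor of the given shape (a scalar shape () has 1)."""
    return prod(shape)


def is_foldable(input_shapes, input_size_limit=DEFAULT_INPUT_SIZE_LIMIT):
    """Hypothetical size gate: fold only if every constant input is smaller than the limit."""
    return all(numel(shape) < input_size_limit for shape in input_shapes)


# nn.Conv2d(4, 4, 3, bias=False) stores its weight as (out_channels, in_channels, kH, kW).
conv_weight = (4, 4, 3, 3)
print(numel(conv_weight))  # 144

# Inputs of the weight-side DequantizeLinear: int8 weight, scale, zero point.
dq_inputs = [conv_weight, (), ()]
print(is_foldable(dq_inputs))  # True

# Raising the limit keeps the node foldable; only a limit below 144 protects it.
for limit in (1024, 512, 100):
    print(limit, is_foldable(dq_inputs, limit))
# 1024 True
# 512 True
# 100 False

# For comparison, a 64 -> 64 channel 3x3 conv weight is far above the default limit.
print(numel((64, 64, 3, 3)))  # 36864
```

Under these assumptions the problem would mostly surface on very small layers like the one in this repro, since typical convolution weights already exceed the default limit and are left alone.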

Versions

Versions of relevant libraries: [pip3] executorch==0.5.0 [pip3] numpy==1.23.5 [pip3] nvidia-cublas-cu11==11.11.3.6 [pip3] nvidia-cuda-cupti-cu11==11.8.87 [pip3] nvidia-cuda-nvrtc-cu11==11.8.89 [pip3] nvidia-cuda-runtime-cu11==11.8.89 [pip3] nvidia-cudnn-cu11==9.1.0.70 [pip3] nvidia-cufft-cu11==10.9.0.58 [pip3] nvidia-curand-cu11==10.3.0.86 [pip3] nvidia-cusolver-cu11==11.4.1.48 [pip3] nvidia-cusparse-cu11==11.7.5.86 [pip3] nvidia-nccl-cu11==2.21.5 [pip3] nvidia-nvtx-cu11==11.8.86 [pip3] onnx==1.17.0 [pip3] onnx_graphsurgeon==0.5.8 [pip3] onnx-ir==0.1.12 [pip3] onnx-simplifier==0.4.36 [pip3] onnxruntime==1.21.0 [pip3] onnxruntime-gpu==1.21.0 [pip3] onnxscript==0.4.0 [pip3] onnxslim==0.1.48 [pip3] torch==2.6.0+cu118 [pip3] torchao==0.14.1 [pip3] torchaudio==2.6.0+cu118 [pip3] torchvision==0.21.0+cu118 [pip3] triton==3.2.0 [conda] executorch 0.5.0 pypi_0 pypi [conda] numpy 1.23.5 pypi_0 pypi [conda] nvidia-cublas-cu11 11.11.3.6 pypi_0 pypi [conda] nvidia-cuda-cupti-cu11 11.8.87 pypi_0 pypi [conda] nvidia-cuda-nvrtc-cu11 11.8.89 pypi_0 pypi [conda] nvidia-cuda-runtime-cu11 11.8.89 pypi_0 pypi [conda] nvidia-cudnn-cu11 9.1.0.70 pypi_0 pypi [conda] nvidia-cufft-cu11 10.9.0.58 pypi_0 pypi [conda] nvidia-curand-cu11 10.3.0.86 pypi_0 pypi [conda] nvidia-cusolver-cu11 11.4.1.48 pypi_0 pypi [conda] nvidia-cusparse-cu11 11.7.5.86 pypi_0 pypi [conda] nvidia-nccl-cu11 2.21.5 pypi_0 pypi [conda] nvidia-nvtx-cu11 11.8.86 pypi_0 pypi [conda] torch 2.6.0+cu118 pypi_0 pypi [conda] torchao 0.14.1 pypi_0 pypi [conda] torchaudio 2.6.0+cu118 pypi_0 pypi [conda] torchvision 0.21.0+cu118 pypi_0 pypi [conda] triton 3.2.0 pypi_0 pypi onnxscript == 0.4.0

extent analysis

Fix Plan

The merged fix (PR #2865) keeps quantization ops out of constant folding by listing them in DEFAULT_CONSTANT_FOLD_BLACKLIST. Independently of that, the request in the issue can be met by exposing the input_size_limit parameter of optimize_ir through onnx_program.optimize(). Mind the direction of this knob: because nodes whose inputs are smaller than input_size_limit get folded, raising the default of 512 (for example to 1024) would fold more nodes, not fewer. To keep a DequantizeLinear on a constant weight, the limit must be set below that weight's size.

Here are the steps:

  • Modify the optimize method in torch/onnx/_internal/exporter/_onnx_program.py to accept an optional input_size_limit and forward it only when it is set:

```python
def optimize(self, input_size_limit: int | None = None) -> None:
    # Leave the onnxscript default in place unless the caller overrides it.
    kwargs = {} if input_size_limit is None else {"input_size_limit": input_size_limit}
    self.model = onnxscript_apis.optimize(self.model, **kwargs)
```

  • Update onnxscript_apis.optimize to pass the keyword through to optimize_ir. Because optimize_ir modifies the model in place and returns None, the wrapper must return the model rather than the result of optimize_ir:

```python
def optimize(model: ir.Model, **kwargs) -> ir.Model:
    optimize_ir(model, **kwargs)  # in-place; returns None
    return model
```

  • optimize_ir itself needs no change: it already forwards input_size_limit to FoldConstantsPass, so its default (DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT) can stay as it is.

  • When calling onnx_program.optimize(), pass a limit below the size of the constant inputs whose DequantizeLinear must survive. In the repro, the Conv2d(4, 4, 3) weight has 4 × 4 × 3 × 3 = 144 elements, so a value such as 100 would keep it (assuming sizes are counted in elements), at the cost of also skipping folding for other nodes whose constant inputs are at least that large:

```python
onnx_program.optimize(input_size_limit=100)
```

Verification

To verify the fix, compare the ONNX graph exported without optimize() against the optimized one: the optimized graph should still contain the QuantizeLinear and DequantizeLinear nodes.

You can inspect both graphs programmatically with the onnx Python package or onnx-graphsurgeon, or view them side by side in a graph viewer such as Netron.
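A small helper makes this comparison mechanical. The sketch below is hypothetical: it works on plain lists of op_type strings so it runs without any ONNX dependency, and the op lists in the example are illustrative rather than taken from a real export.

```python
from collections import Counter

# Quantization op types whose disappearance after optimize() signals the bug.
QDQ_OPS = ("QuantizeLinear", "DequantizeLinear", "DynamicQuantizeLinear")


def count_qdq(op_types):
    """Count quantization ops in an iterable of ONNX node op_type strings."""
    counts = Counter(op for op in op_types if op in QDQ_OPS)
    return {op: counts[op] for op in QDQ_OPS}


def lost_qdq(before_ops, after_ops):
    """Report how many of each quantization op disappeared between two graphs."""
    before, after = count_qdq(before_ops), count_qdq(after_ops)
    return {op: before[op] - after[op] for op in QDQ_OPS if before[op] > after[op]}


# Illustrative op lists, not taken from a real export.
plain = ["QuantizeLinear", "DequantizeLinear", "DequantizeLinear", "Conv", "Relu"]
optimized = ["QuantizeLinear", "DequantizeLinear", "Conv", "Relu"]
print(count_qdq(plain))            # {'QuantizeLinear': 1, 'DequantizeLinear': 2, 'DynamicQuantizeLinear': 0}
print(lost_qdq(plain, optimized))  # {'DequantizeLinear': 1}
```

With the onnx package, the op lists come from the graphs themselves, for example `[n.op_type for n in onnx.load("quant.onnx").graph.node]` for a saved model, or `[n.op_type for n in onnx_program.model_proto.graph.node]` for the exporter's in-memory model. An empty result from lost_qdq means no quantization node was folded away.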

Extra Tips

When optimizing ONNX models, keep the trade-offs between model size, performance, and accuracy in mind. For QAT exports in particular, folding a DequantizeLinear into a float constant simplifies the graph but discards the quantization parameters a quantized runtime relies on, so any change to input_size_limit should be checked against the resulting graph.

Other techniques, such as knowledge distillation, can be combined with quantization-aware training to further improve the model's accuracy and efficiency.
