quantized-custom
import contextlib
import logging
import os
from distutils.version import LooseVersion
from functools import wraps
import torch.onnx
from torch import _C # noqa: N814
from torch.onnx import OperatorExportTypes
# from horizon_plugin_pytorch.utils import _log_first_n
# from . import _register_onnx_ops # noqa: F401
from . import _register_quantized_onnx_ops # noqa: F401
TrainingMode = _C._onnx.TrainingMode
__all__ = ["export_to_onnx", "export_quantized_onnx"]
def _preprocess_graph(func):
    """Remove dead custom registered ops.

    Custom registered quantized functions that return a tuple are usually
    followed by a prim::TupleUnpack node in the traced graph. However, if the
    results of such a function are not used by other nodes, there is no
    prim::TupleUnpack node to unpack them. The dead code elimination pass in
    torch cannot remove the function node (possibly because the custom op is
    traced as a block and marked as live), so we find unused custom ops here
    and delete them from the graph.
    """

    @wraps(func)
    def _preprocess(*args, **kwargs):
        graph, *args = args
        assert type(graph) == torch._C.Graph
        # Collect the unused custom op nodes first, then destroy them, since
        # the graph should not be mutated while iterating over its nodes.
        dead_vslz_node = []
        for node in graph.nodes():
            if node.kind() == "prim::PythonOp" and not node.hasUses():
                dead_vslz_node.append(node)
        for node in dead_vslz_node:
            node.destroy()
        return func(graph, *args, **kwargs)

    return _preprocess
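
# Illustrative sketch (not part of the original module): a pass wrapped with
# _preprocess_graph receives a graph that has already been stripped of unused
# prim::PythonOp nodes. `_example_count_nodes` is a hypothetical stand-in for
# a real pass such as torch.onnx.utils._optimize_graph.
@_preprocess_graph
def _example_count_nodes(graph, *args, **kwargs):
    # Dead custom ops are destroyed before this body runs.
    return sum(1 for _ in graph.nodes())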
# torch 1.10.2 added some logic to onnx shape inference that uses std::cerr
# to print warnings in custom registered ops.
# We redirect stderr to /dev/null to avoid a warning for each custom op,
# run torch.onnx.export, and then redirect stderr back.
@contextlib.contextmanager
def _redirect_stderr():
    # Note: directly using sys.stderr.fileno() causes a 'Tee' error in CI/CD.
    # stderr_fd = sys.stderr.fileno()
    stderr_fd = 2
    fd = os.open("/dev/null", os.O_WRONLY)
    dup_stderr_fd = os.dup(stderr_fd)
    try:
        yield os.dup2(fd, stderr_fd)
    finally:
        # Restore the original stderr and close the duplicated descriptors.
        os.dup2(dup_stderr_fd, stderr_fd)
        os.close(fd)
        os.close(dup_stderr_fd)
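
# Hypothetical usage sketch (not in the original file): suppress C++-level
# warnings written to fd 2 while exporting. `model` and `example_inputs` are
# assumed to be caller-provided.
def _example_silent_export(model, example_inputs, path="model.onnx"):
    with _redirect_stderr():
        torch.onnx.export(model, example_inputs, path)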
# Replace torch.onnx.utils._optimize_graph in torch 1.13 to avoid
# processing the inner implementation of autograd functions.
@contextlib.contextmanager
def _redirect_opt_graph():
    _torch_optimize_graph = torch.onnx.utils._optimize_graph
    try:
        if LooseVersion(torch.__version__) >= LooseVersion("1.13"):
            from ._optimize_graph_helper import _optimize_graph

            torch.onnx.utils._optimize_graph = _preprocess_graph(
                _optimize_graph
            )
            yield True
        else:
            torch.onnx.utils._optimize_graph = _preprocess_graph(
                _torch_optimize_graph
            )
            yield False
    finally:
        # Always restore the original pass, even if the export fails.
        torch.onnx.utils._optimize_graph = _torch_optimize_graph
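
# Hypothetical usage sketch (not in the original file): patch _optimize_graph
# only for the duration of an export. The yielded flag tells the caller
# whether the torch>=1.13 helper replaced the pass (True) or the stock pass
# was merely wrapped (False).
def _example_export_with_patched_pass(model, example_inputs, path="model.onnx"):
    with _redirect_opt_graph() as used_helper:
        logging.debug("torch>=1.13 optimize_graph helper in use: %s", used_helper)
        torch.onnx.export(model, example_inputs, path)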
@contextlib.contextmanager
def _set_is_in_onnx_export_false():
    origin_f = torch.onnx.utils.is_in_onnx_export
    try:
        if LooseVersion(torch.__version__) >= LooseVersion("1.13"):
            torch.onnx.utils.is_in_onnx_export = False
        yield
    finally:
        torch.onnx.utils.is_in_onnx_export = origin_f
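
# Hypothetical sketch (not in the original file): on torch>=1.13 this forces
# torch.onnx.utils.is_in_onnx_export to False while tracing, so modules that
# branch on that flag keep their eager behavior.
def _example_trace_without_export_flag(model, example_inputs):
    with _set_is_in_onnx_export_false():
        return torch.jit.trace(model, example_inputs)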
def export_to_onnx(
    model,
    args,
    f,
    export_params=True,
    verbose=False,
    training=TrainingMode.EVAL,
    input_names=None,
    output_names=None,
    operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
    opset_version=11,
    do_constant_folding=True,
    dynamic_axes=None,
    keep_initializers_as_inputs=None,
):
r"""
Export a (float or qat)model into ONNX format.
Args:
model (torch.nn.Module/torch.jit.ScriptModule/ScriptFunction):
the model to be exported.
args (tuple or torch.Tensor):
args can be structured either as:
1. ONLY A TUPLE OF ARGUMENTS::
args = (x, y, z)
The tuple should contain model inputs such that ``model(*args)``
is a valid invocation of the model. Any non-Tensor arguments will
be hard-coded into the exported model; any Tensor arguments will
become inputs of the exported model, in the order they occur in
the tuple.
2. A TENSOR::
args = torch.Tensor([1])
This is equivalent to a 1-ary tuple of that Tensor.
3. A TUPLE OF ARGUMENTS ENDING WITH A DICTIONARY OF NAMED
ARGUMENTS::
args = (x,
{'y': input_y,
'z': input_z})
All but the last element of the tuple will be passed as non-keyword
arguments, and named arguments will be set from the last element.
If a named argument is not present in the dictionary , it is
assigned the default value, or None if a default value is not
provided.
f: a file-like object or a string containing a file name. A binary
protocol buffer will be written to this file.
export_params (bool, default True): if True, all parameters will
be exported.
verbose (bool, default False): if True, prints a description of the
model being exported to stdout, doc_string will be added to graph.
doc_string may contaion mapping of module scope to node name in
future torch onnx.
training (enum, default TrainingMode.EVAL):
if model.training is False and in training mode if model.training
is True.
* ``TrainingMode.EVAL``: export the model in inference mode.
* ``TrainingMode.PRESERVE``: export the model in inference mode
* ``TrainingMode.TRAINING``: export the model in training mode.
Disables optimizations which might interfere with training.
input_names (list of str, default empty list): names to assign to the
input nodes of the graph, in order.
output_names (list of str, default empty list): names to assign to the
output nodes of the graph, in order.
operator_export_type (enum, default ONNX_FALLTHROUGH):
* ``OperatorExportTypes.ONNX``: Export all ops as regular ONNX ops
(in the default opset domain).
* ``OperatorExportTypes.ONNX_FALLTHROUGH``: Try to convert all ops
to standard ONNX ops in the default opset domain.
* ``OperatorExportTypes.ONNX_ATEN``: All ATen ops (in the
TorchScript namespace "aten") are exported as ATen ops.
* ``OperatorExportTypes.ONNX_ATEN_FALLBACK``: Try to export each
ATen op (in the TorchScript namespace "aten") as a regular ONNX
op. If we are unable to do so,fall back to exporting an ATen op.
opset_version (int, default 11): by default we export the model to the
opset version of the onnx submodule.
do_constant_folding (bool, default False): Apply the constant-folding
optimization. Constant-folding will replace some of the ops that
have all constant inputs with pre-computed constant nodes.
dynamic_axes (dict>, default empty dict):
By default the exported model will have the shapes of all input
and output tensors set to exactly match those given in ``args``
(and ``example_outputs`` when that arg is required). To specify
axes of tensors as dynamic (i.e. known only at run-time), set
``dynamic_axes`` to a dict with schema:
* KEY (str): an input or output name. Each name must also be
provided in ``input_names`` or ``output_names``.
* VALUE (dict or list): If a dict, keys are axis indices and
values are axis names. If a list, each element is an axis index.
keep_initializers_as_inputs (bool, default None): If True, all the
initializers (typically corresponding to parameters) in t