onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,652 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
|
|
2
|
+
import numpy as np
|
|
3
|
+
from onnx import (
|
|
4
|
+
AttributeProto,
|
|
5
|
+
GraphProto,
|
|
6
|
+
FunctionProto,
|
|
7
|
+
ModelProto,
|
|
8
|
+
NodeProto,
|
|
9
|
+
TypeProto,
|
|
10
|
+
ValueInfoProto,
|
|
11
|
+
helper as oh,
|
|
12
|
+
load,
|
|
13
|
+
save as onnx_save,
|
|
14
|
+
shape_inference as shi,
|
|
15
|
+
)
|
|
16
|
+
from onnx.defs import onnx_opset_version
|
|
17
|
+
import onnxruntime
|
|
18
|
+
from ..helpers import string_type
|
|
19
|
+
from ..helpers.onnx_helper import pretty_onnx, dtype_to_tensor_dtype, to_array_extended
|
|
20
|
+
from ..helpers.ort_session import (
|
|
21
|
+
InferenceSessionForTorch,
|
|
22
|
+
InferenceSessionForNumpy,
|
|
23
|
+
_InferenceSession,
|
|
24
|
+
)
|
|
25
|
+
from ..helpers.torch_helper import to_tensor
|
|
26
|
+
from .report_results_comparison import ReportResultComparison
|
|
27
|
+
from .evaluator import ExtendedReferenceEvaluator
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Union of the proto types accepted by the evaluator, for annotations.
Proto = Union[FunctionProto, ModelProto, GraphProto, NodeProto]
# The same proto types as a tuple, usable with isinstance.
PROTO = (FunctionProto, ModelProto, GraphProto, NodeProto)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class OnnxruntimeEvaluator:
    """
    This class loads an onnx model and then executes its nodes one by one
    with onnxruntime. This class is mostly meant for debugging.

    :param proto: a proto (ModelProto, GraphProto, FunctionProto or NodeProto),
        a filename to load, or another :class:`OnnxruntimeEvaluator`
        whose proto is reused
    :param session_options: session options, forwarded to the inference session
    :param providers: `None`, `"CPU"`, `"CUDA"` or a list of providers
    :param nvtx: enable nvidia events
    :param enable_profiling: forwarded to the inference session
        (presumably enables onnxruntime profiling)
    :param graph_optimization_level: see :class:`onnxruntime.SessionOptions`
    :param log_severity_level: see :class:`onnxruntime.SessionOptions`
    :param log_verbosity_level: see :class:`onnxruntime.SessionOptions`
    :param optimized_model_filepath: see :class:`onnxruntime.SessionOptions`
    :param disable_aot_function_inlining: see :class:`onnxruntime.SessionOptions`
    :param use_training_api: use onnxruntime-training API
    :param verbose: verbosity
    :param local_functions: additional local functions,
        keyed by ``(domain, name)``
    :param ir_version: ir version to use when unknown
    :param opsets: opsets to use when unknown, an integer for the main domain
        or a dictionary ``{domain: version}``
    :param whole: if True, do not split node by node,
        the whole proto is run with a single session
    :param torch_or_numpy: force the use of one of them, True for torch,
        False for numpy, None to let the class choose
    """
|
|
58
|
+
|
|
59
|
+
def __init__(
    self,
    proto: Union[str, Proto, "OnnxruntimeEvaluator"],
    session_options: Optional[onnxruntime.SessionOptions] = None,
    providers: Optional[Union[str, List[str]]] = None,
    nvtx: bool = False,
    enable_profiling: bool = False,
    graph_optimization_level: Optional[
        Union[onnxruntime.GraphOptimizationLevel, bool]
    ] = None,
    log_severity_level: Optional[int] = None,
    log_verbosity_level: Optional[int] = None,
    optimized_model_filepath: Optional[str] = None,
    disable_aot_function_inlining: Optional[bool] = None,
    use_training_api: bool = False,
    verbose: int = 0,
    local_functions: Optional[
        Dict[Tuple[str, str], Union[Proto, "OnnxruntimeEvaluator"]]
    ] = None,
    ir_version: int = 10,
    opsets: Optional[Union[int, Dict[str, int]]] = None,
    whole: bool = False,
    torch_or_numpy: Optional[bool] = None,
):
    """
    Loads or wraps the proto and prepares the node-by-node execution state.
    See the class docstring for the meaning of every parameter.
    """
    if isinstance(proto, str):
        self.proto: Proto = load(proto)
    elif isinstance(proto, OnnxruntimeEvaluator):
        assert isinstance(
            proto.proto, PROTO
        ), f"Unexpected type for proto.proto {type(proto.proto)}"
        self.proto = proto.proto
    else:
        self.proto = proto
    assert isinstance(
        self.proto, PROTO
    ), f"Unexpected type for self.proto {type(self.proto)}"

    # Sessions already built, see _run, _run_if, _run_scan.
    self._cache: Dict[
        Any, Tuple[Proto, Union["OnnxruntimeEvaluator", _InferenceSession]]  # noqa: UP037
    ] = {}
    self.ir_version = ir_version
    self.opsets = opsets
    # Forwarded to every session or sub-evaluator this instance creates.
    self.session_kwargs: Dict[str, Any] = dict(
        session_options=session_options,
        providers=providers,
        nvtx=nvtx,
        enable_profiling=enable_profiling,
        graph_optimization_level=graph_optimization_level,
        log_severity_level=log_severity_level,
        log_verbosity_level=log_verbosity_level,
        optimized_model_filepath=optimized_model_filepath,
        disable_aot_function_inlining=disable_aot_function_inlining,
        use_training_api=use_training_api,
    )
    self.to_tensor_or_array = to_array_extended if not torch_or_numpy else to_tensor

    self.verbose = verbose
    self.torch_or_numpy = torch_or_numpy
    self.sess_: Optional[_InferenceSession] = None
    if whole:
        # run() builds a single session for the whole proto.
        self.nodes: Optional[List[NodeProto]] = None
        self.rt_inits_: Optional[Dict[str, Any]] = None
        self.rt_nodes_: Optional[List[NodeProto]] = None
    else:
        if isinstance(self.proto, NodeProto):
            self.nodes = [self.proto]
        elif hasattr(self.proto, "graph"):
            self.nodes = list(self.proto.graph.node)
        else:
            self.nodes = list(self.proto.node)
        # A ModelProto stores its initializers in its graph, a GraphProto
        # stores them directly (they used to be dropped for GraphProto),
        # FunctionProto and NodeProto have none.
        if isinstance(self.proto, ModelProto):
            initializers: Sequence[Any] = self.proto.graph.initializer
        elif isinstance(self.proto, GraphProto):
            initializers = self.proto.initializer
        else:
            initializers = []
        self.rt_inits_ = {
            init.name: self.to_tensor_or_array(init) for init in initializers
        }
        self.rt_nodes_ = self.nodes.copy()

    self.local_functions: Dict[Tuple[str, str], "OnnxruntimeEvaluator"] = (  # noqa: UP037
        {(f.domain, f.name): self.__class__(f) for f in self.proto.functions}
        if hasattr(self.proto, "functions")
        else {}
    )
    if local_functions:
        # Values may be raw protos or evaluators, see _get_sess_local.
        self.local_functions.update(local_functions)  # type: ignore[arg-type]
    self.garbage_collector = self._build_garbage_collector() if self.rt_nodes_ else {}
|
|
150
|
+
|
|
151
|
+
@property
def input_names(self) -> List[str]:
    """
    Returns input names.

    For a NodeProto, the node inputs are read from the proto itself, which
    also works when ``whole=True`` (``self.nodes`` is None in that case).
    """
    assert self.proto, "self.proto is empty"
    if isinstance(self.proto, NodeProto):
        return list(self.proto.input)
    inputs = self.proto.graph.input if hasattr(self.proto, "graph") else self.proto.input
    # ValueInfoProto for models and graphs, plain strings for functions
    return [getattr(o, "name", o) for o in inputs]
|
|
166
|
+
|
|
167
|
+
@property
def output_names(self) -> List[str]:
    """
    Returns output names.

    For a NodeProto, the node outputs are read from the proto itself, which
    also works when ``whole=True`` (``self.nodes`` is None in that case).
    """
    assert self.proto, "self.proto is empty"
    if isinstance(self.proto, NodeProto):
        return list(self.proto.output)
    outputs = (
        self.proto.graph.output if hasattr(self.proto, "graph") else self.proto.output
    )
    # ValueInfoProto for models and graphs, plain strings for functions
    return [getattr(o, "name", o) for o in outputs]
|
|
182
|
+
|
|
183
|
+
@property
def input_types(self) -> List[TypeProto]:
    """
    Returns input types.

    Only available for a ModelProto or a GraphProto, raises ValueError otherwise.
    """
    if isinstance(self.proto, ModelProto):
        graph = self.proto.graph
    elif isinstance(self.proto, GraphProto):
        graph = self.proto
    else:
        raise ValueError(f"Cannot guess input types for type {type(self.proto)}")
    return [vi.type for vi in graph.input]
|
|
190
|
+
|
|
191
|
+
@property
def output_types(self) -> List[TypeProto]:
    """
    Returns output types.

    Only available for a ModelProto or a GraphProto, raises ValueError otherwise.
    """
    if isinstance(self.proto, ModelProto):
        graph = self.proto.graph
    elif isinstance(self.proto, GraphProto):
        graph = self.proto
    else:
        raise ValueError(f"Cannot guess output types for type {type(self.proto)}")
    return [vi.type for vi in graph.output]
|
|
198
|
+
|
|
199
|
+
def _log_arg(self, a: Any) -> Any:
    """
    Returns a short printable description of a value for the logs.

    Strings, ints and floats are returned unchanged. Objects with a shape are
    summarized: dtype, shape and range (``verbose < 4``) or the first ten
    elements (``verbose >= 4``). The prefix is ``A:`` for objects with
    ``astype`` and ``T:`` otherwise; objects with ``detach`` also get their
    device. Objects with ``append`` (lists, repeated proto fields) are
    described element by element.
    """
    if isinstance(a, (str, int, float)):
        return a
    device = f"D{a.get_device()}:" if hasattr(a, "detach") else ""
    if hasattr(a, "shape"):
        prefix = "A:" if hasattr(a, "astype") else "T:"
        if self.verbose < 4:  # noqa: PLR2004
            if 0 in tuple(a.shape):
                # min/max raise on zero-size arrays and tensors
                return f"{prefix}{device}{a.dtype}:{a.shape} empty"
            return f"{prefix}{device}{a.dtype}:{a.shape} in [{a.min()}, {a.max()}]"
        elements = a.ravel().tolist()
        if len(elements) > 10:  # noqa: PLR2004
            elements = elements[:10]
            return f"{prefix}{device}{a.dtype}:{a.shape}:{','.join(map(str, elements))}..."
        return f"{prefix}{device}{a.dtype}:{a.shape}:{elements}"
    if hasattr(a, "append"):
        # str() because elements may be ints, floats or None
        return ", ".join(str(self._log_arg(x)) for x in a)
    return a
|
|
215
|
+
|
|
216
|
+
def _log(self, level: int, pattern: str, *args: Any) -> None:
    """
    Prints ``pattern % args`` when ``self.verbose`` is strictly greater than
    ``level``, every argument being summarized by :meth:`_log_arg` first.
    """
    if level >= self.verbose:
        return
    print(pattern % tuple(map(self._log_arg, args)))
|
|
220
|
+
|
|
221
|
+
def _is_local_function(self, node: NodeProto) -> bool:
    """Tells whether ``node`` calls a local function, keyed by ``(domain, op_type)``."""
    key = (node.domain, node.op_type)
    return key in self.local_functions
|
|
223
|
+
|
|
224
|
+
def run(
    self,
    outputs: Optional[List[str]],
    feed_inputs: Dict[str, Any],
    intermediate: bool = False,
    report_cmp: Optional[ReportResultComparison] = None,
) -> Union[Dict[str, Any], List[Any]]:
    """
    Runs the model.

    Inputs may be numpy arrays or torch tensors, the session type is chosen
    by :meth:`_get_sess` (see parameter ``torch_or_numpy``).

    :param outputs: names of the requested outputs or None for all
        the outputs of the proto
    :param feed_inputs: inputs, a dictionary ``{name: value}``
    :param intermediate: returns every result instead of the requested outputs,
        no result is freed in that case
    :param report_cmp: used as a reference,
        every intermediate result is compared to every existing one,
        if not empty, it is an instance of
        :class:`onnx_diagnostic.reference.ReportResultComparison`
    :return: outputs, as a list if intermediate is False,
        as a dictionary if intermediate is True
    :raises RuntimeError: if a node input or a requested output cannot be found
    """
    if self.rt_nodes_ is None:
        # runs the whole proto with a single session
        if self.sess_ is None:
            assert self.proto, "self.proto is empty"
            _, self.sess_ = self._get_sess(self.proto, list(feed_inputs.values()))
        assert self.sess_, "mypy not happy"
        return self.sess_.run(outputs, feed_inputs)

    # The requested names must be kept apart from the per-node outputs,
    # reusing ``outputs`` for both made the method ignore the request.
    requested = self.output_names if outputs is None else list(outputs)
    requested_set = set(requested)
    # Requested values are kept here so that the garbage collector
    # cannot free an intermediate result the caller asked for.
    kept: Dict[str, Any] = {}

    results: Dict[str, Any] = (self.rt_inits_ or {}).copy()
    for k, v in results.items():
        self._log(2, " +C %s: %s", k, v)
    for k, v in feed_inputs.items():
        assert not isinstance(v, str), f"Unexpected type str for {k!r}"
        self._log(2, " +I %s: %s", k, v)
        results[k] = v
    kept.update({k: v for k, v in results.items() if k in requested_set})

    for i_node, node in enumerate(self.rt_nodes_ or []):
        self._log(1, "%s(%s) -> %s", node.op_type, node.input, node.output)
        for i in node.input:
            if i != "" and i not in results:
                raise RuntimeError(
                    f"Unable to find input {i!r} in known results {sorted(results)}, "
                    f"self.rt_inits_ has {sorted((self.rt_inits_ or {}))}, "
                    f"feed_inputs has {sorted(feed_inputs)}."
                )
        inputs = [(results[i] if i != "" else None) for i in node.input]
        if node.op_type == "If" and node.domain == "":
            node_outputs = self._run_if(node, inputs, results)
        elif node.op_type in {"Scan", "Loop"} and node.domain == "":
            node_outputs = self._run_scan(node, inputs, results)
        elif self._is_local_function(node):
            node_outputs = self._run_local(node, inputs, results)
        else:
            node_outputs = self._run(node, inputs, results)
        for name, value in zip(node.output, node_outputs):
            if name == "":
                continue
            self._log(2, " + %s: %s", name, value)  # type: ignore[arg-type]
            assert isinstance(name, str), f"unexpected type for name {type(name)}"
            results[name] = value
            if name in requested_set:
                kept[name] = value
        if report_cmp:
            reported = report_cmp.report(dict(zip(node.output, node_outputs)))
            if self.verbose > 1:
                print(f" -- report {len(reported)} comparisons")
        if not intermediate:
            self._clean_unused_inplace(i_node, node, results)

    if intermediate:
        return results
    for name in requested:
        if name == "":
            continue
        if name not in kept:
            raise RuntimeError(
                f"Unable to find output name {name!r} "
                f"in {sorted(results)}, proto is\n{pretty_onnx(self.proto)}"
            )
    return [kept[name] for name in requested if name != ""]
|
|
306
|
+
|
|
307
|
+
def _build_garbage_collector(self) -> Dict[str, int]:
    """
    Memorizes the results not needed anymore for every node.

    Returns a dictionary mapping every result name to the index of the last
    node in ``self.rt_nodes_`` reading it, including the names read from the
    outer scope by the subgraphs of Scan, If and Loop nodes. The outputs of
    the proto are mapped to ``len(self.rt_nodes_)``, so they are never freed.
    """
    nodes = self.rt_nodes_ or []
    last_use: Dict[str, int] = {}
    for index, node in enumerate(nodes):
        names = list(node.input)
        if node.op_type in {"Scan", "If", "Loop"}:
            names.extend(self._get_hidden_node_inputs(node))
        last_use.update(dict.fromkeys(names, index))

    if isinstance(self.proto, ModelProto):
        final_names = [o.name for o in self.proto.graph.output]
    elif isinstance(self.proto, GraphProto):
        final_names = [o.name for o in self.proto.output]
    elif isinstance(self.proto, FunctionProto):
        final_names = list(self.proto.output)
    else:
        final_names = []
    last_use.update(dict.fromkeys(final_names, len(nodes)))
    return last_use
|
|
330
|
+
|
|
331
|
+
def _clean_unused_inplace(self, i_node: int, node: NodeProto, results: Dict[str, Any]):
    """
    Cleans all results not needed anymore. Some models require cleaning
    the memory to be able to run.

    A result is removed when ``node`` (at index ``i_node``) is its last
    consumer according to ``self.garbage_collector``. That includes the
    values read from the outer scope by the subgraphs of Scan, If and Loop.

    :param i_node: index of ``node`` in ``self.rt_nodes_``
    :param node: node which has just been executed
    :param results: known results, modified inplace
    """
    if not self.garbage_collector:
        return
    candidates = list(node.input)
    if node.op_type in {"Scan", "If", "Loop"}:
        candidates.extend(self._get_hidden_node_inputs(node))
    for name in candidates:
        if self.garbage_collector[name] != i_node or name not in results:
            continue
        if self.verbose:
            t = results[name]
            # getattr: results may also hold values without dtype/shape (sequences)
            print(
                f" - deletes: {name} - "
                f"{getattr(t, 'dtype', type(t))}:{getattr(t, 'shape', '?')}"
            )
        del results[name]
|
|
352
|
+
|
|
353
|
+
def _make_model_proto(
    self,
    nodes: Sequence[NodeProto],
    vinputs: Sequence[ValueInfoProto],
    voutputs: Sequence[ValueInfoProto],
) -> ModelProto:
    """
    Wraps ``nodes`` into a model and runs onnx shape inference on it.

    The ir_version and the local functions are copied from ``self.proto``
    when it has them. Opsets come from ``self.proto`` when it has some,
    otherwise from ``self.opsets``, otherwise the latest opset known by onnx
    is used for the main domain.

    :param nodes: nodes of the graph
    :param vinputs: graph inputs
    :param voutputs: graph outputs
    :return: the model after shape inference
    """
    model = oh.make_model(
        oh.make_graph(nodes, "-", vinputs, voutputs),
        ir_version=getattr(self.proto, "ir_version", self.ir_version),
        functions=getattr(self.proto, "functions", None),
    )
    del model.opset_import[:]
    if hasattr(self.proto, "opset_import"):
        opsets = list(self.proto.opset_import)
    elif not self.opsets:
        opsets = [oh.make_opsetid("", onnx_opset_version())]
    elif isinstance(self.opsets, int):
        opsets = [oh.make_opsetid("", self.opsets)]
    else:
        opsets = [oh.make_opsetid(k, v) for k, v in self.opsets.items()]
    model.opset_import.extend(opsets)

    # Shape inference helps detecting errors in the generated model.
    return shi.infer_shapes(model)
|
|
380
|
+
|
|
381
|
+
@classmethod
def _get_hidden_inputs(cls, graph: GraphProto) -> Set[str]:
    """
    Returns the hidden inputs (inputs coming from an upper context)
    used by a subgraph.

    A name is hidden when a node of the subgraph (or of a nested subgraph)
    reads it while it is neither a subgraph input, an initializer,
    a sparse initializer nor the output of a previous node.
    Empty names (missing optional inputs) are ignored.

    :param graph: subgraph
    :return: set of names
    """
    hidden: Set[str] = set()
    # The subgraph's own formal inputs are not outer-scope references.
    memo = {i.name for i in graph.input}
    memo |= {i.name for i in graph.initializer}
    # SparseTensorProto has no name field, its name is the one of ``values``.
    memo |= {i.values.name for i in graph.sparse_initializer}
    for node in graph.node:
        hidden |= {i for i in node.input if i and i not in memo}
        for att in node.attribute:
            if att.type == AttributeProto.GRAPH and att.g:
                nested = cls._get_hidden_inputs(att.g)
                hidden |= {h for h in nested if h not in memo}
        memo |= set(node.output)
    return hidden
|
|
401
|
+
|
|
402
|
+
@classmethod
def _get_hidden_node_inputs(cls, node: NodeProto) -> Set[str]:
    """
    Calls :meth:`_get_hidden_inputs` on every graph attribute of a node
    Loop, Scan or If and returns the union, minus the node's own inputs.
    Any other node type returns an empty set.
    """
    if node.op_type not in {"Loop", "Scan", "If"}:
        return set()
    hidden: Set[str] = set()
    for att in node.attribute:
        if att.type == AttributeProto.GRAPH:
            hidden.update(cls._get_hidden_inputs(att.g))
    return hidden.difference(node.input)
|
|
412
|
+
|
|
413
|
+
def _get_sess(
    self, node: Union[ModelProto, NodeProto], inputs: List[Any]
) -> Tuple[ModelProto, _InferenceSession]:
    """
    Builds an inference session for a whole model or for a single node.

    A single node is wrapped into a one-node model whose inputs are typed
    from ``inputs``. A ``Constant`` node is first evaluated with
    :class:`ExtendedReferenceEvaluator` to know the type and shape
    of its output.

    The numpy session is used if one input is a numpy array and
    ``torch_or_numpy`` is not True, the torch session otherwise.
    On failure, the model is saved in
    ``_debug_OnnxruntimeEvaluator_last_failure.onnx`` before raising.

    :param node: model or node to run
    :param inputs: values the session is going to receive
    :return: the model and the session
    :raises RuntimeError: if onnxruntime cannot create the session
    """
    if not isinstance(node, ModelProto):
        assert isinstance(node, NodeProto), f"Unexpected type {type(node)} for node"
        if node.op_type == "Constant":
            # The constant is evaluated to declare the exact output type and shape.
            cst = ExtendedReferenceEvaluator(node).run(None, {})[0]
            vinputs: List[ValueInfoProto] = []
            voutputs = [
                oh.make_tensor_value_info(
                    node.output[0], dtype_to_tensor_dtype(cst.dtype), cst.shape
                )
            ]
        else:
            seen: Set[str] = set()
            vinputs = []
            for iname, value in zip(node.input, inputs):
                if iname == "" or iname in seen:
                    continue
                seen.add(iname)
                vinputs.append(
                    oh.make_tensor_value_info(
                        iname, dtype_to_tensor_dtype(value.dtype), value.shape
                    )
                )
            # Output types are left empty.
            voutputs = [oh.make_value_info(o, TypeProto()) for o in node.output]
        onx = self._make_model_proto([node], vinputs, voutputs)
    else:
        onx = node

    use_numpy = (
        any(isinstance(i, np.ndarray) for i in inputs) and self.torch_or_numpy is not True
    )
    cls = InferenceSessionForNumpy if use_numpy else InferenceSessionForTorch
    try:
        sess = cls(onx, **self.session_kwargs)
    except (
        onnxruntime.capi.onnxruntime_pybind11_state.Fail,
        onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
        onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument,
    ) as e:
        onnx_save(onx, "_debug_OnnxruntimeEvaluator_last_failure.onnx")
        raise RuntimeError(
            f"Unable to infer a session with inputs\n{string_type(inputs)}"
            f"\ndue to {e}\n{pretty_onnx(onx)}"
        ) from e
    return onx, sess
|
|
466
|
+
|
|
467
|
+
def _get_sess_init_subgraph(
    self, node: NodeProto, inputs: List[Any], context: Dict[str, Any], g: GraphProto
) -> List[Any]:
    """
    Declares the typed inputs of a model wrapping a node which holds
    the subgraph ``g``.

    The node inputs come first (empty names and duplicates skipped), then
    every value of ``context`` which ``g`` reads from the outer scope.

    :param node: node holding the subgraph
    :param inputs: node input values
    :param context: known results
    :param g: subgraph
    :return: list of ValueInfoProto
    """
    seen: Set[str] = set()
    vinputs = []

    def _declare(name: str, value: Any) -> None:
        seen.add(name)
        vinputs.append(
            oh.make_tensor_value_info(name, dtype_to_tensor_dtype(value.dtype), value.shape)
        )

    for name, value in zip(node.input, inputs):
        if name and name not in seen:
            _declare(name, value)

    outer_names = self._get_hidden_inputs(g)
    for name, value in context.items():
        if name in outer_names and name not in seen:
            _declare(name, value)
    return vinputs
|
|
486
|
+
|
|
487
|
+
def _get_sess_if(
    self, node: NodeProto, branch: str, inputs: List[Any], context: Dict[str, Any]
) -> Tuple[ModelProto, "OnnxruntimeEvaluator"]:
    """
    Builds an evaluator running one branch of a node If.

    The branch becomes a standalone model whose inputs are the node inputs
    and the outer-scope values the branch reads, and whose outputs are
    the branch outputs. The branch's own initializers are copied as well.

    :param node: node If
    :param branch: ``"then_branch"`` or ``"else_branch"``
    :param inputs: node input values
    :param context: known results
    :return: the model and its evaluator
    """
    g = None
    for att in node.attribute:
        if att.name == branch:
            g = att.g
    assert g, f"Missing attribute {branch!r}"
    vinputs = self._get_sess_init_subgraph(node, inputs, context, g)

    voutputs = g.output

    identities = [
        oh.make_node("Identity", [iname], [ginput.name])
        for iname, ginput in zip(node.input, g.input)
    ]

    onx = self._make_model_proto([*identities, *g.node], vinputs, voutputs)
    # Initializers of the branch are neither node inputs nor hidden inputs:
    # without this copy, the nodes reading them cannot find their inputs.
    onx.graph.initializer.extend(g.initializer)
    onx.graph.sparse_initializer.extend(g.sparse_initializer)
    sess = OnnxruntimeEvaluator(
        onx,
        local_functions=self.local_functions,
        verbose=self.verbose,
        ir_version=self.ir_version,
        opsets=self.opsets,
        torch_or_numpy=self.torch_or_numpy,
        **self.session_kwargs,
    )
    return onx, sess
|
|
515
|
+
|
|
516
|
+
def _get_sess_local(
    self, node: NodeProto, inputs: List[Any]
) -> Tuple[FunctionProto, "OnnxruntimeEvaluator"]:
    """
    Builds an evaluator for a node calling a local function.

    ``self.local_functions`` holds evaluators built from the model functions
    but may also hold raw protos given through the ``local_functions``
    constructor argument; both are accepted.

    :param node: node calling the local function
    :param inputs: node input values (not used here)
    :return: the function proto and its evaluator
    """
    ev = self.local_functions[node.domain, node.op_type]
    # A raw proto has no ``proto`` attribute, it is the proto itself.
    proto = ev.proto if isinstance(ev, OnnxruntimeEvaluator) else ev
    sess = OnnxruntimeEvaluator(
        ev,
        local_functions=self.local_functions,
        verbose=self.verbose,
        ir_version=self.ir_version,
        opsets=self.opsets,
        torch_or_numpy=self.torch_or_numpy,
        **self.session_kwargs,
    )
    return proto, sess
|
|
530
|
+
|
|
531
|
+
def _run(self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]) -> List[Any]:
    """
    Runs a single node with onnxruntime.

    Sessions are cached by node identity and by the (dtype, shape)
    of every input, None standing for a missing input.

    :param node: node to run
    :param inputs: input values, None for empty input names
    :param results: known results (not used here)
    :return: list of outputs
    """
    signature = tuple(None if a is None else (a.dtype, a.shape) for a in inputs)
    key = (id(node), *signature)
    cached = self._cache.get(key)
    if cached is None:
        onx, sess = self._get_sess(node, inputs)
        self._cache[key] = onx, sess
    else:
        sess = cached[1]

    feeds = dict(zip(node.input, inputs))
    if "" in feeds:
        # NOTE(review): _get_sess does not declare empty names as model inputs;
        # this dummy feed presumably relies on the session wrapper ignoring
        # unknown names — confirm.
        feeds[""] = np.array([0], dtype=np.float32)

    assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
    outputs = list(sess.run(None, feeds))
    assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
    return outputs
|
|
549
|
+
|
|
550
|
+
def _run_if(
    self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]
) -> List[Any]:
    """
    Runs a node If.

    The condition selects ``then_branch`` or ``else_branch``; an evaluator
    is built for that branch once and cached by node identity and branch.

    :param node: node If
    :param inputs: node input values
    :param results: known results, the branch takes its inputs from there
    :return: list of outputs of the selected branch
    """
    known = {**dict(zip(node.input, inputs)), **results}
    branch = "then_branch" if known[node.input[0]] else "else_branch"

    key = (id(node), branch)
    if key not in self._cache:
        self._cache[key] = self._get_sess_if(node, branch, inputs, results)
    sess = self._cache[key][1]

    assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
    branch_feeds = {name: results[name] for name in sess.input_names}
    outputs = sess.run(None, branch_feeds)
    assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
    return outputs
|
|
572
|
+
|
|
573
|
+
def _get_sess_scan(
    self, node: NodeProto, branch: str, inputs: List[Any], context: Dict[str, Any]
) -> Tuple[ModelProto, "OnnxruntimeEvaluator"]:
    """
    Builds a model holding only ``node`` and an evaluator running it.

    The graph stored in attribute ``branch`` is used to compute the model
    inputs (through ``_get_sess_init_subgraph``) and to pair the node outputs
    with the graph outputs.

    :param node: the node (``Scan``, or presumably ``Loop`` for other op types)
    :param branch: name of the graph attribute holding the body
    :param inputs: values of the node inputs
    :param context: known results available to the subgraph
    :return: the model and an :class:`OnnxruntimeEvaluator` built on it
    """
    g = None
    for att in node.attribute:
        if att.name == branch:
            g = att.g
    # explicit None check: ``g`` is a protobuf message, not a container
    assert g is not None, f"Missing attribute {branch!r}"
    vinputs = self._get_sess_init_subgraph(node, inputs, context, g)

    # For op types other than Scan, the first body output is skipped;
    # presumably the Loop condition, which has no node output counterpart.
    begin = 0 if node.op_type == "Scan" else 1
    # Outputs only carry a name: no type information is copied from the body.
    voutputs = []
    for name, _goutput in zip(node.output, g.output[begin:]):
        v = ValueInfoProto()
        v.name = name
        voutputs.append(v)

    onx = self._make_model_proto([node], vinputs, voutputs)
    sess = OnnxruntimeEvaluator(
        onx,
        local_functions=self.local_functions,
        verbose=self.verbose,
        ir_version=self.ir_version,
        opsets=self.opsets,
        torch_or_numpy=self.torch_or_numpy,
        whole=True,
        **self.session_kwargs,
    )
    return onx, sess
|
|
607
|
+
|
|
608
|
+
def _run_scan(
    self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]
) -> List[Any]:
    """
    Runs a node Scan.

    The session wrapping the node (built from its ``body`` attribute) is
    cached per node; its inputs are all read from ``results`` by name.

    :param node: the Scan node
    :param inputs: values of the node inputs (used to build the session)
    :param results: known results, the session inputs are taken from it
    :return: the node outputs
    """
    # The former ``dict(zip(node.input, inputs)).update(results)`` was dead
    # code (overwritten before use) and copied ``results`` on every call.
    branch = "body"
    key = (id(node), branch)
    if key in self._cache:
        sess = self._cache[key][1]
    else:
        self._cache[key] = _onx, sess = self._get_sess_scan(node, branch, inputs, results)

    assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
    feeds = {input_name: results[input_name] for input_name in sess.input_names}
    outputs = sess.run(None, feeds)
    assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
    return outputs
|
|
626
|
+
|
|
627
|
+
def _run_local(
    self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]
) -> List[Any]:
    """
    Runs a node through a session whose input names differ from the node's.

    Sessions are cached on the node identity plus the ``(dtype, shape)`` of
    every input. Values are bound to the session inputs by position, so
    a node passing the same value several times (e.g. ``f(x, x)``) still
    feeds every session input.

    :param node: node to run
    :param inputs: input values, in the order of ``node.input``
    :param results: known results, not used by this method
    :return: the node outputs
    """
    types = [(None if a is None else (a.dtype, a.shape)) for a in inputs]
    key = (id(node), *types)
    if key in self._cache:
        sess = self._cache[key][1]
    else:
        onx, sess = self._get_sess_local(node, inputs)
        self._cache[key] = onx, sess

    assert len(node.input) == len(sess.input_names), (
        f"Input mismatch: input_names={sess.input_names}, "
        f"replace={dict(zip(node.input, sess.input_names))}, "
        f"type(self.proto)={type(self.proto)}, and node=\n{node}"
    )
    # Positional binding: a name-based mapping node.input -> input_names
    # collapses repeated node input names and silently drops session inputs.
    feeds = dict(zip(sess.input_names, inputs))
    if "" in feeds:
        # NOTE(review): keys are session input names here, so this only fires
        # if the session declares an input named ''; confirm intent.
        feeds[""] = np.array([0], dtype=np.float32)

    assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
    outputs = sess.run(None, feeds)
    assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
    return outputs
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class QuantizedTensor:
    """
    Quantizes a vector in range [0, 255].

    The quantization is affine and stored as ``uint8``:
    ``q = clip(trunc(x / scale) + zero_point, 0, 255)``. The range
    ``[min(x.min(), 0), max(x.max(), 0)]`` is mapped onto ``[0, 255]``, so 0
    always lies inside it. Values and zero point are truncated toward zero,
    not rounded. An all-zero tensor gets ``scale == 1`` instead of a null scale.

    :param tensor: original tensor (a numpy array, ``astype`` is called on it)
    """

    def __init__(self, tensor):
        # Widen the range so that it always contains 0.
        _min = min(tensor.min(), 0)
        _max = max(tensor.max(), 0)
        qmin = 0
        qmax = 255

        scale = np.array((_max - _min) / (qmax - qmin), dtype=tensor.dtype)
        if scale == 0:
            # Empty range (all-zero tensor) or a scale underflowing in
            # tensor.dtype: dividing by 0 below would produce NaN/inf values
            # cast to garbage integers. Scale 1 keeps everything finite and is
            # exact for an all-zero tensor.
            scale = np.array(1, dtype=tensor.dtype)
        self.scale_ = scale
        initial_zero_point = qmin - _min / self.scale_
        self.zero_point_ = np.array(
            int(max(qmin, min(qmax, initial_zero_point))), dtype=np.uint8
        )
        # astype(int) truncates toward zero; the result is clipped to [qmin, qmax].
        self.quantized_ = np.maximum(
            qmin, np.minimum(qmax, (tensor / self.scale_).astype(int) + self.zero_point_)
        ).astype(self.zero_point_.dtype)

    @property
    def shape(self):
        "accessor"
        return self.quantized_.shape

    @property
    def scale(self):
        "accessor"
        return self.scale_

    @property
    def zero_point(self):
        "accessor"
        return self.zero_point_

    @property
    def qtensor(self):
        "accessor"
        return self.quantized_
|