JSTprove 1.0.0__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of JSTprove might be problematic.
- jstprove-1.0.0.dist-info/METADATA +397 -0
- jstprove-1.0.0.dist-info/RECORD +81 -0
- jstprove-1.0.0.dist-info/WHEEL +5 -0
- jstprove-1.0.0.dist-info/entry_points.txt +2 -0
- jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
- jstprove-1.0.0.dist-info/top_level.txt +1 -0
- python/__init__.py +0 -0
- python/core/__init__.py +3 -0
- python/core/binaries/__init__.py +0 -0
- python/core/binaries/expander-exec +0 -0
- python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
- python/core/circuit_models/__init__.py +0 -0
- python/core/circuit_models/generic_onnx.py +231 -0
- python/core/circuit_models/simple_circuit.py +133 -0
- python/core/circuits/__init__.py +0 -0
- python/core/circuits/base.py +1000 -0
- python/core/circuits/errors.py +188 -0
- python/core/circuits/zk_model_base.py +25 -0
- python/core/model_processing/__init__.py +0 -0
- python/core/model_processing/converters/__init__.py +0 -0
- python/core/model_processing/converters/base.py +143 -0
- python/core/model_processing/converters/onnx_converter.py +1181 -0
- python/core/model_processing/errors.py +147 -0
- python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
- python/core/model_processing/onnx_custom_ops/conv.py +111 -0
- python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
- python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
- python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
- python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
- python/core/model_processing/onnx_custom_ops/relu.py +43 -0
- python/core/model_processing/onnx_quantizer/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
- python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
- python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
- python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
- python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
- python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
- python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
- python/core/model_templates/__init__.py +0 -0
- python/core/model_templates/circuit_template.py +57 -0
- python/core/utils/__init__.py +0 -0
- python/core/utils/benchmarking_helpers.py +163 -0
- python/core/utils/constants.py +4 -0
- python/core/utils/errors.py +117 -0
- python/core/utils/general_layer_functions.py +268 -0
- python/core/utils/helper_functions.py +1138 -0
- python/core/utils/model_registry.py +166 -0
- python/core/utils/scratch_tests.py +66 -0
- python/core/utils/witness_utils.py +291 -0
- python/frontend/__init__.py +0 -0
- python/frontend/cli.py +115 -0
- python/frontend/commands/__init__.py +17 -0
- python/frontend/commands/args.py +100 -0
- python/frontend/commands/base.py +199 -0
- python/frontend/commands/bench/__init__.py +54 -0
- python/frontend/commands/bench/list.py +42 -0
- python/frontend/commands/bench/model.py +172 -0
- python/frontend/commands/bench/sweep.py +212 -0
- python/frontend/commands/compile.py +58 -0
- python/frontend/commands/constants.py +5 -0
- python/frontend/commands/model_check.py +53 -0
- python/frontend/commands/prove.py +50 -0
- python/frontend/commands/verify.py +73 -0
- python/frontend/commands/witness.py +64 -0
- python/scripts/__init__.py +0 -0
- python/scripts/benchmark_runner.py +833 -0
- python/scripts/gen_and_bench.py +482 -0
- python/tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
- python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
- python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
- python/tests/circuit_parent_classes/__init__.py +0 -0
- python/tests/circuit_parent_classes/test_circuit.py +969 -0
- python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
- python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
- python/tests/test_cli.py +1021 -0
- python/tests/utils_testing/__init__.py +0 -0
- python/tests/utils_testing/test_helper_functions.py +891 -0
@@ -0,0 +1,1181 @@
from __future__ import annotations

import copy
import logging
from dataclasses import asdict, dataclass
from importlib.metadata import version as get_version
from pathlib import Path

import numpy as np
import onnx
import torch
from onnx import NodeProto, TensorProto, helper, numpy_helper, shape_inference

# Keep the unused import below as it
# must remain due to the 'SessionOptions' dependency.
from onnxruntime import InferenceSession, SessionOptions
from onnxruntime_extensions import get_library_path

import python.core.model_processing.onnx_custom_ops  # noqa: F401
from python.core import PACKAGE_NAME
from python.core.model_processing.converters.base import ModelConverter, ModelType
from python.core.model_processing.errors import (
    InferenceError,
    InvalidModelError,
    IOInfoExtractionError,
    LayerAnalysisError,
    ModelConversionError,
    ModelLoadError,
    ModelSaveError,
    SerializationError,
)
from python.core.model_processing.onnx_custom_ops.onnx_helpers import (
    extract_shape_dict,
    get_input_shapes,
    parse_attributes,
)
from python.core.model_processing.onnx_quantizer.exceptions import QuantizationError
from python.core.model_processing.onnx_quantizer.layers.base import (
    BaseOpQuantizer,
    ScaleConfig,
)
from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import (
    ONNXOpQuantizer,
)

try:
    import tomllib  # Python 3.11+
except ModuleNotFoundError:
    import tomli as tomllib  # noqa: F401

ONNXLayerDict = dict[
    str,
    int | str | list[str] | dict[str, list[int]] | list | None | dict,
]

ONNXIODict = dict[str, str | int | list[int]]

CircuitParamsDict = dict[str, int | dict[str, bool]]


@dataclass
class ONNXLayer:
    """
    A dataclass representing an ONNX layer in the form
    to be sent to the circuit building process.

    This class encapsulates the essential information
    about a layer in an ONNX model. It is designed to facilitate the
    conversion and processing of ONNX models for circuit building purposes.

    Attributes:
        id (int): A unique identifier for the layer.
        name (str): The name of the layer.
        op_type (str): The operation type of the layer,
            such as "Conv" for convolution layers.
        inputs (list[str]): A list of input names that this layer depends on.
        outputs (list[str]): A list of output names produced by this layer.
        shape (dict[str, list[int]]): A dictionary mapping output names
            to their corresponding shapes.
        tensor (Optional[list]): For constant nodes, this contains the
            tensor data (weights or biases) as a list. For other layers, empty.
        params (Optional[dict]): A dictionary of parameters specific to the
            layer's operation. For example, convolution layers may include
            parameters like dilation, kernel_shape, pad, strides, and group.
        opset_version_number (int): The version number of the ONNX opset
            used for this operation. This is included for infrastructure
            purposes and may not be actively used in all processing steps.
    """

    id: int
    name: str
    op_type: str
    inputs: list[str]
    outputs: list[str]
    shape: dict[
        str,
        list[int],
    ]
    tensor: list | None
    params: dict | None
    opset_version_number: int
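
# Illustrative sketch (hypothetical values, not from the package source):
# a Conv node of a small model would be captured roughly as
#
#     ONNXLayer(
#         id=0, name="conv1", op_type="Conv",
#         inputs=["input", "conv1.weight"], outputs=["conv1_out"],
#         shape={"conv1_out": [1, 16, 28, 28]},
#         tensor=None,
#         params={"kernel_shape": [3, 3], "strides": [1, 1]},
#         opset_version_number=17,
#     )
#
# with the weight data itself carried separately as an op_type="Const"
# layer whose `tensor` field holds the values.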


@dataclass
class ONNXIO:
    """
    A dataclass representing an ONNX input or output,
    in the form to be sent to the circuit building process.
    """

    name: str
    elem_type: int
    shape: list[int]


class ONNXConverter(ModelConverter):
    """Concrete implementation of `ModelConverter` for ONNX models."""

    def __init__(self: ONNXConverter) -> None:
        """Initialize the converter and its operator quantizer.

        Initializes:
            self.op_quantizer (ONNXOpQuantizer): Dispatcher that quantizes
                individual ONNX ops and accumulates newly created initializers.
        """
        self.op_quantizer = ONNXOpQuantizer()
        self.model_type = ModelType.ONNX
        self.logger = logging.getLogger(__name__)

    def save_model(self: ONNXConverter, file_path: str) -> None:
        """Serialize the ONNX model to file.

        Args:
            file_path (str):
                Destination path (e.g., ``"models/my_model.onnx"``).

        Note
        ----
        - For saving and loading, see
          https://onnx.ai/onnx/intro/python.html;
          larger models may require a different structure.
        """
        try:
            Path(file_path).parent.mkdir(parents=True, exist_ok=True)
            onnx.save(self.model, file_path)
        except Exception as e:
            raise ModelSaveError(
                file_path,
                model_type=self.model_type,
                reason=str(e),
            ) from e

    def load_model(self: ONNXConverter, file_path: str) -> onnx.ModelProto:
        """Load an ONNX model from file and extract basic I/O metadata.

        Args:
            file_path (str): Path to the ``.onnx`` file.

        Returns:
            onnx.ModelProto: The loaded ONNX model.

        Raises:
            ModelLoadError: If the model cannot be loaded or validated.
        """
        try:
            onnx_model = onnx.load(file_path)
        except Exception as e:
            raise ModelLoadError(
                file_path,
                model_type=self.model_type,
                reason=str(e),
            ) from e

        self.model = onnx_model

        try:
            self._extract_model_io_info(onnx_model)
        except Exception as e:
            raise IOInfoExtractionError(
                model_path=file_path,
                model_type=self.model_type,
                reason=str(e),
            ) from e
        return self.model

    def save_quantized_model(self: ONNXConverter, file_path: str) -> None:
        """Serialize the quantized ONNX model to file.

        Args:
            file_path (str): Destination path for the quantized model.
        """
        try:
            Path(file_path).parent.mkdir(parents=True, exist_ok=True)
            onnx.save(self.quantized_model, file_path)
        except Exception as e:
            raise ModelSaveError(
                file_path,
                model_type=self.model_type,
                reason=str(e),
            ) from e

    # Not sure this is ideal
    def load_quantized_model(self: ONNXConverter, file_path: str) -> None:
        """Load a quantized ONNX model and create an inference session.

        Note
        ----
        - Uses the custom opset for the quantized layers

        Args:
            file_path (str): Path to the quantized ``.onnx`` file.

        Raises:
            FileNotFoundError: If the file does not exist.
            ModelLoadError: If loading or validation fails.
        """
        if not Path(file_path).exists():
            msg = f"Quantized model file not found: {file_path}"
            raise FileNotFoundError(msg)
        self.logger.info("Loading quantized model", extra={"file_path": file_path})
        onnx_model = onnx.load(file_path)
        custom_domain = onnx.helper.make_operatorsetid(
            domain="ai.onnx.contrib",
            version=1,
        )
        onnx_model.opset_import.append(custom_domain)
        # Fix, can remove this next line
        self.quantized_model = onnx_model
        self.ort_sess = self._create_inference_session(file_path)
        self._extract_model_io_info(onnx_model)

        self.quantized_model_path = file_path
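
    # Illustrative round trip (hypothetical paths, not from the package
    # source):
    #
    #     converter.save_quantized_model("artifacts/model_quant.onnx")
    #     converter.load_quantized_model("artifacts/model_quant.onnx")
    #     outputs = converter.get_outputs(sample_input)
    #
    # load_quantized_model registers the ai.onnx.contrib opset and builds an
    # ONNX Runtime session so the custom quantized ops can execute.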

    def _onnx_check_model_safely(self: ONNXConverter, model: onnx.ModelProto) -> None:
        try:
            onnx.checker.check_model(model)
        except Exception as e:
            raise InvalidModelError(
                model_path=getattr(self, "model_file_name", None),
                reason=f"Model validation failed: {e!s}",
            ) from e

    def analyze_layers(
        self: ONNXConverter,
        output_name_to_shape: dict[str, list[int]] | None = None,
    ) -> tuple[list[ONNXLayer], list[ONNXLayer]]:
        """Analyze the ONNX model graph into
        logical layers and parameter tensors.

        Args:
            output_name_to_shape (dict[str, list[int]], optional):
                Mapping of value name -> shape. If omitted,
                shapes are inferred via `onnx.shape_inference`. Defaults to None.

        Returns:
            tuple[list[ONNXLayer], list[ONNXLayer]]: ``(architecture, w_and_b)`` where:
                - ``architecture`` is a list of `ONNXLayer` describing
                  the computational graph.
                - ``w_and_b`` is a list of `ONNXLayer` representing
                  constant tensors (initializers).
        """
        try:
            # Apply shape inference on the model
            if not output_name_to_shape:
                inferred_model = shape_inference.infer_shapes(self.model)
                self._onnx_check_model_safely(inferred_model)

                output_name_to_shape = extract_shape_dict(inferred_model)
            domain_to_version = {
                opset.domain: opset.version for opset in self.model.opset_import
            }

            id_count = 0
            architecture = self.get_model_architecture(
                self.model,
                output_name_to_shape,
                domain_to_version,
            )
            w_and_b = self.get_model_w_and_b(
                self.model,
                output_name_to_shape,
                id_count,
                domain_to_version,
            )
        except InvalidModelError:
            raise
        except Exception as e:
            raise LayerAnalysisError(model_type=self.model_type, reason=str(e)) from e
        else:
            return (architecture, w_and_b)

    def run_model_onnx_runtime(
        self: ONNXConverter,
        path: str,
        inputs: torch.Tensor,
    ) -> list[np.ndarray]:
        """Execute a model on CPU via ONNX Runtime and return its outputs.

        Creates a fresh inference session for the model at ``path``, feeds
        the provided tensor under the first input name, and returns the
        first output.

        Args:
            path (str): Path to the ONNX model to execute.
            inputs (torch.Tensor): Input tensor to feed into the model's first input.

        Returns:
            list[np.ndarray]: The outputs as returned by `InferenceSession.run`.
        """
        # Fix, can remove this next line
        try:
            ort_sess = self._create_inference_session(path)
            input_name = ort_sess.get_inputs()[0].name
            output_name = ort_sess.get_outputs()[0].name
            if ort_sess.get_inputs()[0].type == "tensor(double)":
                outputs = ort_sess.run(
                    [output_name],
                    {input_name: np.asarray(inputs).astype(np.float64)},
                )
            else:
                outputs = ort_sess.run([output_name], {input_name: np.asarray(inputs)})

        except Exception as e:
            raise InferenceError(
                model_path=path,
                model_type=self.model_type,
                reason=str(e),
            ) from e
        else:
            return outputs

    def _collect_constant_values(
        self: ONNXConverter,
        model: onnx.ModelProto,
    ) -> dict[str, np.ndarray]:
        """Collect constant values from Constant nodes in the model.

        Args:
            model (onnx.ModelProto): The ONNX model to analyze.

        Returns:
            dict[str, np.ndarray]: Mapping of output name to constant value.
        """
        constant_values = {}
        for node in model.graph.node:
            if node.op_type == "Constant":
                self.logger.debug("Constant node", extra={"node": node})
                for attr in node.attribute:
                    if attr.name == "value":
                        tensor = attr.t
                        const_value = numpy_helper.to_array(tensor)
                        constant_values[node.output[0]] = const_value
        return constant_values

    def _attach_constant_parameters(
        self: ONNXConverter,
        layer: ONNXLayer,
        node: NodeProto,
        constant_values: dict[str, np.ndarray],
    ) -> None:
        """Attach constant inputs as parameters to a layer.

        Args:
            layer (ONNXLayer): The layer to modify.
            node (NodeProto): The ONNX node being processed.
            constant_values (dict[str, np.ndarray]): Constant values mapping.
        """
        for input_name in node.input:
            if input_name in constant_values:
                self.logger.debug(
                    "Layer params before:",
                    extra={"layer_params": layer.params},
                )
                if not hasattr(layer, "params") or layer.params is None:
                    layer.params = {}
                result = constant_values[input_name]
                if isinstance(result, (np.ndarray, torch.Tensor)):
                    layer.params[input_name] = result.tolist()
                else:
                    layer.params[input_name] = constant_values[input_name]
                self.logger.debug(
                    "Updated layer params",
                    extra={"layer_params": layer.params},
                )

    def get_model_architecture(
        self: ONNXConverter,
        model: onnx.ModelProto,
        output_name_to_shape: dict[str, list[int]],
        domain_to_version: dict[str, int] | None = None,
    ) -> list[ONNXLayer]:
        """Construct ONNXLayer objects for architecture graph nodes
        (not weights or biases).

        Args:
            model (onnx.ModelProto): The ONNX model to analyze.
            output_name_to_shape (dict[str, list[int]]):
                Map of value name -> inferred shape.
            domain_to_version (dict[str, int], optional):
                Map of opset domain -> version used. Defaults to None.

        Returns:
            list[ONNXLayer]:
                The model's computational layers
                (excluding initializers) in the form of ONNXLayers.
        """
        _ = domain_to_version
        constant_values = self._collect_constant_values(model)
        layers = []
        current_id = 0

        for node in model.graph.node:
            if node.op_type == "Constant":
                continue  # Skip constant nodes

            layer = self.analyze_layer(
                node,
                output_name_to_shape,
                current_id,
                domain_to_version,
            )
            self.logger.debug(
                "Layer",
                extra={
                    "layer_name": layer.name,
                    "layer_op": layer.op_type,
                    "layer_shape": layer.shape,
                },
            )

            self._attach_constant_parameters(layer, node, constant_values)
            layers.append(layer)
            current_id += 1

        return layers

    def get_model_w_and_b(
        self: ONNXConverter,
        model: onnx.ModelProto,
        output_name_to_shape: dict[str, list[int]],
        id_count: int = 0,
        domain_to_version: dict[str, int] | None = None,
    ) -> list[ONNXLayer]:
        """Extract constant initializers (weights/biases) as layers.

        Iterates through graph initializers and wraps each tensor
        into an ONNXLayer.

        Args:
            model (onnx.ModelProto): The ONNX model to analyze.
            output_name_to_shape (dict[str, list[int]]):
                Map of value name -> inferred shape.
            id_count (int, optional):
                Starting numeric ID for layers (incremented per tensor).
                Defaults to 0.
            domain_to_version (dict[str, int], optional):
                Map of opset domain -> version used (unused). Defaults to None.

        Returns:
            list[ONNXLayer]: ONNXLayers representing weights/biases found in the graph.
        """
        layers = []
        # Wrap each graph initializer as a Const ONNXLayer
        for node in model.graph.initializer:
            layer = self.analyze_constant(
                node,
                output_name_to_shape,
                id_count,
                domain_to_version,
            )
            layers.append(layer)
            id_count += 1

        return layers

    def _create_inference_session(
        self: ONNXConverter,
        model_path: str,
    ) -> InferenceSession:
        """Internal helper to create and configure an ONNX Runtime InferenceSession.

        Registers a custom ops shared library for use with the
        custom quantized operations.

        Args:
            model_path (str): Path to the ONNX model to load.

        Returns:
            InferenceSession: A configured InferenceSession.
        """
        try:
            opts = SessionOptions()
            opts.register_custom_ops_library(get_library_path())
            return InferenceSession(
                model_path,
                opts,
                providers=["CPUExecutionProvider"],
            )
        except Exception as e:
            raise InferenceError(
                model_path,
                model_type=self.model_type,
                reason=str(e),
            ) from e

    def analyze_layer(
        self: ONNXConverter,
        node: NodeProto,
        output_name_to_shape: dict[str, list[int]],
        id_count: int = -1,
        domain_to_version: dict[str, int] | None = None,
    ) -> ONNXLayer:
        """Convert a non-constant ONNX node into a structured ONNXLayer.

        Args:
            node (NodeProto): The ONNX node to analyze.
            output_name_to_shape (dict[str, list[int]]):
                Map of value name -> inferred shape.
            id_count (int, optional):
                Numeric ID to assign to this layer (increment handled by caller).
                Defaults to -1.
            domain_to_version (dict[str, int], optional):
                Map of opset domain -> version number. Defaults to None.

        Returns:
            ONNXLayer: ONNXLayer describing the node.
        """
        name = node.name
        layer_id = id_count
        id_count += 1
        op_type = node.op_type
        inputs = node.input
        outputs = node.output
        opset_version = (
            domain_to_version.get(node.domain, -1) if domain_to_version else -1
        )
        params = parse_attributes(node.attribute)

        # Extract output shapes
        output_shapes = {
            out_name: output_name_to_shape.get(out_name, []) for out_name in outputs
        }
        return ONNXLayer(
            id=layer_id,
            name=name,
            op_type=op_type,
            inputs=list(inputs),
            outputs=list(outputs),
            shape=output_shapes,
            params=params,
            opset_version_number=opset_version,
            tensor=None,
        )

    def analyze_constant(
        self: ONNXConverter,
        node: TensorProto,
        output_name_to_shape: dict[str, list[int]],
        id_count: int = -1,
        domain_to_version: dict[str, int] | None = None,
    ) -> ONNXLayer:
        """Convert a constant ONNX tensor (weights or bias) into a structured ONNXLayer.

        Args:
            node (TensorProto): The ONNX initializer tensor to analyze.
            output_name_to_shape (dict[str, list[int]]):
                Map of value name -> inferred shape.
            id_count (int, optional):
                Numeric ID to assign to this layer (increment handled by caller).
                Defaults to -1.
            domain_to_version (dict[str, int], optional):
                Map of opset domain -> version number. Defaults to None.

        Returns:
            ONNXLayer: ONNXLayer describing the tensor.
        """
        _ = domain_to_version
        name = node.name
        id_count += 1
        op_type = "Const"
        inputs = []
        outputs = []
        opset_version = -1
        params = {}
        # Can do this step in rust potentially to keep file sizes low if needed
        try:
            np_data = numpy_helper.to_array(node)
        except Exception as e:
            raise SerializationError(
                tensor_name=node.name,
                reason=f"Failed to convert tensor: {e!s}",
            ) from e
        # Extract output shapes (empty here: initializers have no graph outputs)
        output_shapes = {
            out_name: output_name_to_shape.get(out_name, []) for out_name in outputs
        }
        return ONNXLayer(
            id=id_count,
            name=name,
            op_type=op_type,
            inputs=list(inputs),
            outputs=list(outputs),
            shape=output_shapes,
            params=params,
            opset_version_number=opset_version,
            tensor=np_data.tolist(),
        )

    def _prepare_model_for_quantization(
        self: ONNXConverter,
        unscaled_model: onnx.ModelProto,
    ) -> tuple[onnx.ModelProto, dict[str, onnx.TensorProto], list[str]]:
        """Prepare the model for quantization by creating a copy and necessary mappings.

        Args:
            unscaled_model (onnx.ModelProto): The original unscaled model.

        Returns:
            tuple[onnx.ModelProto, dict[str, onnx.TensorProto], list[str]]:
                Model copy, initializer map, and input names.
        """
        model = copy.deepcopy(unscaled_model)
        self.op_quantizer.check_model(model)
        initializer_map = {init.name: init for init in model.graph.initializer}
        input_names = [inp.name for inp in unscaled_model.graph.input]
        return model, initializer_map, input_names

    def _quantize_inputs(
        self: ONNXConverter,
        model: onnx.ModelProto,
        input_names: list[str],
        scale_base: int,
        scale_exponent: int,
    ) -> list[onnx.NodeProto]:
        """Quantize model inputs and update node connections.

        Args:
            model (onnx.ModelProto): The model being quantized.
            input_names (list[str]): Names of input tensors.
            scale_base (int): Base for scaling.
            scale_exponent (int): Exponent for scaling.

        Returns:
            list[onnx.NodeProto]: New nodes created for input quantization.
        """
        new_nodes = []
        for name in input_names:
            output_name, mul_node, _, cast_to_int64 = self.quantize_input(
                input_name=name,
                op_quantizer=self.op_quantizer,
                scale_base=scale_base,
                scale_exponent=scale_exponent,
            )
            new_nodes.extend([mul_node, cast_to_int64])

            # Update references to this input in all nodes
            for node in model.graph.node:
                for idx, inp in enumerate(node.input):
                    if inp == name:
                        node.input[idx] = output_name
        return new_nodes

    def _update_input_types(self: ONNXConverter, model: onnx.ModelProto) -> None:
        """Update input tensor types from float32 to float64.

        Args:
            model (onnx.ModelProto): The model to update.
        """
        for input_tensor in model.graph.input:
            tensor_type = input_tensor.type.tensor_type
            if tensor_type.elem_type == TensorProto.FLOAT:
                tensor_type.elem_type = TensorProto.DOUBLE
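
    # Rationale (presumably): float64 represents every integer exactly up to
    # 2**53, so values scaled by 2**18 survive the Mul -> Cast round trip
    # without floating-point loss for realistic input magnitudes.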

    def _quantize_nodes(
        self: ONNXConverter,
        model: onnx.ModelProto,
        scale_config: ScaleConfig,
        rescale_config: dict | None,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> list[onnx.NodeProto]:
        """Quantize all nodes in the model.

        Args:
            model (onnx.ModelProto): The model being quantized.
            scale_config (ScaleConfig): Base, exponent, and default rescale flag.
            rescale_config (dict, optional): Rescale configuration.
            initializer_map (dict[str, onnx.TensorProto]): Initializer mapping.

        Returns:
            list[onnx.NodeProto]: Quantized nodes.
        """
        quantized_nodes = []
        for node in model.graph.node:
            rescale = rescale_config.get(node.name, False) if rescale_config else True
            quant_nodes = self.quantize_layer(
                node=node,
                model=model,
                scale_config=ScaleConfig(
                    exponent=scale_config.exponent,
                    base=scale_config.base,
                    rescale=rescale,
                ),
                initializer_map=initializer_map,
            )
            if isinstance(quant_nodes, list):
                quantized_nodes.extend(quant_nodes)
            else:
                quantized_nodes.append(quant_nodes)
        return quantized_nodes
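
    # Rescale semantics, as implemented above: when a rescale_config dict is
    # supplied, nodes absent from it default to *no* rescale; when no config
    # is given at all, every node rescales. A hypothetical config:
    #
    #     rescale_config = {"conv1": True, "gemm_out": False}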

    def _process_initializers(
        self: ONNXConverter,
        model: onnx.ModelProto,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> list[onnx.TensorProto]:
        """Process and filter initializers, converting types as needed.

        Args:
            model (onnx.ModelProto): The quantized model.
            initializer_map (dict[str, onnx.TensorProto]): Original initializer map.

        Returns:
            list[onnx.TensorProto]: Processed initializers to keep.
        """
        used_initializer_names = set()
        for node in model.graph.node:
            used_initializer_names.update(node.input)

        kept_initializers = []
        for name in used_initializer_names:
            if name in initializer_map:
                orig_init = initializer_map[name]
                np_array = numpy_helper.to_array(orig_init)

                if np_array.dtype == np.float32:
                    np_array = np_array.astype(np.float64)
                    new_init = numpy_helper.from_array(np_array, name=name)
                    kept_initializers.append(new_init)
                else:
                    kept_initializers.append(orig_init)

        return kept_initializers

    def _update_graph_types(self: ONNXConverter, model: onnx.ModelProto) -> None:
        """Update output and value_info types to INT64.

        Args:
            model (onnx.ModelProto): The model to update.
        """
        for out in model.graph.output:
            out.type.tensor_type.elem_type = TensorProto.INT64

        for vi in model.graph.value_info:
            vi.type.tensor_type.elem_type = TensorProto.INT64

    def _add_custom_domain(self: ONNXConverter, model: onnx.ModelProto) -> None:
        """Add custom opset domain if not present.

        Args:
            model (onnx.ModelProto): The model to update.
        """
        custom_domain = helper.make_operatorsetid(
            domain="ai.onnx.contrib",
            version=1,
        )
        domains = [op.domain for op in model.opset_import]
        if "ai.onnx.contrib" not in domains:
            model.opset_import.append(custom_domain)

    def _log_quantization_results(self: ONNXConverter, model: onnx.ModelProto) -> None:
        """Log quantization results for debugging.

        Args:
            model (onnx.ModelProto): The quantized model.
        """
        for layer in model.graph.node:
            self.logger.debug(
                "Node",
                extra={
                    "layer_name": layer.name,
                    "op_type": layer.op_type,
                    "input": layer.input,
                    "output": layer.output,
                },
            )

        for layer in model.graph.initializer:
            self.logger.debug("Initializer", extra={"layer_name": layer.name})

    def quantize_model(
        self: ONNXConverter,
        unscaled_model: onnx.ModelProto,
        scale_base: int,
        scale_exponent: int,
        rescale_config: dict | None = None,
    ) -> onnx.ModelProto:
        """Produce a quantized ONNX graph from a floating-point model.

        Args:
            unscaled_model (onnx.ModelProto): The original unscaled model.
            scale_base (int): Base for fixed-point scaling (e.g., 2).
            scale_exponent (int):
                Exponent for scaling (e.g., 18 would lead to a scale factor 2**18).
            rescale_config (dict, optional): Mapping of node name -> bool to control
                whether a given node should apply a final rescale. Defaults to None.

        Returns:
            onnx.ModelProto: A new ONNX model representation of the quantized model.
        """
        try:
            # Prepare model and mappings
            model, initializer_map, input_names = self._prepare_model_for_quantization(
                unscaled_model,
            )

            # Quantize inputs and collect new nodes
            new_nodes = self._quantize_inputs(
                model,
                input_names,
                scale_base,
                scale_exponent,
            )

            # Update input types
            self._update_input_types(model)

            # Quantize all nodes
            quantized_nodes = self._quantize_nodes(
                model,
                ScaleConfig(scale_exponent, scale_base, rescale=True),
                rescale_config,
                initializer_map,
            )
            new_nodes.extend(quantized_nodes)

            # Update graph with new nodes
            model.graph.ClearField("node")
            model.graph.node.extend(new_nodes)

            # Process initializers
            kept_initializers = self._process_initializers(model, initializer_map)

            # Update graph initializers
            model.graph.ClearField("initializer")
            model.graph.initializer.extend(kept_initializers)
            model.graph.initializer.extend(self.op_quantizer.new_initializers)
            self.op_quantizer.new_initializers = []

            # Update types and add custom domain
            self._update_graph_types(model)
            self._add_custom_domain(model)

            # Log results
            self._log_quantization_results(model)

        except (QuantizationError, ModelConversionError):
            raise
        except Exception as e:
            msg = (
                "Quantization failed for model"
                f" '{getattr(self, 'model_file_name', 'unknown')}': {e!s}"
            )
            raise ModelConversionError(
                msg,
                model_type=self.model_type,
            ) from e
        else:
            return model
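
    # Worked example (illustrative): with scale_base=2 and scale_exponent=18,
    # the scale factor is 2**18 = 262144. An input value of 0.3 becomes
    # floor(0.3 * 262144) = 78643; dividing back by 262144 recovers
    # ~0.2999992, so the quantization error is bounded by 2**-18.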

    def quantize_layer(
        self: ONNXConverter,
        node: onnx.NodeProto,
        model: onnx.ModelProto,
        scale_config: ScaleConfig,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> onnx.NodeProto | list[onnx.NodeProto]:
        """Quantize a single ONNX node using the configured op quantizer.

        Args:
            node (onnx.NodeProto): The original ONNX node to quantize.
            model (onnx.ModelProto): The original model used for context.
            scale_config (ScaleConfig): Contains the following:
                - rescale (bool): Whether to apply output rescaling for this node.
                - exponent (int):
                  Exponent for scaling (e.g., 18 would lead to a scale factor 2**18).
                - base (int): Base for fixed-point scaling (e.g., 2).
            initializer_map (dict[str, onnx.TensorProto]):
                Mapping from initializer name to tensor.

        Returns:
            onnx.NodeProto | list[onnx.NodeProto]:
                A quantized node or list of nodes replacing the initial node.
        """
        try:
            return self.op_quantizer.quantize(
                node=node,
                rescale=scale_config.rescale,
                graph=model.graph,
                scale_exponent=scale_config.exponent,
                scale_base=scale_config.base,
                initializer_map=initializer_map,
            )
        except QuantizationError:
            raise
        except Exception as e:
            raise ModelConversionError(str(e), model_type=self.model_type) from e

    def quantize_input(
        self: ONNXConverter,
        input_name: str,
        op_quantizer: ONNXOpQuantizer,
        scale_base: int,
        scale_exponent: int,
    ) -> tuple[str, onnx.NodeProto, onnx.NodeProto, onnx.NodeProto]:
        """Insert scaling and casting nodes to quantize a model input.

        Creates:
            - Mul: scales the input by ``scale_base ** scale_exponent``.
            - Floor: rounds the scaled value down (returned but not wired in
              by `_quantize_inputs`; the Cast truncates on its own).
            - Cast (to INT64): produces the final integer input tensor.

        Args:
            input_name (str): Name of the graph input to quantize.
            op_quantizer (ONNXOpQuantizer): The op quantizer whose
                ``new_initializers`` list is used to store the created scale constant.
            scale_base (int): Base for fixed-point scaling (e.g., 2).
            scale_exponent (int):
                Exponent for scaling (e.g., 18 would lead to a scale factor 2**18).

        Returns:
            tuple[str, onnx.NodeProto, onnx.NodeProto, onnx.NodeProto]:
                A tuple ``(output_name, mul_node, floor_node, cast_node)`` where
                ``output_name`` is the name of the quantized input tensor
                and the nodes are nodes to add to the graph.
        """
        try:
            scale_value = BaseOpQuantizer.get_scaling(
                scale_base=scale_base,
                scale_exponent=scale_exponent,
            )

            # === Create scale constant ===
            scale_const_name = input_name + "_scale"
            scale_tensor = numpy_helper.from_array(
                np.array([scale_value], dtype=np.float64),
                name=scale_const_name,
            )
            op_quantizer.new_initializers.append(scale_tensor)

            # === Add Mul node ===
            scaled_output_name = f"{input_name}_scaled"
            mul_node = helper.make_node(
                "Mul",
                inputs=[input_name, scale_const_name],
                outputs=[scaled_output_name],
                name=f"{input_name}_mul",
            )
            # === Floor node (simulate rounding) ===
            rounded_output_name = f"{input_name}_scaled_floor"
            floor_node = helper.make_node(
                "Floor",
                inputs=[scaled_output_name],
                outputs=[rounded_output_name],
                name=f"{scaled_output_name}",
            )
            # The Cast reads the Mul output directly, bypassing the Floor node,
            # which callers currently discard.
            output_name = f"{rounded_output_name}_int"
            cast_to_int64 = helper.make_node(
                "Cast",
                inputs=[scaled_output_name],
                outputs=[output_name],
                to=TensorProto.INT64,
                name=rounded_output_name,
            )
        except Exception as e:
            msg = f"Error quantizing inputs: {e}"
            raise ModelConversionError(
                msg,
                self.model_type,
            ) from e
        else:
            return output_name, mul_node, floor_node, cast_to_int64
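
    # Resulting input subgraph (sketch): for a graph input "x" with base 2
    # and exponent 18, the rewrite produces
    #
    #     x (float64) --> Mul(x, 262144.0) --> Cast(INT64) --> x_scaled_floor_int
    #
    # and _quantize_inputs rewires downstream nodes to consume
    # "x_scaled_floor_int" in place of "x".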

    def _extract_model_io_info(
        self: ONNXConverter,
        onnx_model: onnx.ModelProto,
    ) -> None:
        """Populate input metadata from a loaded ONNX model.

        Args:
            onnx_model (onnx.ModelProto): The ONNX model.
        """
        self.required_keys = [
            graph_input.name for graph_input in onnx_model.graph.input
        ]
        self.input_shape = get_input_shapes(onnx_model)

    def get_weights(self: ONNXConverter) -> tuple[
        dict[str, list[ONNXLayerDict]],
        dict[str, list[ONNXLayerDict]],
        CircuitParamsDict,
    ]:
        """Export architecture, weights, and circuit parameters for ECC.

        1. Analyze the model for architecture + w & b
        2. Put arch into format to be read by ECC circuit builder
        3. Put w + b into format to be read by ECC circuit builder

        Returns:
            tuple[dict[str, list[ONNXLayerDict]],
            dict[str, list[ONNXLayerDict]], CircuitParamsDict]:
                A tuple ``(architecture, weights, circuit_params)``:
                - ``architecture``: dict with serialized ``architecture`` layers.
                - ``weights``: dict containing ``w_and_b`` (serialized tensors).
                - ``circuit_params``: dict containing scaling parameters and
                  ``rescale_config``.
        """
        inferred_model = shape_inference.infer_shapes(self.model)

        scaling = BaseOpQuantizer.get_scaling(
            scale_base=getattr(self, "scale_base", 2),
            scale_exponent=getattr(self, "scale_exponent", 18),
        )

        # Check the model and extract each value's shape information
        self._onnx_check_model_safely(inferred_model)
        output_name_to_shape = extract_shape_dict(inferred_model)
        (architecture, w_and_b) = self.analyze_layers(output_name_to_shape)
        for w in w_and_b:
            try:
                w_and_b_array = np.asarray(w.tensor)
            except Exception as e:
                raise SerializationError(
                    tensor_name=getattr(w, "name", None),
                    reason=f"cannot convert to ndarray: {e}",
                ) from e

            try:
                # TODO @jsgold-1: We need a better way to distinguish bias tensors from weight tensors # noqa: FIX002, TD003, E501
                if "bias" in w.name:
                    w_and_b_scaled = w_and_b_array * scaling * scaling
                else:
                    w_and_b_scaled = w_and_b_array * scaling
                w_and_b_out = w_and_b_scaled.astype(np.int64).tolist()
                w.tensor = w_and_b_out
            except Exception as e:
                raise SerializationError(
                    tensor_name=getattr(w, "name", None),
                    reason=str(e),
                ) from e

        inputs = []
        outputs = []
        for graph_input in self.model.graph.input:
            shape = output_name_to_shape.get(graph_input.name, [])
            elem_type = graph_input.type.tensor_type.elem_type
            inputs.append(ONNXIO(graph_input.name, elem_type, shape))

        for output in self.model.graph.output:
            shape = output_name_to_shape.get(output.name, [])
            elem_type = output.type.tensor_type.elem_type
            outputs.append(ONNXIO(output.name, elem_type, shape))

        # Get version from package metadata
        try:
            version = get_version(PACKAGE_NAME)
        except Exception:
            version = "0.0.0"

        architecture = {
            "architecture": [asdict(a) for a in architecture],
        }
        weights = {"w_and_b": [asdict(w_b) for w_b in w_and_b]}
        circuit_params = {
            "scale_base": getattr(self, "scale_base", 2),
            "scale_exponent": getattr(self, "scale_exponent", 18),
            "rescale_config": getattr(self, "rescale_config", {}),
            "inputs": [asdict(i) for i in inputs],
            "outputs": [asdict(o) for o in outputs],
            "version": version,
        }
        return architecture, weights, circuit_params
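
    # Why biases pick up scaling**2 above: inputs and weights are each scaled
    # by s = scale_base**scale_exponent, so a product term x*W inside a
    # Gemm/Conv carries a factor of s**2; the additive bias must be scaled by
    # the same s**2 to stay aligned before any rescale divides it back out.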

    def get_model_and_quantize(self: ONNXConverter) -> None:
        """Load the configured model (by path) and build its quantized form.

        Expects the instance to define ``self.model_file_name`` beforehand.

        Raises:
            FileNotFoundError: If ``self.model_file_name`` is unset or invalid.
        """
        if hasattr(self, "model_file_name"):
            self.load_model(self.model_file_name)
        else:
            msg = "An ONNX model is required at the specified path"
            raise FileNotFoundError(msg)
        self.quantized_model = self.quantize_model(
            self.model,
            getattr(self, "scale_base", 2),
            getattr(self, "scale_exponent", 18),
            rescale_config=getattr(self, "rescale_config", {}),
        )

    def get_outputs(
        self: ONNXConverter,
        inputs: np.ndarray | torch.Tensor,
    ) -> list[np.ndarray]:
        """Run the currently loaded (quantized) model via ONNX Runtime.

        Args:
            inputs (np.ndarray | torch.Tensor):
                Input array/tensor matching the model's first input.

        Returns:
            list[np.ndarray]: The output of the onnxruntime inference.
        """
        try:
            input_name = self.ort_sess.get_inputs()[0].name
            output_name = self.ort_sess.get_outputs()[0].name

            # TODO @jsgold-1: This may cause some rounding errors at some point but works for now. # noqa: FIX002, E501, TD003
            inputs = torch.as_tensor(inputs)
            if inputs.dtype in (
                torch.int8,
                torch.int16,
                torch.int32,
                torch.int64,
                torch.uint8,
            ):
                inputs = inputs.double()
                inputs = inputs / BaseOpQuantizer.get_scaling(
                    scale_base=self.scale_base,
                    scale_exponent=self.scale_exponent,
                )
            if self.ort_sess.get_inputs()[0].type == "tensor(double)":
                return self.ort_sess.run(
                    [output_name],
                    {input_name: np.asarray(inputs).astype(np.float64)},
                )
            return self.ort_sess.run(
                [output_name],
                {input_name: np.asarray(inputs)},
            )
        except Exception as e:
            raise InferenceError(
                model_path=getattr(self, "quantized_model_path", None),
                model_type=self.model_type,
                reason=str(e),
            ) from e
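
# Minimal end-to-end sketch (hypothetical paths; mirrors the demo below):
#
#     converter = ONNXConverter()
#     converter.model_file_name = "model.onnx"
#     converter.scale_base, converter.scale_exponent = 2, 18
#     converter.get_model_and_quantize()
#     converter.save_quantized_model("model_quant.onnx")
#     converter.load_quantized_model("model_quant.onnx")
#     outputs = converter.get_outputs(sample_input)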


if __name__ == "__main__":
    path = "./models_onnx/doom.onnx"

    converter = ONNXConverter()
    converter.model_file_name, converter.quantized_model_file_name = (
        path,
        "quantized_doom.onnx",
    )
    converter.scale_base, converter.scale_exponent = 2, 18

    converter.load_model(path)
    converter.get_model_and_quantize()

    # test_accuracy is not defined in this module; presumably provided by the
    # ModelConverter base class.
    converter.test_accuracy()