jstprove-1.0.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. jstprove-1.0.0.dist-info/METADATA +397 -0
  2. jstprove-1.0.0.dist-info/RECORD +81 -0
  3. jstprove-1.0.0.dist-info/WHEEL +6 -0
  4. jstprove-1.0.0.dist-info/entry_points.txt +2 -0
  5. jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
  6. jstprove-1.0.0.dist-info/top_level.txt +1 -0
  7. python/__init__.py +0 -0
  8. python/core/__init__.py +3 -0
  9. python/core/binaries/__init__.py +0 -0
  10. python/core/binaries/expander-exec +0 -0
  11. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  12. python/core/circuit_models/__init__.py +0 -0
  13. python/core/circuit_models/generic_onnx.py +231 -0
  14. python/core/circuit_models/simple_circuit.py +133 -0
  15. python/core/circuits/__init__.py +0 -0
  16. python/core/circuits/base.py +1000 -0
  17. python/core/circuits/errors.py +188 -0
  18. python/core/circuits/zk_model_base.py +25 -0
  19. python/core/model_processing/__init__.py +0 -0
  20. python/core/model_processing/converters/__init__.py +0 -0
  21. python/core/model_processing/converters/base.py +143 -0
  22. python/core/model_processing/converters/onnx_converter.py +1181 -0
  23. python/core/model_processing/errors.py +147 -0
  24. python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
  25. python/core/model_processing/onnx_custom_ops/conv.py +111 -0
  26. python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
  27. python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
  28. python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
  29. python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
  30. python/core/model_processing/onnx_custom_ops/relu.py +43 -0
  31. python/core/model_processing/onnx_quantizer/__init__.py +0 -0
  32. python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
  33. python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
  34. python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
  35. python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
  36. python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
  37. python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
  38. python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
  39. python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
  40. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
  41. python/core/model_templates/__init__.py +0 -0
  42. python/core/model_templates/circuit_template.py +57 -0
  43. python/core/utils/__init__.py +0 -0
  44. python/core/utils/benchmarking_helpers.py +163 -0
  45. python/core/utils/constants.py +4 -0
  46. python/core/utils/errors.py +117 -0
  47. python/core/utils/general_layer_functions.py +268 -0
  48. python/core/utils/helper_functions.py +1138 -0
  49. python/core/utils/model_registry.py +166 -0
  50. python/core/utils/scratch_tests.py +66 -0
  51. python/core/utils/witness_utils.py +291 -0
  52. python/frontend/__init__.py +0 -0
  53. python/frontend/cli.py +115 -0
  54. python/frontend/commands/__init__.py +17 -0
  55. python/frontend/commands/args.py +100 -0
  56. python/frontend/commands/base.py +199 -0
  57. python/frontend/commands/bench/__init__.py +54 -0
  58. python/frontend/commands/bench/list.py +42 -0
  59. python/frontend/commands/bench/model.py +172 -0
  60. python/frontend/commands/bench/sweep.py +212 -0
  61. python/frontend/commands/compile.py +58 -0
  62. python/frontend/commands/constants.py +5 -0
  63. python/frontend/commands/model_check.py +53 -0
  64. python/frontend/commands/prove.py +50 -0
  65. python/frontend/commands/verify.py +73 -0
  66. python/frontend/commands/witness.py +64 -0
  67. python/scripts/__init__.py +0 -0
  68. python/scripts/benchmark_runner.py +833 -0
  69. python/scripts/gen_and_bench.py +482 -0
  70. python/tests/__init__.py +0 -0
  71. python/tests/circuit_e2e_tests/__init__.py +0 -0
  72. python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
  73. python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
  74. python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
  75. python/tests/circuit_parent_classes/__init__.py +0 -0
  76. python/tests/circuit_parent_classes/test_circuit.py +969 -0
  77. python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
  78. python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
  79. python/tests/test_cli.py +1021 -0
  80. python/tests/utils_testing/__init__.py +0 -0
  81. python/tests/utils_testing/test_helper_functions.py +891 -0
python/core/model_processing/converters/onnx_converter.py
@@ -0,0 +1,1181 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import logging
5
+ from dataclasses import asdict, dataclass
6
+ from importlib.metadata import version as get_version
7
+ from pathlib import Path
8
+
9
+ import numpy as np
10
+ import onnx
11
+ import torch
12
+ from onnx import NodeProto, TensorProto, helper, numpy_helper, shape_inference
13
+
14
+ # Keep the seemingly unused import below, as it
15
+ # must remain due to the 'SessionOptions' dependency.
16
+ from onnxruntime import InferenceSession, SessionOptions
17
+ from onnxruntime_extensions import get_library_path
18
+
19
+ import python.core.model_processing.onnx_custom_ops # noqa: F401
20
+ from python.core import PACKAGE_NAME
21
+ from python.core.model_processing.converters.base import ModelConverter, ModelType
22
+ from python.core.model_processing.errors import (
23
+ InferenceError,
24
+ InvalidModelError,
25
+ IOInfoExtractionError,
26
+ LayerAnalysisError,
27
+ ModelConversionError,
28
+ ModelLoadError,
29
+ ModelSaveError,
30
+ SerializationError,
31
+ )
32
+ from python.core.model_processing.onnx_custom_ops.onnx_helpers import (
33
+ extract_shape_dict,
34
+ get_input_shapes,
35
+ parse_attributes,
36
+ )
37
+ from python.core.model_processing.onnx_quantizer.exceptions import QuantizationError
38
+ from python.core.model_processing.onnx_quantizer.layers.base import (
39
+ BaseOpQuantizer,
40
+ ScaleConfig,
41
+ )
42
+ from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import (
43
+ ONNXOpQuantizer,
44
+ )
45
+
46
+ try:
47
+ import tomllib # Python 3.11+
48
+ except ModuleNotFoundError:
49
+ import tomli as tomllib # noqa: F401
50
+
51
+ ONNXLayerDict = dict[
52
+ str,
53
+ int | str | list[str] | dict[str, list[int]] | list | None | dict,
54
+ ]
55
+
56
+ ONNXIODict = dict[str, str | int | list[int]]
57
+
58
+ CircuitParamsDict = dict[str, int | dict[str, bool]]
59
+
60
+
61
+ @dataclass
62
+ class ONNXLayer:
63
+ """
64
+ A dataclass representing an ONNX layer in the form
65
+ to be sent to the circuit building process.
66
+
67
+ This class encapsulates the essential information
68
+ about a layer in an ONNX model. It is designed to facilitate the
69
+ conversion and processing of ONNX models for circuit building purposes.
70
+
71
+ Attributes:
72
+ id (int): A unique identifier for the layer.
73
+ name (str): The name of the layer.
74
+ op_type (str): The operation type of the layer,
75
+ such as "Conv" for convolution layers.
76
+ inputs (list[str]): A list of input names that this layer depends on.
77
+ outputs (list[str]): A list of output names produced by this layer.
78
+ shape (dict[str, list[int]]): A dictionary mapping output names
79
+ to their corresponding shapes.
80
+ tensor (Optional[list]): For constant nodes, this contains the
81
+ tensor data (weights or biases) as a list. For other layers, ``None``.
82
+ params (Optional[dict]): A dictionary of parameters specific to the
83
+ layer's operation. For example, convolution layers may include parameters
84
+ like dilation, kernel_shape, pad, strides, and group.
85
+ opset_version_number (int): The version number of the ONNX opset
86
+ used for this operation. This is included for infrastructure
87
+ purposes and may not be actively used in all processing steps.
88
+ """
89
+
90
+ id: int
91
+ name: str
92
+ op_type: str
93
+ inputs: list[str]
94
+ outputs: list[str]
95
+ shape: dict[str, list[int]]
99
+ tensor: list | None
100
+ params: dict | None
101
+ opset_version_number: int
102
+
103
+
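# Illustrative sketch (hypothetical values, not part of the package): what a
# Conv node might look like once converted to an ONNXLayer. Only "Const"
# layers carry tensor data; ordinary layers leave ``tensor`` as None.
_example_conv_layer = ONNXLayer(
    id=0,
    name="conv1",
    op_type="Conv",
    inputs=["input", "conv1.weight", "conv1.bias"],
    outputs=["conv1_out"],
    shape={"conv1_out": [1, 16, 28, 28]},
    tensor=None,
    params={"kernel_shape": [3, 3], "strides": [1, 1], "group": 1},
    opset_version_number=13,
)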
104
+ @dataclass
105
+ class ONNXIO:
106
+ """
107
+ A dataclass representing an ONNX input or output,
108
+ in the form to be sent to the circuit building process
109
+ """
110
+
111
+ name: str
112
+ elem_type: int
113
+ shape: list[int]
114
+
115
+
116
+ class ONNXConverter(ModelConverter):
117
+ """Concrete implementation of `ModelConverter` for ONNX models."""
118
+
119
+ def __init__(self: ONNXConverter) -> None:
120
+ """Initialize the converter and its operator quantizer.
121
+
122
+ Initializes:
123
+ self.op_quantizer (ONNXOpQuantizer): Dispatcher that quantizes
124
+ individual ONNX ops and accumulates newly created initializers.
125
+ """
126
+ self.op_quantizer = ONNXOpQuantizer()
127
+ self.model_type = ModelType.ONNX
128
+ self.logger = logging.getLogger(__name__)
129
+
130
+ def save_model(self: ONNXConverter, file_path: str) -> None:
131
+ """Serialize the ONNX model to file.
132
+
133
+ Args:
134
+ file_path (str):
135
+ Destination path (e.g., ``"models/my_model.onnx"``).
136
+
137
+ Note
138
+ ----
139
+ - For saving and loading, see
140
+ https://onnx.ai/onnx/intro/python.html;
141
+ larger models may require a different serialization strategy.
142
+ """
143
+ try:
144
+ Path(file_path).parent.mkdir(parents=True, exist_ok=True)
145
+ onnx.save(self.model, file_path)
146
+ except Exception as e:
147
+ raise ModelSaveError(
148
+ file_path,
149
+ model_type=self.model_type,
150
+ reason=str(e),
151
+ ) from e
152
+
153
+ def load_model(self: ONNXConverter, file_path: str) -> onnx.ModelProto:
154
+ """Load an ONNX model from file and extract basic I/O metadata.
155
+
156
+ Args:
157
+ file_path (str): Path to the `.onnx` file.
158
+
159
+ Returns:
160
+ onnx.ModelProto: The loaded onnx model.
161
+
162
+ Raises:
163
+ ModelLoadError: If the model cannot be loaded or validated.
164
+ """
165
+ try:
166
+ onnx_model = onnx.load(file_path)
167
+ except Exception as e:
168
+ raise ModelLoadError(
169
+ file_path,
170
+ model_type=self.model_type,
171
+ reason=str(e),
172
+ ) from e
173
+
174
+ self.model = onnx_model
175
+
176
+ try:
177
+ self._extract_model_io_info(onnx_model)
178
+ except Exception as e:
179
+ raise IOInfoExtractionError(
180
+ model_path=file_path,
181
+ model_type=self.model_type,
182
+ reason=str(e),
183
+ ) from e
184
+ return self.model
185
+
186
+ def save_quantized_model(self: ONNXConverter, file_path: str) -> None:
187
+ """Serialize the quantized ONNX model to file.
188
+
189
+ Args:
190
+ file_path (str): Destination path for the quantized model.
191
+ """
192
+ try:
193
+ Path(file_path).parent.mkdir(parents=True, exist_ok=True)
194
+ onnx.save(self.quantized_model, file_path)
195
+ except Exception as e:
196
+ raise ModelSaveError(
197
+ file_path,
198
+ model_type=self.model_type,
199
+ reason=str(e),
200
+ ) from e
201
+
202
+ # TODO: revisit whether loading should also create the inference session here.
203
+ def load_quantized_model(self: ONNXConverter, file_path: str) -> None:
204
+ """Load a quantized ONNX model and create an inference session.
205
+
206
+ Note
207
+ ----
208
+ - Uses the custom opset for the quantized layers
209
+
210
+ Args:
211
+ file_path (str): Path to the quantized ``.onnx`` file.
212
+
213
+ Raises:
214
+ FileNotFoundError: If the file does not exist.
215
+ ModelLoadError: If loading or validation fails.
216
+ """
217
+ if not Path(file_path).exists():
218
+ msg = f"Quantized model file not found: {file_path}"
219
+ raise FileNotFoundError(msg)
220
+ self.logger.info("Loading quantized model from", extra={"file_path": file_path})
221
+ onnx_model = onnx.load(file_path)
222
+ custom_domain = onnx.helper.make_operatorsetid(
223
+ domain="ai.onnx.contrib",
224
+ version=1,
225
+ )
226
+ onnx_model.opset_import.append(custom_domain)
228
+ self.quantized_model = onnx_model
229
+ self.ort_sess = self._create_inference_session(file_path)
230
+ self._extract_model_io_info(onnx_model)
231
+
232
+ self.quantized_model_path = file_path
233
+
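# Illustrative usage sketch (hypothetical paths and shapes, not part of the
# package): load a previously quantized model, then run it via get_outputs().
#
#     converter = ONNXConverter()
#     converter.scale_base, converter.scale_exponent = 2, 18
#     converter.load_quantized_model("models/my_model_quantized.onnx")
#     outputs = converter.get_outputs(np.zeros((1, 3, 28, 28), dtype=np.float32))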
234
+ def _onnx_check_model_safely(self: ONNXConverter, model: onnx.ModelProto) -> None:
235
+ try:
236
+ onnx.checker.check_model(model)
237
+ except Exception as e:
238
+ raise InvalidModelError(
239
+ model_path=getattr(self, "model_file_name", None),
240
+ reason=f"Model validation failed: {e!s}",
241
+ ) from e
242
+
243
+ def analyze_layers(
244
+ self: ONNXConverter,
245
+ output_name_to_shape: dict[str, list[int]] | None = None,
246
+ ) -> tuple[list[ONNXLayer], list[ONNXLayer]]:
247
+ """Analyze the onnx model graph into
248
+ logical layers and parameter tensors.
249
+
250
+ Args:
251
+ output_name_to_shape (dict[str, list[int]], optional):
252
+ mapping of value name -> shape. If omitted,
253
+ shapes are inferred via `onnx.shape_inference`. Defaults to None.
254
+
255
+ Returns:
256
+ Tuple[list[ONNXLayer], list[ONNXLayer]]: ``(architecture, w_and_b)`` where:
257
+ - ``architecture`` is a list of `ONNXLayer` describing
258
+ the computational graph.
259
+ - ``w_and_b`` is a list of `ONNXLayer` representing
260
+ constant tensors (initializers).
261
+ """
262
+ try:
263
+ id_count = 0
264
+ # Apply shape inference on the model
265
+ if not output_name_to_shape:
266
+ inferred_model = shape_inference.infer_shapes(self.model)
267
+ self._onnx_check_model_safely(inferred_model)
268
+
269
+ output_name_to_shape = extract_shape_dict(inferred_model)
270
+ domain_to_version = {
271
+ opset.domain: opset.version for opset in self.model.opset_import
272
+ }
273
+
275
+ architecture = self.get_model_architecture(
276
+ self.model,
277
+ output_name_to_shape,
278
+ domain_to_version,
279
+ )
280
+ w_and_b = self.get_model_w_and_b(
281
+ self.model,
282
+ output_name_to_shape,
283
+ id_count,
284
+ domain_to_version,
285
+ )
286
+ except InvalidModelError:
287
+ raise
288
+ except Exception as e:
289
+ raise LayerAnalysisError(model_type=self.model_type, reason=str(e)) from e
292
+ else:
293
+ return (architecture, w_and_b)
294
+
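# Illustrative usage sketch (hypothetical path, not part of the package):
#
#     converter = ONNXConverter()
#     converter.load_model("models/my_model.onnx")
#     architecture, w_and_b = converter.analyze_layers()
#     for layer in architecture:
#         print(layer.id, layer.op_type, layer.shape)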
295
+ def run_model_onnx_runtime(
296
+ self: ONNXConverter,
297
+ path: str,
298
+ inputs: torch.Tensor,
299
+ ) -> list[np.ndarray]:
300
+ """Execute a model on CPU via ONNX Runtime and return its outputs.
301
+
302
+ Creates a fresh inference session for the model at ``path``, feeds
303
+ the provided tensor under the first input name, and returns the
304
+ first output.
305
+
306
+ Args:
307
+ path (str): Path to the ONNX model to execute.
308
+ inputs (torch.Tensor): Input tensor to feed into the model's first input.
309
+
310
+ Returns:
311
+ list[np.ndarray]: The outputs as returned by `InferenceSession.run`.
312
+ """
314
+ try:
315
+ ort_sess = self._create_inference_session(path)
316
+ input_name = ort_sess.get_inputs()[0].name
317
+ output_name = ort_sess.get_outputs()[0].name
318
+ if ort_sess.get_inputs()[0].type == "tensor(double)":
319
+ outputs = ort_sess.run(
320
+ [output_name],
321
+ {input_name: np.asarray(inputs).astype(np.float64)},
322
+ )
323
+ else:
324
+ outputs = ort_sess.run([output_name], {input_name: np.asarray(inputs)})
325
+
326
+ except Exception as e:
334
+ raise InferenceError(
335
+ model_path=path,
336
+ model_type=self.model_type,
337
+ reason=str(e),
338
+ ) from e
339
+ else:
340
+ return outputs
341
+
342
+ def _collect_constant_values(
343
+ self: ONNXConverter,
344
+ model: onnx.ModelProto,
345
+ ) -> dict[str, np.ndarray]:
346
+ """Collect constant values from Constant nodes in the model.
347
+
348
+ Args:
349
+ model (onnx.ModelProto): The ONNX model to analyze.
350
+
351
+ Returns:
352
+ dict[str, np.ndarray]: Mapping of output name to constant value.
353
+ """
354
+ constant_values = {}
355
+ for node in model.graph.node:
356
+ if node.op_type == "Constant":
357
+ self.logger.debug("Constant node", extra={"node": node})
358
+ for attr in node.attribute:
359
+ if attr.name == "value":
360
+ tensor = attr.t
361
+ const_value = numpy_helper.to_array(tensor)
362
+ constant_values[node.output[0]] = const_value
363
+ return constant_values
364
+
365
+ def _attach_constant_parameters(
366
+ self: ONNXConverter,
367
+ layer: ONNXLayer,
368
+ node: NodeProto,
369
+ constant_values: dict[str, np.ndarray],
370
+ ) -> None:
371
+ """Attach constant inputs as parameters to a layer.
372
+
373
+ Args:
374
+ layer (ONNXLayer): The layer to modify.
375
+ node (NodeProto): The ONNX node being processed.
376
+ constant_values (dict[str, np.ndarray]): Constant values mapping.
377
+ """
378
+ for input_name in node.input:
379
+ if input_name in constant_values:
380
+ self.logger.debug(
381
+ "Layer params before:",
382
+ extra={"layer_params": layer.params},
383
+ )
384
+ if not hasattr(layer, "params") or layer.params is None:
385
+ layer.params = {}
386
+ result = constant_values[input_name]
387
+ if isinstance(result, (np.ndarray, torch.Tensor)):
388
+ layer.params[input_name] = result.tolist()
389
+ else:
390
+ layer.params[input_name] = constant_values[input_name]
391
+ self.logger.debug(
392
+ "Updated layer params",
393
+ extra={"layer_params": layer.params},
394
+ )
395
+
396
+ def get_model_architecture(
397
+ self: ONNXConverter,
398
+ model: onnx.ModelProto,
399
+ output_name_to_shape: dict[str, list[int]],
400
+ domain_to_version: dict[str, int] | None = None,
401
+ ) -> list[ONNXLayer]:
402
+ """Construct ONNXLayer objects for architecture graph nodes
403
+ (not weights or biases).
404
+
405
+ Args:
406
+ model (onnx.ModelProto): The ONNX model to analyze.
407
+ output_name_to_shape (dict[str, list[int]]):
408
+ Map of value name -> inferred shape.
412
+ domain_to_version (dict[str, int], optional):
413
+ Map of opset domain -> version used. Defaults to None.
414
+
415
+ Returns:
416
+ list[ONNXLayer]:
417
+ The model's computational layers
418
+ (excluding initializers) as ONNXLayer objects.
419
+ """
421
+ constant_values = self._collect_constant_values(model)
422
+ layers = []
423
+ current_id = 0
424
+
425
+ for node in model.graph.node:
426
+ if node.op_type == "Constant":
427
+ continue # Skip constant nodes
428
+
429
+ layer = self.analyze_layer(
430
+ node,
431
+ output_name_to_shape,
432
+ current_id,
433
+ domain_to_version,
434
+ )
435
+ self.logger.debug(
436
+ "Layer",
437
+ extra={
438
+ "layer_name": layer.name,
439
+ "layer_op": layer.op_type,
440
+ "layer_shape": layer.shape,
441
+ },
442
+ )
443
+
444
+ self._attach_constant_parameters(layer, node, constant_values)
445
+ layers.append(layer)
446
+ current_id += 1
447
+
448
+ return layers
449
+
450
+ def get_model_w_and_b(
451
+ self: ONNXConverter,
452
+ model: onnx.ModelProto,
453
+ output_name_to_shape: dict[str, list[int]],
454
+ id_count: int = 0,
455
+ domain_to_version: dict[str, int] | None = None,
456
+ ) -> list[ONNXLayer]:
457
+ """Extract constant initializers (weights/biases) as layers.
458
+
459
+ Iterates through graph initializers and wraps each tensor
460
+ into an ONNXLayer.
461
+
462
+ Args:
463
+ model (onnx.ModelProto): The ONNX model to analyze.
464
+ output_name_to_shape (dict[str, list[int]]):
465
+ Map of value name -> inferred shape
466
+ id_count (int, optional):
467
+ Starting numeric ID for layers (incremented per tensor).
468
+ Defaults to 0.
469
+ domain_to_version (dict[str, int], optional):
470
+ Map of opset domain -> version used (unused). Defaults to None.
471
+
472
+ Returns:
473
+ list[ONNXLayer]: ONNXLayers representing weights/biases found in the graph
474
+ """
475
+ layers = []
476
+ # Wrap each graph initializer tensor in an ONNXLayer.
477
+ for node in model.graph.initializer:
478
+ layer = self.analyze_constant(
479
+ node,
480
+ output_name_to_shape,
481
+ id_count,
482
+ domain_to_version,
483
+ )
484
+ layers.append(layer)
485
+ id_count += 1
486
+
487
+ return layers
488
+
489
+ def _create_inference_session(
490
+ self: ONNXConverter,
491
+ model_path: str,
492
+ ) -> InferenceSession:
493
+ """Internal helper to create and configure an ONNX Runtime InferenceSession.
494
+ Registers a custom ops shared library for use with the
495
+ custom quantized operations.
496
+
497
+ Args:
498
+ model_path (str): Path to the ONNX model to load.
499
+
500
+ Returns:
501
+ InferenceSession: A configured InferenceSession.
502
+ """
503
+ try:
504
+ opts = SessionOptions()
505
+ opts.register_custom_ops_library(get_library_path())
506
+ return InferenceSession(
507
+ model_path,
508
+ opts,
509
+ providers=["CPUExecutionProvider"],
510
+ )
511
+ except Exception as e:
512
+ raise InferenceError(
513
+ model_path,
514
+ model_type=self.model_type,
515
+ reason=str(e),
516
+ ) from e
517
+
518
+ def analyze_layer(
519
+ self: ONNXConverter,
520
+ node: NodeProto,
521
+ output_name_to_shape: dict[str, list[int]],
522
+ id_count: int = -1,
523
+ domain_to_version: dict[str, int] | None = None,
524
+ ) -> ONNXLayer:
525
+ """Convert a non-constant ONNX node into a structured ONNXLayer.
526
+
527
+ Args:
528
+ node (NodeProto): The ONNX node to analyze.
529
+ output_name_to_shape (dict[str, list[int]]):
530
+ Map of value name -> inferred shape.
531
+ id_count (int, optional):
532
+ Numeric ID to assign to this layer (increment handled by caller).
533
+ Defaults to -1.
534
+ domain_to_version (dict[str, int], optional):
535
+ Map of opset domain -> version number. Defaults to None.
536
+
537
+ Returns:
538
+ ONNXLayer: ONNXLayer describing the node
539
+ """
540
+ name = node.name
541
+ layer_id = id_count
542
+ id_count += 1
543
+ op_type = node.op_type
544
+ inputs = node.input
545
+ outputs = node.output
546
+ opset_version = (
547
+ domain_to_version.get(node.domain, -1) if domain_to_version else -1
548
+ )
549
+ params = parse_attributes(node.attribute)
550
+
551
+ # Extract output shapes
552
+ output_shapes = {
553
+ out_name: output_name_to_shape.get(out_name, []) for out_name in outputs
554
+ }
555
+ return ONNXLayer(
556
+ id=layer_id,
557
+ name=name,
558
+ op_type=op_type,
559
+ inputs=list(inputs),
560
+ outputs=list(outputs),
561
+ shape=output_shapes,
562
+ params=params,
563
+ opset_version_number=opset_version,
564
+ tensor=None,
565
+ )
566
+
567
+ def analyze_constant(
568
+ self: ONNXConverter,
569
+ node: TensorProto,
570
+ output_name_to_shape: dict[str, list[int]],
571
+ id_count: int = -1,
572
+ domain_to_version: dict[str, int] | None = None,
573
+ ) -> ONNXLayer:
574
+ """Convert a constant ONNX node (weights or bias) into a structured ONNXLayer.
575
+
576
+ Args:
577
+ node (TensorProto): The initializer tensor to analyze.
578
+ output_name_to_shape (dict[str, list[int]]):
579
+ Map of value name -> inferred shape.
580
+ id_count (int, optional):
581
+ Numeric ID to assign to this layer (increment handled by caller).
582
+ Defaults to -1.
583
+ domain_to_version (dict[str, int], optional):
584
+ Map of opset domain -> version number. Defaults to None.
585
+
586
+ Returns:
587
+ ONNXLayer: ONNXLayer describing the node
588
+ """
589
+ _ = domain_to_version
590
+ name = node.name
591
+ id_count += 1
592
+ op_type = "Const"
593
+ inputs = []
594
+ outputs = []
595
+ opset_version = -1
596
+ params = {}
597
+ # This conversion could potentially be done in Rust to keep file sizes low.
598
+ try:
599
+ np_data = numpy_helper.to_array(node)
600
+ except Exception as e:
602
+ raise SerializationError(
603
+ tensor_name=node.name,
604
+ reason=f"Failed to convert tensor: {e!s}",
605
+ ) from e
606
+ # Extract output shapes
607
+ output_shapes = {
608
+ out_name: output_name_to_shape.get(out_name, []) for out_name in outputs
609
+ }
610
+ return ONNXLayer(
611
+ id=id_count,
612
+ name=name,
613
+ op_type=op_type,
614
+ inputs=list(inputs),
615
+ outputs=list(outputs),
616
+ shape=output_shapes,
617
+ params=params,
618
+ opset_version_number=opset_version,
619
+ tensor=np_data.tolist(),
620
+ )
621
+
622
+ def _prepare_model_for_quantization(
623
+ self: ONNXConverter,
624
+ unscaled_model: onnx.ModelProto,
625
+ ) -> tuple[onnx.ModelProto, dict[str, onnx.TensorProto], list[str]]:
626
+ """Prepare the model for quantization by creating a copy and necessary mappings.
627
+
628
+ Args:
629
+ unscaled_model (onnx.ModelProto): The original unscaled model.
630
+
631
+ Returns:
632
+ tuple[onnx.ModelProto, dict[str, onnx.TensorProto], list[str]]:
633
+ Model copy, initializer map, and input names.
634
+ """
635
+ model = copy.deepcopy(unscaled_model)
636
+ self.op_quantizer.check_model(model)
637
+ initializer_map = {init.name: init for init in model.graph.initializer}
638
+ input_names = [inp.name for inp in unscaled_model.graph.input]
639
+ return model, initializer_map, input_names
640
+
641
+ def _quantize_inputs(
642
+ self: ONNXConverter,
643
+ model: onnx.ModelProto,
644
+ input_names: list[str],
645
+ scale_base: int,
646
+ scale_exponent: int,
647
+ ) -> list[onnx.NodeProto]:
648
+ """Quantize model inputs and update node connections.
649
+
650
+ Args:
651
+ model (onnx.ModelProto): The model being quantized.
652
+ input_names (list[str]): Names of input tensors.
653
+ scale_base (int): Base for scaling.
654
+ scale_exponent (int): Exponent for scaling.
655
+
656
+ Returns:
657
+ list[onnx.NodeProto]: New nodes created for input quantization.
658
+ """
659
+ new_nodes = []
660
+ for name in input_names:
661
+ output_name, mul_node, floor_node, cast_to_int64 = self.quantize_input(
662
+ input_name=name,
663
+ op_quantizer=self.op_quantizer,
664
+ scale_base=scale_base,
665
+ scale_exponent=scale_exponent,
666
+ )
667
+ new_nodes.extend([mul_node, floor_node, cast_to_int64])
668
+
669
+ # Update references to this input in all nodes
670
+ for node in model.graph.node:
671
+ for idx, inp in enumerate(node.input):
672
+ if inp == name:
673
+ node.input[idx] = output_name
674
+ return new_nodes
675
+
676
+ def _update_input_types(self: ONNXConverter, model: onnx.ModelProto) -> None:
677
+ """Update input tensor types from float32 to float64.
678
+
679
+ Args:
680
+ model (onnx.ModelProto): The model to update.
681
+ """
682
+ for input_tensor in model.graph.input:
683
+ tensor_type = input_tensor.type.tensor_type
684
+ if tensor_type.elem_type == TensorProto.FLOAT:
685
+ tensor_type.elem_type = TensorProto.DOUBLE
686
+
687
+ def _quantize_nodes(
688
+ self: ONNXConverter,
689
+ model: onnx.ModelProto,
690
+ scale_config: ScaleConfig,
691
+ rescale_config: dict | None,
692
+ initializer_map: dict[str, onnx.TensorProto],
693
+ ) -> list[onnx.NodeProto]:
694
+ """Quantize all nodes in the model.
695
+
696
+ Args:
697
+ model (onnx.ModelProto): The model being quantized.
698
+ scale_config (ScaleConfig): Scaling base, exponent, and default rescale flag.
700
+ rescale_config (dict, optional): Rescale configuration.
701
+ initializer_map (dict[str, onnx.TensorProto]): Initializer mapping.
702
+
703
+ Returns:
704
+ list[onnx.NodeProto]: Quantized nodes.
705
+ """
706
+ quantized_nodes = []
707
+ for node in model.graph.node:
708
+ rescale = rescale_config.get(node.name, False) if rescale_config else True
709
+ quant_nodes = self.quantize_layer(
710
+ node=node,
711
+ model=model,
712
+ scale_config=ScaleConfig(
713
+ exponent=scale_config.exponent,
714
+ base=scale_config.base,
715
+ rescale=rescale,
716
+ ),
717
+ initializer_map=initializer_map,
718
+ )
719
+ if isinstance(quant_nodes, list):
720
+ quantized_nodes.extend(quant_nodes)
721
+ else:
722
+ quantized_nodes.append(quant_nodes)
723
+ return quantized_nodes
724
+
725
+ def _process_initializers(
726
+ self: ONNXConverter,
727
+ model: onnx.ModelProto,
728
+ initializer_map: dict[str, onnx.TensorProto],
729
+ ) -> list[onnx.TensorProto]:
730
+ """Process and filter initializers, converting types as needed.
731
+
732
+ Args:
733
+ model (onnx.ModelProto): The quantized model.
734
+ initializer_map (dict[str, onnx.TensorProto]): Original initializer map.
735
+
736
+ Returns:
737
+ list[onnx.TensorProto]: Processed initializers to keep.
738
+ """
739
+ used_initializer_names = set()
740
+ for node in model.graph.node:
741
+ used_initializer_names.update(node.input)
742
+
743
+ kept_initializers = []
744
+ for name in used_initializer_names:
745
+ if name in initializer_map:
746
+ orig_init = initializer_map[name]
747
+ np_array = numpy_helper.to_array(orig_init)
748
+
749
+ if np_array.dtype == np.float32:
750
+ np_array = np_array.astype(np.float64)
751
+ new_init = numpy_helper.from_array(np_array, name=name)
752
+ kept_initializers.append(new_init)
753
+ else:
754
+ kept_initializers.append(orig_init)
755
+
756
+ return kept_initializers
757
+
758
+ def _update_graph_types(self: ONNXConverter, model: onnx.ModelProto) -> None:
759
+ """Update output and value_info types to INT64.
760
+
761
+ Args:
762
+ model (onnx.ModelProto): The model to update.
763
+ """
764
+ for out in model.graph.output:
765
+ out.type.tensor_type.elem_type = onnx.TensorProto.INT64
766
+
767
+ for vi in model.graph.value_info:
768
+ vi.type.tensor_type.elem_type = TensorProto.INT64
769
+
770
+ def _add_custom_domain(self: ONNXConverter, model: onnx.ModelProto) -> None:
771
+ """Add custom opset domain if not present.
772
+
773
+ Args:
774
+ model (onnx.ModelProto): The model to update.
775
+ """
776
+ custom_domain = helper.make_operatorsetid(
777
+ domain="ai.onnx.contrib",
778
+ version=1,
779
+ )
780
+ domains = [op.domain for op in model.opset_import]
781
+ if "ai.onnx.contrib" not in domains:
782
+ model.opset_import.append(custom_domain)
783
+
784
+ def _log_quantization_results(self: ONNXConverter, model: onnx.ModelProto) -> None:
785
+ """Log quantization results for debugging.
786
+
787
+ Args:
788
+ model (onnx.ModelProto): The quantized model.
789
+ """
790
+ for layer in model.graph.node:
791
+ self.logger.debug(
792
+ "Node",
793
+ extra={
794
+ "layer_name": layer.name,
795
+ "op_type": layer.op_type,
796
+ "input": layer.input,
797
+ "output": layer.output,
798
+ },
799
+ )
800
+
801
+ for layer in model.graph.initializer:
802
+ self.logger.debug("Initializer", extra={"layer_name": layer.name})
803
+
804
+ def quantize_model(
805
+ self: ONNXConverter,
806
+ unscaled_model: onnx.ModelProto,
807
+ scale_base: int,
808
+ scale_exponent: int,
809
+ rescale_config: dict | None = None,
810
+ ) -> onnx.ModelProto:
811
+ """Produce a quantized ONNX graph from a floating-point model.
812
+
813
+ Args:
814
+ unscaled_model (onnx.ModelProto): The original unscaled model.
815
+ scale_base (int): Base for fixed-point scaling (e.g., 2).
816
+ scale_exponent (int):
817
+ Exponent for scaling (e.g., 18 would lead to a scale factor 2**18).
818
+ rescale_config (dict, optional): mapping of node name -> bool to control
819
+ whether a given node should apply a final rescale. Defaults to None.
820
+
821
+ Returns:
822
+ onnx.ModelProto: A new onnx model representation of the quantized model.
823
+ """
824
+ try:
825
+ # Prepare model and mappings
826
+ model, initializer_map, input_names = self._prepare_model_for_quantization(
827
+ unscaled_model,
828
+ )
829
+
830
+ # Quantize inputs and collect new nodes
831
+ new_nodes = self._quantize_inputs(
832
+ model,
833
+ input_names,
834
+ scale_base,
835
+ scale_exponent,
836
+ )
837
+
838
+ # Update input types
839
+ self._update_input_types(model)
840
+
841
+ # Quantize all nodes
842
+ quantized_nodes = self._quantize_nodes(
843
+ model,
844
+ ScaleConfig(scale_exponent, scale_base, rescale=True),
845
+ rescale_config,
846
+ initializer_map,
847
+ )
848
+ new_nodes.extend(quantized_nodes)
849
+
850
+ # Update graph with new nodes
851
+ model.graph.ClearField("node")
852
+ model.graph.node.extend(new_nodes)
853
+
854
+ # Process initializers
855
+ kept_initializers = self._process_initializers(model, initializer_map)
856
+
857
+ # Update graph initializers
858
+ model.graph.ClearField("initializer")
859
+ model.graph.initializer.extend(kept_initializers)
860
+ model.graph.initializer.extend(self.op_quantizer.new_initializers)
861
+ self.op_quantizer.new_initializers = []
862
+
863
+ # Update types and add custom domain
864
+ self._update_graph_types(model)
865
+ self._add_custom_domain(model)
866
+
867
+ # Log results
868
+ self._log_quantization_results(model)
869
+
870
+ except (QuantizationError, ModelConversionError):
871
+ raise
872
+ except Exception as e:
880
+ msg = "Quantization failed for model"
881
+ f" '{getattr(self, 'model_file_name', 'unknown')}': {e!s}"
882
+ raise ModelConversionError(
883
+ msg,
884
+ model_type=self.model_type,
885
+ ) from e
886
+ else:
887
+ return model
888
+
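# Illustrative sketch (hypothetical paths, not part of the package): quantize
# a loaded float model with scale factor 2**18 and save the result.
#
#     model = converter.load_model("models/my_model.onnx")
#     quantized = converter.quantize_model(model, scale_base=2, scale_exponent=18)
#     converter.quantized_model = quantized
#     converter.save_quantized_model("models/my_model_quantized.onnx")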
889
+ def quantize_layer(
890
+ self: ONNXConverter,
891
+ node: onnx.NodeProto,
892
+ model: onnx.ModelProto,
893
+ scale_config: ScaleConfig,
894
+ initializer_map: dict[str, onnx.TensorProto],
895
+ ) -> onnx.NodeProto | list[onnx.NodeProto]:
896
+ """Quantize a single ONNX node using the configured op quantizer.
897
+
898
+ Args:
899
+ node (onnx.NodeProto): The original onnx node to quantize.
900
+ model (onnx.ModelProto): The original model used for context
901
+ scale_config (ScaleConfig): Contains the following:
902
+ - rescale (bool): Whether to apply output rescaling for this node.
903
+ - scale_exponent (int):
904
+ Exponent for scaling (e.g., 18 would lead to a scale factor 2**18).
905
+ - scale_base (int): Base for fixed-point scaling (e.g., 2).
906
+ initializer_map (dict[str, onnx.TensorProto]):
907
+ Mapping from initializer name to tensor.
908
+
909
+ Returns:
910
+ onnx.NodeProto:
911
+ A quantized node or list of nodes replacing the initial node.
912
+ """
913
+ try:
914
+ return self.op_quantizer.quantize(
915
+ node=node,
916
+ rescale=scale_config.rescale,
917
+ graph=model.graph,
918
+ scale_exponent=scale_config.exponent,
919
+ scale_base=scale_config.base,
920
+ initializer_map=initializer_map,
921
+ )
922
+ except QuantizationError:
923
+ raise
924
+ except Exception as e:
925
+ raise ModelConversionError(str(e), model_type=self.model_type) from e
926
+
927
+ def quantize_input(
928
+ self: ONNXConverter,
929
+ input_name: str,
930
+ op_quantizer: ONNXOpQuantizer,
931
+ scale_base: int,
932
+ scale_exponent: int,
933
+ ) -> tuple[str, onnx.NodeProto, onnx.NodeProto, onnx.NodeProto]:
934
+ """Insert scaling and casting nodes to quantize a model input.
935
+
936
+ Creates:
937
+ - Mul: scales the input by ``scale_base ** scale_exponent``.
938
+ - Floor: rounds the scaled value down (simulated rounding).
+ - Cast (to INT64): produces the final integer input tensor.
939
+
940
+ Args:
941
+ input_name (str): Name of the graph input to quantize.
942
+ op_quantizer (ONNXOpQuantizer): The op quantizer whose
943
+ ``new_initializers`` list is used to store the created scale constant.
944
+ scale_base (int): Base for fixed-point scaling (e.g., 2).
945
+ scale_exponent (int):
946
+ Exponent for scaling (e.g., 18 would lead to a scale factor 2**18).
947
+
948
+ Returns:
949
+ tuple[str, onnx.NodeProto, onnx.NodeProto, onnx.NodeProto]:
950
+ A tuple ``(output_name, mul_node, floor_node, cast_node)`` where
951
+ ``output_name`` is the name of the quantized input tensor
952
+ and the nodes are nodes to add to the graph.
953
+ """
954
+ try:
955
+ scale_value = BaseOpQuantizer.get_scaling(
956
+ scale_base=scale_base,
957
+ scale_exponent=scale_exponent,
958
+ )
959
+
960
+ # === Create scale constant ===
961
+ scale_const_name = input_name + "_scale"
962
+ scale_tensor = numpy_helper.from_array(
963
+ np.array([scale_value], dtype=np.float64),
964
+ name=scale_const_name,
965
+ )
966
+ op_quantizer.new_initializers.append(scale_tensor)
967
+
968
+ # === Add Mul node ===
969
+ scaled_output_name = f"{input_name}_scaled"
970
+ mul_node = helper.make_node(
971
+ "Mul",
972
+ inputs=[input_name, scale_const_name],
973
+ outputs=[scaled_output_name],
974
+ name=f"{input_name}_mul",
975
+ )
976
+ # === Floor node (simulate rounding) ===
977
+ rounded_output_name = f"{input_name}_scaled_floor"
978
+ floor_node = helper.make_node(
979
+ "Floor",
980
+ inputs=[scaled_output_name],
981
+ outputs=[rounded_output_name],
982
+ name=f"{scaled_output_name}",
983
+ )
984
+ output_name = f"{rounded_output_name}_int"
985
+ cast_to_int64 = helper.make_node(
986
+ "Cast",
987
+ inputs=[rounded_output_name],
988
+ outputs=[output_name],
989
+ to=onnx.TensorProto.INT64,
990
+ name=rounded_output_name,
991
+ )
992
+ except Exception as e:
993
+ msg = f"Error quantizing inputs: {e}"
994
+ raise ModelConversionError(
995
+ msg,
996
+ self.model_type,
997
+ ) from e
998
+ else:
999
+ return output_name, mul_node, floor_node, cast_to_int64
1000
+
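# The Mul -> Floor -> Cast chain above implements plain fixed-point scaling.
# A minimal sketch of the same arithmetic (illustrative, not part of the
# package), assuming scale_base=2 and scale_exponent=18:
#
#     scale = 2 ** 18                       # 262144
#     int(np.floor(0.5 * scale))            # 0.5  -> 131072 (int64 in the graph)
#     int(np.floor(-0.5 * scale))           # -0.5 -> -131072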
1001
+ def _extract_model_io_info(
1002
+ self: ONNXConverter,
1003
+ onnx_model: onnx.ModelProto,
1004
+ ) -> None:
1005
+ """Populate input metadata from a loaded ONNX model.
1006
+
1007
+ Args:
1008
+ onnx_model (onnx.ModelProto): Onnx model
1009
+ """
1010
+ self.required_keys = [
1011
+ graph_input.name for graph_input in onnx_model.graph.input
1012
+ ]
1013
+ self.input_shape = get_input_shapes(onnx_model)
1014
+
1015
+ def get_weights(self: ONNXConverter) -> tuple[
1016
+ dict[str, list[ONNXLayerDict]],
1017
+ dict[str, list[ONNXLayerDict]],
1018
+ CircuitParamsDict,
1019
+ ]:
1020
+ """Export architecture, weights, and circuit parameters for ECC.
1021
+
1022
+ 1. Analyze the model for architecture + w & b
1023
+ 2. Put arch into format to be read by ECC circuit builder
1024
+ 3. Put w + b into format to be read by ECC circuit builder
1025
+
1026
+ Returns:
1027
+ tuple[dict[str, list[dict[str, Any]]],
1028
+ dict[str, list[dict[str, Any]]], dict[str, Any]]:
1029
+ A tuple ``(architecture, weights, circuit_params)``:
1030
+ - ``architecture``: dict with serialized ``architecture`` layers.
1031
+ - ``weights``: dict containing ``w_and_b`` (serialized tensors).
1032
+ - ``circuit_params``: dict containing scaling parameters and
1033
+ ``rescale_config``.
1034
+ """
1035
+ inferred_model = shape_inference.infer_shapes(self.model)
1036
+
1037
+ scaling = BaseOpQuantizer.get_scaling(
1038
+ scale_base=getattr(self, "scale_base", 2),
1039
+ scale_exponent=(getattr(self, "scale_exponent", 18)),
1040
+ )
1041
+
1042
+ # Validate the model and extract shape information.
1043
+ self._onnx_check_model_safely(inferred_model)
1044
+ output_name_to_shape = extract_shape_dict(inferred_model)
1045
+ (architecture, w_and_b) = self.analyze_layers(output_name_to_shape)
1046
+ for w in w_and_b:
1047
+ try:
1048
+ w_and_b_array = np.asarray(w.tensor)
1049
+ except Exception as e:
1050
+ raise SerializationError(
1051
+ tensor_name=getattr(w, "name", None),
1052
+ reason=f"cannot convert to ndarray: {e}",
1053
+ ) from e
1054
+
1055
+ try:
1056
+ # TODO @jsgold-1: We need a better way to distinguish bias tensors from weight tensors # noqa: FIX002, TD003,E501
1057
+ if "bias" in w.name:
1058
+ w_and_b_scaled = w_and_b_array * scaling * scaling
1059
+ else:
1060
+ w_and_b_scaled = w_and_b_array * scaling
1061
+ w_and_b_out = w_and_b_scaled.astype(np.int64).tolist()
1062
+ w.tensor = w_and_b_out
1063
+ except Exception as e:
1064
+ raise SerializationError(
1065
+ tensor_name=getattr(w, "name", None),
1066
+ reason=str(e),
1067
+ ) from e
1068
+
1069
+ inputs = []
1070
+ outputs = []
1071
+ for graph_input in self.model.graph.input:
1072
+ shape = output_name_to_shape.get(graph_input.name, [])
1073
+ elem_type = graph_input.type.tensor_type.elem_type
1074
+ inputs.append(ONNXIO(graph_input.name, elem_type, shape))
1075
+
1076
+ for output in self.model.graph.output:
1077
+ shape = output_name_to_shape.get(output.name, [])
1078
+ elem_type = output.type.tensor_type.elem_type
1079
+ outputs.append(ONNXIO(output.name, elem_type, shape))
1080
+
1081
+ # Get version from package metadata
1082
+ try:
1083
+ version = get_version(PACKAGE_NAME)
1084
+ except Exception:
1085
+ version = "0.0.0"
1086
+
1087
+ architecture = {
1088
+ "architecture": [asdict(a) for a in architecture],
1089
+ }
1090
+ weights = {"w_and_b": [asdict(w_b) for w_b in w_and_b]}
1091
+ circuit_params = {
1092
+ "scale_base": getattr(self, "scale_base", 2),
1093
+ "scale_exponent": getattr(self, "scale_exponent", 18),
1094
+ "rescale_config": getattr(self, "rescale_config", {}),
1095
+ "inputs": [asdict(i) for i in inputs],
1096
+ "outputs": [asdict(o) for o in outputs],
1097
+ "version": version,
1098
+ }
1099
+ return architecture, weights, circuit_params
1100
+
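# Why biases get scaling**2 above (illustrative arithmetic, not part of the
# package): each input and each weight carries one factor of the scale, so an
# (input * weight) product carries two; biases must match. With scale 2**18:
#
#     weight 0.25 -> int(0.25 * 2**18)          == 65536
#     bias   0.25 -> int(0.25 * 2**18 * 2**18)  == 17179869184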
1101
+ def get_model_and_quantize(self: ONNXConverter) -> None:
1102
+ """Load the configured model (by path) and build its quantized form.
1103
+
1104
+ Expects the instance to define ``self.model_file_name`` beforehand.
1105
+
1106
+ Raises:
1107
+ FileNotFoundError: If ``self.model_file_name`` is unset or invalid.
1108
+ """
1109
+ if hasattr(self, "model_file_name"):
1110
+ self.load_model(self.model_file_name)
1111
+ else:
1112
+ msg = "An ONNX model is required at the specified path"
1113
+ raise FileNotFoundError(msg)
1114
+ self.quantized_model = self.quantize_model(
1115
+ self.model,
1116
+ getattr(self, "scale_base", 2),
1117
+ getattr(self, "scale_exponent", 18),
1118
+ rescale_config=getattr(self, "rescale_config", {}),
1119
+ )
1120
+
1121
+ def get_outputs(
1122
+ self: ONNXConverter,
1123
+ inputs: np.ndarray | torch.Tensor,
1124
+ ) -> list[np.ndarray]:
1125
+ """Run the currently loaded (quantized) model via ONNX Runtime.
1126
+
1127
+ Args:
1128
+ inputs (np.ndarray | torch.Tensor): Input matching the model's first input.
1129
+
1130
+ Returns:
1131
+ list[np.ndarray]: The outputs of the ONNX Runtime inference.
1132
+ """
1133
+ try:
1134
+ input_name = self.ort_sess.get_inputs()[0].name
1135
+ output_name = self.ort_sess.get_outputs()[0].name
1136
+
1137
+ # TODO @jsgold-1: This may cause some rounding errors at some point but works for now. # noqa: FIX002, E501, TD003
1138
+ inputs = torch.as_tensor(inputs)
1139
+ if inputs.dtype in (
1140
+ torch.int8,
1141
+ torch.int16,
1142
+ torch.int32,
1143
+ torch.int64,
1144
+ torch.uint8,
1145
+ ):
1146
+ inputs = inputs.double()
1147
+ inputs = inputs / BaseOpQuantizer.get_scaling(
1148
+ scale_base=self.scale_base,
1149
+ scale_exponent=self.scale_exponent,
1150
+ )
1151
+ if self.ort_sess.get_inputs()[0].type == "tensor(double)":
1152
+ return self.ort_sess.run(
1153
+ [output_name],
1154
+ {input_name: np.asarray(inputs).astype(np.float64)},
1155
+ )
1156
+ return self.ort_sess.run(
1157
+ [output_name],
1158
+ {input_name: np.asarray(inputs)},
1159
+ )
1160
+ except Exception as e:
1161
+ raise InferenceError(
1162
+ model_path=getattr(self, "quantized_model_path", None),
1163
+ model_type=self.model_type,
1164
+ reason=str(e),
1165
+ ) from e
1166
+
1167
+
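# Outputs of the quantized model are scaled int64 values. A minimal sketch for
# recovering approximate floats (illustrative, not part of the package):
#
#     outs = converter.get_outputs(x)           # int64, fixed point
#     approx = np.asarray(outs[0]) / 2 ** 18    # back to floating point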
1168
+ if __name__ == "__main__":
1169
+ path = "./models_onnx/doom.onnx"
1170
+
1171
+ converter = ONNXConverter()
1172
+ converter.model_file_name, converter.quantized_model_file_name = (
1173
+ path,
1174
+ "quantized_doom.onnx",
1175
+ )
1176
+ converter.scale_base, converter.scale_exponent = 2, 18
1177
+
1178
+ converter.load_model(path)
1179
+ converter.get_model_and_quantize()
1180
+
1181
+ converter.test_accuracy()