jstprove-1.0.0-py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of JSTprove might be problematic.

Files changed (81)
  1. jstprove-1.0.0.dist-info/METADATA +397 -0
  2. jstprove-1.0.0.dist-info/RECORD +81 -0
  3. jstprove-1.0.0.dist-info/WHEEL +5 -0
  4. jstprove-1.0.0.dist-info/entry_points.txt +2 -0
  5. jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
  6. jstprove-1.0.0.dist-info/top_level.txt +1 -0
  7. python/__init__.py +0 -0
  8. python/core/__init__.py +3 -0
  9. python/core/binaries/__init__.py +0 -0
  10. python/core/binaries/expander-exec +0 -0
  11. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  12. python/core/circuit_models/__init__.py +0 -0
  13. python/core/circuit_models/generic_onnx.py +231 -0
  14. python/core/circuit_models/simple_circuit.py +133 -0
  15. python/core/circuits/__init__.py +0 -0
  16. python/core/circuits/base.py +1000 -0
  17. python/core/circuits/errors.py +188 -0
  18. python/core/circuits/zk_model_base.py +25 -0
  19. python/core/model_processing/__init__.py +0 -0
  20. python/core/model_processing/converters/__init__.py +0 -0
  21. python/core/model_processing/converters/base.py +143 -0
  22. python/core/model_processing/converters/onnx_converter.py +1181 -0
  23. python/core/model_processing/errors.py +147 -0
  24. python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
  25. python/core/model_processing/onnx_custom_ops/conv.py +111 -0
  26. python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
  27. python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
  28. python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
  29. python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
  30. python/core/model_processing/onnx_custom_ops/relu.py +43 -0
  31. python/core/model_processing/onnx_quantizer/__init__.py +0 -0
  32. python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
  33. python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
  34. python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
  35. python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
  36. python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
  37. python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
  38. python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
  39. python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
  40. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
  41. python/core/model_templates/__init__.py +0 -0
  42. python/core/model_templates/circuit_template.py +57 -0
  43. python/core/utils/__init__.py +0 -0
  44. python/core/utils/benchmarking_helpers.py +163 -0
  45. python/core/utils/constants.py +4 -0
  46. python/core/utils/errors.py +117 -0
  47. python/core/utils/general_layer_functions.py +268 -0
  48. python/core/utils/helper_functions.py +1138 -0
  49. python/core/utils/model_registry.py +166 -0
  50. python/core/utils/scratch_tests.py +66 -0
  51. python/core/utils/witness_utils.py +291 -0
  52. python/frontend/__init__.py +0 -0
  53. python/frontend/cli.py +115 -0
  54. python/frontend/commands/__init__.py +17 -0
  55. python/frontend/commands/args.py +100 -0
  56. python/frontend/commands/base.py +199 -0
  57. python/frontend/commands/bench/__init__.py +54 -0
  58. python/frontend/commands/bench/list.py +42 -0
  59. python/frontend/commands/bench/model.py +172 -0
  60. python/frontend/commands/bench/sweep.py +212 -0
  61. python/frontend/commands/compile.py +58 -0
  62. python/frontend/commands/constants.py +5 -0
  63. python/frontend/commands/model_check.py +53 -0
  64. python/frontend/commands/prove.py +50 -0
  65. python/frontend/commands/verify.py +73 -0
  66. python/frontend/commands/witness.py +64 -0
  67. python/scripts/__init__.py +0 -0
  68. python/scripts/benchmark_runner.py +833 -0
  69. python/scripts/gen_and_bench.py +482 -0
  70. python/tests/__init__.py +0 -0
  71. python/tests/circuit_e2e_tests/__init__.py +0 -0
  72. python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
  73. python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
  74. python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
  75. python/tests/circuit_parent_classes/__init__.py +0 -0
  76. python/tests/circuit_parent_classes/test_circuit.py +969 -0
  77. python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
  78. python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
  79. python/tests/test_cli.py +1021 -0
  80. python/tests/utils_testing/__init__.py +0 -0
  81. python/tests/utils_testing/test_helper_functions.py +891 -0
@@ -0,0 +1,43 @@
+ import numpy as np
+ from onnxruntime_extensions import PyCustomOpDef, onnx_op
+
+
+ @onnx_op(
+     op_type="Int64Relu",
+     domain="ai.onnx.contrib",
+     inputs=[PyCustomOpDef.dt_int64],
+     outputs=[PyCustomOpDef.dt_int64],
+ )
+ def int64_relu(x: np.ndarray) -> np.ndarray:
+     """
+     Performs a ReLU operation on int64 input tensors.
+
+     This function is registered as a custom ONNX operator via onnxruntime_extensions
+     and is used in the JSTprove quantized inference pipeline.
+     It applies ReLU as-is (ReLU has no attributes).
+
+     Parameters
+     ----------
+     x : Input tensor with dtype int64.
+
+     Returns
+     -------
+     numpy.ndarray
+         ReLU tensor with dtype int64.
+
+     Notes
+     -----
+     - This op is part of the `ai.onnx.contrib` custom domain.
+     - ONNX Runtime Extensions is required to register this op.
+
+     References
+     ----------
+     For more information on the ReLU operation, please refer to the
+     ONNX standard ReLU operator documentation:
+     https://onnx.ai/onnx/operators/onnx__Relu.html
+     """
+     try:
+         return np.maximum(x, 0).astype(np.int64)
+     except Exception as e:
+         msg = f"Int64Relu failed: {e}"
+         raise RuntimeError(msg) from e
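
The Int64Relu op above only becomes visible to ONNX Runtime once onnxruntime-extensions is registered on the session and the module containing the @onnx_op registration has been imported. A minimal usage sketch, assuming a quantized model that already references the op; the model path and input name are illustrative:

import numpy as np
import onnxruntime as ort
from onnxruntime_extensions import get_library_path

# The module defining Int64Relu must be imported first so @onnx_op runs,
# e.g. `import python.core.model_processing.onnx_custom_ops  # noqa: F401`.

# Register the extensions library so ORT can resolve ai.onnx.contrib::Int64Relu.
opts = ort.SessionOptions()
opts.register_custom_ops_library(get_library_path())

session = ort.InferenceSession("quantized_model.onnx", opts)  # hypothetical path
(result,) = session.run(None, {"input": np.array([-3, 0, 7], dtype=np.int64)})
print(result)  # [0 0 7]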
@@ -0,0 +1,168 @@
+ from __future__ import annotations
+
+ REPORTING_URL = "https://discord.com/invite/inferencelabs"
+
+
+ class QuantizationError(Exception):
+     """
+     Base exception class for errors raised during model quantization.
+     Can be extended for specific quantization-related errors.
+     """
+
+     GENERIC_MESSAGE = (
+         "\nThis model is not supported by JSTprove."
+         f"\nSubmit model support requests via the JSTprove channel: {REPORTING_URL}."
+     )
+
+     def __init__(self: Exception, message: str) -> None:
+         """Initialize QuantizationError with a detailed message.
+
+         Args:
+             message (str): Specific error message describing the quantization issue.
+         """
+         full_msg = f"{self.GENERIC_MESSAGE}\n\n{message}"
+         super().__init__(full_msg)
+
+
+ class InvalidParamError(QuantizationError):
+     """
+     Exception raised when invalid or unsupported parameters are
+     encountered in a node reached during the quantization process.
+     """
+
+     def __init__(  # noqa: PLR0913
+         self: QuantizationError,
+         node_name: str,
+         op_type: str,
+         message: str,
+         attr_key: str | None = None,
+         expected: str | None = None,
+     ) -> None:
+         """Initialize InvalidParamError with context about the invalid parameter.
+
+         Args:
+             node_name (str): The name of the node where the error occurred.
+             op_type (str): The type of operation of the node.
+             message (str): Description of the invalid parameter error.
+             attr_key (str, optional): The attribute key that caused the error.
+                 Defaults to None.
+             expected (str, optional): The expected value or format for the attribute.
+                 Defaults to None.
+         """
+         self.node_name = node_name
+         self.op_type = op_type
+         self.message = message
+         self.attr_key = attr_key
+         self.expected = expected
+
+         error_msg = (
+             f"Invalid parameters in node '{node_name}' "
+             f"(op_type='{op_type}'): {message}"
+         )
+         if attr_key:
+             error_msg += f" [Attribute: {attr_key}]"
+         if expected:
+             error_msg += f" [Expected: {expected}]"
+         super().__init__(error_msg)
+
+
+ class UnsupportedOpError(QuantizationError):
+     """
+     Exception to be raised when an unsupported operation type is
+     reached during quantization.
+     """
+
+     def __init__(
+         self: QuantizationError,
+         op_type: str,
+         node_name: str | None = None,
+     ) -> None:
+         """Initialize UnsupportedOpError with details about the unsupported operation.
+
+         Args:
+             op_type (str): The type of the unsupported operation.
+             node_name (str, optional): The name of the node where the
+                 unsupported operation was found, to help with debugging.
+                 Defaults to None.
+         """
+         error_msg = f"Unsupported op type: '{op_type}'"
+         if node_name:
+             error_msg += f" in node '{node_name}'"
+         error_msg += ". Please check out the documentation for supported layers."
+         self.unsupported_ops = op_type
+         super().__init__(error_msg)
+
+
+ class MissingHandlerError(QuantizationError):
+     """
+     Raised when no handler is registered for an operator.
+     """
+
+     def __init__(self: QuantizationError, op_type: str) -> None:
+         error_msg = f"No quantization handler registered for operator type '{op_type}'."
+         super().__init__(error_msg)
+
+
+ class InitializerNotFoundError(QuantizationError):
+     """
+     Raised when an initializer required by a node is missing from the initializer map.
+     """
+
+     def __init__(
+         self: QuantizationError,
+         node_name: str,
+         initializer_name: str,
+     ) -> None:
+         error_msg = (
+             f"Initializer '{initializer_name}' required for node '{node_name}' "
+             "was not found in the initializer map."
+         )
+         super().__init__(error_msg)
+
+
+ class HandlerImplementationError(QuantizationError):
+     """
+     Raised when a handler does not conform to the expected quantizer interface.
+     For example, missing 'quantize' method, wrong return type, etc.
+     """
+
+     def __init__(self: QuantizationError, op_type: str, message: str) -> None:
+         error_msg = f"Handler implementation error for operator '{op_type}': {message}"
+         super().__init__(error_msg)
+
+
+ class InvalidGraphError(QuantizationError):
+     """
+     Raised when the ONNX graph is malformed or missing critical information.
+     """
+
+     def __init__(self: QuantizationError, message: str) -> None:
+         error_msg = f"Invalid ONNX graph structure: {message}"
+         super().__init__(error_msg)
+
+
+ class InvalidConfigError(QuantizationError):
+     """
+     Exception raised when the overall quantization configuration is invalid
+     or unsupported (e.g., bad scale_base, scale_exponent, or global settings).
+     """
+
+     def __init__(
+         self: QuantizationError,
+         key: str,
+         value: str | int | float | bool | None,
+         expected: str | None = None,
+     ) -> None:
+         """Initialize InvalidConfigError with context about the bad config.
+
+         Args:
+             key (str): The name of the configuration parameter.
+             value (Union[str, int, float, bool, None]):
+                 The invalid value that was provided.
+             expected (str, optional): Description of the expected valid range or type.
+         """
+         error_msg = f"Invalid configuration for '{key}': got {value}"
+         if expected:
+             error_msg += f" (expected {expected})"
+         super().__init__(error_msg)
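
Every concrete error composes its specifics onto QuantizationError.GENERIC_MESSAGE, so a single `except QuantizationError` catches the whole family. A minimal sketch of the composed message (node and attribute values are illustrative):

from python.core.model_processing.onnx_quantizer.exceptions import (
    InvalidParamError,
    QuantizationError,
)

try:
    # Hypothetical failure while quantizing a Conv node.
    raise InvalidParamError(
        node_name="conv_0",
        op_type="Conv",
        message="auto_pad is not supported",
        attr_key="auto_pad",
        expected="NOTSET",
    )
except QuantizationError as err:
    # Prints the generic support banner (with REPORTING_URL), followed by:
    # Invalid parameters in node 'conv_0' (op_type='Conv'): auto_pad is not
    # supported [Attribute: auto_pad] [Expected: NOTSET]
    print(err)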
@@ -0,0 +1,396 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+ import numpy as np
+ import onnx
+ from onnx import helper, numpy_helper
+
+ from python.core.model_processing.onnx_custom_ops.onnx_helpers import (
+     replace_input_references,
+ )
+ from python.core.model_processing.onnx_quantizer.exceptions import (
+     HandlerImplementationError,
+     InitializerNotFoundError,
+     InvalidConfigError,
+     InvalidParamError,
+ )
+
+
+ @dataclass
+ class ScaleConfig:
+     exponent: int
+     base: int
+     rescale: bool
+
+
+ class BaseOpQuantizer:
+     """
+     Abstract base class for ONNX operator quantizers.
+
+     Subclasses must implement:
+         - `quantize`: Apply quantization logic to an ONNX node.
+         - `check_supported`: Check whether the layer and its parameters are supported.
+
+     Attributes:
+         new_initializers (list[onnx.TensorProto]):
+             A list of initializers created during quantization.
+             These should be added to the graph after processing.
+     """
+
+     def __init__(self: BaseOpQuantizer) -> None:
+         self.new_initializers: list[onnx.TensorProto] = []
+
+     @staticmethod
+     def get_scaling(scale_base: int, scale_exponent: int) -> int:
+         """Validate and compute the scaling factor.
+
+         Args:
+             scale_base (int): Base for the scaling exponent.
+             scale_exponent (int): Scaling exponent.
+
+         Returns:
+             int: The computed scaling factor (scale_base ** scale_exponent).
+
+         Raises:
+             InvalidConfigError: If parameters are invalid.
+         """
+         if scale_base <= 0:
+             key = "scale_base"
+             raise InvalidConfigError(key, scale_base, expected="> 0")
+         if scale_exponent < 0:
+             key = "scale_exponent"
+             raise InvalidConfigError(key, scale_exponent, expected=">= 0")
+
+         try:
+             return scale_base**scale_exponent
+         except (TypeError, OverflowError, ValueError) as e:
+             key = "scaling"
+             raise InvalidConfigError(
+                 key,
+                 f"{scale_base}^{scale_exponent}",
+                 expected=str(e),
+             ) from e
+
+     @staticmethod
+     def validate_node_has_output(node: onnx.NodeProto) -> None:
+         """Ensure a node has at least one output.
+
+         Args:
+             node (onnx.NodeProto): The node to validate.
+
+         Raises:
+             HandlerImplementationError: If the node has no outputs.
+         """
+         if not node.output:
+             raise HandlerImplementationError(
+                 op_type=node.op_type,
+                 message=f"Node '{node.name or '<unnamed>'}' of type '{node.op_type}'"
+                 " has no outputs.",
+             )
+
+     @staticmethod
+     def validate_required_attrs(
+         node: onnx.NodeProto,
+         required_attrs: list[str],
+     ) -> None:
+         """
+         Ensure that a node contains all required attributes.
+
+         Args:
+             node (onnx.NodeProto): The ONNX node to validate.
+             required_attrs (list[str]): List of attribute names that must exist.
+
+         Raises:
+             InvalidParamError: If any required attribute is missing.
+         """
+         missing_attrs = []
+         for attr_name in required_attrs:
+             found = any(attr.name == attr_name for attr in node.attribute)
+             if not found:
+                 missing_attrs.append(attr_name)
+
+         if missing_attrs:
+             missing_str = ", ".join(missing_attrs)
+             raise InvalidParamError(
+                 node_name=node.name,
+                 op_type=node.op_type,
+                 message=f"Missing required attributes: {missing_str}",
+             )
+
+     def quantize(
+         self: BaseOpQuantizer,
+         node: onnx.NodeProto,
+         graph: onnx.GraphProto,
+         scale_config: ScaleConfig,
+         initializer_map: dict[str, onnx.TensorProto],
+     ) -> list[onnx.NodeProto]:
+         """
+         Quantize the given node.
+
+         Must be implemented by subclasses.
+
+         Raises:
+             HandlerImplementationError: If the subclass does not implement quantize.
+         """
+         _ = node, graph, scale_config, initializer_map
+         raise HandlerImplementationError(
+             op_type=self.__class__.__name__,
+             message="quantize() not implemented in subclass.",
+         )
+
+     def check_supported(
+         self: BaseOpQuantizer,
+         node: onnx.NodeProto,
+         initializer_map: dict[str, onnx.TensorProto] | None = None,
+     ) -> str | None:
+         """
+         Check if the node is supported by the quantizer.
+
+         Must be overridden by subclasses to validate parameters.
+
+         Raises:
+             HandlerImplementationError: If called on BaseOpQuantizer directly.
+         """
+         _ = node, initializer_map
+         raise HandlerImplementationError(
+             op_type=self.__class__.__name__,
+             message="check_supported() not implemented in subclass.",
+         )
+
+     def rescale_layer(
+         self: BaseOpQuantizer,
+         node: onnx.NodeProto,
+         scale_base: int,
+         scale_exponent: int,
+         graph: onnx.GraphProto,
+     ) -> list[onnx.NodeProto]:
+         """
+         Helper for any quantizer: adds a rescaling step after the given node.
+
+         The node's output is renamed, and a Div node is inserted that divides
+         the raw output by the scale factor and restores the original output
+         name, so downstream consumers of the graph are unaffected.
+
+         Args:
+             node (onnx.NodeProto): Node to rescale.
+             scale_base (int): Base for the scaling exponent.
+             scale_exponent (int): Scaling exponent.
+             graph (onnx.GraphProto): The ONNX graph.
+
+         Returns:
+             list[onnx.NodeProto]: Original node and the inserted Div node.
+
+         Raises:
+             HandlerImplementationError: If there are no outputs to be rescaled.
+         """
+         self.validate_node_has_output(node)
+
+         original_output = node.output[0]
+         quantized_output = original_output + "_raw"
+         node.output[0] = quantized_output
+
+         # Create scale constant initializer
+         scale_const_name = node.name + "_scale"
+
+         scale_value = self.get_scaling(scale_base, scale_exponent)
+         scale_tensor = numpy_helper.from_array(
+             np.array([scale_value], dtype=np.int64),
+             name=scale_const_name,
+         )
+         self.new_initializers.append(scale_tensor)
+
+         # Create Div node for rescaling output
+         div_node = helper.make_node(
+             "Div",
+             inputs=[quantized_output, scale_const_name],
+             outputs=[original_output],  # restore original output name
+             name=node.name + "_rescale",
+         )
+
+         # Rewire consumers to point to the new output
+         replace_input_references(
+             graph=graph,
+             old_output=original_output,
+             new_output=div_node.output[0],
+         )
+
+         return [node, div_node]
+
+     def add_nodes_w_and_b(
+         self: BaseOpQuantizer,
+         node: onnx.NodeProto,
+         scale_exponent: int,
+         scale_base: int,
+         initializer_map: dict[str, onnx.TensorProto],
+     ) -> tuple[list[onnx.NodeProto], list[str]]:
+         """Insert scaling and casting nodes for weight and bias,
+         to convert from float to scaled int64 values.
+
+         Args:
+             node (onnx.NodeProto): Node whose weight and bias inputs are quantized.
+             scale_exponent (int): Scaling exponent.
+             scale_base (int): Base for the scaling exponent.
+             initializer_map (dict[str, onnx.TensorProto]): The initializer map.
+
+         Returns:
+             tuple[list[onnx.NodeProto], list[str]]:
+                 List of new nodes added and the updated input names for the node.
+
+         Raises:
+             InitializerNotFoundError: If weights or biases are missing from the graph.
+             HandlerImplementationError:
+                 If the node is missing a weight input or the weight tensor is unnamed.
+         """
+         weights_input_length = 2
+         if len(node.input) < weights_input_length:
+             raise HandlerImplementationError(
+                 op_type=node.op_type,
+                 message=f"Node '{node.name}'"
+                 " has fewer than 2 inputs (weights missing).",
+             )
+         # Quantize weight
+         weight_name = node.input[1]
+         if not weight_name or weight_name not in initializer_map:
+             raise InitializerNotFoundError(node.name, weight_name or "<missing>")
+
+         weight_tensor = initializer_map[weight_name]
+         if not weight_tensor.name:
+             raise HandlerImplementationError(
+                 op_type=node.op_type,
+                 message=f"Weight tensor for node '{node.name}' is missing a name.",
+             )
+
+         quant_weight_name, mul_node, cast_node = self.insert_scale_node(
+             tensor=weight_tensor,
+             scale_base=scale_base,
+             scale_exponent=scale_exponent,
+         )
+
+         # Quantize bias if present
+         new_inputs = [node.input[0], quant_weight_name]
+         nodes = [mul_node, cast_node]
+
+         bias_inputs_length = 3
+
+         if len(node.input) >= bias_inputs_length:
+             bias_name = node.input[2]
+             if bias_name not in initializer_map:
+                 raise InitializerNotFoundError(node.name, bias_name)
+
+             bias_tensor = initializer_map[bias_name]
+             # Bias uses scale_exponent * 2 because it is added to the product
+             # of two tensors that each carry one factor of the scale.
+             quant_bias_name, mul_node_2, cast_node_2 = self.insert_scale_node(
+                 tensor=bias_tensor,
+                 scale_base=scale_base,
+                 scale_exponent=(scale_exponent * 2),
+             )
+             new_inputs.append(quant_bias_name)
+             nodes.append(mul_node_2)
+             nodes.append(cast_node_2)
+
+         return nodes, new_inputs
+
+     def insert_scale_node(
+         self: BaseOpQuantizer,
+         tensor: onnx.TensorProto,
+         scale_base: int,
+         scale_exponent: int,
+     ) -> tuple[str, onnx.NodeProto, onnx.NodeProto]:
+         """Insert Mul and Cast nodes to apply scaling to a tensor.
+
+         Args:
+             tensor (onnx.TensorProto): Tensor to scale.
+             scale_base (int): Base for scaling exponent.
+             scale_exponent (int): Scaling exponent.
+
+         Returns:
+             tuple[str, onnx.NodeProto, onnx.NodeProto]:
+                 New tensor name, Mul node, Cast node.
+
+         Raises:
+             HandlerImplementationError:
+                 If the tensor does not exist, is incorrectly formatted, or is unnamed.
+         """
+         if not tensor or not isinstance(tensor, onnx.TensorProto):
+             raise HandlerImplementationError(
+                 op_type="insert_scale_node",
+                 message="Expected a valid onnx.TensorProto, got None or wrong type.",
+             )
+
+         if not tensor.name:
+             raise HandlerImplementationError(
+                 op_type="insert_scale_node",
+                 message="Tensor is missing a name.",
+             )
+
+         scale_value = self.get_scaling(scale_base, scale_exponent)
+
+         # Create scale constant
+         scale_const_name = tensor.name + "_scale"
+         scale_tensor = numpy_helper.from_array(
+             np.array([scale_value], dtype=np.float64),
+             name=scale_const_name,
+         )
+         self.new_initializers.append(scale_tensor)
+
+         # Add Mul node
+         scaled_output_name = f"{tensor.name}_scaled"
+         mul_node = helper.make_node(
+             "Mul",
+             inputs=[tensor.name, scale_const_name],
+             outputs=[scaled_output_name],
+             name=f"{tensor.name}_mul",
+         )
+
+         # Add Cast node (truncates the scaled floats to int64)
+         output_name = f"{scaled_output_name}_cast"
+         cast_to_int64 = helper.make_node(
+             "Cast",
+             inputs=[scaled_output_name],
+             outputs=[output_name],
+             to=onnx.TensorProto.INT64,
+             name=scaled_output_name,
+         )
+         return output_name, mul_node, cast_to_int64
+
+
+ class PassthroughQuantizer(BaseOpQuantizer):
+     """
+     Quantizer that leaves the node unchanged.
+     Useful for operators that do not require quantization, such as shaping operations.
+     """
+
+     def __init__(
+         self: BaseOpQuantizer,
+         new_initializer: dict[str, onnx.TensorProto] | None = None,
+     ) -> None:
+         _ = new_initializer
+         super().__init__()
+
+     def quantize(
+         self: BaseOpQuantizer,
+         node: onnx.NodeProto,
+         graph: onnx.GraphProto,
+         scale_config: ScaleConfig,
+         initializer_map: dict[str, onnx.TensorProto],
+     ) -> list[onnx.NodeProto]:
+         _ = graph, scale_config, initializer_map
+         if not isinstance(node, onnx.NodeProto):
+             raise HandlerImplementationError(
+                 op_type="PassthroughQuantizer",
+                 message="quantize() expected a NodeProto",
+             )
+         return [node]
+
+     def check_supported(
+         self: BaseOpQuantizer,
+         node: onnx.NodeProto,
+         initializer_map: dict[str, onnx.TensorProto] | None = None,
+     ) -> None:
+         _ = node, initializer_map
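
The arithmetic that get_scaling, insert_scale_node, and rescale_layer implement is plain fixed-point scaling, sketched below with illustrative numbers (numpy floor division stands in for the inserted integer Div node):

import numpy as np

base, exponent = 2, 8
scale = base ** exponent                    # get_scaling(2, 8) == 256

w = np.array([0.5, -1.25])                  # float weight initializer
w_q = (w * scale).astype(np.int64)          # Mul + Cast(INT64), as in insert_scale_node
x_q = (np.array([2.0, 2.0]) * scale).astype(np.int64)

acc = x_q * w_q                             # integer compute carries scale**2,
                                            # which is why biases use exponent * 2
y_q = acc // scale                          # the Div inserted by rescale_layer
print(y_q / scale)                          # [ 1.  -2.5], matching x * w in floats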