JSTprove 1.0.0-py3-none-macosx_11_0_arm64.whl → 1.2.0-py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of JSTprove might be problematic.

Files changed (61)
  1. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/METADATA +3 -3
  2. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/RECORD +60 -25
  3. python/core/binaries/onnx_generic_circuit_1-2-0 +0 -0
  4. python/core/circuit_models/generic_onnx.py +43 -9
  5. python/core/circuits/base.py +231 -71
  6. python/core/model_processing/converters/onnx_converter.py +114 -59
  7. python/core/model_processing/onnx_custom_ops/batchnorm.py +64 -0
  8. python/core/model_processing/onnx_custom_ops/maxpool.py +1 -1
  9. python/core/model_processing/onnx_custom_ops/mul.py +66 -0
  10. python/core/model_processing/onnx_custom_ops/relu.py +1 -1
  11. python/core/model_processing/onnx_quantizer/layers/add.py +54 -0
  12. python/core/model_processing/onnx_quantizer/layers/base.py +188 -1
  13. python/core/model_processing/onnx_quantizer/layers/batchnorm.py +224 -0
  14. python/core/model_processing/onnx_quantizer/layers/constant.py +1 -1
  15. python/core/model_processing/onnx_quantizer/layers/conv.py +20 -68
  16. python/core/model_processing/onnx_quantizer/layers/gemm.py +20 -66
  17. python/core/model_processing/onnx_quantizer/layers/maxpool.py +53 -43
  18. python/core/model_processing/onnx_quantizer/layers/mul.py +53 -0
  19. python/core/model_processing/onnx_quantizer/layers/relu.py +20 -35
  20. python/core/model_processing/onnx_quantizer/layers/sub.py +54 -0
  21. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +43 -1
  22. python/core/utils/general_layer_functions.py +17 -12
  23. python/core/utils/model_registry.py +6 -3
  24. python/scripts/gen_and_bench.py +2 -2
  25. python/tests/circuit_e2e_tests/other_e2e_test.py +202 -9
  26. python/tests/circuit_parent_classes/test_circuit.py +561 -38
  27. python/tests/circuit_parent_classes/test_onnx_converter.py +22 -13
  28. python/tests/onnx_quantizer_tests/__init__.py +1 -0
  29. python/tests/onnx_quantizer_tests/layers/__init__.py +13 -0
  30. python/tests/onnx_quantizer_tests/layers/add_config.py +102 -0
  31. python/tests/onnx_quantizer_tests/layers/base.py +279 -0
  32. python/tests/onnx_quantizer_tests/layers/batchnorm_config.py +190 -0
  33. python/tests/onnx_quantizer_tests/layers/constant_config.py +39 -0
  34. python/tests/onnx_quantizer_tests/layers/conv_config.py +154 -0
  35. python/tests/onnx_quantizer_tests/layers/factory.py +142 -0
  36. python/tests/onnx_quantizer_tests/layers/flatten_config.py +61 -0
  37. python/tests/onnx_quantizer_tests/layers/gemm_config.py +160 -0
  38. python/tests/onnx_quantizer_tests/layers/maxpool_config.py +82 -0
  39. python/tests/onnx_quantizer_tests/layers/mul_config.py +102 -0
  40. python/tests/onnx_quantizer_tests/layers/relu_config.py +61 -0
  41. python/tests/onnx_quantizer_tests/layers/reshape_config.py +61 -0
  42. python/tests/onnx_quantizer_tests/layers/sub_config.py +102 -0
  43. python/tests/onnx_quantizer_tests/layers_tests/__init__.py +0 -0
  44. python/tests/onnx_quantizer_tests/layers_tests/base_test.py +94 -0
  45. python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py +115 -0
  46. python/tests/onnx_quantizer_tests/layers_tests/test_e2e.py +196 -0
  47. python/tests/onnx_quantizer_tests/layers_tests/test_error_cases.py +59 -0
  48. python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +198 -0
  49. python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +267 -0
  50. python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py +109 -0
  51. python/tests/onnx_quantizer_tests/layers_tests/test_validation.py +45 -0
  52. python/tests/onnx_quantizer_tests/test_base_layer.py +228 -0
  53. python/tests/onnx_quantizer_tests/test_exceptions.py +99 -0
  54. python/tests/onnx_quantizer_tests/test_onnx_op_quantizer.py +246 -0
  55. python/tests/onnx_quantizer_tests/test_registered_quantizers.py +121 -0
  56. python/tests/onnx_quantizer_tests/testing_helper_functions.py +17 -0
  57. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  58. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/WHEEL +0 -0
  59. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/entry_points.txt +0 -0
  60. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/licenses/LICENSE +0 -0
  61. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/top_level.txt +0 -0
python/core/model_processing/converters/onnx_converter.py

@@ -5,6 +5,10 @@ import logging
 from dataclasses import asdict, dataclass
 from importlib.metadata import version as get_version
 from pathlib import Path
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from onnxruntime import NodeArg
 
 import numpy as np
 import onnx
@@ -18,6 +22,7 @@ from onnxruntime_extensions import get_library_path
 
 import python.core.model_processing.onnx_custom_ops  # noqa: F401
 from python.core import PACKAGE_NAME
+from python.core.circuits.errors import CircuitConfigurationError
 from python.core.model_processing.converters.base import ModelConverter, ModelType
 from python.core.model_processing.errors import (
     InferenceError,
@@ -242,6 +247,7 @@ class ONNXConverter(ModelConverter):
 
     def analyze_layers(
         self: ONNXConverter,
+        model: onnx.ModelProto,
         output_name_to_shape: dict[str, list[int]] | None = None,
     ) -> tuple[list[ONNXLayer], list[ONNXLayer]]:
         """Analyze the onnx model graph into
@@ -263,29 +269,29 @@
             id_count = 0
             # Apply shape inference on the model
             if not output_name_to_shape:
-                inferred_model = shape_inference.infer_shapes(self.model)
+                inferred_model = shape_inference.infer_shapes(model)
                 self._onnx_check_model_safely(inferred_model)

                 output_name_to_shape = extract_shape_dict(inferred_model)
             domain_to_version = {
-                opset.domain: opset.version for opset in self.model.opset_import
+                opset.domain: opset.version for opset in model.opset_import
             }

             id_count = 0
             architecture = self.get_model_architecture(
-                self.model,
+                model,
                 output_name_to_shape,
                 domain_to_version,
             )
             w_and_b = self.get_model_w_and_b(
-                self.model,
+                model,
                 output_name_to_shape,
                 id_count,
                 domain_to_version,
             )
         except InvalidModelError:
             raise
-        except (ValueError, TypeError, RuntimeError, OSError, onnx.ONNXException) as e:
+        except (ValueError, TypeError, RuntimeError, OSError) as e:
             raise LayerAnalysisError(model_type=self.model_type, reason=str(e)) from e
         except Exception as e:
             raise LayerAnalysisError(model_type=self.model_type, reason=str(e)) from e
@@ -508,9 +514,9 @@ class ONNXConverter(ModelConverter):
                 opts,
                 providers=["CPUExecutionProvider"],
             )
-        except (OSError, onnx.ONNXException, RuntimeError, Exception) as e:
+        except (OSError, RuntimeError, Exception) as e:
             raise InferenceError(
-                model_path,
+                model_path=model_path,
                 model_type=self.model_type,
                 reason=str(e),
             ) from e
@@ -552,6 +558,7 @@ class ONNXConverter(ModelConverter):
         output_shapes = {
             out_name: output_name_to_shape.get(out_name, []) for out_name in outputs
         }
+
         return ONNXLayer(
             id=layer_id,
             name=name,
@@ -600,6 +607,7 @@ class ONNXConverter(ModelConverter):
             np_data = onnx.numpy_helper.to_array(node, constant_dtype)
         except (ValueError, TypeError, onnx.ONNXException, Exception) as e:
             raise SerializationError(
+                model_type=self.model_type,
                 tensor_name=node.name,
                 reason=f"Failed to convert tensor: {e!s}",
             ) from e
@@ -877,8 +885,10 @@ class ONNXConverter(ModelConverter):
             OSError,
             Exception,
         ) as e:
-            msg = "Quantization failed for model"
-            f" '{getattr(self, 'model_file_name', 'unknown')}': {e!s}"
+            msg = (
+                "Quantization failed for model"
+                f" '{getattr(self, 'model_file_name', 'unknown')}': {e!s}"
+            )
             raise ModelConversionError(
                 msg,
                 model_type=self.model_type,
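Note on the hunk above: in the old code the f-string sat on its own line, so Python evaluated and discarded it and msg held only the first fragment; wrapping both literals in parentheses makes them one implicitly concatenated string. A minimal, self-contained illustration of that behavior (the model name here is made up):

    name = "model.onnx"

    # Before: the second line is a separate expression statement and is discarded
    msg = "Quantization failed for model"
    f" '{name}': example error"
    assert msg == "Quantization failed for model"

    # After: parentheses join the adjacent string literals into one message
    msg = (
        "Quantization failed for model"
        f" '{name}': example error"
    )
    assert msg == "Quantization failed for model 'model.onnx': example error"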
@@ -1033,38 +1043,36 @@
         ``rescale_config``.
         """
         inferred_model = shape_inference.infer_shapes(self.model)
-
-        scaling = BaseOpQuantizer.get_scaling(
-            scale_base=getattr(self, "scale_base", 2),
-            scale_exponent=(getattr(self, "scale_exponent", 18)),
-        )
+        scale_base = getattr(self, "scale_base", 2)
+        scale_exponent = getattr(self, "scale_exponent", 18)

         # Check the model and print Y"s shape information
         self._onnx_check_model_safely(inferred_model)
         output_name_to_shape = extract_shape_dict(inferred_model)
-        (architecture, w_and_b) = self.analyze_layers(output_name_to_shape)
-        for w in w_and_b:
+        scaled_and_transformed_model = self.op_quantizer.apply_pre_analysis_transforms(
+            inferred_model,
+            scale_exponent=scale_exponent,
+            scale_base=scale_base,
+        )
+        # Get layers in correct format
+        (architecture, w_and_b) = self.analyze_layers(
+            scaled_and_transformed_model,
+            output_name_to_shape,
+        )
+
+        def _convert_tensor_to_int_list(w: ONNXLayer) -> list:
             try:
-                w_and_b_array = np.asarray(w.tensor)
-            except (ValueError, TypeError, Exception) as e:
+                arr = np.asarray(w.tensor).astype(np.int64)
+                return arr.tolist()
+            except Exception as e:
                 raise SerializationError(
                     tensor_name=getattr(w, "name", None),
+                    model_type=self.model_type,
                     reason=f"cannot convert to ndarray: {e}",
                 ) from e

-            try:
-                # TODO @jsgold-1: We need a better way to distinguish bias tensors from weight tensors  # noqa: FIX002, TD003,E501
-                if "bias" in w.name:
-                    w_and_b_scaled = w_and_b_array * scaling * scaling
-                else:
-                    w_and_b_scaled = w_and_b_array * scaling
-                w_and_b_out = w_and_b_scaled.astype(np.int64).tolist()
-                w.tensor = w_and_b_out
-            except (ValueError, TypeError, OverflowError, Exception) as e:
-                raise SerializationError(
-                    tensor_name=getattr(w, "name", None),
-                    reason=str(e),
-                ) from e
+        for w in w_and_b:
+            w.tensor = _convert_tensor_to_int_list(w)

         inputs = []
         outputs = []
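Context for the hunk above: the removed loop scaled weights (and biases by the square of the scale) before casting to int64; that scaling now happens earlier, inside op_quantizer.apply_pre_analysis_transforms, so this method only converts the already-scaled tensors to plain int64 lists. A rough sketch of the fixed-point convention, assuming BaseOpQuantizer.get_scaling(scale_base, scale_exponent) amounts to scale_base ** scale_exponent (an assumption, not shown in this diff):

    import numpy as np

    scale_base, scale_exponent = 2, 18            # defaults used above
    scaling = scale_base ** scale_exponent        # assumed: 2**18 = 262144
    w_float = np.array([0.5, -1.25])
    w_int = (w_float * scaling).astype(np.int64)  # quantize -> [131072, -327680]
    w_back = w_int / scaling                      # dequantize -> [0.5, -1.25]
    assert np.allclose(w_back, w_float)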
@@ -1118,45 +1126,92 @@
             rescale_config=getattr(self, "rescale_config", {}),
         )

+    def _process_single_input_for_get_outputs(
+        self: ONNXConverter,
+        value: np.ndarray | torch.Tensor,
+        input_def: NodeArg,
+    ) -> np.ndarray:
+        """Process a single input tensor according to dtype and scale settings."""
+        value = torch.as_tensor(value)
+
+        if value.dtype in (
+            torch.int8,
+            torch.int16,
+            torch.int32,
+            torch.int64,
+            torch.uint8,
+        ):
+            value = value.double()
+            value = value / BaseOpQuantizer.get_scaling(
+                scale_base=self.scale_base,
+                scale_exponent=self.scale_exponent,
+            )
+
+        if input_def.type == "tensor(double)":
+            return np.asarray(value).astype(np.float64)
+        return np.asarray(value)
+
     def get_outputs(
         self: ONNXConverter,
-        inputs: np.ndarray | torch.Tensor,
+        inputs: np.ndarray | torch.Tensor | dict[str, np.ndarray | torch.Tensor],
     ) -> list[np.ndarray]:
         """Run the currently loaded (quantized) model via ONNX Runtime.

         Args:
-            inputs (Any): Input array/tensor matching the models first input.
+            inputs: Single tensor/array or a dict of named inputs.

         Returns:
-            Any: The output of the onnxruntime inference.
+            list[np.ndarray]: List of output arrays from ONNX Runtime inference.
         """
+
+        def _raise_type_error(inputs: np.ndarray | torch.Tensor) -> None:
+            msg = (
+                "Expected np.ndarray, torch.Tensor, or dict "
+                f"for inputs, got {type(inputs)}"
+            )
+            raise TypeError(msg)
+
+        def _raise_value_error(msg: str) -> None:
+            raise ValueError(msg)
+
+        def _raise_no_scale_configs() -> None:
+            raise CircuitConfigurationError(
+                missing_attributes=["scale_base", "scale_exponent"],
+            )
+
+        scale_base = getattr(self, "scale_base", None)
+        scale_exponent = getattr(self, "scale_exponent", None)
+
         try:
-            input_name = self.ort_sess.get_inputs()[0].name
-            output_name = self.ort_sess.get_outputs()[0].name
-
-            # TODO @jsgold-1: This may cause some rounding errors at some point but works for now.  # noqa: FIX002, E501, TD003
-            inputs = torch.as_tensor(inputs)
-            if inputs.dtype in (
-                torch.int8,
-                torch.int16,
-                torch.int32,
-                torch.int64,
-                torch.uint8,
-            ):
-                inputs = inputs.double()
-                inputs = inputs / BaseOpQuantizer.get_scaling(
-                    scale_base=self.scale_base,
-                    scale_exponent=self.scale_exponent,
+            input_defs = self.ort_sess.get_inputs()
+            output_defs = self.ort_sess.get_outputs()
+            output_names = [out.name for out in output_defs]
+
+            if scale_base is None or scale_exponent is None:
+                _raise_no_scale_configs()
+
+            # Normalize inputs into a dict
+            if isinstance(inputs, (np.ndarray, torch.Tensor)):
+                input_name = input_defs[0].name
+                inputs = {input_name: inputs}
+            elif not isinstance(inputs, dict):
+                _raise_type_error(inputs)
+
+            # Process inputs
+            processed_inputs = {}
+            for input_def in input_defs:
+                name = input_def.name
+                if name not in inputs:
+                    _raise_value_error(
+                        f"Missing required input '{name}' in provided inputs",
+                    )
+                processed_inputs[name] = self._process_single_input_for_get_outputs(
+                    inputs[name],
+                    input_def,
                 )
-            if self.ort_sess.get_inputs()[0].type == "tensor(double)":
-                return self.ort_sess.run(
-                    [output_name],
-                    {input_name: np.asarray(inputs).astype(np.float64)},
-                )
-            return self.ort_sess.run(
-                [output_name],
-                {input_name: np.asarray(inputs)},
-            )
+
+            return self.ort_sess.run(output_names, processed_inputs)
+
         except (RuntimeError, ValueError, TypeError, Exception) as e:
             raise InferenceError(
                 model_path=getattr(self, "quantized_model_path", None),
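get_outputs now accepts either a single array/tensor or a dict keyed by ONNX input name, feeds every declared input, and returns all outputs. Each value goes through _process_single_input_for_get_outputs, which de-scales integer tensors before inference. A small runnable sketch of just that de-scaling step (standalone; it assumes get_scaling(2, 18) is 2**18, and the input name is made up):

    import numpy as np
    import torch

    scaling = 2.0 ** 18                              # assumed fixed-point scale
    x_int = np.array([262144, 524288], dtype=np.int64)
    x = torch.as_tensor(x_int)
    if x.dtype in (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8):
        x = x.double() / scaling                     # integer inputs are de-scaled
    feed = {"input": np.asarray(x)}                  # dict keyed by ONNX input name
    print(feed["input"])                             # [1. 2.]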
python/core/model_processing/onnx_custom_ops/batchnorm.py

@@ -0,0 +1,64 @@
+from __future__ import annotations
+
+import numpy as np
+from onnxruntime_extensions import PyCustomOpDef, onnx_op
+
+from .custom_helpers import rescaling
+
+
+@onnx_op(
+    op_type="Int64BatchNorm",
+    domain="ai.onnx.contrib",
+    inputs=[
+        PyCustomOpDef.dt_int64,  # X (int64)
+        PyCustomOpDef.dt_int64,  # mul (int64 scaled multiplier)
+        PyCustomOpDef.dt_int64,  # add (int64 scaled adder)
+        PyCustomOpDef.dt_int64,  # scaling_factor
+    ],
+    outputs=[PyCustomOpDef.dt_int64],
+    attrs={"rescale": PyCustomOpDef.dt_int64},
+)
+def int64_batchnorm(
+    x: np.ndarray,
+    mul: np.ndarray,
+    add: np.ndarray,
+    scaling_factor: np.ndarray | None = None,
+    rescale: int | None = None,
+) -> np.ndarray:
+    """
+    Int64 BatchNorm (folded into affine transform).
+
+    Computes:
+        Y = X * mul + add
+    where mul/add are already scaled to int64.
+
+    Parameters
+    ----------
+    x : Input int64 tensor
+    mul : Per-channel int64 scale multipliers
+    add : Per-channel int64 bias terms
+    scaling_factor: factor to rescale
+    rescale : Optional flag to apply post-scaling
+
+    Returns
+    -------
+    numpy.ndarray (int64)
+    """
+    try:
+        # Broadcasting shapes must match batchnorm layout: NCHW
+        # Typically mul/add have shape [C]
+        dims_x = len(x.shape)
+        dim_ones = (1,) * (dims_x - 2)
+        mul = mul.reshape(-1, *dim_ones)
+        add = add.reshape(-1, *dim_ones)
+
+        y = x * mul + add
+
+        if rescale is not None:
+            y = rescaling(scaling_factor, rescale, y)
+
+        return y.astype(np.int64)
+
+    except Exception as e:
+        msg = f"Int64BatchNorm failed: {e}"
+        raise RuntimeError(msg) from e
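The new Int64BatchNorm op folds batch normalization into a per-channel integer affine transform, Y = X * mul + add, with mul/add broadcast over an NCHW tensor. A standalone NumPy sketch of the same broadcast (without the custom-op registration or the rescaling helper):

    import numpy as np

    x = np.arange(24, dtype=np.int64).reshape(1, 3, 2, 4)   # NCHW input
    mul = np.array([2, 3, 4], dtype=np.int64)                # per-channel multiplier
    add = np.array([10, 20, 30], dtype=np.int64)             # per-channel bias

    dim_ones = (1,) * (x.ndim - 2)                           # (1, 1) for NCHW
    y = x * mul.reshape(-1, *dim_ones) + add.reshape(-1, *dim_ones)
    assert y.shape == x.shape and y.dtype == np.int64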
python/core/model_processing/onnx_custom_ops/maxpool.py

@@ -75,5 +75,5 @@ def int64_maxpool(
         )
         return result.numpy().astype(np.int64)
     except Exception as e:
-        msg = f"Int64Gemm failed: {e}"
+        msg = f"Int64MaxPool failed: {e}"
         raise RuntimeError(msg) from e
python/core/model_processing/onnx_custom_ops/mul.py

@@ -0,0 +1,66 @@
+import numpy as np
+from onnxruntime_extensions import PyCustomOpDef, onnx_op
+
+from .custom_helpers import rescaling
+
+
+@onnx_op(
+    op_type="Int64Mul",
+    domain="ai.onnx.contrib",
+    inputs=[
+        PyCustomOpDef.dt_int64,
+        PyCustomOpDef.dt_int64,
+        PyCustomOpDef.dt_int64,  # Scalar
+    ],
+    outputs=[PyCustomOpDef.dt_int64],
+    attrs={
+        "rescale": PyCustomOpDef.dt_int64,
+    },
+)
+def int64_mul(
+    a: np.ndarray,
+    b: np.ndarray,
+    scaling_factor: np.ndarray | None = None,
+    rescale: int | None = None,
+) -> np.ndarray:
+    """
+    Performs a Mul (hadamard product) operation on int64 input tensors.
+
+    This function is registered as a custom ONNX operator via onnxruntime_extensions
+    and is used in the JSTprove quantized inference pipeline.
+    It applies Mul with the rescaling the outputs back to the original scale.
+
+    Parameters
+    ----------
+    a : np.ndarray
+        First input tensor with dtype int64.
+    b : np.ndarray
+        Second input tensor with dtype int64.
+    scaling_factor : Scaling factor for rescaling the output.
+        Optional scalar tensor for rescaling when rescale=1.
+    rescale : int, optional
+        Whether to apply rescaling (0=no, 1=yes).
+
+    Returns
+    -------
+    numpy.ndarray
+        Mul tensor with dtype int64.
+
+    Notes
+    -----
+    - This op is part of the `ai.onnx.contrib` custom domain.
+    - ONNX Runtime Extensions is required to register this op.
+
+    References
+    ----------
+    For more information on the Mul operation, please refer to the
+    ONNX standard Mul operator documentation:
+    https://onnx.ai/onnx/operators/onnx__Mul.html
+    """
+    try:
+        result = a * b
+        result = rescaling(scaling_factor, rescale, result)
+        return result.astype(np.int64)
+    except Exception as e:
+        msg = f"Int64Mul failed: {e}"
+        raise RuntimeError(msg) from e
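Int64Mul multiplies two fixed-point int64 tensors; because each operand already carries the scale, their product carries it twice, which is what the optional rescale step undoes. A rough standalone sketch of that idea, assuming the rescaling helper effectively divides by scaling_factor when rescale=1 (an assumption; the real logic lives in custom_helpers.rescaling):

    import numpy as np

    scaling = np.int64(2**18)                    # assumed fixed-point scale
    a = np.int64(round(1.5 * scaling))           # 393216
    b = np.int64(round(-2.0 * scaling))          # -524288
    prod = a * b                                 # carries scaling**2
    rescaled = prod // scaling                   # back to a single scaling factor
    print(rescaled / scaling)                    # -3.0, i.e. 1.5 * -2.0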
python/core/model_processing/onnx_custom_ops/relu.py

@@ -39,5 +39,5 @@ def int64_relu(x: np.ndarray) -> np.ndarray:
     try:
         return np.maximum(x, 0).astype(np.int64)
     except Exception as e:
-        msg = f"Int64Gemm failed: {e}"
+        msg = f"Int64ReLU failed: {e}"
         raise RuntimeError(msg) from e
python/core/model_processing/onnx_quantizer/layers/add.py

@@ -0,0 +1,54 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, ClassVar
+
+if TYPE_CHECKING:
+    import onnx
+
+from python.core.model_processing.onnx_quantizer.layers.base import (
+    BaseOpQuantizer,
+    QuantizerBase,
+    ScaleConfig,
+)
+
+
+class QuantizeAdd(QuantizerBase):
+    OP_TYPE = "Add"
+    DOMAIN = ""
+    USE_WB = True
+    USE_SCALING = False
+    SCALE_PLAN: ClassVar = {0: 1, 1: 1}
+
+
+class AddQuantizer(BaseOpQuantizer, QuantizeAdd):
+    """
+    Quantizer for ONNX Add layers.
+
+    - Uses standard ONNX Add layer in standard domain, and
+      makes relevant additional changes to the graph.
+    """
+
+    def __init__(
+        self: AddQuantizer,
+        new_initializers: list[onnx.TensorProto] | None = None,
+    ) -> None:
+        super().__init__()
+        # Only replace if caller provided something
+        if new_initializers is not None:
+            self.new_initializers = new_initializers
+
+    def quantize(
+        self: AddQuantizer,
+        node: onnx.NodeProto,
+        graph: onnx.GraphProto,
+        scale_config: ScaleConfig,
+        initializer_map: dict[str, onnx.TensorProto],
+    ) -> list[onnx.NodeProto]:
+        return QuantizeAdd.quantize(self, node, graph, scale_config, initializer_map)
+
+    def check_supported(
+        self: AddQuantizer,
+        node: onnx.NodeProto,
+        initializer_map: dict[str, onnx.TensorProto] | None = None,
+    ) -> None:
+        pass