JSTprove 1.0.0-py3-none-macosx_11_0_arm64.whl → 1.2.0-py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of JSTprove might be problematic.

Files changed (61)
  1. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/METADATA +3 -3
  2. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/RECORD +60 -25
  3. python/core/binaries/onnx_generic_circuit_1-2-0 +0 -0
  4. python/core/circuit_models/generic_onnx.py +43 -9
  5. python/core/circuits/base.py +231 -71
  6. python/core/model_processing/converters/onnx_converter.py +114 -59
  7. python/core/model_processing/onnx_custom_ops/batchnorm.py +64 -0
  8. python/core/model_processing/onnx_custom_ops/maxpool.py +1 -1
  9. python/core/model_processing/onnx_custom_ops/mul.py +66 -0
  10. python/core/model_processing/onnx_custom_ops/relu.py +1 -1
  11. python/core/model_processing/onnx_quantizer/layers/add.py +54 -0
  12. python/core/model_processing/onnx_quantizer/layers/base.py +188 -1
  13. python/core/model_processing/onnx_quantizer/layers/batchnorm.py +224 -0
  14. python/core/model_processing/onnx_quantizer/layers/constant.py +1 -1
  15. python/core/model_processing/onnx_quantizer/layers/conv.py +20 -68
  16. python/core/model_processing/onnx_quantizer/layers/gemm.py +20 -66
  17. python/core/model_processing/onnx_quantizer/layers/maxpool.py +53 -43
  18. python/core/model_processing/onnx_quantizer/layers/mul.py +53 -0
  19. python/core/model_processing/onnx_quantizer/layers/relu.py +20 -35
  20. python/core/model_processing/onnx_quantizer/layers/sub.py +54 -0
  21. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +43 -1
  22. python/core/utils/general_layer_functions.py +17 -12
  23. python/core/utils/model_registry.py +6 -3
  24. python/scripts/gen_and_bench.py +2 -2
  25. python/tests/circuit_e2e_tests/other_e2e_test.py +202 -9
  26. python/tests/circuit_parent_classes/test_circuit.py +561 -38
  27. python/tests/circuit_parent_classes/test_onnx_converter.py +22 -13
  28. python/tests/onnx_quantizer_tests/__init__.py +1 -0
  29. python/tests/onnx_quantizer_tests/layers/__init__.py +13 -0
  30. python/tests/onnx_quantizer_tests/layers/add_config.py +102 -0
  31. python/tests/onnx_quantizer_tests/layers/base.py +279 -0
  32. python/tests/onnx_quantizer_tests/layers/batchnorm_config.py +190 -0
  33. python/tests/onnx_quantizer_tests/layers/constant_config.py +39 -0
  34. python/tests/onnx_quantizer_tests/layers/conv_config.py +154 -0
  35. python/tests/onnx_quantizer_tests/layers/factory.py +142 -0
  36. python/tests/onnx_quantizer_tests/layers/flatten_config.py +61 -0
  37. python/tests/onnx_quantizer_tests/layers/gemm_config.py +160 -0
  38. python/tests/onnx_quantizer_tests/layers/maxpool_config.py +82 -0
  39. python/tests/onnx_quantizer_tests/layers/mul_config.py +102 -0
  40. python/tests/onnx_quantizer_tests/layers/relu_config.py +61 -0
  41. python/tests/onnx_quantizer_tests/layers/reshape_config.py +61 -0
  42. python/tests/onnx_quantizer_tests/layers/sub_config.py +102 -0
  43. python/tests/onnx_quantizer_tests/layers_tests/__init__.py +0 -0
  44. python/tests/onnx_quantizer_tests/layers_tests/base_test.py +94 -0
  45. python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py +115 -0
  46. python/tests/onnx_quantizer_tests/layers_tests/test_e2e.py +196 -0
  47. python/tests/onnx_quantizer_tests/layers_tests/test_error_cases.py +59 -0
  48. python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +198 -0
  49. python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +267 -0
  50. python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py +109 -0
  51. python/tests/onnx_quantizer_tests/layers_tests/test_validation.py +45 -0
  52. python/tests/onnx_quantizer_tests/test_base_layer.py +228 -0
  53. python/tests/onnx_quantizer_tests/test_exceptions.py +99 -0
  54. python/tests/onnx_quantizer_tests/test_onnx_op_quantizer.py +246 -0
  55. python/tests/onnx_quantizer_tests/test_registered_quantizers.py +121 -0
  56. python/tests/onnx_quantizer_tests/testing_helper_functions.py +17 -0
  57. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  58. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/WHEEL +0 -0
  59. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/entry_points.txt +0 -0
  60. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/licenses/LICENSE +0 -0
  61. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/top_level.txt +0 -0
python/core/model_processing/onnx_quantizer/layers/base.py:

@@ -1,12 +1,14 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
+from typing import ClassVar
 
 import numpy as np
 import onnx
 from onnx import helper, numpy_helper
 
 from python.core.model_processing.onnx_custom_ops.onnx_helpers import (
+    extract_attributes,
     replace_input_references,
 )
 from python.core.model_processing.onnx_quantizer.exceptions import (
@@ -188,7 +190,7 @@ class BaseOpQuantizer:
         """
         self.validate_node_has_output(node)
 
-        original_output = node.output.get(0)
+        original_output = node.output[0]
         quantized_output = original_output + "_raw"
         node.output[0] = quantized_output
 
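The `node.output.get(0)` → `node.output[0]` change is a genuine bug fix, not a style tweak: protobuf repeated fields support indexing and slicing but expose no `.get()` method, so the old line raised `AttributeError` whenever this path ran. A minimal sketch (hypothetical Relu node, for illustration only):

    from onnx import helper

    # Hypothetical node, just to exercise the repeated-field API.
    node = helper.make_node("Relu", inputs=["x"], outputs=["y"], name="relu0")

    print(node.output[0])   # "y" -- indexing a repeated field works
    # node.output.get(0)    # would raise AttributeError: repeated fields have no .get()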
@@ -294,6 +296,61 @@ class BaseOpQuantizer:
         # === Mutate the original node ===
         return nodes, new_inputs
 
+    def add_scaled_initializer_inputs(
+        self: BaseOpQuantizer,
+        node: onnx.NodeProto,
+        initializer_map: dict[str, onnx.TensorProto],
+        scale_base: int,
+        scale_exponent: int,
+        scale_plan: dict[int, int],
+    ) -> tuple[list[onnx.NodeProto], list[str]]:
+        """
+        Scale and cast specific initializer inputs
+        of a node according to a scaling plan.
+
+        Handles optional inputs gracefully (e.g. missing bias).
+        """
+        new_nodes: list[onnx.NodeProto] = []
+        new_inputs = list(node.input)
+
+        for input_idx, scale_mult in scale_plan.items():
+            # Skip if node doesn't have that many inputs (e.g. missing bias)
+            if input_idx >= len(node.input):
+                # Just ignore — optional input not provided
+                continue
+
+            input_name = node.input[input_idx]
+            if not input_name:
+                # Empty input name → optional input not present
+                continue
+
+            if input_name not in initializer_map:
+                # Optional inputs may be missing from initializers (e.g., dynamic bias)
+                continue
+
+            tensor = initializer_map[input_name]
+            if not tensor.name:
+                raise HandlerImplementationError(
+                    op_type=node.op_type,
+                    message=f"Initializer tensor for '{input_name}' on node "
+                    f"'{node.name}' is missing a name.",
+                )
+
+            # Scale according to plan (e.g., scale_exponent * 2 for bias)
+            quant_name, mul_node, cast_node = self.insert_scale_node(
+                tensor=tensor,
+                scale_base=scale_base,
+                scale_exponent=(scale_exponent * scale_mult),
+            )
+
+            # Update node input to point to scaled version
+            new_inputs[input_idx] = quant_name
+
+            # Record new scaling/cast nodes
+            new_nodes.extend([mul_node, cast_node])
+
+        return new_nodes, new_inputs
+
     def insert_scale_node(
         self: BaseOpQuantizer,
         tensor: onnx.TensorProto,
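The scale-plan multipliers encode a fixed-point bookkeeping rule: if activations and weights are each scaled by base**exponent, their product carries base**(2*exponent), so a bias added to that product must be pre-scaled twice to line up. That is why the Conv/Gemm plans below map input 1 (weight) to 1x and input 2 (bias) to 2x. A toy numpy check of that arithmetic, assuming (consistent with the deleted Conv/Gemm code further down) that insert_scale_node multiplies by base**exponent and casts to int64:

    import numpy as np

    base, exponent = 2, 8
    s = base**exponent                  # per-tensor fixed-point scale

    x, w, b = 1.5, 0.25, 0.1            # toy float activation/weight/bias
    xq = np.int64(round(x * s))         # activation scaled once
    wq = np.int64(round(w * s))         # weight scaled once (plan: {1: 1, ...})
    bq = np.int64(round(b * s**2))      # bias scaled twice (plan: {..., 2: 2})

    acc = xq * wq + bq                  # accumulator carries s**2
    print(acc / s**2)                   # ~0.475 == x * w + b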
@@ -360,6 +417,136 @@ class BaseOpQuantizer:
         return output_name, mul_node, cast_to_int64
 
 
+class QuantizerBase:
+    OP_TYPE = None
+    DOMAIN = "ai.onnx.contrib"
+    DEFAULT_ATTRS: ClassVar = {}
+    USE_WB = False
+    USE_SCALING = False
+
+    def quantize(
+        self,
+        node: onnx.NodeProto,
+        graph: onnx.GraphProto,
+        scale_config: ScaleConfig,
+        initializer_map: dict[str, onnx.TensorProto],
+    ) -> list[onnx.NodeProto]:
+        """Generic quantization template for most Int64 ops."""
+        _ = graph
+        nodes = []
+
+        # (1) Quantize weights/bias if applicable
+        if self.USE_WB:
+            # Each subclass defines its scaling plan for which inputs get scaled and how
+            scale_plan = getattr(self, "SCALE_PLAN", {1: 1, 2: 2})  # default for W & B
+            nodes, new_inputs = self.add_scaled_initializer_inputs(
+                node=node,
+                initializer_map=initializer_map,
+                scale_base=scale_config.base,
+                scale_exponent=scale_config.exponent,
+                scale_plan=scale_plan,
+            )
+            node.input[:] = new_inputs
+
+        # (2) Collect & merge attributes
+        attrs = extract_attributes(node)
+        for k, v in self.DEFAULT_ATTRS.items():
+            attrs.setdefault(k, v)
+        if self.USE_SCALING:
+            attrs["rescale"] = int(scale_config.rescale)
+
+        # (3) Add scaling constant if needed
+        if self.USE_SCALING:
+            scale_value = self.get_scaling(scale_config.base, scale_config.exponent)
+            scale_name = f"{node.name}_int_scaler"
+            scale_tensor = numpy_helper.from_array(
+                np.array([scale_value], dtype=np.int64),
+                name=scale_name,
+            )
+            self.new_initializers.append(scale_tensor)
+            node.input.append(scale_name)
+
+        # (4) Create quantized node
+        quantized_node = onnx.helper.make_node(
+            self.OP_TYPE,
+            inputs=node.input,
+            outputs=node.output,
+            name=node.name,
+            domain=self.DOMAIN,
+            **attrs,
+        )
+
+        nodes.append(quantized_node)
+        return nodes
+
+    def pre_analysis_transform(
+        self: QuantizerBase,
+        node: onnx.NodeProto,
+        graph: onnx.GraphProto,
+        initializer_map: dict[str, onnx.TensorProto],
+        scale_base: int,
+        scale_exponent: int,
+    ) -> None:
+        """
+        pre_analysis_transform aims to transform the given layer along the
+        same lines as it would be transformed for the quantized model, but
+        for the weights and biases file instead, to be sent to the backend
+
+        Default pre-analysis behavior:
+
+        - If the subclass uses weights/bias (`USE_WB=True`), apply the SAME
+          scaling rules as quantization, but directly mutate the initializers.
+
+        - Subclasses can override this to implement more complex rewrites
+          (e.g., BatchNorm → Mul/Add).
+
+        Args:
+            node (onnx.NodeProto): Node to transform.
+            graph (onnx.GraphProto): Rest of the Onnx graph for initializers.
+            initializer_map (dict[str, onnx.TensorProto]): The initializer map.
+
+            scale_base (int): Scaling base.
+            scale_exponent (int): Scaling exponent.
+
+        NOTE
+        - The resulting model will not make accurate predictions and should be
+          used solely for analysis and keeping track of w_and_b
+        """
+        # If subclass does not want auto-scaling, do nothing
+        if not getattr(self, "USE_WB", False):
+            return
+
+        # Each quantizer defines which inputs to scale (Weight:1x, Bias:2x etc.)
+        scale_plan = getattr(self, "SCALE_PLAN", {})
+
+        # Perform the same scaling as quantization, but directly modify initializers
+        for input_idx, scale_mult in scale_plan.items():
+            if input_idx >= len(node.input):
+                continue
+
+            name = node.input[input_idx]
+            if name not in initializer_map:
+                continue  # optional input missing
+
+            tensor = initializer_map[name]
+            arr = numpy_helper.to_array(tensor).astype(np.float64)
+
+            scale = scale_base ** (scale_exponent * scale_mult)
+            new_arr = arr * scale
+
+            # Replace initializer directly
+            new_tensor = numpy_helper.from_array(new_arr, name=tensor.name)
+
+            # Modify graph initializer in place
+            for j in range(len(graph.initializer)):
+                if graph.initializer[j].name == tensor.name:
+                    del graph.initializer[j]
+                    break
+            graph.initializer.append(new_tensor)
+
+            initializer_map[tensor.name] = new_tensor
+
+
 class PassthroughQuantizer(BaseOpQuantizer):
     """
     Quantizer that leaves the node unchanged.
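Taken together, QuantizerBase turns per-op quantizers into declarations: a subclass names its custom op and states whether weight/bias scaling and output rescaling apply. A hedged sketch of the shape a subclass takes (illustrative only; the real per-op configs live in the add/mul/sub/conv/gemm modules listed above and may differ in detail):

    from typing import ClassVar

    from python.core.model_processing.onnx_quantizer.layers.base import QuantizerBase

    # Illustrative subclass only -- not one of the shipped quantizers.
    class QuantizeExample(QuantizerBase):
        OP_TYPE = "Int64Example"                  # custom op emitted in the contrib domain
        USE_WB = True                             # scale initializer inputs per SCALE_PLAN
        USE_SCALING = True                        # append int scaler input + rescale attr
        DEFAULT_ATTRS: ClassVar = {"alpha": 1.0}  # merged via attrs.setdefault(...)
        SCALE_PLAN: ClassVar = {1: 1, 2: 2}       # input 1 scaled once, input 2 twice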
python/core/model_processing/onnx_quantizer/layers/batchnorm.py (new file):

@@ -0,0 +1,224 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, ClassVar
+
+from python.core.circuits.errors import CircuitConfigurationError
+
+if TYPE_CHECKING:
+    import onnx
+
+import numpy as np
+from onnx import helper, numpy_helper
+
+from python.core.model_processing.onnx_custom_ops.onnx_helpers import extract_attributes
+from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError
+from python.core.model_processing.onnx_quantizer.layers.base import (
+    BaseOpQuantizer,
+    QuantizerBase,
+    ScaleConfig,
+)
+
+
+class QuantizeBatchnorm(QuantizerBase):
+    OP_TYPE = "Int64BatchNorm"
+    USE_WB = True
+    USE_SCALING = False
+    SCALE_PLAN: ClassVar = {}
+
+
+class BatchnormQuantizer(BaseOpQuantizer, QuantizeBatchnorm):
+    """
+    Quantizer for ONNX Batchnorm layers.
+
+    - Uses standard ONNX Batchnorm layer in standard domain, and
+      makes relevant additional changes to the graph.
+    """
+
+    def __init__(
+        self: BatchnormQuantizer,
+        new_initializers: list[onnx.TensorProto] | None = None,
+    ) -> None:
+        super().__init__()
+        # Only replace if caller provided something
+        if new_initializers is not None:
+            self.new_initializers = new_initializers
+
+    def _compute_mul_add(
+        self: BatchnormQuantizer,
+        initializer_map: dict[str, onnx.TensorProto],
+        node: onnx.NodeProto,
+        scale_base: int,
+        scale_exponent: int,
+    ) -> tuple[np.ndarray, np.ndarray]:
+        """
+        Compute the 'mul' and 'add' tensors for BatchNorm folding.
+        """
+        self._validate_inputs(node=node)
+        # ONNX BatchNorm inputs: [X, scale, bias, mean, var]
+        scale_factor = scale_base**scale_exponent
+        scale = numpy_helper.to_array(initializer_map[node.input[1]]).astype(np.float32)
+        bias = numpy_helper.to_array(initializer_map[node.input[2]]).astype(np.float32)
+        mean = numpy_helper.to_array(initializer_map[node.input[3]]).astype(np.float32)
+        var = numpy_helper.to_array(initializer_map[node.input[4]]).astype(np.float32)
+
+        # Find epsilon attribute
+        epsilon_attr = next((a for a in node.attribute if a.name == "epsilon"), None)
+        epsilon = float(epsilon_attr.f) if epsilon_attr else 1e-5
+
+        mul = scale / np.sqrt(var + epsilon)
+        add = bias - mean * mul
+        scaled_add = add * (scale_factor**2)
+        scaled_mul = scale_factor * mul
+        return scaled_mul, scaled_add
+
+    def pre_analysis_transform(
+        self: BatchnormQuantizer,
+        node: onnx.NodeProto,
+        graph: onnx.GraphProto,
+        initializer_map: dict[str, onnx.TensorProto],
+        scale_base: int,
+        scale_exponent: int,
+    ) -> None:
+        # Compute linearized BN tensors
+        mul, add = self._compute_mul_add(
+            initializer_map,
+            node,
+            scale_base=scale_base,
+            scale_exponent=scale_exponent,
+        )
+
+        # Name base
+        node_name = node.name if node.name else node.input[0]
+        mul_name = f"{node_name}_mul"
+        add_name = f"{node_name}_add"
+
+        # Create ONNX tensors
+        mul_tensor = numpy_helper.from_array(mul.astype(np.int64), name=mul_name)
+        add_tensor = numpy_helper.from_array(add.astype(np.int64), name=add_name)
+
+        # Insert them into the graph
+        graph.initializer.extend([mul_tensor, add_tensor])
+        initializer_map[mul_name] = mul_tensor
+        initializer_map[add_name] = add_tensor
+        self.new_initializers.extend([mul_tensor, add_tensor])
+
+        node.input[:] = [node.input[0], mul_name, add_name]
+
+        del node.attribute[:]
+
+    def quantize(
+        self,
+        node: onnx.NodeProto,
+        graph: onnx.GraphProto,
+        scale_config: ScaleConfig,
+        initializer_map: dict[str, onnx.TensorProto],
+    ) -> list[onnx.NodeProto]:
+        _ = graph
+
+        nodes: list[onnx.NodeProto] = []
+
+        # 1. Compute unscaled float mul/add coefficients
+        mul, add = self._compute_mul_add(
+            initializer_map,
+            node,
+            scale_base=1,
+            scale_exponent=1,
+        )
+
+        node_name = node.name if node.name else node.input[0]
+        mul_name = f"{node_name}_mul"
+        add_name = f"{node_name}_add"
+
+        # 2. Store unscaled mul and add initializers (as floats)
+        scale_value = self.get_scaling(scale_config.base, scale_config.exponent)
+        scale_name = f"{node.name}_int_scaler"
+        scale_tensor = numpy_helper.from_array(
+            np.array([scale_value], dtype=np.int64),
+            name=scale_name,
+        )
+        self.new_initializers.append(scale_tensor)
+
+        mul_tensor = numpy_helper.from_array(mul.astype(np.float32), name=mul_name)
+        add_tensor = numpy_helper.from_array(add.astype(np.float32), name=add_name)
+
+        initializer_map[mul_name] = mul_tensor
+        initializer_map[add_name] = add_tensor
+
+        # 3. Insert scale and cast for mul_tensor
+        scaled_mul_name, mul_scale_node, mul_cast_node = self.insert_scale_node(
+            tensor=mul_tensor,
+            scale_base=scale_config.base,
+            scale_exponent=scale_config.exponent,
+        )
+
+        # 4. Insert scale and cast for add_tensor
+        scaled_add_name, add_scale_node, add_cast_node = self.insert_scale_node(
+            tensor=add_tensor,
+            scale_base=scale_config.base,
+            scale_exponent=scale_config.exponent * 2,
+        )
+        # Note, order is important here
+        nodes.extend(
+            [
+                mul_scale_node,
+                mul_cast_node,
+                add_scale_node,
+                add_cast_node,
+            ],
+        )
+
+        # 5. Build final Int64BatchNorm node
+        attrs = extract_attributes(node)
+        for k, v in getattr(self, "DEFAULT_ATTRS", {}).items():
+            attrs.setdefault(k, v)
+        attrs["rescale"] = 1
+
+        quant_node = helper.make_node(
+            self.OP_TYPE,  # Should be "Int64BatchNorm"
+            inputs=[
+                node.input[0],  # original X
+                scaled_mul_name,  # scaled mul
+                scaled_add_name,  # scaled add
+                scale_name,  # scaling factor
+            ],
+            outputs=node.output,
+            name=node.name,
+            domain=self.DOMAIN,
+            **attrs,
+        )
+
+        nodes.append(quant_node)
+        return nodes
+
+    def check_supported(
+        self: BatchnormQuantizer,
+        node: onnx.NodeProto,
+        initializer_map: dict[str, onnx.TensorProto] | None = None,
+    ) -> None:
+        """
+        For our current implementation, all batchnorm inputs
+        (scale, variance, mean, etc.)
+        must be initializers to the circuit and not inputs from earlier in the graph.
+        """
+
+        if initializer_map is None:
+            msg = "initializer_map is required for BatchNorm support check"
+            raise CircuitConfigurationError(node.name, node.op_type, msg)
+
+        self._validate_inputs(node=node)
+
+        # First, check to make sure that each of the batchnorm inputs are initializers
+        initializer_inputs = node.input[1:]
+        if not all(i in initializer_map for i in initializer_inputs):
+            msg = "Unsupported BatchNorm with normalization inputs not in initializers"
+            raise InvalidParamError(node.name, node.op_type, msg)
+
+    def _validate_inputs(self, node: onnx.NodeProto) -> None:
+        """Validate BatchNorm has required inputs in initializer_map."""
+        num_inputs = 5
+        if len(node.input) < num_inputs:
+            raise InvalidParamError(
+                node.name,
+                node.op_type,
+                f"BatchNorm requires 5 inputs, got {len(node.input)}",
+            )
python/core/model_processing/onnx_quantizer/layers/constant.py:

@@ -38,7 +38,7 @@ class ConstantQuantizer(BaseOpQuantizer):
 
     def __init__(
         self: ConstantQuantizer,
-        new_initializer: dict[str, onnx.TensorProto] | None = None,
+        new_initializer: list[onnx.TensorProto] | None = None,
     ) -> None:
         super().__init__()
         _ = new_initializer
python/core/model_processing/onnx_quantizer/layers/conv.py:

@@ -1,18 +1,27 @@
 from __future__ import annotations
 
-import numpy as np
-import onnx
-from onnx import numpy_helper
+from typing import TYPE_CHECKING, ClassVar
+
+if TYPE_CHECKING:
+    import onnx
 
-from python.core.model_processing.onnx_custom_ops.onnx_helpers import extract_attributes
 from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError
 from python.core.model_processing.onnx_quantizer.layers.base import (
     BaseOpQuantizer,
+    QuantizerBase,
     ScaleConfig,
 )
 
 
-class ConvQuantizer(BaseOpQuantizer):
+class QuantizeConv(QuantizerBase):
+    OP_TYPE = "Int64Conv"
+    USE_WB = True
+    USE_SCALING = True
+    DEFAULT_ATTRS: ClassVar = {"group": 1, "auto_pad": "NOTSET"}
+    SCALE_PLAN: ClassVar = {1: 1, 2: 2}  # weight = 1x scale, bias = 2x scale
+
+
+class ConvQuantizer(BaseOpQuantizer, QuantizeConv):
     """
     Quantizer for ONNX Conv layers.
 
@@ -23,9 +32,12 @@ class ConvQuantizer(BaseOpQuantizer):
 
     def __init__(
         self: ConvQuantizer,
-        new_initializers: dict[str, onnx.TensorProto],
+        new_initializers: list[onnx.TensorProto] | None = None,
     ) -> None:
-        self.new_initializers = new_initializers
+        super().__init__()
+        # Only replace if caller provided something
+        if new_initializers is not None:
+            self.new_initializers = new_initializers
 
     def quantize(
         self: ConvQuantizer,
@@ -34,67 +46,7 @@ class ConvQuantizer(BaseOpQuantizer):
         scale_config: ScaleConfig,
         initializer_map: dict[str, onnx.TensorProto],
     ) -> list[onnx.NodeProto]:
-        """
-        Quantize a Conv node by:
-        1. Quantizing its weights and bias.
-        2. Adding a scale constant.
-        3. Replacing it with an Int64Conv node.
-
-        Args:
-            node (onnx.NodeProto): The node to quantize.
-            rescale (bool): Whether rescaling is enabled
-                (Doesnt have an affect on this op type)
-            graph (onnx.GraphProto): The ONNX graph.
-            scale_exponent (int): Scale exponent.
-            scale_base (int): The base of scaling.
-            initializer_map (dict[str, onnx.TensorProto]):
-                Map of initializer names to tensor data.
-
-        Returns:
-            list[onnx.NodeProto]: A list of ONNX nodes
-            (quantized and any auxiliary nodes).
-        """
-        _ = graph
-
-        nodes = []
-        output_name = f"{node.name}_int"
-
-        nodes, node.input[:] = self.add_nodes_w_and_b(
-            node=node,
-            scale_exponent=scale_config.exponent,
-            scale_base=scale_config.base,
-            initializer_map=initializer_map,
-        )
-        attrs = extract_attributes(node)
-        attrs.setdefault("group", 1)
-        attrs.setdefault("auto_pad", "NOTSET")
-
-        attrs["rescale"] = int(scale_config.rescale)
-
-        scale_value = self.get_scaling(
-            scale_config.base,
-            scale_config.exponent,
-        )
-
-        # Create scale constant
-        scale_const_name = f"{output_name}_scaler"
-        scale_tensor = numpy_helper.from_array(
-            np.array([scale_value], dtype=np.int64),
-            name=scale_const_name,
-        )
-        self.new_initializers.append(scale_tensor)
-        node.input.append(scale_const_name)
-        int64_conv_node = onnx.helper.make_node(
-            "Int64Conv",
-            inputs=node.input,
-            outputs=node.output,  # preserve original output name
-            name=node.name,
-            domain="ai.onnx.contrib",
-            **attrs,
-        )
-
-        nodes.append(int64_conv_node)
-        return nodes
+        return QuantizeConv.quantize(self, node, graph, scale_config, initializer_map)
 
     def check_supported(
         self: ConvQuantizer,
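With ConvQuantizer now inheriting from both BaseOpQuantizer (shared helpers and new_initializers state) and QuantizeConv (the template config), the body delegates through an explicit unbound call rather than super(), which pins exactly which implementation runs regardless of how the MRO orders the bases. A minimal sketch of the pattern, with stand-in classes:

    class Helper:                        # stand-in for BaseOpQuantizer
        def quantize(self) -> str:
            return "helper"

    class Template:                      # stand-in for QuantizeConv
        def quantize(self) -> str:
            return "template"

    class Quantizer(Helper, Template):   # the MRO would pick Helper.quantize...
        def quantize(self) -> str:
            return Template.quantize(self)   # ...so the template is named explicitly

    assert Quantizer().quantize() == "template"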
python/core/model_processing/onnx_quantizer/layers/gemm.py:

@@ -1,18 +1,27 @@
 from __future__ import annotations
 
-import numpy as np
-import onnx
-from onnx import numpy_helper
+from typing import TYPE_CHECKING, ClassVar
+
+if TYPE_CHECKING:
+    import onnx
 
-from python.core.model_processing.onnx_custom_ops.onnx_helpers import extract_attributes
 from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError
 from python.core.model_processing.onnx_quantizer.layers.base import (
     BaseOpQuantizer,
+    QuantizerBase,
     ScaleConfig,
 )
 
 
-class GemmQuantizer(BaseOpQuantizer):
+class QuantizeGemm(QuantizerBase):
+    OP_TYPE = "Int64Gemm"
+    USE_WB = True
+    USE_SCALING = True
+    DEFAULT_ATTRS: ClassVar = {"transA": 0, "transB": 0}
+    SCALE_PLAN: ClassVar = {1: 1, 2: 2}
+
+
+class GemmQuantizer(BaseOpQuantizer, QuantizeGemm):
     """
     Quantizer for ONNX Gemm layers.
 
@@ -23,9 +32,12 @@ class GemmQuantizer(BaseOpQuantizer):
 
     def __init__(
         self: GemmQuantizer,
-        new_initializers: dict[str, onnx.TensorProto],
+        new_initializers: list[onnx.TensorProto] | None = None,
    ) -> None:
-        self.new_initializers = new_initializers
+        super().__init__()
+        # Only replace if caller provided something
+        if new_initializers is not None:
+            self.new_initializers = new_initializers
 
     def quantize(
         self: GemmQuantizer,
@@ -34,65 +46,7 @@ class GemmQuantizer(BaseOpQuantizer):
         scale_config: ScaleConfig,
         initializer_map: dict[str, onnx.TensorProto],
     ) -> list[onnx.NodeProto]:
-        """
-        Quantize a Gemm node by:
-        1. Quantizing its weights and bias.
-        2. Adding a scale constant.
-        3. Replacing it with an Int64Gemm node.
-
-        Args:
-            node (onnx.NodeProto): The node to quantize.
-            rescale (bool): Whether rescaling is enabled
-            graph (onnx.GraphProto): The ONNX graph.
-            scale_exponent (int): Scale exponent.
-            scale_base (int): The base of scaling.
-            initializer_map (dict[str, onnx.TensorProto]):
-                Map of initializer names to tensor data.
-
-        Returns:
-            List[onnx.NodeProto]: A list of ONNX nodes
-            (quantized and any auxiliary nodes).
-        """
-        _ = graph
-        nodes = []
-        output_name = f"{node.name}_int"
-
-        nodes, new_inputs = self.add_nodes_w_and_b(
-            node=node,
-            scale_exponent=scale_config.exponent,
-            scale_base=scale_config.base,
-            initializer_map=initializer_map,
-        )
-        node.input[:] = new_inputs
-
-        attrs = extract_attributes(node)
-        attrs.setdefault("transA", 0)
-        attrs.setdefault("transB", 0)
-        attrs["rescale"] = int(scale_config.rescale)
-
-        scale_value = self.get_scaling(
-            scale_config.base,
-            scale_config.exponent,
-        )
-
-        # === Create scale constant ===
-        scale_const_name = f"{output_name}_scaler"
-        scale_tensor = numpy_helper.from_array(
-            np.array([scale_value], dtype=np.int64),
-            name=scale_const_name,
-        )
-        self.new_initializers.append(scale_tensor)
-        node.input.append(scale_const_name)
-        int64_gemm = onnx.helper.make_node(
-            "Int64Gemm",
-            inputs=node.input,
-            outputs=node.output,  # preserve original output name
-            name=output_name,
-            domain="ai.onnx.contrib",
-            **attrs,
-        )
-        nodes.append(int64_gemm)
-        return nodes
+        return QuantizeGemm.quantize(self, node, graph, scale_config, initializer_map)
 
     def check_supported(
         self: GemmQuantizer,
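One behavioral detail the consolidation quietly settles: the deleted Gemm code named its replacement node f"{node.name}_int" (via output_name) while Conv kept node.name, and the shared template now uses node.name for both. After quantization, a Gemm node comes out shaped roughly like this (illustrative reconstruction of what QuantizerBase.quantize builds; the tensor names here are placeholders):

    from onnx import helper

    # Illustrative reconstruction only -- mirrors QuantizerBase.quantize above.
    quant_node = helper.make_node(
        "Int64Gemm",
        inputs=["x", "W_scaled", "B_scaled", "gemm0_int_scaler"],  # scaled W/B + scaler
        outputs=["y"],
        name="gemm0",              # original node name, not the old "gemm0_int"
        domain="ai.onnx.contrib",
        transA=0,
        transB=0,
        rescale=1,
    )
    print(quant_node.op_type, list(quant_node.input))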