JSTprove 1.0.0-py3-none-macosx_11_0_arm64.whl → 1.1.0-py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/METADATA +2 -2
  2. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/RECORD +51 -24
  3. python/core/binaries/onnx_generic_circuit_1-1-0 +0 -0
  4. python/core/circuit_models/generic_onnx.py +43 -9
  5. python/core/circuits/base.py +231 -71
  6. python/core/model_processing/converters/onnx_converter.py +86 -32
  7. python/core/model_processing/onnx_custom_ops/maxpool.py +1 -1
  8. python/core/model_processing/onnx_custom_ops/relu.py +1 -1
  9. python/core/model_processing/onnx_quantizer/layers/add.py +54 -0
  10. python/core/model_processing/onnx_quantizer/layers/base.py +121 -1
  11. python/core/model_processing/onnx_quantizer/layers/constant.py +1 -1
  12. python/core/model_processing/onnx_quantizer/layers/conv.py +20 -68
  13. python/core/model_processing/onnx_quantizer/layers/gemm.py +20 -66
  14. python/core/model_processing/onnx_quantizer/layers/maxpool.py +53 -43
  15. python/core/model_processing/onnx_quantizer/layers/relu.py +20 -35
  16. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +6 -1
  17. python/core/utils/general_layer_functions.py +17 -12
  18. python/core/utils/model_registry.py +6 -3
  19. python/tests/circuit_e2e_tests/other_e2e_test.py +202 -9
  20. python/tests/circuit_parent_classes/test_circuit.py +561 -38
  21. python/tests/circuit_parent_classes/test_onnx_converter.py +22 -13
  22. python/tests/onnx_quantizer_tests/__init__.py +1 -0
  23. python/tests/onnx_quantizer_tests/layers/__init__.py +13 -0
  24. python/tests/onnx_quantizer_tests/layers/add_config.py +102 -0
  25. python/tests/onnx_quantizer_tests/layers/base.py +279 -0
  26. python/tests/onnx_quantizer_tests/layers/constant_config.py +39 -0
  27. python/tests/onnx_quantizer_tests/layers/conv_config.py +154 -0
  28. python/tests/onnx_quantizer_tests/layers/factory.py +142 -0
  29. python/tests/onnx_quantizer_tests/layers/flatten_config.py +61 -0
  30. python/tests/onnx_quantizer_tests/layers/gemm_config.py +160 -0
  31. python/tests/onnx_quantizer_tests/layers/maxpool_config.py +82 -0
  32. python/tests/onnx_quantizer_tests/layers/relu_config.py +61 -0
  33. python/tests/onnx_quantizer_tests/layers/reshape_config.py +61 -0
  34. python/tests/onnx_quantizer_tests/layers_tests/__init__.py +0 -0
  35. python/tests/onnx_quantizer_tests/layers_tests/base_test.py +94 -0
  36. python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py +115 -0
  37. python/tests/onnx_quantizer_tests/layers_tests/test_e2e.py +196 -0
  38. python/tests/onnx_quantizer_tests/layers_tests/test_error_cases.py +59 -0
  39. python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +198 -0
  40. python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +265 -0
  41. python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py +109 -0
  42. python/tests/onnx_quantizer_tests/layers_tests/test_validation.py +45 -0
  43. python/tests/onnx_quantizer_tests/test_base_layer.py +228 -0
  44. python/tests/onnx_quantizer_tests/test_exceptions.py +99 -0
  45. python/tests/onnx_quantizer_tests/test_onnx_op_quantizer.py +246 -0
  46. python/tests/onnx_quantizer_tests/test_registered_quantizers.py +121 -0
  47. python/tests/onnx_quantizer_tests/testing_helper_functions.py +17 -0
  48. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  49. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/WHEEL +0 -0
  50. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/entry_points.txt +0 -0
  51. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/licenses/LICENSE +0 -0
  52. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/top_level.txt +0 -0
python/core/model_processing/onnx_quantizer/layers/add.py (new file)
@@ -0,0 +1,54 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, ClassVar
+
+if TYPE_CHECKING:
+    import onnx
+
+from python.core.model_processing.onnx_quantizer.layers.base import (
+    BaseOpQuantizer,
+    QuantizerBase,
+    ScaleConfig,
+)
+
+
+class QuantizeAdd(QuantizerBase):
+    OP_TYPE = "Add"
+    DOMAIN = ""
+    USE_WB = True
+    USE_SCALING = False
+    SCALE_PLAN: ClassVar = {0: 1, 1: 1}
+
+
+class AddQuantizer(BaseOpQuantizer, QuantizeAdd):
+    """
+    Quantizer for ONNX Add layers.
+
+    - Uses standard ONNX Add layer in standard domain, and
+      makes relevant additional changes to the graph.
+    """
+
+    def __init__(
+        self: AddQuantizer,
+        new_initializers: list[onnx.TensorProto] | None = None,
+    ) -> None:
+        super().__init__()
+        # Only replace if caller provided something
+        if new_initializers is not None:
+            self.new_initializers = new_initializers
+
+    def quantize(
+        self: AddQuantizer,
+        node: onnx.NodeProto,
+        graph: onnx.GraphProto,
+        scale_config: ScaleConfig,
+        initializer_map: dict[str, onnx.TensorProto],
+    ) -> list[onnx.NodeProto]:
+        return QuantizeAdd.quantize(self, node, graph, scale_config, initializer_map)
+
+    def check_supported(
+        self: AddQuantizer,
+        node: onnx.NodeProto,
+        initializer_map: dict[str, onnx.TensorProto] | None = None,
+    ) -> None:
+        pass
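The new AddQuantizer above is the clearest example of the declarative pattern this release introduces: an op is described entirely by class variables on a QuantizerBase subclass (OP_TYPE, DOMAIN, USE_WB, USE_SCALING, SCALE_PLAN), while the shared quantize template in layers/base.py (shown further down) does the work. As a hedged sketch of how another elementwise op might be added under the same assumptions (the Sub op and its plan here are hypothetical, not part of this release):

    from typing import ClassVar

    from python.core.model_processing.onnx_quantizer.layers.base import QuantizerBase

    class QuantizeSub(QuantizerBase):  # hypothetical, mirrors QuantizeAdd above
        OP_TYPE = "Sub"                      # emit a standard ONNX Sub node
        DOMAIN = ""                          # standard ONNX domain, like Add
        USE_WB = True                        # scale initializer operands
        USE_SCALING = False                  # no rescale constant for elementwise ops
        SCALE_PLAN: ClassVar = {0: 1, 1: 1}  # both operands carry one factor of scale

A concrete SubQuantizer would then pair this with BaseOpQuantizer and delegate quantize/check_supported, exactly as AddQuantizer does above.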
python/core/model_processing/onnx_quantizer/layers/base.py
@@ -1,12 +1,14 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
+from typing import ClassVar
 
 import numpy as np
 import onnx
 from onnx import helper, numpy_helper
 
 from python.core.model_processing.onnx_custom_ops.onnx_helpers import (
+    extract_attributes,
     replace_input_references,
 )
 from python.core.model_processing.onnx_quantizer.exceptions import (
@@ -188,7 +190,7 @@ class BaseOpQuantizer:
         """
         self.validate_node_has_output(node)
 
-        original_output = node.output.get(0)
+        original_output = node.output[0]
         quantized_output = original_output + "_raw"
         node.output[0] = quantized_output
 
@@ -294,6 +296,61 @@ class BaseOpQuantizer:
         # === Mutate the original node ===
         return nodes, new_inputs
 
+    def add_scaled_initializer_inputs(
+        self: BaseOpQuantizer,
+        node: onnx.NodeProto,
+        initializer_map: dict[str, onnx.TensorProto],
+        scale_base: int,
+        scale_exponent: int,
+        scale_plan: dict[int, int],
+    ) -> tuple[list[onnx.NodeProto], list[str]]:
+        """
+        Scale and cast specific initializer inputs
+        of a node according to a scaling plan.
+
+        Handles optional inputs gracefully (e.g. missing bias).
+        """
+        new_nodes: list[onnx.NodeProto] = []
+        new_inputs = list(node.input)
+
+        for input_idx, scale_mult in scale_plan.items():
+            # Skip if node doesn't have that many inputs (e.g. missing bias)
+            if input_idx >= len(node.input):
+                # Just ignore — optional input not provided
+                continue
+
+            input_name = node.input[input_idx]
+            if not input_name:
+                # Empty input name → optional input not present
+                continue
+
+            if input_name not in initializer_map:
+                # Optional inputs may be missing from initializers (e.g., dynamic bias)
+                continue
+
+            tensor = initializer_map[input_name]
+            if not tensor.name:
+                raise HandlerImplementationError(
+                    op_type=node.op_type,
+                    message=f"Initializer tensor for '{input_name}' on node "
+                    f"'{node.name}' is missing a name.",
+                )
+
+            # Scale according to plan (e.g., scale_exponent * 2 for bias)
+            quant_name, mul_node, cast_node = self.insert_scale_node(
+                tensor=tensor,
+                scale_base=scale_base,
+                scale_exponent=(scale_exponent * scale_mult),
+            )
+
+            # Update node input to point to scaled version
+            new_inputs[input_idx] = quant_name
+
+            # Record new scaling/cast nodes
+            new_nodes.extend([mul_node, cast_node])
+
+        return new_nodes, new_inputs
+
     def insert_scale_node(
         self: BaseOpQuantizer,
         tensor: onnx.TensorProto,
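The scale_plan passed here maps an input index to an exponent multiplier. For Conv and Gemm the plan is {1: 1, 2: 2}: the weight (input 1) is scaled by base**exponent, while the bias (input 2) is scaled by base**(2*exponent), because the product of a scaled input and a scaled weight already carries two factors of the scale. A small numeric check of that bookkeeping (illustrative values, not package code):

    base, exponent = 2, 4
    s = base ** exponent                 # scale factor, s = 16

    x, w, b = 0.5, 0.25, 1.0             # float input, weight, bias
    xq, wq = round(x * s), round(w * s)  # each carries one factor of s
    bq = round(b * s ** 2)               # plan multiplier 2 gives s**2

    acc = xq * wq + bq                   # integer accumulator carries s**2
    assert acc / s ** 2 == x * w + b     # 0.125 + 1.0 == 1.125, exact here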
@@ -360,6 +417,69 @@ class BaseOpQuantizer:
         return output_name, mul_node, cast_to_int64
 
 
+class QuantizerBase:
+    OP_TYPE = None
+    DOMAIN = "ai.onnx.contrib"
+    DEFAULT_ATTRS: ClassVar = {}
+    USE_WB = False
+    USE_SCALING = False
+
+    def quantize(
+        self,
+        node: onnx.NodeProto,
+        graph: onnx.GraphProto,
+        scale_config: ScaleConfig,
+        initializer_map: dict[str, onnx.TensorProto],
+    ) -> list[onnx.NodeProto]:
+        """Generic quantization template for most Int64 ops."""
+        _ = graph
+        nodes = []
+
+        # (1) Quantize weights/bias if applicable
+        if self.USE_WB:
+            # Each subclass defines its scaling plan for which inputs get scaled and how
+            scale_plan = getattr(self, "SCALE_PLAN", {1: 1, 2: 2})  # default for W & B
+            nodes, new_inputs = self.add_scaled_initializer_inputs(
+                node=node,
+                initializer_map=initializer_map,
+                scale_base=scale_config.base,
+                scale_exponent=scale_config.exponent,
+                scale_plan=scale_plan,
+            )
+            node.input[:] = new_inputs
+
+        # (2) Collect & merge attributes
+        attrs = extract_attributes(node)
+        for k, v in self.DEFAULT_ATTRS.items():
+            attrs.setdefault(k, v)
+        if self.USE_SCALING:
+            attrs["rescale"] = int(scale_config.rescale)
+
+        # (3) Add scaling constant if needed
+        if self.USE_SCALING:
+            scale_value = self.get_scaling(scale_config.base, scale_config.exponent)
+            scale_name = f"{node.name}_int_scaler"
+            scale_tensor = numpy_helper.from_array(
+                np.array([scale_value], dtype=np.int64),
+                name=scale_name,
+            )
+            self.new_initializers.append(scale_tensor)
+            node.input.append(scale_name)
+
+        # (4) Create quantized node
+        quantized_node = onnx.helper.make_node(
+            self.OP_TYPE,
+            inputs=node.input,
+            outputs=node.output,
+            name=node.name,
+            domain=self.DOMAIN,
+            **attrs,
+        )
+
+        nodes.append(quantized_node)
+        return nodes
+
+
 class PassthroughQuantizer(BaseOpQuantizer):
     """
     Quantizer that leaves the node unchanged.
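QuantizerBase replaces the per-op quantize methods deleted below with a four-step template driven by those class variables: (1) scale initializer inputs per SCALE_PLAN, (2) merge DEFAULT_ATTRS into the node's attributes, (3) append an int64 scale constant when USE_SCALING is set, and (4) re-emit the node under OP_TYPE/DOMAIN. A hedged usage sketch, assuming the ScaleConfig field names visible in this diff (base, exponent, rescale) are its full field set and using an illustrative single-node graph:

    import numpy as np
    import onnx
    from onnx import helper, numpy_helper

    from python.core.model_processing.onnx_quantizer.layers.base import ScaleConfig
    from python.core.model_processing.onnx_quantizer.layers.conv import ConvQuantizer

    # Illustrative graph: one Conv with a weight initializer and no bias.
    w = numpy_helper.from_array(np.ones((1, 1, 3, 3), dtype=np.float32), name="W")
    node = helper.make_node("Conv", inputs=["X", "W"], outputs=["Y"], name="conv0")
    graph = helper.make_graph([node], "g", inputs=[], outputs=[], initializer=[w])

    q = ConvQuantizer()
    new_nodes = q.quantize(
        node,
        graph,
        ScaleConfig(base=2, exponent=8, rescale=True),  # assumed constructor keywords
        initializer_map={"W": w},
    )
    # new_nodes should hold the Mul/Cast scaling nodes plus the Int64Conv
    # replacement; q.new_initializers collects the added scale constants.
    # The missing bias (SCALE_PLAN index 2) is skipped gracefully.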
python/core/model_processing/onnx_quantizer/layers/constant.py
@@ -38,7 +38,7 @@ class ConstantQuantizer:
 
     def __init__(
         self: ConstantQuantizer,
-        new_initializer: dict[str, onnx.TensorProto] | None = None,
+        new_initializer: list[onnx.TensorProto] | None = None,
     ) -> None:
         super().__init__()
         _ = new_initializer
python/core/model_processing/onnx_quantizer/layers/conv.py
@@ -1,18 +1,27 @@
 from __future__ import annotations
 
-import numpy as np
-import onnx
-from onnx import numpy_helper
+from typing import TYPE_CHECKING, ClassVar
+
+if TYPE_CHECKING:
+    import onnx
 
-from python.core.model_processing.onnx_custom_ops.onnx_helpers import extract_attributes
 from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError
 from python.core.model_processing.onnx_quantizer.layers.base import (
     BaseOpQuantizer,
+    QuantizerBase,
     ScaleConfig,
 )
 
 
-class ConvQuantizer(BaseOpQuantizer):
+class QuantizeConv(QuantizerBase):
+    OP_TYPE = "Int64Conv"
+    USE_WB = True
+    USE_SCALING = True
+    DEFAULT_ATTRS: ClassVar = {"group": 1, "auto_pad": "NOTSET"}
+    SCALE_PLAN: ClassVar = {1: 1, 2: 2}  # weight = 1x scale, bias = 2x scale
+
+
+class ConvQuantizer(BaseOpQuantizer, QuantizeConv):
     """
     Quantizer for ONNX Conv layers.
 
@@ -23,9 +32,12 @@ class ConvQuantizer:
 
     def __init__(
        self: ConvQuantizer,
-        new_initializers: dict[str, onnx.TensorProto],
+        new_initializers: list[onnx.TensorProto] | None = None,
     ) -> None:
-        self.new_initializers = new_initializers
+        super().__init__()
+        # Only replace if caller provided something
+        if new_initializers is not None:
+            self.new_initializers = new_initializers
 
     def quantize(
         self: ConvQuantizer,
@@ -34,67 +46,7 @@ class ConvQuantizer:
         scale_config: ScaleConfig,
         initializer_map: dict[str, onnx.TensorProto],
     ) -> list[onnx.NodeProto]:
-        """
-        Quantize a Conv node by:
-        1. Quantizing its weights and bias.
-        2. Adding a scale constant.
-        3. Replacing it with an Int64Conv node.
-
-        Args:
-            node (onnx.NodeProto): The node to quantize.
-            rescale (bool): Whether rescaling is enabled
-                (Doesnt have an affect on this op type)
-            graph (onnx.GraphProto): The ONNX graph.
-            scale_exponent (int): Scale exponent.
-            scale_base (int): The base of scaling.
-            initializer_map (dict[str, onnx.TensorProto]):
-                Map of initializer names to tensor data.
-
-        Returns:
-            list[onnx.NodeProto]: A list of ONNX nodes
-            (quantized and any auxiliary nodes).
-        """
-        _ = graph
-
-        nodes = []
-        output_name = f"{node.name}_int"
-
-        nodes, node.input[:] = self.add_nodes_w_and_b(
-            node=node,
-            scale_exponent=scale_config.exponent,
-            scale_base=scale_config.base,
-            initializer_map=initializer_map,
-        )
-        attrs = extract_attributes(node)
-        attrs.setdefault("group", 1)
-        attrs.setdefault("auto_pad", "NOTSET")
-
-        attrs["rescale"] = int(scale_config.rescale)
-
-        scale_value = self.get_scaling(
-            scale_config.base,
-            scale_config.exponent,
-        )
-
-        # Create scale constant
-        scale_const_name = f"{output_name}_scaler"
-        scale_tensor = numpy_helper.from_array(
-            np.array([scale_value], dtype=np.int64),
-            name=scale_const_name,
-        )
-        self.new_initializers.append(scale_tensor)
-        node.input.append(scale_const_name)
-        int64_conv_node = onnx.helper.make_node(
-            "Int64Conv",
-            inputs=node.input,
-            outputs=node.output,  # preserve original output name
-            name=node.name,
-            domain="ai.onnx.contrib",
-            **attrs,
-        )
-
-        nodes.append(int64_conv_node)
-        return nodes
+        return QuantizeConv.quantize(self, node, graph, scale_config, initializer_map)
 
     def check_supported(
         self: ConvQuantizer,
python/core/model_processing/onnx_quantizer/layers/gemm.py
@@ -1,18 +1,27 @@
 from __future__ import annotations
 
-import numpy as np
-import onnx
-from onnx import numpy_helper
+from typing import TYPE_CHECKING, ClassVar
+
+if TYPE_CHECKING:
+    import onnx
 
-from python.core.model_processing.onnx_custom_ops.onnx_helpers import extract_attributes
 from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError
 from python.core.model_processing.onnx_quantizer.layers.base import (
     BaseOpQuantizer,
+    QuantizerBase,
     ScaleConfig,
 )
 
 
-class GemmQuantizer(BaseOpQuantizer):
+class QuantizeGemm(QuantizerBase):
+    OP_TYPE = "Int64Gemm"
+    USE_WB = True
+    USE_SCALING = True
+    DEFAULT_ATTRS: ClassVar = {"transA": 0, "transB": 0}
+    SCALE_PLAN: ClassVar = {1: 1, 2: 2}
+
+
+class GemmQuantizer(BaseOpQuantizer, QuantizeGemm):
     """
     Quantizer for ONNX Gemm layers.
 
@@ -23,9 +32,12 @@ class GemmQuantizer:
 
     def __init__(
         self: GemmQuantizer,
-        new_initializers: dict[str, onnx.TensorProto],
+        new_initializers: list[onnx.TensorProto] | None = None,
     ) -> None:
-        self.new_initializers = new_initializers
+        super().__init__()
+        # Only replace if caller provided something
+        if new_initializers is not None:
+            self.new_initializers = new_initializers
 
     def quantize(
         self: GemmQuantizer,
@@ -34,65 +46,7 @@ class GemmQuantizer:
         scale_config: ScaleConfig,
         initializer_map: dict[str, onnx.TensorProto],
     ) -> list[onnx.NodeProto]:
-        """
-        Quantize a Gemm node by:
-        1. Quantizing its weights and bias.
-        2. Adding a scale constant.
-        3. Replacing it with an Int64Gemm node.
-
-        Args:
-            node (onnx.NodeProto): The node to quantize.
-            rescale (bool): Whether rescaling is enabled
-            graph (onnx.GraphProto): The ONNX graph.
-            scale_exponent (int): Scale exponent.
-            scale_base (int): The base of scaling.
-            initializer_map (dict[str, onnx.TensorProto]):
-                Map of initializer names to tensor data.
-
-        Returns:
-            List[onnx.NodeProto]: A list of ONNX nodes
-            (quantized and any auxiliary nodes).
-        """
-        _ = graph
-        nodes = []
-        output_name = f"{node.name}_int"
-
-        nodes, new_inputs = self.add_nodes_w_and_b(
-            node=node,
-            scale_exponent=scale_config.exponent,
-            scale_base=scale_config.base,
-            initializer_map=initializer_map,
-        )
-        node.input[:] = new_inputs
-
-        attrs = extract_attributes(node)
-        attrs.setdefault("transA", 0)
-        attrs.setdefault("transB", 0)
-        attrs["rescale"] = int(scale_config.rescale)
-
-        scale_value = self.get_scaling(
-            scale_config.base,
-            scale_config.exponent,
-        )
-
-        # === Create scale constant ===
-        scale_const_name = f"{output_name}_scaler"
-        scale_tensor = numpy_helper.from_array(
-            np.array([scale_value], dtype=np.int64),
-            name=scale_const_name,
-        )
-        self.new_initializers.append(scale_tensor)
-        node.input.append(scale_const_name)
-        int64_gemm = onnx.helper.make_node(
-            "Int64Gemm",
-            inputs=node.input,
-            outputs=node.output,  # preserve original output name
-            name=output_name,
-            domain="ai.onnx.contrib",
-            **attrs,
-        )
-        nodes.append(int64_gemm)
-        return nodes
+        return QuantizeGemm.quantize(self, node, graph, scale_config, initializer_map)
 
     def check_supported(
         self: GemmQuantizer,
python/core/model_processing/onnx_quantizer/layers/maxpool.py
@@ -1,20 +1,28 @@
 from __future__ import annotations
 
-import onnx
-from onnx import helper
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import onnx
 
 from python.core.model_processing.onnx_custom_ops.onnx_helpers import (
-    extract_attributes,
     get_attribute_ints,
 )
 from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError
 from python.core.model_processing.onnx_quantizer.layers.base import (
     BaseOpQuantizer,
+    QuantizerBase,
     ScaleConfig,
 )
 
 
-class MaxpoolQuantizer(BaseOpQuantizer):
+class QuantizeMaxpool(QuantizerBase):
+    OP_TYPE = "Int64MaxPool"
+    USE_WB = False
+    USE_SCALING = False
+
+
+class MaxpoolQuantizer(BaseOpQuantizer, QuantizeMaxpool):
     """
     Quantizer for ONNX MaxPool layers.
 
@@ -25,55 +33,26 @@ class MaxpoolQuantizer:
 
     def __init__(
         self: MaxpoolQuantizer,
-        new_initializer: dict[str, onnx.TensorProto] | None = None,
+        new_initializer: list[onnx.TensorProto] | None = None,
     ) -> None:
         super().__init__()
         self.accepted_kernel_shapes = [2]
         _ = new_initializer
 
     def quantize(
-        self: BaseOpQuantizer,
+        self: MaxpoolQuantizer,
         node: onnx.NodeProto,
         graph: onnx.GraphProto,
         scale_config: ScaleConfig,
         initializer_map: dict[str, onnx.TensorProto],
     ) -> list[onnx.NodeProto]:
-        """
-        Quantize a node by converting the node to Int64 version
-
-        Args:
-            node (onnx.NodeProto): The node to quantize.
-            rescale (bool): Whether rescaling is enabled
-                (Doesnt have an affect on this op type)
-            graph (onnx.GraphProto): The ONNX graph.
-            scale_exponent (int): Scale exponent.
-            scale_base (int): The base of scaling.
-            initializer_map (dict[str, onnx.TensorProto]):
-                Map of initializer names to tensor data.
-
-        Returns:
-            List[onnx.NodeProto]: A list of ONNX nodes
-            (quantized MaxPool and any auxiliary nodes).
-        """
-        _ = initializer_map, graph
-
-        attrs = extract_attributes(node)
-        attrs["rescale"] = int(scale_config.rescale)
-
-        attr_str = {
-            k: ",".join(map(str, v)) if isinstance(v, list) else str(v)
-            for k, v in attrs.items()
-        }
-        return [
-            helper.make_node(
-                "Int64MaxPool",
-                inputs=node.input,
-                outputs=node.output,
-                name=node.name,
-                domain="ai.onnx.contrib",
-                **attr_str,
-            ),
-        ]
+        return QuantizeMaxpool.quantize(
+            self,
+            node,
+            graph,
+            scale_config,
+            initializer_map,
+        )
 
     def check_supported(
         self: MaxpoolQuantizer,
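QuantizeMaxpool sets both USE_WB and USE_SCALING to False, so the template touches no initializers and appends no scale constant. That is sound because max pooling commutes with a uniform positive scale (max(s*x) = s*max(x) for s > 0), so already-scaled integer inputs pass straight through. An illustrative check, not package code:

    import numpy as np

    window = np.array([3, -1, 7, 2], dtype=np.int64)  # one pooling window of scaled ints
    s = 16                                            # uniform scale base**exponent
    assert (window * s).max() == window.max() * s     # pool-then-scale == scale-then-pool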
@@ -95,6 +74,7 @@ class MaxpoolQuantizer:
         _ = initializer_map
         self.check_all_params_exist(node)
         self.check_params_size(node)
+        self.check_pool_pads(node)
 
     def check_all_params_exist(self: MaxpoolQuantizer, node: onnx.NodeProto) -> None:
         """Checks all parameters that are needed, do exist
@@ -131,10 +111,40 @@ class MaxpoolQuantizer:
             InvalidParamError: If shape requirement is not met.
         """
 
-        kernel_shape = get_attribute_ints(node, "kernel_shape", default="N/A")
+        kernel_shape = get_attribute_ints(node, "kernel_shape", default=[])
         if len(kernel_shape) not in self.accepted_kernel_shapes:
             raise InvalidParamError(
                 node.name,
                 node.op_type,
                 f"Currently only maxpool2d is supported. Found {len(kernel_shape)}D",
             )
+
+    def check_pool_pads(self: MaxpoolQuantizer, node: onnx.NodeProto) -> None:
+        kernel_shape = get_attribute_ints(node, "kernel_shape", default=[])
+        pads = get_attribute_ints(node, "pads", default=None)
+        if pads is None:
+            return
+        num_dims = len(kernel_shape)
+        if len(pads) != num_dims * 2:
+            raise InvalidParamError(
+                node.name,
+                node.op_type,
+                f"Expected {num_dims * 2} pads, got {len(pads)}",
+            )
+
+        for dim in range(num_dims):
+            pad_before = pads[dim]
+            pad_after = pads[dim + num_dims]
+            kernel = kernel_shape[dim]
+            if pad_before >= kernel:
+                raise InvalidParamError(
+                    node.name,
+                    node.op_type,
+                    f"pads[{dim}]={pad_before} >= kernel[{dim}]={kernel}",
+                )
+            if pad_after >= kernel:
+                raise InvalidParamError(
+                    node.name,
+                    node.op_type,
+                    f"pads[{dim + num_dims}]={pad_after} >= kernel[{dim}]={kernel}",
+                )
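check_pool_pads is new in this release: it validates the standard ONNX pads layout [d1_begin, d2_begin, ..., d1_end, d2_end] and rejects any pad that reaches the kernel size in its dimension. A hedged sketch of a node this check now rejects (the graph setup is illustrative only; the attribute names are standard ONNX):

    from onnx import helper

    from python.core.model_processing.onnx_quantizer.layers.maxpool import (
        MaxpoolQuantizer,
    )

    node = helper.make_node(
        "MaxPool",
        inputs=["X"],
        outputs=["Y"],
        name="pool0",
        kernel_shape=[2, 2],
        strides=[2, 2],
        pads=[2, 0, 0, 0],  # h_begin = 2 >= kernel h = 2, so this is invalid
    )
    MaxpoolQuantizer().check_supported(node)  # expected to raise InvalidParamError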
python/core/model_processing/onnx_quantizer/layers/relu.py
@@ -1,14 +1,24 @@
 from __future__ import annotations
 
-import onnx
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from onnx import GraphProto, NodeProto, TensorProto
 
 from python.core.model_processing.onnx_quantizer.layers.base import (
     BaseOpQuantizer,
+    QuantizerBase,
     ScaleConfig,
 )
 
 
-class ReluQuantizer(BaseOpQuantizer):
+class QuantizeRelu(QuantizerBase):
+    OP_TYPE = "Int64Relu"
+    USE_WB = False
+    USE_SCALING = False
+
+
+class ReluQuantizer(BaseOpQuantizer, QuantizeRelu):
     """
     Quantizer for ONNX ReLU layers.
 
@@ -19,49 +29,24 @@ class ReluQuantizer:
 
     def __init__(
         self: ReluQuantizer,
-        new_initializer: dict[str, onnx.TensorProto] | None = None,
+        new_initializer: list[TensorProto] | None = None,
     ) -> None:
         super().__init__()
         _ = new_initializer
 
     def quantize(
         self: ReluQuantizer,
-        node: onnx.NodeProto,
-        graph: onnx.GraphProto,
+        node: NodeProto,
+        graph: GraphProto,
         scale_config: ScaleConfig,
-        initializer_map: dict[str, onnx.TensorProto],
-    ) -> list[onnx.NodeProto]:
-        """
-        Quantize a node by converting the node to Int64 version
-
-        Args:
-            node (onnx.NodeProto): The node to quantize.
-            rescale (bool): Whether rescaling is enabled
-                (Doesnt have an affect on this op type)
-            graph (onnx.GraphProto): The ONNX graph.
-            scale_exponent (int): Scale exponent.
-            scale_base (int): The base of scaling.
-            initializer_map (dict[str, onnx.TensorProto]):
-                Map of initializer names to tensor data.
-
-        Returns:
-            List[onnx.NodeProto]: The quantized ONNX node.
-        """
-        _ = graph, scale_config, initializer_map
-        return [
-            onnx.helper.make_node(
-                "Int64Relu",
-                inputs=node.input,
-                outputs=node.output,  # preserve original output name
-                name=node.name,
-                domain="ai.onnx.contrib",
-            ),
-        ]
+        initializer_map: dict[str, TensorProto],
+    ) -> list[NodeProto]:
+        return QuantizeRelu.quantize(self, node, graph, scale_config, initializer_map)
 
     def check_supported(
         self: ReluQuantizer,
-        node: onnx.NodeProto,
-        initializer_map: dict[str, onnx.TensorProto] | None = None,
+        node: NodeProto,
+        initializer_map: dict[str, TensorProto] | None = None,
     ) -> None:
         """
         Perform high-level validation to ensure that this node