JSTprove 1.0.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of JSTprove might be problematic. Click here for more details.

Files changed (81) hide show
  1. jstprove-1.0.0.dist-info/METADATA +397 -0
  2. jstprove-1.0.0.dist-info/RECORD +81 -0
  3. jstprove-1.0.0.dist-info/WHEEL +5 -0
  4. jstprove-1.0.0.dist-info/entry_points.txt +2 -0
  5. jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
  6. jstprove-1.0.0.dist-info/top_level.txt +1 -0
  7. python/__init__.py +0 -0
  8. python/core/__init__.py +3 -0
  9. python/core/binaries/__init__.py +0 -0
  10. python/core/binaries/expander-exec +0 -0
  11. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  12. python/core/circuit_models/__init__.py +0 -0
  13. python/core/circuit_models/generic_onnx.py +231 -0
  14. python/core/circuit_models/simple_circuit.py +133 -0
  15. python/core/circuits/__init__.py +0 -0
  16. python/core/circuits/base.py +1000 -0
  17. python/core/circuits/errors.py +188 -0
  18. python/core/circuits/zk_model_base.py +25 -0
  19. python/core/model_processing/__init__.py +0 -0
  20. python/core/model_processing/converters/__init__.py +0 -0
  21. python/core/model_processing/converters/base.py +143 -0
  22. python/core/model_processing/converters/onnx_converter.py +1181 -0
  23. python/core/model_processing/errors.py +147 -0
  24. python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
  25. python/core/model_processing/onnx_custom_ops/conv.py +111 -0
  26. python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
  27. python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
  28. python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
  29. python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
  30. python/core/model_processing/onnx_custom_ops/relu.py +43 -0
  31. python/core/model_processing/onnx_quantizer/__init__.py +0 -0
  32. python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
  33. python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
  34. python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
  35. python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
  36. python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
  37. python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
  38. python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
  39. python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
  40. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
  41. python/core/model_templates/__init__.py +0 -0
  42. python/core/model_templates/circuit_template.py +57 -0
  43. python/core/utils/__init__.py +0 -0
  44. python/core/utils/benchmarking_helpers.py +163 -0
  45. python/core/utils/constants.py +4 -0
  46. python/core/utils/errors.py +117 -0
  47. python/core/utils/general_layer_functions.py +268 -0
  48. python/core/utils/helper_functions.py +1138 -0
  49. python/core/utils/model_registry.py +166 -0
  50. python/core/utils/scratch_tests.py +66 -0
  51. python/core/utils/witness_utils.py +291 -0
  52. python/frontend/__init__.py +0 -0
  53. python/frontend/cli.py +115 -0
  54. python/frontend/commands/__init__.py +17 -0
  55. python/frontend/commands/args.py +100 -0
  56. python/frontend/commands/base.py +199 -0
  57. python/frontend/commands/bench/__init__.py +54 -0
  58. python/frontend/commands/bench/list.py +42 -0
  59. python/frontend/commands/bench/model.py +172 -0
  60. python/frontend/commands/bench/sweep.py +212 -0
  61. python/frontend/commands/compile.py +58 -0
  62. python/frontend/commands/constants.py +5 -0
  63. python/frontend/commands/model_check.py +53 -0
  64. python/frontend/commands/prove.py +50 -0
  65. python/frontend/commands/verify.py +73 -0
  66. python/frontend/commands/witness.py +64 -0
  67. python/scripts/__init__.py +0 -0
  68. python/scripts/benchmark_runner.py +833 -0
  69. python/scripts/gen_and_bench.py +482 -0
  70. python/tests/__init__.py +0 -0
  71. python/tests/circuit_e2e_tests/__init__.py +0 -0
  72. python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
  73. python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
  74. python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
  75. python/tests/circuit_parent_classes/__init__.py +0 -0
  76. python/tests/circuit_parent_classes/test_circuit.py +969 -0
  77. python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
  78. python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
  79. python/tests/test_cli.py +1021 -0
  80. python/tests/utils_testing/__init__.py +0 -0
  81. python/tests/utils_testing/test_helper_functions.py +891 -0
@@ -0,0 +1,118 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import ClassVar
4
+
5
+ import numpy as np
6
+ import onnx
7
+ from onnx import numpy_helper
8
+
9
+ from python.core.model_processing.onnx_quantizer.exceptions import (
10
+ HandlerImplementationError,
11
+ )
12
+ from python.core.model_processing.onnx_quantizer.layers.base import (
13
+ BaseOpQuantizer,
14
+ ScaleConfig,
15
+ )
16
+
17
+
18
class ConstantQuantizer(BaseOpQuantizer):
    """
    Quantizer for ONNX Constant nodes.

    This quantizer only modifies constants that are:
    - Numeric tensors
    - Used directly in computation

    Constants used for shape, indexing, or other non-numeric roles are left
    unchanged.
    """

    # Ops that consume numeric constants: a Constant feeding any of these is
    # treated as numeric data and gets scaled.
    DATA_OPS: ClassVar = {
        "Add",
        "Mul",
        "Conv",
        "MatMul",
        "Sub",
        "Div",
        "Gemm",
    }

    def __init__(
        self: ConstantQuantizer,
        new_initializer: dict[str, onnx.TensorProto] | None = None,
    ) -> None:
        """Initialize the quantizer.

        Args:
            new_initializer (dict[str, onnx.TensorProto], optional): Unused;
                accepted for interface parity with quantizers that collect
                new initializers. Defaults to None.
        """
        super().__init__()
        _ = new_initializer

    def quantize(
        self: ConstantQuantizer,
        node: onnx.NodeProto,
        graph: onnx.GraphProto,
        scale_config: ScaleConfig,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> list[onnx.NodeProto]:
        """Apply quantization scaling to a constant if it is used in
        numeric computation.

        Args:
            node (onnx.NodeProto): The Constant node to quantize.
            graph (onnx.GraphProto): The ONNX graph, scanned to find the
                constant's consumers.
            scale_config (ScaleConfig): Scaling configuration
                (base, exponent, rescale flag).
            initializer_map (dict[str, onnx.TensorProto]):
                Map of initializer names to tensor data. Unused here.

        Returns:
            list[onnx.NodeProto]: The modified node (possibly unchanged).

        Raises:
            HandlerImplementationError: If the tensor is unreadable.
        """
        _ = initializer_map
        self.validate_node_has_output(node)

        output_name = node.output[0]

        # The constant counts as "data" only when some consumer treats it as
        # a numeric operand (as opposed to e.g. a shape or index input).
        is_data_constant = any(
            output_name in n.input and n.op_type in self.DATA_OPS for n in graph.node
        )

        if not is_data_constant:
            # Skip quantization for non-numeric constants
            return [node]

        # Safe to quantize: numeric constant used in computation
        for attr in node.attribute:
            if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR:
                try:
                    arr = numpy_helper.to_array(attr.t).astype(np.float64)
                # Original caught (ValueError, Exception), which is redundant:
                # Exception already subsumes ValueError. Any read failure is
                # surfaced as a handler error with the node's name attached.
                except Exception as e:
                    raise HandlerImplementationError(
                        op_type="Constant",
                        message="Failed to read tensor from Constant node"
                        f" '{node.name}': {e}",
                    ) from e

                arr *= self.get_scaling(
                    scale_config.base,
                    scale_config.exponent,
                )
                attr.t.CopyFrom(numpy_helper.from_array(arr, name=""))

        node.name += "_quant"
        return [node]

    def check_supported(
        self: ConstantQuantizer,
        node: onnx.NodeProto,
        initializer_map: dict[str, onnx.TensorProto] | None = None,
    ) -> None:
        """All Constant nodes are supported... For now.

        Args:
            node (onnx.NodeProto): Node to be checked
            initializer_map (dict[str, onnx.TensorProto], optional):
                Map of initializer names to tensor data. Defaults to None.
        """
        _ = node, initializer_map
@@ -0,0 +1,180 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+ import onnx
5
+ from onnx import numpy_helper
6
+
7
+ from python.core.model_processing.onnx_custom_ops.onnx_helpers import extract_attributes
8
+ from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError
9
+ from python.core.model_processing.onnx_quantizer.layers.base import (
10
+ BaseOpQuantizer,
11
+ ScaleConfig,
12
+ )
13
+
14
+
15
class ConvQuantizer(BaseOpQuantizer):
    """
    Quantizer for ONNX Conv layers.

    - Replaces standard Conv with Int64Conv from the `ai.onnx.contrib` domain
      and makes relevant additional changes to the graph.
    - Validates that all required Conv parameters are present.
    """

    def __init__(
        self: ConvQuantizer,
        new_initializers: list[onnx.TensorProto],
    ) -> None:
        """Initialize the quantizer.

        Args:
            new_initializers (list[onnx.TensorProto]): Shared collection that
                gathers initializers created during quantization (e.g. scale
                constants). Note: this is a list — it is appended to — not a
                mapping as the original annotation claimed.
        """
        super().__init__()
        self.new_initializers = new_initializers

    def quantize(
        self: ConvQuantizer,
        node: onnx.NodeProto,
        graph: onnx.GraphProto,
        scale_config: ScaleConfig,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> list[onnx.NodeProto]:
        """
        Quantize a Conv node by:
        1. Quantizing its weights and bias.
        2. Adding a scale constant.
        3. Replacing it with an Int64Conv node.

        Args:
            node (onnx.NodeProto): The node to quantize.
            graph (onnx.GraphProto): The ONNX graph. Unused here.
            scale_config (ScaleConfig): Scaling configuration
                (base, exponent, rescale flag).
            initializer_map (dict[str, onnx.TensorProto]):
                Map of initializer names to tensor data.

        Returns:
            list[onnx.NodeProto]: A list of ONNX nodes
            (quantized and any auxiliary nodes).
        """
        _ = graph

        nodes = []
        output_name = f"{node.name}_int"

        # Quantize weights/bias; the helper returns auxiliary nodes plus the
        # rewritten input list for this node.
        nodes, node.input[:] = self.add_nodes_w_and_b(
            node=node,
            scale_exponent=scale_config.exponent,
            scale_base=scale_config.base,
            initializer_map=initializer_map,
        )
        attrs = extract_attributes(node)
        attrs.setdefault("group", 1)
        attrs.setdefault("auto_pad", "NOTSET")

        attrs["rescale"] = int(scale_config.rescale)

        scale_value = self.get_scaling(
            scale_config.base,
            scale_config.exponent,
        )

        # Create scale constant
        scale_const_name = f"{output_name}_scaler"
        scale_tensor = numpy_helper.from_array(
            np.array([scale_value], dtype=np.int64),
            name=scale_const_name,
        )
        self.new_initializers.append(scale_tensor)
        node.input.append(scale_const_name)
        int64_conv_node = onnx.helper.make_node(
            "Int64Conv",
            inputs=node.input,
            outputs=node.output,  # preserve original output name
            name=node.name,
            domain="ai.onnx.contrib",
            **attrs,
        )

        nodes.append(int64_conv_node)
        return nodes

    def check_supported(
        self: ConvQuantizer,
        node: onnx.NodeProto,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> None:
        """
        Perform high-level validation to ensure that this Conv node
        can be quantized safely.

        Args:
            node (onnx.NodeProto): ONNX node to be checked
            initializer_map (dict[str, onnx.TensorProto]):
                Initializer map (name of weight or bias and tensor)

        Raises:
            InvalidParamError: If any requirement is not met.
        """
        # Two-stage check so the error message names exactly what is missing:
        # fewer than 2 inputs -> weights missing; exactly 2 -> bias missing.
        num_inputs = 2
        if len(node.input) < num_inputs:
            raise InvalidParamError(
                node.name,
                node.op_type,
                f"Expected at least 2 inputs (input, weights), got {len(node.input)}",
            )
        num_inputs = 3

        if len(node.input) < num_inputs:
            raise InvalidParamError(
                node.name,
                node.op_type,
                "Expected at least 3 inputs (input, weights, bias),"
                f" got {len(node.input)}",
            )

        self.check_supported_shape(node, initializer_map)
        self.check_all_params_exist(node)

    def check_all_params_exist(self: ConvQuantizer, node: onnx.NodeProto) -> None:
        """Verify that all required Conv attributes are present.

        Args:
            node (onnx.NodeProto): The Conv node being validated.

        Raises:
            InvalidParamError: If any required parameter is missing.
        """
        # May need: ["strides", "kernel_shape", "dilations", "pads"]
        required_attrs = ["strides", "kernel_shape", "dilations"]
        self.validate_required_attrs(node, required_attrs)

    def check_supported_shape(
        self: ConvQuantizer,
        node: onnx.NodeProto,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> None:
        """Ensure that Conv weights are available and have the correct dimensionality.

        Args:
            node (onnx.NodeProto): The node being validated.
            initializer_map (dict[str, onnx.TensorProto]):
                Mapping of initializer tensor names to TensorProtos.

        Raises:
            InvalidParamError: If weights are missing or have an unsupported shape.
        """
        # Only Conv2D (4D weights: out_ch, in_ch/group, kH, kW) is supported.
        supported_size = [4]
        weight_name = node.input[1]
        initializer = initializer_map.get(weight_name)

        if initializer is None:
            raise InvalidParamError(
                node.name,
                node.op_type,
                f"Weight tensor '{weight_name}' not found in initializers",
            )

        weight_dims = list(initializer.dims)

        if len(weight_dims) not in supported_size:
            msg = f"Unsupported Conv weight dimensionality {len(weight_dims)}. "
            msg += f"Expected 4D weights for Conv2D, got shape {weight_dims}"
            raise InvalidParamError(node.name, node.op_type, msg)
@@ -0,0 +1,171 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+ import onnx
5
+ from onnx import numpy_helper
6
+
7
+ from python.core.model_processing.onnx_custom_ops.onnx_helpers import extract_attributes
8
+ from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError
9
+ from python.core.model_processing.onnx_quantizer.layers.base import (
10
+ BaseOpQuantizer,
11
+ ScaleConfig,
12
+ )
13
+
14
+
15
class GemmQuantizer(BaseOpQuantizer):
    """
    Quantizer for ONNX Gemm layers.

    - Replaces standard Gemm with Int64Gemm from the `ai.onnx.contrib`
      domain and makes relevant additional changes to the graph.
    - Validates that all required Gemm parameters are present.
    """

    def __init__(
        self: GemmQuantizer,
        new_initializers: list[onnx.TensorProto],
    ) -> None:
        """Initialize the quantizer.

        Args:
            new_initializers (list[onnx.TensorProto]): Shared collection that
                gathers initializers created during quantization (e.g. scale
                constants). Note: this is a list — it is appended to — not a
                mapping as the original annotation claimed.
        """
        super().__init__()
        self.new_initializers = new_initializers

    def quantize(
        self: GemmQuantizer,
        node: onnx.NodeProto,
        graph: onnx.GraphProto,
        scale_config: ScaleConfig,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> list[onnx.NodeProto]:
        """
        Quantize a Gemm node by:
        1. Quantizing its weights and bias.
        2. Adding a scale constant.
        3. Replacing it with an Int64Gemm node.

        Args:
            node (onnx.NodeProto): The node to quantize.
            graph (onnx.GraphProto): The ONNX graph. Unused here.
            scale_config (ScaleConfig): Scaling configuration
                (base, exponent, rescale flag).
            initializer_map (dict[str, onnx.TensorProto]):
                Map of initializer names to tensor data.

        Returns:
            list[onnx.NodeProto]: A list of ONNX nodes
            (quantized and any auxiliary nodes).
        """
        _ = graph
        nodes = []
        output_name = f"{node.name}_int"

        # Quantize weights/bias; the helper returns auxiliary nodes plus the
        # rewritten input list for this node.
        nodes, new_inputs = self.add_nodes_w_and_b(
            node=node,
            scale_exponent=scale_config.exponent,
            scale_base=scale_config.base,
            initializer_map=initializer_map,
        )
        node.input[:] = new_inputs

        attrs = extract_attributes(node)
        attrs.setdefault("transA", 0)
        attrs.setdefault("transB", 0)
        attrs["rescale"] = int(scale_config.rescale)

        scale_value = self.get_scaling(
            scale_config.base,
            scale_config.exponent,
        )

        # === Create scale constant ===
        scale_const_name = f"{output_name}_scaler"
        scale_tensor = numpy_helper.from_array(
            np.array([scale_value], dtype=np.int64),
            name=scale_const_name,
        )
        self.new_initializers.append(scale_tensor)
        node.input.append(scale_const_name)
        int64_gemm = onnx.helper.make_node(
            "Int64Gemm",
            inputs=node.input,
            outputs=node.output,  # preserve original output name
            name=output_name,
            domain="ai.onnx.contrib",
            **attrs,
        )
        nodes.append(int64_gemm)
        return nodes

    def check_supported(
        self: GemmQuantizer,
        node: onnx.NodeProto,
        initializer_map: dict[str, onnx.TensorProto] | None = None,
    ) -> None:
        """
        Perform high-level validation to ensure that this node
        can be quantized safely.

        Args:
            node (onnx.NodeProto): ONNX node to be checked
            initializer_map (dict[str, onnx.TensorProto]):
                Initializer map (name of weight or bias and tensor). Unused.

        Raises:
            InvalidParamError: If any requirement is not met.
        """
        _ = initializer_map
        # Two-stage check so the error message names exactly what is missing:
        # fewer than 2 inputs -> weights missing; exactly 2 -> bias missing.
        num_valid_inputs = 2
        # Ensure inputs exist
        if len(node.input) < num_valid_inputs:
            raise InvalidParamError(
                node.name,
                node.op_type,
                f"Expected at least 2 inputs (input, weights), got {len(node.input)}",
            )
        num_valid_inputs = 3

        if len(node.input) < num_valid_inputs:
            raise InvalidParamError(
                node.name,
                node.op_type,
                "Expected at least 3 inputs (input, weights, bias)"
                f", got {len(node.input)}",
            )

        # Validate attributes with defaults. attrs.get(...) returns None for
        # a missing attribute, so getattr(None, ..., default) yields the
        # ONNX spec default (alpha=1.0, beta=1.0, transA=0, transB=0).
        # The original used 1 as the fallback for transB, which contradicted
        # both the ONNX spec and quantize()'s setdefault("transB", 0);
        # validation outcome is unchanged since both 0 and 1 pass the check.
        attrs = {attr.name: attr for attr in node.attribute}
        alpha = getattr(attrs.get("alpha"), "f", 1.0)
        beta = getattr(attrs.get("beta"), "f", 1.0)
        trans_a = getattr(attrs.get("transA"), "i", 0)
        trans_b = getattr(attrs.get("transB"), "i", 0)

        if alpha != 1.0:
            raise InvalidParamError(
                node.name,
                node.op_type,
                f"alpha value of {alpha} not supported",
                "alpha",
                "1.0",
            )
        if beta != 1.0:
            raise InvalidParamError(
                node.name,
                node.op_type,
                f"beta value of {beta} not supported",
                "beta",
                "1.0",
            )
        if trans_a not in [0, 1]:
            raise InvalidParamError(
                node.name,
                node.op_type,
                f"transA value of {trans_a} not supported",
                "transA",
                "(0,1)",
            )
        if trans_b not in [0, 1]:
            raise InvalidParamError(
                node.name,
                node.op_type,
                f"transB value of {trans_b} not supported",
                "transB",
                "(0,1)",
            )
1
+ from __future__ import annotations
2
+
3
+ import onnx
4
+ from onnx import helper
5
+
6
+ from python.core.model_processing.onnx_custom_ops.onnx_helpers import (
7
+ extract_attributes,
8
+ get_attribute_ints,
9
+ )
10
+ from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError
11
+ from python.core.model_processing.onnx_quantizer.layers.base import (
12
+ BaseOpQuantizer,
13
+ ScaleConfig,
14
+ )
15
+
16
+
17
class MaxpoolQuantizer(BaseOpQuantizer):
    """
    Quantizer for ONNX MaxPool layers.

    - Replaces standard MaxPool with Int64MaxPool from the `ai.onnx.contrib`
      domain and makes relevant additional changes to the graph.
    - Validates that all required MaxPool parameters are present.
    """

    def __init__(
        self: MaxpoolQuantizer,
        new_initializer: dict[str, onnx.TensorProto] | None = None,
    ) -> None:
        """Initialize the quantizer.

        Args:
            new_initializer (dict[str, onnx.TensorProto], optional): Unused;
                accepted for interface parity with quantizers that collect
                new initializers. Defaults to None.
        """
        super().__init__()
        # Only 2D pooling kernels (len(kernel_shape) == 2) are supported.
        self.accepted_kernel_shapes = [2]
        _ = new_initializer

    def quantize(
        self: MaxpoolQuantizer,
        node: onnx.NodeProto,
        graph: onnx.GraphProto,
        scale_config: ScaleConfig,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> list[onnx.NodeProto]:
        """
        Quantize a node by converting the node to Int64 version

        Args:
            node (onnx.NodeProto): The node to quantize.
            graph (onnx.GraphProto): The ONNX graph. Unused here.
            scale_config (ScaleConfig): Scaling configuration; only the
                rescale flag is consumed (rescaling otherwise has no effect
                on this op type).
            initializer_map (dict[str, onnx.TensorProto]):
                Map of initializer names to tensor data. Unused here.

        Returns:
            list[onnx.NodeProto]: A list of ONNX nodes
            (quantized MaxPool and any auxiliary nodes).
        """
        _ = initializer_map, graph

        attrs = extract_attributes(node)
        attrs["rescale"] = int(scale_config.rescale)

        # Int64MaxPool takes its attributes as strings; join list-valued
        # attributes (strides, kernel_shape, ...) with commas.
        attr_str = {
            k: ",".join(map(str, v)) if isinstance(v, list) else str(v)
            for k, v in attrs.items()
        }
        return [
            helper.make_node(
                "Int64MaxPool",
                inputs=node.input,
                outputs=node.output,
                name=node.name,
                domain="ai.onnx.contrib",
                **attr_str,
            ),
        ]

    def check_supported(
        self: MaxpoolQuantizer,
        node: onnx.NodeProto,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> None:
        """
        Perform high-level validation to ensure that this node
        can be quantized safely.

        Args:
            node (onnx.NodeProto): ONNX node to be checked
            initializer_map (dict[str, onnx.TensorProto]):
                Initializer map (name of weight or bias and tensor). Unused.

        Raises:
            InvalidParamError: If any requirement is not met.
        """
        _ = initializer_map
        self.check_all_params_exist(node)
        self.check_params_size(node)

    def check_all_params_exist(self: MaxpoolQuantizer, node: onnx.NodeProto) -> None:
        """Checks all parameters that are needed, do exist

        Args:
            node (onnx.NodeProto): ONNX node to check

        Raises:
            InvalidParamError: If shape requirement is not met.
        """
        # May need: ["strides", "kernel_shape", "pads", "dilations"]
        required_attrs = ["strides", "kernel_shape"]
        self.validate_required_attrs(node, required_attrs)

        # Check dimension of kernel
        kernel_shape = get_attribute_ints(node, "kernel_shape", default=[])
        if len(kernel_shape) not in self.accepted_kernel_shapes:
            raise InvalidParamError(
                node.name,
                node.op_type,
                "Currently only MaxPool2D is supported."
                f"Found {len(kernel_shape)}D kernel",
                "kernel_shape",
                "2D",
            )

    def check_params_size(self: MaxpoolQuantizer, node: onnx.NodeProto) -> None:
        """Checks dimension of the layer and ensures that it is supported

        Args:
            node (onnx.NodeProto): ONNX node to check

        Raises:
            InvalidParamError: If shape requirement is not met.
        """
        # Default must be a list: the original used default="N/A", which made
        # a missing kernel_shape report len("N/A") == 3 as "Found 3D".
        kernel_shape = get_attribute_ints(node, "kernel_shape", default=[])
        if len(kernel_shape) not in self.accepted_kernel_shapes:
            raise InvalidParamError(
                node.name,
                node.op_type,
                f"Currently only maxpool2d is supported. Found {len(kernel_shape)}D",
            )
@@ -0,0 +1,76 @@
1
+ from __future__ import annotations
2
+
3
+ import onnx
4
+
5
+ from python.core.model_processing.onnx_quantizer.layers.base import (
6
+ BaseOpQuantizer,
7
+ ScaleConfig,
8
+ )
9
+
10
+
11
class ReluQuantizer(BaseOpQuantizer):
    """
    Quantizer for ONNX ReLU layers.

    - Swaps the standard ReLU node for an Int64Relu node from the
      `ai.onnx.contrib` domain, preserving inputs, outputs, and name.
    - ReLU carries no attributes of interest, so support checking is a no-op.
    """

    def __init__(
        self: ReluQuantizer,
        new_initializer: dict[str, onnx.TensorProto] | None = None,
    ) -> None:
        """Initialize the quantizer.

        Args:
            new_initializer (dict[str, onnx.TensorProto], optional): Unused;
                accepted for interface parity with quantizers that collect
                new initializers. Defaults to None.
        """
        super().__init__()
        _ = new_initializer

    def quantize(
        self: ReluQuantizer,
        node: onnx.NodeProto,
        graph: onnx.GraphProto,
        scale_config: ScaleConfig,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> list[onnx.NodeProto]:
        """Convert the node to its Int64 equivalent.

        Args:
            node (onnx.NodeProto): The node to quantize.
            graph (onnx.GraphProto): The ONNX graph. Unused here.
            scale_config (ScaleConfig): Scaling configuration. Unused here;
                rescaling has no effect on this op type.
            initializer_map (dict[str, onnx.TensorProto]):
                Map of initializer names to tensor data. Unused here.

        Returns:
            list[onnx.NodeProto]: The quantized ONNX node.
        """
        _ = graph, scale_config, initializer_map
        int64_relu = onnx.helper.make_node(
            "Int64Relu",
            inputs=node.input,
            outputs=node.output,  # preserve original output name
            name=node.name,
            domain="ai.onnx.contrib",
        )
        return [int64_relu]

    def check_supported(
        self: ReluQuantizer,
        node: onnx.NodeProto,
        initializer_map: dict[str, onnx.TensorProto] | None = None,
    ) -> None:
        """
        Perform high-level validation to ensure that this node
        can be quantized safely. ReLU is always supported, so this
        never raises.

        Args:
            node (onnx.NodeProto): ONNX node to be checked
            initializer_map (dict[str, onnx.TensorProto]):
                Initializer map (name of weight or bias and tensor)
        """
        _ = node, initializer_map