JSTprove 1.0.0 (jstprove-1.0.0-py3-none-macosx_11_0_arm64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of JSTprove might be problematic.
- jstprove-1.0.0.dist-info/METADATA +397 -0
- jstprove-1.0.0.dist-info/RECORD +81 -0
- jstprove-1.0.0.dist-info/WHEEL +5 -0
- jstprove-1.0.0.dist-info/entry_points.txt +2 -0
- jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
- jstprove-1.0.0.dist-info/top_level.txt +1 -0
- python/__init__.py +0 -0
- python/core/__init__.py +3 -0
- python/core/binaries/__init__.py +0 -0
- python/core/binaries/expander-exec +0 -0
- python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
- python/core/circuit_models/__init__.py +0 -0
- python/core/circuit_models/generic_onnx.py +231 -0
- python/core/circuit_models/simple_circuit.py +133 -0
- python/core/circuits/__init__.py +0 -0
- python/core/circuits/base.py +1000 -0
- python/core/circuits/errors.py +188 -0
- python/core/circuits/zk_model_base.py +25 -0
- python/core/model_processing/__init__.py +0 -0
- python/core/model_processing/converters/__init__.py +0 -0
- python/core/model_processing/converters/base.py +143 -0
- python/core/model_processing/converters/onnx_converter.py +1181 -0
- python/core/model_processing/errors.py +147 -0
- python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
- python/core/model_processing/onnx_custom_ops/conv.py +111 -0
- python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
- python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
- python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
- python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
- python/core/model_processing/onnx_custom_ops/relu.py +43 -0
- python/core/model_processing/onnx_quantizer/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
- python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
- python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
- python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
- python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
- python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
- python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
- python/core/model_templates/__init__.py +0 -0
- python/core/model_templates/circuit_template.py +57 -0
- python/core/utils/__init__.py +0 -0
- python/core/utils/benchmarking_helpers.py +163 -0
- python/core/utils/constants.py +4 -0
- python/core/utils/errors.py +117 -0
- python/core/utils/general_layer_functions.py +268 -0
- python/core/utils/helper_functions.py +1138 -0
- python/core/utils/model_registry.py +166 -0
- python/core/utils/scratch_tests.py +66 -0
- python/core/utils/witness_utils.py +291 -0
- python/frontend/__init__.py +0 -0
- python/frontend/cli.py +115 -0
- python/frontend/commands/__init__.py +17 -0
- python/frontend/commands/args.py +100 -0
- python/frontend/commands/base.py +199 -0
- python/frontend/commands/bench/__init__.py +54 -0
- python/frontend/commands/bench/list.py +42 -0
- python/frontend/commands/bench/model.py +172 -0
- python/frontend/commands/bench/sweep.py +212 -0
- python/frontend/commands/compile.py +58 -0
- python/frontend/commands/constants.py +5 -0
- python/frontend/commands/model_check.py +53 -0
- python/frontend/commands/prove.py +50 -0
- python/frontend/commands/verify.py +73 -0
- python/frontend/commands/witness.py +64 -0
- python/scripts/__init__.py +0 -0
- python/scripts/benchmark_runner.py +833 -0
- python/scripts/gen_and_bench.py +482 -0
- python/tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
- python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
- python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
- python/tests/circuit_parent_classes/__init__.py +0 -0
- python/tests/circuit_parent_classes/test_circuit.py +969 -0
- python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
- python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
- python/tests/test_cli.py +1021 -0
- python/tests/utils_testing/__init__.py +0 -0
- python/tests/utils_testing/test_helper_functions.py +891 -0
@@ -0,0 +1,118 @@ (python/core/model_processing/onnx_quantizer/layers/constant.py)

from __future__ import annotations

from typing import ClassVar

import numpy as np
import onnx
from onnx import numpy_helper

from python.core.model_processing.onnx_quantizer.exceptions import (
    HandlerImplementationError,
)
from python.core.model_processing.onnx_quantizer.layers.base import (
    BaseOpQuantizer,
    ScaleConfig,
)


class ConstantQuantizer(BaseOpQuantizer):
    """
    Quantizer for ONNX Constant nodes.

    This quantizer only modifies constants that are:

    - Numeric tensors
    - Used directly in computation

    Constants used for shape, indexing, or other non-numeric roles are left unchanged.
    """

    DATA_OPS: ClassVar = {
        "Add",
        "Mul",
        "Conv",
        "MatMul",
        "Sub",
        "Div",
        "Gemm",
    }  # ops that consume numeric constants

    def __init__(
        self: ConstantQuantizer,
        new_initializer: dict[str, onnx.TensorProto] | None = None,
    ) -> None:
        super().__init__()
        _ = new_initializer

    def quantize(
        self: ConstantQuantizer,
        node: onnx.NodeProto,
        graph: onnx.GraphProto,
        scale_config: ScaleConfig,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> list[onnx.NodeProto]:
        """Apply quantization scaling to a constant if it is used in
        numeric computation.

        Args:
            node (onnx.NodeProto): The Constant node to quantize.
            graph (onnx.GraphProto): The ONNX graph.
            scale_config (ScaleConfig): Scaling configuration (base, exponent,
                and rescale flag; rescale doesn't affect this op type in some cases).
            initializer_map (dict[str, onnx.TensorProto]):
                Map of initializer names to tensor data.

        Returns:
            list[onnx.NodeProto]: The modified node (possibly unchanged).

        Raises:
            HandlerImplementationError: If the tensor is unreadable.
        """
        _ = initializer_map
        self.validate_node_has_output(node)

        output_name = node.output[0]

        is_data_constant = any(
            output_name in n.input and n.op_type in self.DATA_OPS for n in graph.node
        )

        if not is_data_constant:
            # Skip quantization for non-numeric constants
            return [node]

        # Safe to quantize: numeric constant used in computation
        for attr in node.attribute:
            if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR:
                try:
                    arr = numpy_helper.to_array(attr.t).astype(np.float64)
                except Exception as e:
                    raise HandlerImplementationError(
                        op_type="Constant",
                        message="Failed to read tensor from Constant node"
                        f" '{node.name}': {e}",
                    ) from e

                arr *= self.get_scaling(
                    scale_config.base,
                    scale_config.exponent,
                )
                attr.t.CopyFrom(numpy_helper.from_array(arr, name=""))

        node.name += "_quant"
        return [node]

    def check_supported(
        self: ConstantQuantizer,
        node: onnx.NodeProto,
        initializer_map: dict[str, onnx.TensorProto] | None = None,
    ) -> None:
        """All Constant nodes are supported... For now.

        Args:
            node (onnx.NodeProto): Node to be checked.
            initializer_map (dict[str, onnx.TensorProto], optional):
                Map of initializer names to tensor data. Defaults to None.
        """
        _ = node, initializer_map
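Every quantizer in this package derives its fixed-point scale from ScaleConfig via get_scaling, which is defined in layers/base.py and does not appear in this diff. The following is a minimal sketch of the assumed behavior, consistent with how the result is used above (multiplying float weights before they enter the int64 circuit domain); it is an assumption, not the package's actual implementation.

# Sketch only: the real get_scaling lives in
# python/core/model_processing/onnx_quantizer/layers/base.py (not in this diff).
# Its usage above is consistent with a plain fixed-point factor of base**exponent.
def get_scaling(base: int, exponent: int) -> int:
    return base**exponent

# A float weight w is mapped to w * base**exponent, e.g. with base=2, exponent=16:
w = 0.5
print(get_scaling(2, 16))      # 65536
print(w * get_scaling(2, 16))  # 32768.0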
@@ -0,0 +1,180 @@ (python/core/model_processing/onnx_quantizer/layers/conv.py)

from __future__ import annotations

import numpy as np
import onnx
from onnx import numpy_helper

from python.core.model_processing.onnx_custom_ops.onnx_helpers import extract_attributes
from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError
from python.core.model_processing.onnx_quantizer.layers.base import (
    BaseOpQuantizer,
    ScaleConfig,
)


class ConvQuantizer(BaseOpQuantizer):
    """
    Quantizer for ONNX Conv layers.

    - Replaces standard Conv with Int64Conv from the `ai.onnx.contrib` domain
      and makes relevant additional changes to the graph.
    - Validates that all required Conv parameters are present.
    """

    def __init__(
        self: ConvQuantizer,
        new_initializers: list[onnx.TensorProto],
    ) -> None:
        super().__init__()
        self.new_initializers = new_initializers

    def quantize(
        self: ConvQuantizer,
        node: onnx.NodeProto,
        graph: onnx.GraphProto,
        scale_config: ScaleConfig,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> list[onnx.NodeProto]:
        """
        Quantize a Conv node by:
        1. Quantizing its weights and bias.
        2. Adding a scale constant.
        3. Replacing it with an Int64Conv node.

        Args:
            node (onnx.NodeProto): The node to quantize.
            graph (onnx.GraphProto): The ONNX graph.
            scale_config (ScaleConfig): Scaling configuration (base, exponent,
                and rescale flag).
            initializer_map (dict[str, onnx.TensorProto]):
                Map of initializer names to tensor data.

        Returns:
            list[onnx.NodeProto]: A list of ONNX nodes
                (quantized and any auxiliary nodes).
        """
        _ = graph

        nodes = []
        output_name = f"{node.name}_int"

        nodes, new_inputs = self.add_nodes_w_and_b(
            node=node,
            scale_exponent=scale_config.exponent,
            scale_base=scale_config.base,
            initializer_map=initializer_map,
        )
        node.input[:] = new_inputs

        attrs = extract_attributes(node)
        attrs.setdefault("group", 1)
        attrs.setdefault("auto_pad", "NOTSET")

        attrs["rescale"] = int(scale_config.rescale)

        scale_value = self.get_scaling(
            scale_config.base,
            scale_config.exponent,
        )

        # Create scale constant
        scale_const_name = f"{output_name}_scaler"
        scale_tensor = numpy_helper.from_array(
            np.array([scale_value], dtype=np.int64),
            name=scale_const_name,
        )
        self.new_initializers.append(scale_tensor)
        node.input.append(scale_const_name)
        int64_conv_node = onnx.helper.make_node(
            "Int64Conv",
            inputs=node.input,
            outputs=node.output,  # preserve original output name
            name=node.name,
            domain="ai.onnx.contrib",
            **attrs,
        )

        nodes.append(int64_conv_node)
        return nodes

    def check_supported(
        self: ConvQuantizer,
        node: onnx.NodeProto,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> None:
        """
        Perform high-level validation to ensure that this Conv node
        can be quantized safely.

        Args:
            node (onnx.NodeProto): ONNX node to be checked.
            initializer_map (dict[str, onnx.TensorProto]):
                Initializer map (name of weight or bias and tensor).

        Raises:
            InvalidParamError: If any requirement is not met.
        """
        num_inputs = 2
        if len(node.input) < num_inputs:
            raise InvalidParamError(
                node.name,
                node.op_type,
                f"Expected at least 2 inputs (input, weights), got {len(node.input)}",
            )
        num_inputs = 3

        if len(node.input) < num_inputs:
            raise InvalidParamError(
                node.name,
                node.op_type,
                "Expected at least 3 inputs (input, weights, bias),"
                f" got {len(node.input)}",
            )

        self.check_supported_shape(node, initializer_map)
        self.check_all_params_exist(node)

    def check_all_params_exist(self: ConvQuantizer, node: onnx.NodeProto) -> None:
        """Verify that all required Conv attributes are present.

        Args:
            node (onnx.NodeProto): The Conv node being validated.

        Raises:
            InvalidParamError: If any required parameter is missing.
        """
        # May need: ["strides", "kernel_shape", "dilations", "pads"]
        required_attrs = ["strides", "kernel_shape", "dilations"]
        self.validate_required_attrs(node, required_attrs)

    def check_supported_shape(
        self: ConvQuantizer,
        node: onnx.NodeProto,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> None:
        """Ensure that Conv weights are available and have the correct dimensionality.

        Args:
            node (onnx.NodeProto): The node being validated.
            initializer_map (dict[str, onnx.TensorProto]):
                Mapping of initializer tensor names to TensorProtos.

        Raises:
            InvalidParamError: If weights are missing or have an unsupported shape.
        """
        supported_size = [4]
        weight_name = node.input[1]
        initializer = initializer_map.get(weight_name)

        if initializer is None:
            raise InvalidParamError(
                node.name,
                node.op_type,
                f"Weight tensor '{weight_name}' not found in initializers",
            )

        weight_dims = list(initializer.dims)

        if len(weight_dims) not in supported_size:
            msg = f"Unsupported Conv weight dimensionality {len(weight_dims)}. "
            msg += f"Expected 4D weights for Conv2D, got shape {weight_dims}"
            raise InvalidParamError(node.name, node.op_type, msg)
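A hedged usage sketch of the validation path above: check_supported accepts a Conv node only with input, 4D weights, and bias present, plus the strides/kernel_shape/dilations attributes. It assumes the wheel's top-level `python` package is importable; names such as conv0, W, and B are illustrative, not from the package.

# Sketch: exercising ConvQuantizer.check_supported on a hand-built Conv node.
import numpy as np
import onnx
from onnx import helper, numpy_helper

from python.core.model_processing.onnx_quantizer.layers.conv import ConvQuantizer

w = numpy_helper.from_array(np.ones((8, 3, 3, 3), dtype=np.float32), name="W")
b = numpy_helper.from_array(np.zeros(8, dtype=np.float32), name="B")
conv = helper.make_node(
    "Conv", ["x", "W", "B"], ["y"], name="conv0",
    strides=[1, 1], kernel_shape=[3, 3], dilations=[1, 1], pads=[1, 1, 1, 1],
)
# Passes: 3 inputs, 4D weights, all required attributes present.
ConvQuantizer(new_initializers=[]).check_supported(conv, {"W": w, "B": b})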
@@ -0,0 +1,171 @@ (python/core/model_processing/onnx_quantizer/layers/gemm.py)

from __future__ import annotations

import numpy as np
import onnx
from onnx import numpy_helper

from python.core.model_processing.onnx_custom_ops.onnx_helpers import extract_attributes
from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError
from python.core.model_processing.onnx_quantizer.layers.base import (
    BaseOpQuantizer,
    ScaleConfig,
)


class GemmQuantizer(BaseOpQuantizer):
    """
    Quantizer for ONNX Gemm layers.

    - Replaces standard Gemm with Int64Gemm from the `ai.onnx.contrib`
      domain and makes relevant additional changes to the graph.
    - Validates that all required Gemm parameters are present.
    """

    def __init__(
        self: GemmQuantizer,
        new_initializers: list[onnx.TensorProto],
    ) -> None:
        super().__init__()
        self.new_initializers = new_initializers

    def quantize(
        self: GemmQuantizer,
        node: onnx.NodeProto,
        graph: onnx.GraphProto,
        scale_config: ScaleConfig,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> list[onnx.NodeProto]:
        """
        Quantize a Gemm node by:
        1. Quantizing its weights and bias.
        2. Adding a scale constant.
        3. Replacing it with an Int64Gemm node.

        Args:
            node (onnx.NodeProto): The node to quantize.
            graph (onnx.GraphProto): The ONNX graph.
            scale_config (ScaleConfig): Scaling configuration (base, exponent,
                and rescale flag).
            initializer_map (dict[str, onnx.TensorProto]):
                Map of initializer names to tensor data.

        Returns:
            list[onnx.NodeProto]: A list of ONNX nodes
                (quantized and any auxiliary nodes).
        """
        _ = graph
        nodes = []
        output_name = f"{node.name}_int"

        nodes, new_inputs = self.add_nodes_w_and_b(
            node=node,
            scale_exponent=scale_config.exponent,
            scale_base=scale_config.base,
            initializer_map=initializer_map,
        )
        node.input[:] = new_inputs

        attrs = extract_attributes(node)
        attrs.setdefault("transA", 0)
        attrs.setdefault("transB", 0)
        attrs["rescale"] = int(scale_config.rescale)

        scale_value = self.get_scaling(
            scale_config.base,
            scale_config.exponent,
        )

        # === Create scale constant ===
        scale_const_name = f"{output_name}_scaler"
        scale_tensor = numpy_helper.from_array(
            np.array([scale_value], dtype=np.int64),
            name=scale_const_name,
        )
        self.new_initializers.append(scale_tensor)
        node.input.append(scale_const_name)
        int64_gemm = onnx.helper.make_node(
            "Int64Gemm",
            inputs=node.input,
            outputs=node.output,  # preserve original output name
            name=output_name,
            domain="ai.onnx.contrib",
            **attrs,
        )
        nodes.append(int64_gemm)
        return nodes

    def check_supported(
        self: GemmQuantizer,
        node: onnx.NodeProto,
        initializer_map: dict[str, onnx.TensorProto] | None = None,
    ) -> None:
        """
        Perform high-level validation to ensure that this node
        can be quantized safely.

        Args:
            node (onnx.NodeProto): ONNX node to be checked.
            initializer_map (dict[str, onnx.TensorProto]):
                Initializer map (name of weight or bias and tensor).

        Raises:
            InvalidParamError: If any requirement is not met.
        """
        _ = initializer_map
        num_valid_inputs = 2
        # Ensure inputs exist
        if len(node.input) < num_valid_inputs:
            raise InvalidParamError(
                node.name,
                node.op_type,
                f"Expected at least 2 inputs (input, weights), got {len(node.input)}",
            )
        num_valid_inputs = 3

        if len(node.input) < num_valid_inputs:
            raise InvalidParamError(
                node.name,
                node.op_type,
                "Expected at least 3 inputs (input, weights, bias)"
                f", got {len(node.input)}",
            )

        # Validate attributes against their ONNX defaults
        attrs = {attr.name: attr for attr in node.attribute}
        alpha = getattr(attrs.get("alpha"), "f", 1.0)
        beta = getattr(attrs.get("beta"), "f", 1.0)
        trans_a = getattr(attrs.get("transA"), "i", 0)
        trans_b = getattr(attrs.get("transB"), "i", 0)  # ONNX default for transB is 0

        if alpha != 1.0:
            raise InvalidParamError(
                node.name,
                node.op_type,
                f"alpha value of {alpha} not supported",
                "alpha",
                "1.0",
            )
        if beta != 1.0:
            raise InvalidParamError(
                node.name,
                node.op_type,
                f"beta value of {beta} not supported",
                "beta",
                "1.0",
            )
        if trans_a not in [0, 1]:
            raise InvalidParamError(
                node.name,
                node.op_type,
                f"transA value of {trans_a} not supported",
                "transA",
                "(0,1)",
            )
        if trans_b not in [0, 1]:
            raise InvalidParamError(
                node.name,
                node.op_type,
                f"transB value of {trans_b} not supported",
                "transB",
                "(0,1)",
            )
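ONNX Gemm computes Y = alpha * A' * B' + beta * C, and the validator above only accepts the default alpha = beta = 1.0, presumably because fractional coefficients have no exact representation in the int64 fixed-point domain. A hedged sketch of the rejection path (node names are illustrative; assumes the `python` package is importable):

# Sketch: GemmQuantizer.check_supported rejects non-default alpha.
from onnx import helper

from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError
from python.core.model_processing.onnx_quantizer.layers.gemm import GemmQuantizer

gemm = helper.make_node("Gemm", ["x", "W", "B"], ["y"], name="fc0", alpha=0.5)
try:
    GemmQuantizer(new_initializers=[]).check_supported(gemm)
except InvalidParamError as e:
    print(e)  # alpha value of 0.5 not supported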
@@ -0,0 +1,140 @@ (python/core/model_processing/onnx_quantizer/layers/maxpool.py)

from __future__ import annotations

import onnx
from onnx import helper

from python.core.model_processing.onnx_custom_ops.onnx_helpers import (
    extract_attributes,
    get_attribute_ints,
)
from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError
from python.core.model_processing.onnx_quantizer.layers.base import (
    BaseOpQuantizer,
    ScaleConfig,
)


class MaxpoolQuantizer(BaseOpQuantizer):
    """
    Quantizer for ONNX MaxPool layers.

    - Replaces standard MaxPool with Int64MaxPool from the `ai.onnx.contrib`
      domain and makes relevant additional changes to the graph.
    - Validates that all required MaxPool parameters are present.
    """

    def __init__(
        self: MaxpoolQuantizer,
        new_initializer: dict[str, onnx.TensorProto] | None = None,
    ) -> None:
        super().__init__()
        self.accepted_kernel_shapes = [2]
        _ = new_initializer

    def quantize(
        self: MaxpoolQuantizer,
        node: onnx.NodeProto,
        graph: onnx.GraphProto,
        scale_config: ScaleConfig,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> list[onnx.NodeProto]:
        """
        Quantize a node by converting it to its Int64 version.

        Args:
            node (onnx.NodeProto): The node to quantize.
            graph (onnx.GraphProto): The ONNX graph.
            scale_config (ScaleConfig): Scaling configuration (base, exponent,
                and rescale flag).
            initializer_map (dict[str, onnx.TensorProto]):
                Map of initializer names to tensor data.

        Returns:
            list[onnx.NodeProto]: A list of ONNX nodes
                (quantized MaxPool and any auxiliary nodes).
        """
        _ = initializer_map, graph

        attrs = extract_attributes(node)
        attrs["rescale"] = int(scale_config.rescale)

        attr_str = {
            k: ",".join(map(str, v)) if isinstance(v, list) else str(v)
            for k, v in attrs.items()
        }
        return [
            helper.make_node(
                "Int64MaxPool",
                inputs=node.input,
                outputs=node.output,
                name=node.name,
                domain="ai.onnx.contrib",
                **attr_str,
            ),
        ]

    def check_supported(
        self: MaxpoolQuantizer,
        node: onnx.NodeProto,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> None:
        """
        Perform high-level validation to ensure that this node
        can be quantized safely.

        Args:
            node (onnx.NodeProto): ONNX node to be checked.
            initializer_map (dict[str, onnx.TensorProto]):
                Initializer map (name of weight or bias and tensor).

        Raises:
            InvalidParamError: If any requirement is not met.
        """
        _ = initializer_map
        self.check_all_params_exist(node)
        self.check_params_size(node)

    def check_all_params_exist(self: MaxpoolQuantizer, node: onnx.NodeProto) -> None:
        """Check that all required parameters exist.

        Args:
            node (onnx.NodeProto): ONNX node to check.

        Raises:
            InvalidParamError: If a required attribute is missing.
        """
        # May need: ["strides", "kernel_shape", "pads", "dilations"]
        required_attrs = ["strides", "kernel_shape"]
        self.validate_required_attrs(node, required_attrs)

    def check_params_size(self: MaxpoolQuantizer, node: onnx.NodeProto) -> None:
        """Check the kernel dimensionality and ensure that it is supported.

        Args:
            node (onnx.NodeProto): ONNX node to check.

        Raises:
            InvalidParamError: If the shape requirement is not met.
        """
        kernel_shape = get_attribute_ints(node, "kernel_shape", default=[])
        if len(kernel_shape) not in self.accepted_kernel_shapes:
            raise InvalidParamError(
                node.name,
                node.op_type,
                "Currently only MaxPool2D is supported. "
                f"Found {len(kernel_shape)}D kernel",
                "kernel_shape",
                "2D",
            )
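Before emitting Int64MaxPool, the quantizer flattens list-valued attributes into comma-separated strings, presumably so the custom op can parse them from STRING attributes. A self-contained sketch of that transformation (sample values are illustrative):

# Sketch of the attribute stringification applied in MaxpoolQuantizer.quantize.
attrs = {"kernel_shape": [2, 2], "strides": [2, 2], "rescale": 1}
attr_str = {
    k: ",".join(map(str, v)) if isinstance(v, list) else str(v)
    for k, v in attrs.items()
}
print(attr_str)  # {'kernel_shape': '2,2', 'strides': '2,2', 'rescale': '1'}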
@@ -0,0 +1,76 @@ (python/core/model_processing/onnx_quantizer/layers/relu.py)

from __future__ import annotations

import onnx

from python.core.model_processing.onnx_quantizer.layers.base import (
    BaseOpQuantizer,
    ScaleConfig,
)


class ReluQuantizer(BaseOpQuantizer):
    """
    Quantizer for ONNX ReLU layers.

    - Replaces standard ReLU with Int64Relu from the `ai.onnx.contrib` domain
      and makes relevant additional changes to the graph.
    - Validates that all required ReLU parameters are present.
    """

    def __init__(
        self: ReluQuantizer,
        new_initializer: dict[str, onnx.TensorProto] | None = None,
    ) -> None:
        super().__init__()
        _ = new_initializer

    def quantize(
        self: ReluQuantizer,
        node: onnx.NodeProto,
        graph: onnx.GraphProto,
        scale_config: ScaleConfig,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> list[onnx.NodeProto]:
        """
        Quantize a node by converting it to its Int64 version.

        Args:
            node (onnx.NodeProto): The node to quantize.
            graph (onnx.GraphProto): The ONNX graph.
            scale_config (ScaleConfig): Scaling configuration
                (has no effect on this op type).
            initializer_map (dict[str, onnx.TensorProto]):
                Map of initializer names to tensor data.

        Returns:
            list[onnx.NodeProto]: The quantized ONNX node.
        """
        _ = graph, scale_config, initializer_map
        return [
            onnx.helper.make_node(
                "Int64Relu",
                inputs=node.input,
                outputs=node.output,  # preserve original output name
                name=node.name,
                domain="ai.onnx.contrib",
            ),
        ]

    def check_supported(
        self: ReluQuantizer,
        node: onnx.NodeProto,
        initializer_map: dict[str, onnx.TensorProto] | None = None,
    ) -> None:
        """
        Perform high-level validation to ensure that this node
        can be quantized safely.

        Args:
            node (onnx.NodeProto): ONNX node to be checked.
            initializer_map (dict[str, onnx.TensorProto]):
                Initializer map (name of weight or bias and tensor).
        """
        _ = node, initializer_map
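ReluQuantizer is a pure node swap, so its quantize method can be demonstrated standalone. A hedged sketch (assumes the `python` package is importable; scale_config is unused by this op, so None is passed purely for illustration, where real callers supply a ScaleConfig):

# Sketch: the Relu -> Int64Relu swap.
from onnx import helper

from python.core.model_processing.onnx_quantizer.layers.relu import ReluQuantizer

relu = helper.make_node("Relu", ["x"], ["y"], name="relu0")
graph = helper.make_graph([relu], "g", inputs=[], outputs=[])
(new_node,) = ReluQuantizer().quantize(relu, graph, None, {})
print(new_node.op_type, new_node.domain)  # Int64Relu ai.onnx.contrib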