JSTprove 1.0.0__py3-none-macosx_11_0_arm64.whl → 1.2.0__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/METADATA +3 -3
- {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/RECORD +60 -25
- python/core/binaries/onnx_generic_circuit_1-2-0 +0 -0
- python/core/circuit_models/generic_onnx.py +43 -9
- python/core/circuits/base.py +231 -71
- python/core/model_processing/converters/onnx_converter.py +114 -59
- python/core/model_processing/onnx_custom_ops/batchnorm.py +64 -0
- python/core/model_processing/onnx_custom_ops/maxpool.py +1 -1
- python/core/model_processing/onnx_custom_ops/mul.py +66 -0
- python/core/model_processing/onnx_custom_ops/relu.py +1 -1
- python/core/model_processing/onnx_quantizer/layers/add.py +54 -0
- python/core/model_processing/onnx_quantizer/layers/base.py +188 -1
- python/core/model_processing/onnx_quantizer/layers/batchnorm.py +224 -0
- python/core/model_processing/onnx_quantizer/layers/constant.py +1 -1
- python/core/model_processing/onnx_quantizer/layers/conv.py +20 -68
- python/core/model_processing/onnx_quantizer/layers/gemm.py +20 -66
- python/core/model_processing/onnx_quantizer/layers/maxpool.py +53 -43
- python/core/model_processing/onnx_quantizer/layers/mul.py +53 -0
- python/core/model_processing/onnx_quantizer/layers/relu.py +20 -35
- python/core/model_processing/onnx_quantizer/layers/sub.py +54 -0
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +43 -1
- python/core/utils/general_layer_functions.py +17 -12
- python/core/utils/model_registry.py +6 -3
- python/scripts/gen_and_bench.py +2 -2
- python/tests/circuit_e2e_tests/other_e2e_test.py +202 -9
- python/tests/circuit_parent_classes/test_circuit.py +561 -38
- python/tests/circuit_parent_classes/test_onnx_converter.py +22 -13
- python/tests/onnx_quantizer_tests/__init__.py +1 -0
- python/tests/onnx_quantizer_tests/layers/__init__.py +13 -0
- python/tests/onnx_quantizer_tests/layers/add_config.py +102 -0
- python/tests/onnx_quantizer_tests/layers/base.py +279 -0
- python/tests/onnx_quantizer_tests/layers/batchnorm_config.py +190 -0
- python/tests/onnx_quantizer_tests/layers/constant_config.py +39 -0
- python/tests/onnx_quantizer_tests/layers/conv_config.py +154 -0
- python/tests/onnx_quantizer_tests/layers/factory.py +142 -0
- python/tests/onnx_quantizer_tests/layers/flatten_config.py +61 -0
- python/tests/onnx_quantizer_tests/layers/gemm_config.py +160 -0
- python/tests/onnx_quantizer_tests/layers/maxpool_config.py +82 -0
- python/tests/onnx_quantizer_tests/layers/mul_config.py +102 -0
- python/tests/onnx_quantizer_tests/layers/relu_config.py +61 -0
- python/tests/onnx_quantizer_tests/layers/reshape_config.py +61 -0
- python/tests/onnx_quantizer_tests/layers/sub_config.py +102 -0
- python/tests/onnx_quantizer_tests/layers_tests/__init__.py +0 -0
- python/tests/onnx_quantizer_tests/layers_tests/base_test.py +94 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py +115 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_e2e.py +196 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_error_cases.py +59 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +198 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +267 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py +109 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_validation.py +45 -0
- python/tests/onnx_quantizer_tests/test_base_layer.py +228 -0
- python/tests/onnx_quantizer_tests/test_exceptions.py +99 -0
- python/tests/onnx_quantizer_tests/test_onnx_op_quantizer.py +246 -0
- python/tests/onnx_quantizer_tests/test_registered_quantizers.py +121 -0
- python/tests/onnx_quantizer_tests/testing_helper_functions.py +17 -0
- python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/WHEEL +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/entry_points.txt +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
from python.core.model_processing.onnx_quantizer.exceptions import (
|
|
4
|
+
REPORTING_URL,
|
|
5
|
+
InvalidParamError,
|
|
6
|
+
QuantizationError,
|
|
7
|
+
UnsupportedOpError,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@pytest.mark.unit
def test_quantization_error_message() -> None:
    """QuantizationError wraps a custom message with the generic guidance text."""
    detail = "Something went wrong."
    with pytest.raises(QuantizationError) as exc_info:
        raise QuantizationError(detail)
    rendered = str(exc_info.value)
    # The generic boilerplate, the caller-supplied detail, and the reporting
    # link must all appear in the final rendered message.
    assert "This model is not supported by JSTprove." in rendered
    assert detail in rendered
    assert REPORTING_URL in rendered
    assert "Submit model support requests via the JSTprove channel:" in rendered
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@pytest.mark.unit
def test_invalid_param_error_basic() -> None:
    """Without attr_key/expected, the bracketed sections are omitted."""
    with pytest.raises(InvalidParamError) as exc_info:
        raise InvalidParamError(
            node_name="Conv_1",
            op_type="Conv",
            message="Missing 'strides' attribute.",
        )
    err_msg = str(exc_info.value)
    assert "Invalid parameters in node 'Conv_1'" in err_msg
    assert "(op_type='Conv')" in err_msg
    assert "Missing 'strides' attribute." in err_msg
    # No attribute/expected details were supplied, so no bracketed sections.
    assert "[Attribute:" not in err_msg
    assert "[Expected:" not in err_msg

    # An empty QuantizationError renders just the generic boilerplate,
    # which must also be embedded in the InvalidParamError message.
    with pytest.raises(QuantizationError) as exc_info_quantization:
        raise QuantizationError("")
    assert str(exc_info_quantization.value) in err_msg
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@pytest.mark.unit
def test_invalid_param_error_with_attr_and_expected() -> None:
    """attr_key and expected render as bracketed sections in the message."""
    with pytest.raises(InvalidParamError) as exc_info:
        raise InvalidParamError(
            node_name="MaxPool_3",
            op_type="MaxPool",
            message="Kernel shape is invalid.",
            attr_key="kernel_shape",
            expected="a list of 2 positive integers",
        )
    err_msg = str(exc_info.value)
    assert "Invalid parameters in node 'MaxPool_3'" in err_msg
    assert "[Attribute: kernel_shape]" in err_msg
    assert "[Expected: a list of 2 positive integers]" in err_msg

    # The generic QuantizationError boilerplate must be embedded as well.
    with pytest.raises(QuantizationError) as exc_info_quantization:
        raise QuantizationError("")
    assert str(exc_info_quantization.value) in err_msg
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@pytest.mark.unit
def test_unsupported_op_error_with_node() -> None:
    """UnsupportedOpError names both the op type and the offending node."""
    with pytest.raises(UnsupportedOpError) as exc_info:
        raise UnsupportedOpError(op_type="Resize", node_name="Resize_42")
    err_msg = str(exc_info.value)
    assert "Unsupported op type: 'Resize'" in err_msg
    assert "in node 'Resize_42'" in err_msg
    assert "documentation for supported layers" in err_msg

    # The generic QuantizationError boilerplate must be embedded as well.
    with pytest.raises(QuantizationError) as exc_info_quantization:
        raise QuantizationError("")
    assert str(exc_info_quantization.value) in err_msg
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@pytest.mark.unit
def test_unsupported_op_error_without_node() -> None:
    """Without node_name, the 'in node' clause is omitted from the message."""
    with pytest.raises(UnsupportedOpError) as exc_info:
        raise UnsupportedOpError(op_type="Upsample")
    err_msg = str(exc_info.value)
    assert "Unsupported op type: 'Upsample'" in err_msg
    assert "in node" not in err_msg

    # The generic QuantizationError boilerplate must be embedded as well.
    with pytest.raises(QuantizationError) as exc_info_quantization:
        raise QuantizationError("")
    assert str(exc_info_quantization.value) in err_msg
|
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Any
|
|
4
|
+
|
|
5
|
+
import onnx
|
|
6
|
+
import pytest
|
|
7
|
+
from onnx import GraphProto, ModelProto, NodeProto, TensorProto, helper
|
|
8
|
+
|
|
9
|
+
from python.core.model_processing.onnx_quantizer.exceptions import (
|
|
10
|
+
MissingHandlerError,
|
|
11
|
+
QuantizationError,
|
|
12
|
+
UnsupportedOpError,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from python.core.model_processing.onnx_quantizer.layers.base import ScaleConfig
|
|
17
|
+
|
|
18
|
+
# Optional: mock layers if needed
|
|
19
|
+
from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import (
|
|
20
|
+
ONNXOpQuantizer,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Mocks
|
|
25
|
+
class MockHandler:
    """Minimal stand-in for a layer quantizer; records which hooks ran."""

    def __init__(self: MockHandler) -> None:
        # Flags flipped by the hooks below so tests can verify dispatch.
        self.called_quantize = False
        self.called_supported = False

    def quantize(
        self: MockHandler,
        node: NodeProto,
        graph: GraphProto,
        scale_config: ScaleConfig,
        initializer_map: dict[str, TensorProto],
    ) -> list[NodeProto]:
        """Pretend to quantize: mark the call and echo the node back."""
        _ = graph, scale_config, initializer_map
        self.called_quantize = True
        return [node]  # Return the original node as a list for simplicity

    def check_supported(
        self: MockHandler,
        node: NodeProto,
        initializer_map: dict[str, TensorProto],
    ) -> None:
        """Record the call; reject the sentinel node named "bad_node"."""
        _ = initializer_map
        self.called_supported = True
        if node.name == "bad_node":
            msg = "Invalid node parameters"
            raise ValueError(msg)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# Fixtures
|
|
54
|
+
@pytest.fixture
def quantizer() -> ONNXOpQuantizer:
    """Fresh quantizer instance per test."""
    return ONNXOpQuantizer()
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@pytest.fixture
def dummy_node() -> NodeProto:
    """Node whose op type has no registered handler."""
    node = helper.make_node("FakeOp", inputs=["x"], outputs=["y"])
    return node
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@pytest.fixture
def valid_node() -> NodeProto:
    """Dummy-op node that MockHandler.check_supported accepts."""
    node = helper.make_node("Dummy", inputs=["x"], outputs=["y"], name="good_node")
    return node
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@pytest.fixture
def invalid_node() -> NodeProto:
    """Dummy-op node named "bad_node", which MockHandler.check_supported rejects."""
    node = helper.make_node("Dummy", inputs=["x"], outputs=["y"], name="bad_node")
    return node
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@pytest.fixture
def dummy_model(valid_node: NodeProto, invalid_node: NodeProto) -> ModelProto:
    """Two-node model carrying a single scalar initializer named 'x'."""
    scalar_init = helper.make_tensor("x", TensorProto.FLOAT, [1], [0.5])
    graph = helper.make_graph(
        [valid_node, invalid_node],
        "test_graph",
        inputs=[],
        outputs=[],
        initializer=[scalar_init],
    )
    return helper.make_model(graph)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
# Tests
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@pytest.mark.unit
def test_check_model_raises_on_unsupported_op() -> None:
    """check_model must reject a graph containing an op with no handler."""
    q = ONNXOpQuantizer()
    unhandled = helper.make_node("UnsupportedOp", ["x"], ["y"])
    model = helper.make_model(helper.make_graph([unhandled], "test_graph", [], []))
    with pytest.raises(UnsupportedOpError):
        q.check_model(model)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@pytest.mark.unit
def test_check_layer_invokes_check_supported() -> None:
    """check_layer should delegate to the handler's check_supported hook."""
    q = ONNXOpQuantizer()
    mock = MockHandler()
    q.register("FakeOp", mock)

    fake_node = helper.make_node("FakeOp", ["x"], ["y"])
    q.check_layer(fake_node, {})
    # The mock records the delegation so we can assert it happened.
    assert mock.called_supported
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@pytest.mark.unit
def test_get_initializer_map_returns_correct_dict() -> None:
    """get_initializer_map keys initializers by name and preserves contents."""
    weight = helper.make_tensor(
        name="W",
        data_type=TensorProto.FLOAT,
        dims=[1],
        vals=[1.0],
    )
    model = helper.make_model(helper.make_graph([], "test_graph", [], [], [weight]))

    init_map = ONNXOpQuantizer().get_initializer_map(model)
    # The initializer is present and maps to the exact tensor we supplied.
    assert "W" in init_map
    assert init_map["W"] == weight
    # Spot-check the stored tensor's metadata and payload.
    assert init_map["W"].data_type == TensorProto.FLOAT
    assert init_map["W"].dims == [1]
    assert onnx.numpy_helper.to_array(init_map["W"])[0] == 1.0
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
@pytest.mark.unit
def test_quantize_with_unregistered_op_warns(dummy_node: NodeProto) -> None:
    """Quantizing a node with no registered handler raises UnsupportedOpError."""
    q = ONNXOpQuantizer()
    empty_graph = helper.make_graph([], "g", [], [])
    with pytest.raises(UnsupportedOpError) as excinfo:
        _ = q.quantize(dummy_node, empty_graph, 1, 1, {}, rescale=False)

    assert "Unsupported op type: 'FakeOp'" in str(excinfo.value)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
# Could be unit or integration?
@pytest.mark.unit
def test_check_model_raises_unsupported(dummy_model: ModelProto) -> None:
    """A node whose op type has no handler makes check_model raise."""
    q = ONNXOpQuantizer()
    q.handlers = {"Dummy": MockHandler()}

    # Append a node with an unhandled op type to trigger the failure.
    dummy_model.graph.node.append(helper.make_node("FakeOp", ["a"], ["b"]))

    with pytest.raises(UnsupportedOpError) as excinfo:
        q.check_model(dummy_model)

    assert "FakeOp" in str(excinfo.value)
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
@pytest.mark.unit
def test_check_layer_missing_handler(valid_node: NodeProto) -> None:
    """check_layer raises MissingHandlerError when no handler is registered."""
    q = ONNXOpQuantizer()
    with pytest.raises(MissingHandlerError) as exc_info:
        q.check_layer(valid_node, {})

    rendered = str(exc_info.value)
    # The generic quantization boilerplate plus the handler-specific detail.
    assert QuantizationError("").GENERIC_MESSAGE in rendered
    assert "No quantization handler registered for operator type 'Dummy'." in rendered
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
@pytest.mark.unit
def test_check_layer_with_bad_handler(invalid_node: NodeProto) -> None:
    """Errors raised by a handler's check_supported propagate out of check_layer."""
    q = ONNXOpQuantizer()
    q.handlers = {"Dummy": MockHandler()}

    # MockHandler.check_supported raises for nodes named "bad_node".
    with pytest.raises(ValueError, match="Invalid node parameters"):
        q.check_layer(invalid_node, {})
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
@pytest.mark.unit
def test_get_initializer_map_extracts_all() -> None:
    """Every graph initializer appears in the map with its data intact."""
    one_f, two_f = 1.0, 2.0
    count_init = 2
    tensors = [
        helper.make_tensor("a", TensorProto.FLOAT, [1], [one_f]),
        helper.make_tensor("b", TensorProto.FLOAT, [1], [two_f]),
    ]
    model = helper.make_model(helper.make_graph([], "g", [], [], initializer=tensors))

    init_map = ONNXOpQuantizer().get_initializer_map(model)
    assert init_map["a"].float_data[0] == one_f
    assert init_map["b"].float_data[0] == two_f

    # Check all properties of both extracted tensors.
    assert len(init_map) == count_init
    for tensor_name, expected in (("a", one_f), ("b", two_f)):
        assert init_map[tensor_name].name == tensor_name
        assert init_map[tensor_name].data_type == TensorProto.FLOAT
        assert init_map[tensor_name].dims == [1]
        # Using numpy_helper for consistency with the other tests.
        assert onnx.numpy_helper.to_array(init_map[tensor_name])[0] == expected
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
@pytest.mark.unit
def test_check_layer_skips_handler_without_check_supported() -> None:
    """Handlers lacking a check_supported hook are tolerated by check_layer."""

    class NoCheckHandler:
        def quantize(self, *args: tuple, **kwargs: dict[str, Any]) -> None:
            pass  # no check_supported

    q = ONNXOpQuantizer()
    q.register("NoCheckOp", NoCheckHandler())

    # Should not raise even though the handler defines no check_supported.
    q.check_layer(helper.make_node("NoCheckOp", ["x"], ["y"]), {})
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
@pytest.mark.unit
def test_register_overwrites_handler() -> None:
    """Registering the same op type twice keeps only the latest handler."""
    q = ONNXOpQuantizer()
    first, second = MockHandler(), MockHandler()

    q.register("Dummy", first)
    q.register("Dummy", second)

    assert q.handlers["Dummy"] is second
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
@pytest.mark.unit
def test_check_empty_model() -> None:
    """A model with no nodes passes check_model without raising."""
    empty = helper.make_model(helper.make_graph([], "empty", [], []))
    # Should not raise.
    ONNXOpQuantizer().check_model(empty)
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
# This file performs very basic integration tests on each registered quantizer
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import onnx
|
|
5
|
+
import pytest
|
|
6
|
+
from onnx import helper
|
|
7
|
+
|
|
8
|
+
from python.core.model_processing.onnx_quantizer.layers.base import ScaleConfig
|
|
9
|
+
from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import (
|
|
10
|
+
ONNXOpQuantizer,
|
|
11
|
+
)
|
|
12
|
+
from python.tests.onnx_quantizer_tests import TEST_RNG_SEED
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@pytest.fixture
def dummy_graph() -> onnx.GraphProto:
    """Empty graph proto used as a placeholder argument."""
    return onnx.GraphProto()
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def mock_initializer_map(input_names: list[str]) -> dict[str, onnx.TensorProto]:
    """Build a seeded random 2x2 FLOAT tensor for each requested input name."""
    rng = np.random.default_rng(TEST_RNG_SEED)  # deterministic across runs
    tensors: dict[str, onnx.TensorProto] = {}
    for name in input_names:
        tensors[name] = onnx.helper.make_tensor(
            name=name,
            data_type=onnx.TensorProto.FLOAT,
            dims=[2, 2],  # minimal shape
            vals=rng.random(4, dtype=np.float32).tolist(),
        )
    return tensors
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def get_required_input_names(op_type: str) -> list[str]:
    """Return the names of an op schema's required (non-optional) inputs.

    Falls back to a single generic name when the schema lookup fails.
    """
    try:
        schema = onnx.defs.get_schema(op_type)
        required = []
        for i, inp in enumerate(schema.inputs):
            if inp.option != 1:  # 1 = optional in the ONNX schema
                required.append(inp.name or f"input{i}")
        return required
    except Exception:
        # NOTE(review): deliberately broad — any schema-lookup failure
        # degrades to a single generic input name.
        return ["input0"]  # fallback
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def validate_quantized_node(node_result: onnx.NodeProto, op_type: str) -> None:
    """Validate a single quantized node."""
    assert isinstance(node_result, onnx.NodeProto), f"Invalid node type for {op_type}"
    assert node_result.op_type, f"Missing op_type for {op_type}"
    assert node_result.output, f"Missing outputs for {op_type}"

    try:
        # Create a minimal model with custom opset for validation
        scratch = onnx.GraphProto()
        scratch.name = "temp_graph"

        # Declare every node input/output as a dummy [1]-FLOAT value info so
        # the graph satisfies the checker's well-formedness requirements.
        seen_inputs = {vi.name for vi in scratch.input}
        for tensor_name in node_result.input:
            if tensor_name not in seen_inputs:
                seen_inputs.add(tensor_name)
                scratch.input.append(
                    onnx.helper.make_tensor_value_info(
                        tensor_name,
                        onnx.TensorProto.FLOAT,
                        [1],
                    ),
                )
        seen_outputs = {vi.name for vi in scratch.output}
        for tensor_name in node_result.output:
            if tensor_name not in seen_outputs:
                seen_outputs.add(tensor_name)
                scratch.output.append(
                    onnx.helper.make_tensor_value_info(
                        tensor_name,
                        onnx.TensorProto.FLOAT,
                        [1],
                    ),
                )

        scratch.node.append(node_result)
        candidate = onnx.helper.make_model(scratch)
        # Register the contrib domain so custom quantized ops pass the checker.
        candidate.opset_import.append(
            onnx.helper.make_operatorsetid(domain="ai.onnx.contrib", version=1),
        )
        onnx.checker.check_model(candidate)
    except onnx.checker.ValidationError as e:
        pytest.fail(f"ONNX node validation failed for {op_type}: {e}")
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@pytest.mark.integration
@pytest.mark.parametrize("op_type", list(ONNXOpQuantizer().handlers.keys()))
def test_registered_quantizer_quantize(
    op_type: str,
    dummy_graph: onnx.GraphProto,
) -> None:
    """Smoke-test every registered handler's quantize() on a minimal node."""
    handler = ONNXOpQuantizer().handlers[op_type]

    required_inputs = get_required_input_names(op_type)
    init_map = mock_initializer_map(required_inputs)

    minimal_node = helper.make_node(
        op_type=op_type,
        inputs=required_inputs,
        outputs=["dummy_output"],
    )

    result = handler.quantize(
        node=minimal_node,
        graph=dummy_graph,
        scale_config=ScaleConfig(exponent=10, base=2, rescale=True),
        initializer_map=init_map,
    )
    assert result is not None

    # Handlers may return either a list of nodes or a single node; validate
    # whichever shape came back.
    if isinstance(result, list):
        assert len(result) > 0, f"Quantize returned empty list for {op_type}"
        for produced in result:
            validate_quantized_node(produced, op_type)
    else:
        assert result.input, f"Missing inputs for {op_type}"
        validate_quantized_node(result, op_type)
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from onnx import ModelProto
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
# Helper to extract input shapes
def get_input_shapes(model: ModelProto) -> dict:
    """Map each graph input name to its shape, substituting 1 for dynamic dims."""
    input_shapes: dict = {}
    for inp in model.graph.input:
        dims = inp.type.tensor_type.shape.dim
        # Symbolic (dim_param) and entirely unset dims both default to 1.
        input_shapes[inp.name] = [
            int(d.dim_value) if d.HasField("dim_value") else 1 for d in dims
        ]
    return input_shapes
|
|
Binary file
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|