JSTprove 1.0.0-py3-none-macosx_11_0_arm64.whl → 1.1.0-py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of JSTprove might be problematic.
- {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/METADATA +2 -2
- {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/RECORD +51 -24
- python/core/binaries/onnx_generic_circuit_1-1-0 +0 -0
- python/core/circuit_models/generic_onnx.py +43 -9
- python/core/circuits/base.py +231 -71
- python/core/model_processing/converters/onnx_converter.py +86 -32
- python/core/model_processing/onnx_custom_ops/maxpool.py +1 -1
- python/core/model_processing/onnx_custom_ops/relu.py +1 -1
- python/core/model_processing/onnx_quantizer/layers/add.py +54 -0
- python/core/model_processing/onnx_quantizer/layers/base.py +121 -1
- python/core/model_processing/onnx_quantizer/layers/constant.py +1 -1
- python/core/model_processing/onnx_quantizer/layers/conv.py +20 -68
- python/core/model_processing/onnx_quantizer/layers/gemm.py +20 -66
- python/core/model_processing/onnx_quantizer/layers/maxpool.py +53 -43
- python/core/model_processing/onnx_quantizer/layers/relu.py +20 -35
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +6 -1
- python/core/utils/general_layer_functions.py +17 -12
- python/core/utils/model_registry.py +6 -3
- python/tests/circuit_e2e_tests/other_e2e_test.py +202 -9
- python/tests/circuit_parent_classes/test_circuit.py +561 -38
- python/tests/circuit_parent_classes/test_onnx_converter.py +22 -13
- python/tests/onnx_quantizer_tests/__init__.py +1 -0
- python/tests/onnx_quantizer_tests/layers/__init__.py +13 -0
- python/tests/onnx_quantizer_tests/layers/add_config.py +102 -0
- python/tests/onnx_quantizer_tests/layers/base.py +279 -0
- python/tests/onnx_quantizer_tests/layers/constant_config.py +39 -0
- python/tests/onnx_quantizer_tests/layers/conv_config.py +154 -0
- python/tests/onnx_quantizer_tests/layers/factory.py +142 -0
- python/tests/onnx_quantizer_tests/layers/flatten_config.py +61 -0
- python/tests/onnx_quantizer_tests/layers/gemm_config.py +160 -0
- python/tests/onnx_quantizer_tests/layers/maxpool_config.py +82 -0
- python/tests/onnx_quantizer_tests/layers/relu_config.py +61 -0
- python/tests/onnx_quantizer_tests/layers/reshape_config.py +61 -0
- python/tests/onnx_quantizer_tests/layers_tests/__init__.py +0 -0
- python/tests/onnx_quantizer_tests/layers_tests/base_test.py +94 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py +115 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_e2e.py +196 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_error_cases.py +59 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +198 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +265 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py +109 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_validation.py +45 -0
- python/tests/onnx_quantizer_tests/test_base_layer.py +228 -0
- python/tests/onnx_quantizer_tests/test_exceptions.py +99 -0
- python/tests/onnx_quantizer_tests/test_onnx_op_quantizer.py +246 -0
- python/tests/onnx_quantizer_tests/test_registered_quantizers.py +121 -0
- python/tests/onnx_quantizer_tests/testing_helper_functions.py +17 -0
- python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/WHEEL +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/entry_points.txt +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/licenses/LICENSE +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/top_level.txt +0 -0
python/tests/onnx_quantizer_tests/test_base_layer.py
@@ -0,0 +1,228 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+import onnx
+import pytest
+from onnx import helper, numpy_helper
+
+if TYPE_CHECKING:
+    from typing import Any
+
+    from onnx import GraphProto, ModelProto, NodeProto, TensorProto
+
+from python.core.model_processing.onnx_quantizer.exceptions import (
+    HandlerImplementationError,
+)
+from python.core.model_processing.onnx_quantizer.layers.base import (
+    BaseOpQuantizer,
+    ScaleConfig,
+)
+
+
+class DummyQuantizer(BaseOpQuantizer):
+    def __init__(self: DummyQuantizer) -> None:
+        self.new_initializers = []
+
+
+@pytest.fixture
+def dummy_tensor() -> TensorProto:
+    return numpy_helper.from_array(np.array([[1.0, 2.0], [3.0, 4.0]]), name="W")
+
+
+@pytest.fixture
+def dummy_bias() -> TensorProto:
+    return numpy_helper.from_array(np.array([1.0, 2.0]), name="B")
+
+
+@pytest.fixture
+def dummy_node() -> NodeProto:
+    return helper.make_node(
+        "DummyOp",
+        inputs=["X", "W", "B"],
+        outputs=["Y"],
+        name="DummyOp",
+    )
+
+
+@pytest.fixture
+def dummy_graph() -> GraphProto:
+    return helper.make_graph([], "dummy_graph", inputs=[], outputs=[])
+
+
+@pytest.fixture
+def initializer_map(
+    dummy_tensor: TensorProto,
+    dummy_bias: TensorProto,
+) -> dict[str, TensorProto]:
+    return {"W": dummy_tensor, "B": dummy_bias}
+
+
+@pytest.fixture
+def minimal_model() -> ModelProto:
+    graph = onnx.helper.make_graph(
+        nodes=[],  # No nodes
+        name="EmptyGraph",
+        inputs=[],
+        outputs=[],
+        initializer=[],
+    )
+    return onnx.helper.make_model(graph)
+
+
+@pytest.fixture
+def unsupported_model() -> ModelProto:
+    node = onnx.helper.make_node("UnsupportedOp", ["X"], ["Y"])
+    graph = onnx.helper.make_graph(
+        nodes=[node],
+        name="UnsupportedGraph",
+        inputs=[],
+        outputs=[],
+        initializer=[],
+    )
+    return onnx.helper.make_model(graph)
+
+
+@pytest.mark.unit
+def test_quantize_raises_not_implemented() -> None:
+    quantizer = BaseOpQuantizer()
+    with pytest.raises(
+        HandlerImplementationError,
+    ) as excinfo:
+        quantizer.quantize(
+            node=None,
+            graph=None,
+            scale_config=ScaleConfig(exponent=1, base=1, rescale=False),
+            initializer_map={},
+        )
+    assert "quantize() not implemented in subclass." in str(excinfo.value)
+
+
+@pytest.mark.unit
+def test_check_supported_returns_none(dummy_node: NodeProto) -> None:
+    quantizer = DummyQuantizer()
+    with pytest.raises(HandlerImplementationError) as excinfo:
+        quantizer.check_supported(dummy_node, {})
+
+    assert (
+        "Handler implementation error for operator 'DummyQuantizer':"
+        " check_supported() not implemented in subclass." in str(excinfo.value)
+    )
+
+
+@pytest.mark.unit
+def test_rescale_layer_modifies_node_output(
+    dummy_node: NodeProto,
+    dummy_graph: GraphProto,
+) -> None:
+    quantizer = DummyQuantizer()
+    result_nodes = quantizer.rescale_layer(
+        dummy_node,
+        scale_base=10,
+        scale_exponent=2,
+        graph=dummy_graph,
+    )
+    total_scale = 100.0
+    count_nodes = 2
+
+    assert len(result_nodes) == count_nodes
+    assert dummy_node.output[0] == "Y_raw"
+    assert result_nodes[1].op_type == "Div"
+    assert result_nodes[1].output[0] == "Y"
+
+    # Check if scale tensor added
+    assert len(quantizer.new_initializers) == 1
+    scale_tensor = quantizer.new_initializers[0]
+    assert scale_tensor.name.endswith("_scale")
+    assert scale_tensor.data_type == onnx.TensorProto.INT64
+    assert onnx.numpy_helper.to_array(scale_tensor)[0] == total_scale
+
+    # Validate that result_nodes are valid ONNX nodes
+    for node in result_nodes:
+        assert isinstance(node, onnx.NodeProto)
+        assert node.name
+        assert node.op_type
+        assert node.input
+        assert node.output
+
+    # Check Div node inputs: should divide Y_raw by scale
+    div_node = result_nodes[1]
+    assert len(div_node.input) == count_nodes
+    assert div_node.input[0] == "Y_raw"
+    assert div_node.input[1] == scale_tensor.name
+
+
+@pytest.mark.unit
+def test_add_nodes_w_and_b_creates_mul_and_cast(
+    dummy_node: NodeProto,
+    dummy_graph: GraphProto,
+    initializer_map: dict[str, Any],
+) -> None:
+    _ = dummy_graph
+    quantizer = DummyQuantizer()
+    exp = 2
+    base = 10
+    nodes, new_inputs = quantizer.add_nodes_w_and_b(
+        dummy_node,
+        scale_exponent=exp,
+        scale_base=base,
+        initializer_map=initializer_map,
+    )
+    four = 4
+    two = 2
+
+    assert len(nodes) == four  # Mul + Cast for W, Mul + Cast for B
+    assert nodes[0].op_type == "Mul"
+    assert nodes[1].op_type == "Cast"
+    assert nodes[2].op_type == "Mul"
+    assert nodes[3].op_type == "Cast"
+    assert new_inputs == ["X", "W_scaled_cast", "B_scaled_cast"]
+    assert len(quantizer.new_initializers) == two
+
+    weight_scaled = base**exp
+    bias_scaled = base ** (exp * 2)
+
+    # Enhanced assertions: check node inputs/outputs and tensor details
+    # Mul for W: input W and W_scale, output W_scaled
+    assert nodes[0].input == ["W", "W_scale"]
+    assert nodes[0].output == ["W_scaled"]
+    # Cast for W: input W_scaled, output W_scaled_cast
+    assert nodes[1].input == ["W_scaled"]
+    assert nodes[1].output == ["W_scaled_cast"]
+    # Similarly for B
+    assert nodes[2].input == ["B", "B_scale"]
+    assert nodes[2].output == ["B_scaled"]
+    assert nodes[3].input == ["B_scaled"]
+    assert nodes[3].output == ["B_scaled_cast"]
+
+    # Check scale tensors
+    w_scale = quantizer.new_initializers[0]
+    b_scale = quantizer.new_initializers[1]
+    assert w_scale.name == "W_scale"
+    assert b_scale.name == "B_scale"
+    assert onnx.numpy_helper.to_array(w_scale)[0] == weight_scaled  # 10**2
+    assert onnx.numpy_helper.to_array(b_scale)[0] == bias_scaled
+
+
+@pytest.mark.unit
+def test_insert_scale_node_creates_mul_and_cast(
+    dummy_tensor: TensorProto,
+    dummy_graph: GraphProto,
+) -> None:
+    _ = dummy_graph
+    quantizer = DummyQuantizer()
+    output_name, mul_node, cast_node = quantizer.insert_scale_node(
+        dummy_tensor,
+        scale_base=10,
+        scale_exponent=1,
+    )
+
+    assert mul_node.op_type == "Mul"
+    assert cast_node.op_type == "Cast"
+    assert "_scaled" in mul_node.output[0]
+    assert output_name.endswith("_cast")
+    assert len(quantizer.new_initializers) == 1
+    assert quantizer.new_initializers[0].name.endswith("_scale")
+    ten = 10.0
+    assert onnx.numpy_helper.to_array(quantizer.new_initializers[0])[0] == ten
python/tests/onnx_quantizer_tests/test_exceptions.py
@@ -0,0 +1,99 @@
+import pytest
+
+from python.core.model_processing.onnx_quantizer.exceptions import (
+    REPORTING_URL,
+    InvalidParamError,
+    QuantizationError,
+    UnsupportedOpError,
+)
+
+
+@pytest.mark.unit
+def test_quantization_error_message() -> None:
+    custom_msg = "Something went wrong."
+    with pytest.raises(QuantizationError) as exc_info:
+        raise QuantizationError(custom_msg)
+    assert "This model is not supported by JSTprove." in str(exc_info.value)
+    assert custom_msg in str(exc_info.value)
+
+    assert REPORTING_URL in str(exc_info.value)
+
+    assert "Submit model support requests via the JSTprove channel:" in str(
+        exc_info.value,
+    )
+
+
+@pytest.mark.unit
+def test_invalid_param_error_basic() -> None:
+    with pytest.raises(InvalidParamError) as exc_info:
+        raise InvalidParamError(
+            node_name="Conv_1",
+            op_type="Conv",
+            message="Missing 'strides' attribute.",
+        )
+    err_msg = str(exc_info.value)
+    assert "Invalid parameters in node 'Conv_1'" in err_msg
+    assert "(op_type='Conv')" in err_msg
+    assert "Missing 'strides' attribute." in err_msg
+    assert "[Attribute:" not in err_msg
+    assert "[Expected:" not in err_msg
+
+    msg = ""
+
+    with pytest.raises(QuantizationError) as exc_info_quantization:
+        raise QuantizationError(msg)
+    # Assert contains generic error message from quantization error
+    assert str(exc_info_quantization.value) in err_msg
+
+
+@pytest.mark.unit
+def test_invalid_param_error_with_attr_and_expected() -> None:
+    with pytest.raises(InvalidParamError) as exc_info:
+        raise InvalidParamError(
+            node_name="MaxPool_3",
+            op_type="MaxPool",
+            message="Kernel shape is invalid.",
+            attr_key="kernel_shape",
+            expected="a list of 2 positive integers",
+        )
+    err_msg = str(exc_info.value)
+    assert "Invalid parameters in node 'MaxPool_3'" in err_msg
+    assert "[Attribute: kernel_shape]" in err_msg
+    assert "[Expected: a list of 2 positive integers]" in err_msg
+    msg = ""
+
+    with pytest.raises(QuantizationError) as exc_info_quantization:
+        raise QuantizationError(msg)
+    # Assert contains generic error message from quantization error
+    assert str(exc_info_quantization.value) in err_msg
+
+
+@pytest.mark.unit
+def test_unsupported_op_error_with_node() -> None:
+    with pytest.raises(UnsupportedOpError) as exc_info:
+        raise UnsupportedOpError(op_type="Resize", node_name="Resize_42")
+    err_msg = str(exc_info.value)
+    assert "Unsupported op type: 'Resize'" in err_msg
+    assert "in node 'Resize_42'" in err_msg
+    assert "documentation for supported layers" in err_msg
+    msg = ""
+
+    with pytest.raises(QuantizationError) as exc_info_quantization:
+        raise QuantizationError(msg)
+    # Assert contains generic error message from quantization error
+    assert str(exc_info_quantization.value) in err_msg
+
+
+@pytest.mark.unit
+def test_unsupported_op_error_without_node() -> None:
+    with pytest.raises(UnsupportedOpError) as exc_info:
+        raise UnsupportedOpError(op_type="Upsample")
+    err_msg = str(exc_info.value)
+    assert "Unsupported op type: 'Upsample'" in err_msg
+    assert "in node" not in err_msg
+    msg = ""
+
+    with pytest.raises(QuantizationError) as exc_info_quantization:
+        raise QuantizationError(msg)
+    # Assert contains generic error message from quantization error
+    assert str(exc_info_quantization.value) in err_msg
python/tests/onnx_quantizer_tests/test_onnx_op_quantizer.py
@@ -0,0 +1,246 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+import onnx
+import pytest
+from onnx import GraphProto, ModelProto, NodeProto, TensorProto, helper
+
+from python.core.model_processing.onnx_quantizer.exceptions import (
+    MissingHandlerError,
+    QuantizationError,
+    UnsupportedOpError,
+)
+
+if TYPE_CHECKING:
+    from python.core.model_processing.onnx_quantizer.layers.base import ScaleConfig
+
+# Optional: mock layers if needed
+from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import (
+    ONNXOpQuantizer,
+)
+
+
+# Mocks
+class MockHandler:
+    def __init__(self: MockHandler) -> None:
+        self.called_quantize = False
+        self.called_supported = False
+
+    def quantize(
+        self: MockHandler,
+        node: NodeProto,
+        graph: GraphProto,
+        scale_config: ScaleConfig,
+        initializer_map: dict[str, TensorProto],
+    ) -> list[NodeProto]:
+        _ = graph, scale_config, initializer_map
+        self.called_quantize = True
+        return [node]  # Return the original node as a list for simplicity
+
+    def check_supported(
+        self: MockHandler,
+        node: NodeProto,
+        initializer_map: dict[str, TensorProto],
+    ) -> None:
+        _ = initializer_map
+        self.called_supported = True
+        if node.name == "bad_node":
+            msg = "Invalid node parameters"
+            raise ValueError(msg)
+
+
+# Fixtures
+@pytest.fixture
+def quantizer() -> ONNXOpQuantizer:
+    return ONNXOpQuantizer()
+
+
+@pytest.fixture
+def dummy_node() -> NodeProto:
+    return helper.make_node("FakeOp", inputs=["x"], outputs=["y"])
+
+
+@pytest.fixture
+def valid_node() -> NodeProto:
+    return helper.make_node("Dummy", inputs=["x"], outputs=["y"], name="good_node")
+
+
+@pytest.fixture
+def invalid_node() -> NodeProto:
+    return helper.make_node("Dummy", inputs=["x"], outputs=["y"], name="bad_node")
+
+
+@pytest.fixture
+def dummy_model(valid_node: NodeProto, invalid_node: NodeProto) -> ModelProto:
+    graph = helper.make_graph(
+        [valid_node, invalid_node],
+        "test_graph",
+        inputs=[],
+        outputs=[],
+        initializer=[helper.make_tensor("x", TensorProto.FLOAT, [1], [0.5])],
+    )
+    return helper.make_model(graph)
+
+
+# Tests
+
+
+@pytest.mark.unit
+def test_check_model_raises_on_unsupported_op() -> None:
+    quantizer = ONNXOpQuantizer()
+
+    unsupported_node = helper.make_node("UnsupportedOp", ["x"], ["y"])
+    graph = helper.make_graph([unsupported_node], "test_graph", [], [])
+    model = helper.make_model(graph)
+
+    with pytest.raises(UnsupportedOpError):
+        quantizer.check_model(model)
+
+
+@pytest.mark.unit
+def test_check_layer_invokes_check_supported() -> None:
+    quantizer = ONNXOpQuantizer()
+    handler = MockHandler()
+    quantizer.register("FakeOp", handler)
+
+    node = helper.make_node("FakeOp", ["x"], ["y"])
+    initializer_map = {}
+
+    quantizer.check_layer(node, initializer_map)
+    # Check that check_supported is called
+    assert handler.called_supported
+
+
+@pytest.mark.unit
+def test_get_initializer_map_returns_correct_dict() -> None:
+    quantizer = ONNXOpQuantizer()
+
+    tensor = helper.make_tensor(
+        name="W",
+        data_type=TensorProto.FLOAT,
+        dims=[1],
+        vals=[1.0],
+    )
+    graph = helper.make_graph([], "test_graph", [], [], [tensor])
+    model = helper.make_model(graph)
+
+    init_map = quantizer.get_initializer_map(model)
+    # Test initializer in map
+    assert "W" in init_map
+    # Test initializer map lines up
+    assert init_map["W"] == tensor
+    # Enhanced: check tensor properties
+    assert init_map["W"].data_type == TensorProto.FLOAT
+    assert init_map["W"].dims == [1]
+    assert onnx.numpy_helper.to_array(init_map["W"])[0] == 1.0
+
+
+@pytest.mark.unit
+def test_quantize_with_unregistered_op_warns(dummy_node: NodeProto) -> None:
+    quantizer = ONNXOpQuantizer()
+    graph = helper.make_graph([], "g", [], [])
+    with pytest.raises(UnsupportedOpError) as excinfo:
+        _ = quantizer.quantize(dummy_node, graph, 1, 1, {}, rescale=False)
+
+    captured = str(excinfo.value)
+    assert "Unsupported op type: 'FakeOp'" in captured
+
+
+# Could be unit or integration?
+@pytest.mark.unit
+def test_check_model_raises_unsupported(dummy_model: ModelProto) -> None:
+    quantizer = ONNXOpQuantizer()
+    quantizer.handlers = {"Dummy": MockHandler()}
+
+    # Remove one node to simulate unsupported ops
+    dummy_model.graph.node.append(helper.make_node("FakeOp", ["a"], ["b"]))
+
+    with pytest.raises(UnsupportedOpError) as excinfo:
+        quantizer.check_model(dummy_model)
+
+    assert "FakeOp" in str(excinfo.value)
+
+
+@pytest.mark.unit
+def test_check_layer_missing_handler(valid_node: NodeProto) -> None:
+    quantizer = ONNXOpQuantizer()
+    with pytest.raises(MissingHandlerError) as exc_info:
+        quantizer.check_layer(valid_node, {})
+
+    assert QuantizationError("").GENERIC_MESSAGE in str(exc_info.value)
+    assert "No quantization handler registered for operator type 'Dummy'." in str(
+        exc_info.value,
+    )
+
+
+@pytest.mark.unit
+def test_check_layer_with_bad_handler(invalid_node: NodeProto) -> None:
+    quantizer = ONNXOpQuantizer()
+    quantizer.handlers = {"Dummy": MockHandler()}
+
+    # This error is created in our mock handler
+    with pytest.raises(ValueError, match="Invalid node parameters"):
+        quantizer.check_layer(invalid_node, {})
+
+
+@pytest.mark.unit
+def test_get_initializer_map_extracts_all() -> None:
+    one_f = 1.0
+    two_f = 2.0
+    count_init = 2
+    tensor1 = helper.make_tensor("a", TensorProto.FLOAT, [1], [one_f])
+    tensor2 = helper.make_tensor("b", TensorProto.FLOAT, [1], [two_f])
+    graph = helper.make_graph([], "g", [], [], initializer=[tensor1, tensor2])
+    model = helper.make_model(graph)
+
+    quantizer = ONNXOpQuantizer()
+    init_map = quantizer.get_initializer_map(model)
+    assert init_map["a"].float_data[0] == one_f
+    assert init_map["b"].float_data[0] == two_f
+
+    # Enhanced: check all properties
+    assert len(init_map) == count_init
+    assert init_map["a"].name == "a"
+    assert init_map["a"].data_type == TensorProto.FLOAT
+    assert init_map["a"].dims == [1]
+    assert init_map["b"].name == "b"
+    assert init_map["b"].data_type == TensorProto.FLOAT
+    assert init_map["b"].dims == [1]
+    # Using numpy_helper for consistency
+    assert onnx.numpy_helper.to_array(init_map["a"])[0] == one_f
+    assert onnx.numpy_helper.to_array(init_map["b"])[0] == two_f
+
+
+@pytest.mark.unit
+def test_check_layer_skips_handler_without_check_supported() -> None:
+    class NoCheckHandler:
+        def quantize(self, *args: tuple, **kwargs: dict[str, Any]) -> None:
+            pass  # no check_supported
+
+    quantizer = ONNXOpQuantizer()
+    quantizer.register("NoCheckOp", NoCheckHandler())
+
+    node = helper.make_node("NoCheckOp", ["x"], ["y"])
+    # Should not raise
+    quantizer.check_layer(node, {})
+
+
+@pytest.mark.unit
+def test_register_overwrites_handler() -> None:
+    quantizer = ONNXOpQuantizer()
+    handler1 = MockHandler()
+    handler2 = MockHandler()
+
+    quantizer.register("Dummy", handler1)
+    quantizer.register("Dummy", handler2)
+
+    assert quantizer.handlers["Dummy"] is handler2
+
+
+@pytest.mark.unit
+def test_check_empty_model() -> None:
+    model = helper.make_model(helper.make_graph([], "empty", [], []))
+    quantizer = ONNXOpQuantizer()
+    # Should not raise
+    quantizer.check_model(model)
python/tests/onnx_quantizer_tests/test_registered_quantizers.py
@@ -0,0 +1,121 @@
+# This file performs very basic integration tests on each registered quantizer
+
+import numpy as np
+import onnx
+import pytest
+from onnx import helper
+
+from python.core.model_processing.onnx_quantizer.layers.base import ScaleConfig
+from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import (
+    ONNXOpQuantizer,
+)
+from python.tests.onnx_quantizer_tests import TEST_RNG_SEED
+
+
+@pytest.fixture
+def dummy_graph() -> onnx.GraphProto:
+    return onnx.GraphProto()
+
+
+def mock_initializer_map(input_names: list[str]) -> dict[str, onnx.TensorProto]:
+    rng = np.random.default_rng(TEST_RNG_SEED)
+    return {
+        name: onnx.helper.make_tensor(
+            name=name,
+            data_type=onnx.TensorProto.FLOAT,
+            dims=[2, 2],  # minimal shape
+            vals=rng.random(4, dtype=np.float32).tolist(),
+        )
+        for name in input_names
+    }
+
+
+def get_required_input_names(op_type: str) -> list[str]:
+    try:
+        schema = onnx.defs.get_schema(op_type)
+        return [
+            inp.name or f"input{i}"
+            for i, inp in enumerate(schema.inputs)
+            if inp.option != 1
+        ]  # 1 = optional
+    except Exception:
+        return ["input0"]  # fallback
+
+
+def validate_quantized_node(node_result: onnx.NodeProto, op_type: str) -> None:
+    """Validate a single quantized node."""
+    assert isinstance(node_result, onnx.NodeProto), f"Invalid node type for {op_type}"
+    assert node_result.op_type, f"Missing op_type for {op_type}"
+    assert node_result.output, f"Missing outputs for {op_type}"
+
+    try:
+        # Create a minimal model with custom opset for validation
+        temp_graph = onnx.GraphProto()
+        temp_graph.name = "temp_graph"
+
+        # Add dummy inputs/outputs to satisfy graph requirements
+        for inp in node_result.input:
+            if not any(vi.name == inp for vi in temp_graph.input):
+                temp_graph.input.append(
+                    onnx.helper.make_tensor_value_info(
+                        inp,
+                        onnx.TensorProto.FLOAT,
+                        [1],
+                    ),
+                )
+        for out in node_result.output:
+            if not any(vi.name == out for vi in temp_graph.output):
+                temp_graph.output.append(
+                    onnx.helper.make_tensor_value_info(
+                        out,
+                        onnx.TensorProto.FLOAT,
+                        [1],
+                    ),
+                )
+
+        temp_graph.node.append(node_result)
+        temp_model = onnx.helper.make_model(temp_graph)
+        custom_domain = onnx.helper.make_operatorsetid(
+            domain="ai.onnx.contrib",
+            version=1,
+        )
+        temp_model.opset_import.append(custom_domain)
+        onnx.checker.check_model(temp_model)
+    except onnx.checker.ValidationError as e:
+        pytest.fail(f"ONNX node validation failed for {op_type}: {e}")
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("op_type", list(ONNXOpQuantizer().handlers.keys()))
+def test_registered_quantizer_quantize(
+    op_type: str,
+    dummy_graph: onnx.GraphProto,
+) -> None:
+    quantizer = ONNXOpQuantizer()
+    handler = quantizer.handlers[op_type]
+
+    inputs = get_required_input_names(op_type)
+    dummy_initializer_map = mock_initializer_map(inputs)
+
+    dummy_node = helper.make_node(
+        op_type=op_type,
+        inputs=inputs,
+        outputs=["dummy_output"],
+    )
+
+    result = handler.quantize(
+        node=dummy_node,
+        graph=dummy_graph,
+        scale_config=ScaleConfig(exponent=10, base=2, rescale=True),
+        initializer_map=dummy_initializer_map,
+    )
+    assert result is not None
+
+    # Enhanced assertions: validate result type and structure
+    if isinstance(result, list):
+        assert len(result) > 0, f"Quantize returned empty list for {op_type}"
+        for node_result in result:
+            validate_quantized_node(node_result, op_type)
+    else:
+        assert result.input, f"Missing inputs for {op_type}"
+        validate_quantized_node(result, op_type)
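
Note: the new test modules above exercise the handler-registry surface of ONNXOpQuantizer (register, check_layer, check_model) and the BaseOpQuantizer handler interface (quantize, check_supported). As a minimal usage sketch of that flow, assuming only the package layout and call signatures shown in the tests above (IdentityHandler and MyCustomOp are hypothetical names introduced here for illustration, not part of the released wheel):

# Hypothetical sketch based on the APIs exercised by the tests above;
# not verified against the published 1.1.0 wheel.
from onnx import TensorProto, helper

from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import (
    ONNXOpQuantizer,
)


class IdentityHandler:
    """Toy handler: passes nodes through unchanged (illustration only)."""

    def quantize(self, node, graph, scale_config, initializer_map):
        _ = graph, scale_config, initializer_map
        return [node]

    def check_supported(self, node, initializer_map):
        _ = node, initializer_map  # accept everything


quantizer = ONNXOpQuantizer()
# Same register() API that test_onnx_op_quantizer.py uses with MockHandler
quantizer.register("MyCustomOp", IdentityHandler())

node = helper.make_node("MyCustomOp", ["x"], ["y"], name="custom_0")
graph = helper.make_graph(
    [node],
    "demo_graph",
    inputs=[],
    outputs=[],
    initializer=[helper.make_tensor("x", TensorProto.FLOAT, [1], [0.5])],
)
model = helper.make_model(graph)

# Per the tests, check_model raises UnsupportedOpError for any op type
# without a registered handler; with the handler registered it passes.
quantizer.check_model(model)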