JSTprove 1.0.0__py3-none-macosx_11_0_arm64.whl → 1.1.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of JSTprove might be problematic. Click here for more details.

Files changed (52)
  1. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/METADATA +2 -2
  2. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/RECORD +51 -24
  3. python/core/binaries/onnx_generic_circuit_1-1-0 +0 -0
  4. python/core/circuit_models/generic_onnx.py +43 -9
  5. python/core/circuits/base.py +231 -71
  6. python/core/model_processing/converters/onnx_converter.py +86 -32
  7. python/core/model_processing/onnx_custom_ops/maxpool.py +1 -1
  8. python/core/model_processing/onnx_custom_ops/relu.py +1 -1
  9. python/core/model_processing/onnx_quantizer/layers/add.py +54 -0
  10. python/core/model_processing/onnx_quantizer/layers/base.py +121 -1
  11. python/core/model_processing/onnx_quantizer/layers/constant.py +1 -1
  12. python/core/model_processing/onnx_quantizer/layers/conv.py +20 -68
  13. python/core/model_processing/onnx_quantizer/layers/gemm.py +20 -66
  14. python/core/model_processing/onnx_quantizer/layers/maxpool.py +53 -43
  15. python/core/model_processing/onnx_quantizer/layers/relu.py +20 -35
  16. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +6 -1
  17. python/core/utils/general_layer_functions.py +17 -12
  18. python/core/utils/model_registry.py +6 -3
  19. python/tests/circuit_e2e_tests/other_e2e_test.py +202 -9
  20. python/tests/circuit_parent_classes/test_circuit.py +561 -38
  21. python/tests/circuit_parent_classes/test_onnx_converter.py +22 -13
  22. python/tests/onnx_quantizer_tests/__init__.py +1 -0
  23. python/tests/onnx_quantizer_tests/layers/__init__.py +13 -0
  24. python/tests/onnx_quantizer_tests/layers/add_config.py +102 -0
  25. python/tests/onnx_quantizer_tests/layers/base.py +279 -0
  26. python/tests/onnx_quantizer_tests/layers/constant_config.py +39 -0
  27. python/tests/onnx_quantizer_tests/layers/conv_config.py +154 -0
  28. python/tests/onnx_quantizer_tests/layers/factory.py +142 -0
  29. python/tests/onnx_quantizer_tests/layers/flatten_config.py +61 -0
  30. python/tests/onnx_quantizer_tests/layers/gemm_config.py +160 -0
  31. python/tests/onnx_quantizer_tests/layers/maxpool_config.py +82 -0
  32. python/tests/onnx_quantizer_tests/layers/relu_config.py +61 -0
  33. python/tests/onnx_quantizer_tests/layers/reshape_config.py +61 -0
  34. python/tests/onnx_quantizer_tests/layers_tests/__init__.py +0 -0
  35. python/tests/onnx_quantizer_tests/layers_tests/base_test.py +94 -0
  36. python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py +115 -0
  37. python/tests/onnx_quantizer_tests/layers_tests/test_e2e.py +196 -0
  38. python/tests/onnx_quantizer_tests/layers_tests/test_error_cases.py +59 -0
  39. python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +198 -0
  40. python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +265 -0
  41. python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py +109 -0
  42. python/tests/onnx_quantizer_tests/layers_tests/test_validation.py +45 -0
  43. python/tests/onnx_quantizer_tests/test_base_layer.py +228 -0
  44. python/tests/onnx_quantizer_tests/test_exceptions.py +99 -0
  45. python/tests/onnx_quantizer_tests/test_onnx_op_quantizer.py +246 -0
  46. python/tests/onnx_quantizer_tests/test_registered_quantizers.py +121 -0
  47. python/tests/onnx_quantizer_tests/testing_helper_functions.py +17 -0
  48. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  49. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/WHEEL +0 -0
  50. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/entry_points.txt +0 -0
  51. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/licenses/LICENSE +0 -0
  52. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,228 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ import numpy as np
6
+ import onnx
7
+ import pytest
8
+ from onnx import helper, numpy_helper
9
+
10
+ if TYPE_CHECKING:
11
+ from typing import Any
12
+
13
+ from onnx import GraphProto, ModelProto, NodeProto, TensorProto
14
+
15
+ from python.core.model_processing.onnx_quantizer.exceptions import (
16
+ HandlerImplementationError,
17
+ )
18
+ from python.core.model_processing.onnx_quantizer.layers.base import (
19
+ BaseOpQuantizer,
20
+ ScaleConfig,
21
+ )
22
+
23
+
24
class DummyQuantizer(BaseOpQuantizer):
    """Minimal concrete ``BaseOpQuantizer`` used to exercise base-class helpers.

    Only ``new_initializers`` is prepared; ``BaseOpQuantizer.__init__`` is not
    called (NOTE(review): confirm this is intentional).
    """

    def __init__(self: DummyQuantizer) -> None:
        # The helpers under test append the scale tensors they create here.
        self.new_initializers = []
27
+
28
+
29
@pytest.fixture
def dummy_tensor() -> TensorProto:
    """Return a 2x2 float weight initializer named ``W``."""
    weights = np.array([[1.0, 2.0], [3.0, 4.0]])
    return numpy_helper.from_array(weights, name="W")
32
+
33
+
34
@pytest.fixture
def dummy_bias() -> TensorProto:
    """Return a length-2 float bias initializer named ``B``."""
    bias_values = np.array([1.0, 2.0])
    return numpy_helper.from_array(bias_values, name="B")
37
+
38
+
39
@pytest.fixture
def dummy_node() -> NodeProto:
    """Return a ``DummyOp`` node reading ``X``/``W``/``B`` and writing ``Y``."""
    op_name = "DummyOp"
    return helper.make_node(
        op_name,
        inputs=["X", "W", "B"],
        outputs=["Y"],
        name=op_name,
    )
47
+
48
+
49
@pytest.fixture
def dummy_graph() -> GraphProto:
    """Return a graph named ``dummy_graph`` with no nodes, inputs or outputs."""
    graph_name = "dummy_graph"
    return helper.make_graph([], graph_name, inputs=[], outputs=[])
52
+
53
+
54
@pytest.fixture
def initializer_map(
    dummy_tensor: TensorProto,
    dummy_bias: TensorProto,
) -> dict[str, TensorProto]:
    """Map the initializer names ``W`` and ``B`` to their fixture tensors."""
    mapping: dict[str, TensorProto] = {}
    mapping["W"] = dummy_tensor
    mapping["B"] = dummy_bias
    return mapping
60
+
61
+
62
@pytest.fixture
def minimal_model() -> ModelProto:
    """Return a model wrapping ``EmptyGraph``: no nodes, I/O or initializers."""
    empty_graph = onnx.helper.make_graph(
        nodes=[],
        name="EmptyGraph",
        inputs=[],
        outputs=[],
        initializer=[],
    )
    return onnx.helper.make_model(empty_graph)
72
+
73
+
74
@pytest.fixture
def unsupported_model() -> ModelProto:
    """Return a model holding a single ``UnsupportedOp`` node (``X`` -> ``Y``)."""
    lone_node = onnx.helper.make_node("UnsupportedOp", ["X"], ["Y"])
    unsupported_graph = onnx.helper.make_graph(
        nodes=[lone_node],
        name="UnsupportedGraph",
        inputs=[],
        outputs=[],
        initializer=[],
    )
    return onnx.helper.make_model(unsupported_graph)
85
+
86
+
87
@pytest.mark.unit
def test_quantize_raises_not_implemented() -> None:
    """``BaseOpQuantizer.quantize`` raises until a subclass overrides it."""
    base_quantizer = BaseOpQuantizer()
    with pytest.raises(HandlerImplementationError) as excinfo:
        base_quantizer.quantize(
            node=None,
            graph=None,
            scale_config=ScaleConfig(exponent=1, base=1, rescale=False),
            initializer_map={},
        )
    expected = "quantize() not implemented in subclass."
    assert expected in str(excinfo.value)
100
+
101
+
102
@pytest.mark.unit
def test_check_supported_returns_none(dummy_node: NodeProto) -> None:
    """``check_supported`` on a subclass that does not override it must raise.

    NOTE: despite the test name, the expectation is a
    ``HandlerImplementationError``, not a ``None`` return.
    """
    expected_message = (
        "Handler implementation error for operator 'DummyQuantizer':"
        " check_supported() not implemented in subclass."
    )
    dummy_quantizer = DummyQuantizer()
    with pytest.raises(HandlerImplementationError) as excinfo:
        dummy_quantizer.check_supported(dummy_node, {})

    assert expected_message in str(excinfo.value)
112
+
113
+
114
@pytest.mark.unit
def test_rescale_layer_modifies_node_output(
    dummy_node: NodeProto,
    dummy_graph: GraphProto,
) -> None:
    """``rescale_layer`` reroutes the node output through a ``Div`` by the scale.

    The original node's output is renamed to ``Y_raw`` and a ``Div`` node
    produces ``Y`` from ``Y_raw`` and a newly registered INT64 scale tensor
    whose value is ``scale_base ** scale_exponent``.
    """
    scale_base = 10
    scale_exponent = 2
    quantizer = DummyQuantizer()
    result_nodes = quantizer.rescale_layer(
        dummy_node,
        scale_base=scale_base,
        scale_exponent=scale_exponent,
        graph=dummy_graph,
    )
    total_scale = scale_base**scale_exponent
    expected_node_count = 2
    # Div takes exactly two operands (dividend, divisor); this is unrelated to
    # the number of returned nodes even though both happen to be 2.
    div_input_count = 2

    assert len(result_nodes) == expected_node_count
    assert dummy_node.output[0] == "Y_raw"
    div_node = result_nodes[1]
    assert div_node.op_type == "Div"
    assert div_node.output[0] == "Y"

    # Exactly one scale tensor must have been registered as a new initializer.
    assert len(quantizer.new_initializers) == 1
    scale_tensor = quantizer.new_initializers[0]
    assert scale_tensor.name.endswith("_scale")
    assert scale_tensor.data_type == onnx.TensorProto.INT64
    assert onnx.numpy_helper.to_array(scale_tensor)[0] == total_scale

    # Every returned node must be a fully populated ONNX node.
    for node in result_nodes:
        assert isinstance(node, onnx.NodeProto)
        assert node.name
        assert node.op_type
        assert node.input
        assert node.output

    # The Div node divides the raw output by the scale tensor.
    assert len(div_node.input) == div_input_count
    assert div_node.input[0] == "Y_raw"
    assert div_node.input[1] == scale_tensor.name
154
+
155
+
156
@pytest.mark.unit
def test_add_nodes_w_and_b_creates_mul_and_cast(
    dummy_node: NodeProto,
    dummy_graph: GraphProto,
    initializer_map: dict[str, Any],
) -> None:
    """``add_nodes_w_and_b`` scales W by base**exp and B by base**(2*exp).

    W and B each get a ``Mul`` by a new scale initializer followed by a
    ``Cast``. The node's inputs are rewritten to point at the cast outputs.
    """
    _ = dummy_graph  # requested but not used by this test
    exp, base = 2, 10
    quantizer = DummyQuantizer()
    nodes, new_inputs = quantizer.add_nodes_w_and_b(
        dummy_node,
        scale_exponent=exp,
        scale_base=base,
        initializer_map=initializer_map,
    )
    expected_node_count = 4  # Mul + Cast for W, Mul + Cast for B
    expected_initializer_count = 2  # one scale tensor for each of W and B

    assert len(nodes) == expected_node_count
    assert [n.op_type for n in nodes] == ["Mul", "Cast", "Mul", "Cast"]
    assert new_inputs == ["X", "W_scaled_cast", "B_scaled_cast"]
    assert len(quantizer.new_initializers) == expected_initializer_count

    # Mul/Cast wiring: W uses nodes[0:2], B uses nodes[2:4].
    for offset, name in ((0, "W"), (2, "B")):
        mul_node, cast_node = nodes[offset], nodes[offset + 1]
        assert mul_node.input == [name, f"{name}_scale"]
        assert mul_node.output == [f"{name}_scaled"]
        assert cast_node.input == [f"{name}_scaled"]
        assert cast_node.output == [f"{name}_scaled_cast"]

    # Scale tensors: W by base**exp, B by base**(exp * 2).
    w_scale, b_scale = quantizer.new_initializers[:2]
    assert w_scale.name == "W_scale"
    assert b_scale.name == "B_scale"
    assert onnx.numpy_helper.to_array(w_scale)[0] == base**exp
    assert onnx.numpy_helper.to_array(b_scale)[0] == base ** (exp * 2)
206
+
207
+
208
@pytest.mark.unit
def test_insert_scale_node_creates_mul_and_cast(
    dummy_tensor: TensorProto,
    dummy_graph: GraphProto,
) -> None:
    """``insert_scale_node`` emits Mul -> Cast and registers one scale tensor."""
    _ = dummy_graph  # requested but not used by this test
    base, exponent = 10, 1
    quantizer = DummyQuantizer()
    output_name, mul_node, cast_node = quantizer.insert_scale_node(
        dummy_tensor,
        scale_base=base,
        scale_exponent=exponent,
    )

    assert mul_node.op_type == "Mul"
    assert cast_node.op_type == "Cast"
    assert "_scaled" in mul_node.output[0]
    assert output_name.endswith("_cast")

    assert len(quantizer.new_initializers) == 1
    scale_tensor = quantizer.new_initializers[0]
    assert scale_tensor.name.endswith("_scale")
    assert onnx.numpy_helper.to_array(scale_tensor)[0] == float(base**exponent)
@@ -0,0 +1,99 @@
1
+ import pytest
2
+
3
+ from python.core.model_processing.onnx_quantizer.exceptions import (
4
+ REPORTING_URL,
5
+ InvalidParamError,
6
+ QuantizationError,
7
+ UnsupportedOpError,
8
+ )
9
+
10
+
11
@pytest.mark.unit
def test_quantization_error_message() -> None:
    """A ``QuantizationError`` embeds its custom text, banner and reporting link."""
    custom_msg = "Something went wrong."
    with pytest.raises(QuantizationError) as exc_info:
        raise QuantizationError(custom_msg)

    rendered = str(exc_info.value)
    expected_fragments = (
        "This model is not supported by JSTprove.",
        custom_msg,
        REPORTING_URL,
        "Submit model support requests via the JSTprove channel:",
    )
    for fragment in expected_fragments:
        assert fragment in rendered
24
+
25
+
26
@pytest.mark.unit
def test_invalid_param_error_basic() -> None:
    """``InvalidParamError`` without attr/expected omits those message sections."""
    with pytest.raises(InvalidParamError) as exc_info:
        raise InvalidParamError(
            node_name="Conv_1",
            op_type="Conv",
            message="Missing 'strides' attribute.",
        )
    err_msg = str(exc_info.value)
    assert "Invalid parameters in node 'Conv_1'" in err_msg
    assert "(op_type='Conv')" in err_msg
    assert "Missing 'strides' attribute." in err_msg
    assert "[Attribute:" not in err_msg
    assert "[Expected:" not in err_msg

    # The generic QuantizationError text (rendered with an empty custom
    # message) must be embedded. Constructing the exception is enough to
    # render it; there is no need to raise and catch it.
    assert str(QuantizationError("")) in err_msg
47
+
48
+
49
@pytest.mark.unit
def test_invalid_param_error_with_attr_and_expected() -> None:
    """``InvalidParamError`` renders the attribute and expected-value sections."""
    with pytest.raises(InvalidParamError) as exc_info:
        raise InvalidParamError(
            node_name="MaxPool_3",
            op_type="MaxPool",
            message="Kernel shape is invalid.",
            attr_key="kernel_shape",
            expected="a list of 2 positive integers",
        )
    err_msg = str(exc_info.value)
    assert "Invalid parameters in node 'MaxPool_3'" in err_msg
    assert "[Attribute: kernel_shape]" in err_msg
    assert "[Expected: a list of 2 positive integers]" in err_msg

    # The generic QuantizationError text must be embedded; build it directly
    # rather than raising and catching it.
    assert str(QuantizationError("")) in err_msg
69
+
70
+
71
@pytest.mark.unit
def test_unsupported_op_error_with_node() -> None:
    """``UnsupportedOpError`` with a node name mentions both op type and node."""
    with pytest.raises(UnsupportedOpError) as exc_info:
        raise UnsupportedOpError(op_type="Resize", node_name="Resize_42")
    err_msg = str(exc_info.value)
    assert "Unsupported op type: 'Resize'" in err_msg
    assert "in node 'Resize_42'" in err_msg
    assert "documentation for supported layers" in err_msg

    # The generic QuantizationError text must be embedded; build it directly
    # rather than raising and catching it.
    assert str(QuantizationError("")) in err_msg
85
+
86
+
87
@pytest.mark.unit
def test_unsupported_op_error_without_node() -> None:
    """``UnsupportedOpError`` without a node name omits the ``in node`` clause."""
    with pytest.raises(UnsupportedOpError) as exc_info:
        raise UnsupportedOpError(op_type="Upsample")
    err_msg = str(exc_info.value)
    assert "Unsupported op type: 'Upsample'" in err_msg
    assert "in node" not in err_msg

    # The generic QuantizationError text must be embedded; build it directly
    # rather than raising and catching it.
    assert str(QuantizationError("")) in err_msg
@@ -0,0 +1,246 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any
4
+
5
+ import onnx
6
+ import pytest
7
+ from onnx import GraphProto, ModelProto, NodeProto, TensorProto, helper
8
+
9
+ from python.core.model_processing.onnx_quantizer.exceptions import (
10
+ MissingHandlerError,
11
+ QuantizationError,
12
+ UnsupportedOpError,
13
+ )
14
+
15
+ if TYPE_CHECKING:
16
+ from python.core.model_processing.onnx_quantizer.layers.base import ScaleConfig
17
+
18
+ # Optional: mock layers if needed
19
+ from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import (
20
+ ONNXOpQuantizer,
21
+ )
22
+
23
+
24
+ # Mocks
25
class MockHandler:
    """Stand-in quantization handler that records which hooks were invoked."""

    def __init__(self: MockHandler) -> None:
        # Both flags start False and flip to True on the first call.
        self.called_supported = False
        self.called_quantize = False

    def quantize(
        self: MockHandler,
        node: NodeProto,
        graph: GraphProto,
        scale_config: ScaleConfig,
        initializer_map: dict[str, TensorProto],
    ) -> list[NodeProto]:
        """Record the call and return *node* unchanged, wrapped in a list."""
        del graph, scale_config, initializer_map  # unused by the mock
        self.called_quantize = True
        return [node]

    def check_supported(
        self: MockHandler,
        node: NodeProto,
        initializer_map: dict[str, TensorProto],
    ) -> None:
        """Record the call; reject a node named ``bad_node`` with ``ValueError``."""
        del initializer_map  # unused by the mock
        self.called_supported = True
        if node.name != "bad_node":
            return
        msg = "Invalid node parameters"
        raise ValueError(msg)
51
+
52
+
53
+ # Fixtures
54
@pytest.fixture
def quantizer() -> ONNXOpQuantizer:
    """Provide a freshly constructed ``ONNXOpQuantizer`` for each test."""
    return ONNXOpQuantizer()
57
+
58
+
59
@pytest.fixture
def dummy_node() -> NodeProto:
    """Return an unnamed ``FakeOp`` node mapping ``x`` to ``y``."""
    op_type = "FakeOp"
    return helper.make_node(op_type, inputs=["x"], outputs=["y"])
62
+
63
+
64
@pytest.fixture
def valid_node() -> NodeProto:
    """Return a ``Dummy`` node named ``good_node`` (accepted by ``MockHandler``)."""
    node_name = "good_node"
    return helper.make_node("Dummy", inputs=["x"], outputs=["y"], name=node_name)
67
+
68
+
69
@pytest.fixture
def invalid_node() -> NodeProto:
    """Return a ``Dummy`` node named ``bad_node`` (rejected by ``MockHandler``)."""
    node_name = "bad_node"
    return helper.make_node("Dummy", inputs=["x"], outputs=["y"], name=node_name)
72
+
73
+
74
@pytest.fixture
def dummy_model(valid_node: NodeProto, invalid_node: NodeProto) -> ModelProto:
    """Return a model with the good/bad ``Dummy`` nodes and a float initializer ``x``."""
    x_initializer = helper.make_tensor("x", TensorProto.FLOAT, [1], [0.5])
    graph = helper.make_graph(
        [valid_node, invalid_node],
        "test_graph",
        inputs=[],
        outputs=[],
        initializer=[x_initializer],
    )
    return helper.make_model(graph)
84
+
85
+
86
+ # Tests
87
+
88
+
89
@pytest.mark.unit
def test_check_model_raises_on_unsupported_op() -> None:
    """``check_model`` rejects a model whose only node is ``UnsupportedOp``."""
    quantizer = ONNXOpQuantizer()

    lone_node = helper.make_node("UnsupportedOp", ["x"], ["y"])
    model = helper.make_model(helper.make_graph([lone_node], "test_graph", [], []))

    with pytest.raises(UnsupportedOpError):
        quantizer.check_model(model)
99
+
100
+
101
@pytest.mark.unit
def test_check_layer_invokes_check_supported() -> None:
    """``check_layer`` delegates to the registered handler's ``check_supported``."""
    quantizer = ONNXOpQuantizer()
    handler = MockHandler()
    quantizer.register("FakeOp", handler)

    quantizer.check_layer(helper.make_node("FakeOp", ["x"], ["y"]), {})

    assert handler.called_supported
113
+
114
+
115
@pytest.mark.unit
def test_get_initializer_map_returns_correct_dict() -> None:
    """``get_initializer_map`` keys each graph initializer by its name."""
    quantizer = ONNXOpQuantizer()

    weight = helper.make_tensor(
        name="W",
        data_type=TensorProto.FLOAT,
        dims=[1],
        vals=[1.0],
    )
    model = helper.make_model(helper.make_graph([], "test_graph", [], [], [weight]))

    init_map = quantizer.get_initializer_map(model)

    assert "W" in init_map
    mapped = init_map["W"]
    assert mapped == weight
    assert mapped.data_type == TensorProto.FLOAT
    assert mapped.dims == [1]
    assert onnx.numpy_helper.to_array(mapped)[0] == 1.0
137
+
138
+
139
@pytest.mark.unit
def test_quantize_with_unregistered_op_warns(dummy_node: NodeProto) -> None:
    """Quantizing a ``FakeOp`` node raises ``UnsupportedOpError``.

    NOTE: despite the test name, this asserts an exception, not a warning.
    """
    quantizer = ONNXOpQuantizer()
    empty_graph = helper.make_graph([], "g", [], [])
    with pytest.raises(UnsupportedOpError) as excinfo:
        quantizer.quantize(dummy_node, empty_graph, 1, 1, {}, rescale=False)

    assert "Unsupported op type: 'FakeOp'" in str(excinfo.value)
148
+
149
+
150
+ # Could be unit or integration?
151
@pytest.mark.unit
def test_check_model_raises_unsupported(dummy_model: ModelProto) -> None:
    """``check_model`` reports an op type that has no registered handler.

    Only ``Dummy`` has a handler here, so the appended ``FakeOp`` node must
    trigger an ``UnsupportedOpError`` whose message names ``FakeOp``.
    """
    quantizer = ONNXOpQuantizer()
    quantizer.handlers = {"Dummy": MockHandler()}

    # Append (not remove) a node whose op type has no registered handler.
    dummy_model.graph.node.append(helper.make_node("FakeOp", ["a"], ["b"]))

    with pytest.raises(UnsupportedOpError) as excinfo:
        quantizer.check_model(dummy_model)

    assert "FakeOp" in str(excinfo.value)
163
+
164
+
165
@pytest.mark.unit
def test_check_layer_missing_handler(valid_node: NodeProto) -> None:
    """``check_layer`` raises ``MissingHandlerError`` for an unhandled op type."""
    quantizer = ONNXOpQuantizer()
    with pytest.raises(MissingHandlerError) as exc_info:
        quantizer.check_layer(valid_node, {})

    message = str(exc_info.value)
    assert QuantizationError("").GENERIC_MESSAGE in message
    expected = "No quantization handler registered for operator type 'Dummy'."
    assert expected in message
175
+
176
+
177
@pytest.mark.unit
def test_check_layer_with_bad_handler(invalid_node: NodeProto) -> None:
    """An error raised by ``check_supported`` propagates out of ``check_layer``."""
    quantizer = ONNXOpQuantizer()
    mock_handlers = {"Dummy": MockHandler()}
    quantizer.handlers = mock_handlers

    # MockHandler.check_supported raises ValueError for a node named "bad_node".
    with pytest.raises(ValueError, match="Invalid node parameters"):
        quantizer.check_layer(invalid_node, {})
185
+
186
+
187
@pytest.mark.unit
def test_get_initializer_map_extracts_all() -> None:
    """Each initializer is mapped with the right name, dtype, dims and value."""
    expected_values = {"a": 1.0, "b": 2.0}
    tensors = [
        helper.make_tensor(name, TensorProto.FLOAT, [1], [value])
        for name, value in expected_values.items()
    ]
    graph = helper.make_graph([], "g", [], [], initializer=tensors)
    model = helper.make_model(graph)

    init_map = ONNXOpQuantizer().get_initializer_map(model)

    assert len(init_map) == len(expected_values)
    for name, value in expected_values.items():
        tensor = init_map[name]
        assert tensor.float_data[0] == value
        assert tensor.name == name
        assert tensor.data_type == TensorProto.FLOAT
        assert tensor.dims == [1]
        # Cross-check through numpy_helper as well as the raw float_data field.
        assert onnx.numpy_helper.to_array(tensor)[0] == value
213
+
214
+
215
@pytest.mark.unit
def test_check_layer_skips_handler_without_check_supported() -> None:
    """``check_layer`` tolerates a handler that lacks ``check_supported``."""

    class NoCheckHandler:
        """Handler exposing only ``quantize``; ``check_supported`` is absent on purpose."""

        # ``*args: Any`` annotates each positional argument; the previous
        # ``*args: tuple`` / ``**kwargs: dict[...]`` annotations were wrong.
        def quantize(self, *args: Any, **kwargs: Any) -> None:
            _ = args, kwargs

    quantizer = ONNXOpQuantizer()
    quantizer.register("NoCheckOp", NoCheckHandler())

    node = helper.make_node("NoCheckOp", ["x"], ["y"])
    # Should not raise.
    quantizer.check_layer(node, {})
227
+
228
+
229
@pytest.mark.unit
def test_register_overwrites_handler() -> None:
    """Registering the same op type twice keeps only the latest handler."""
    quantizer = ONNXOpQuantizer()
    first, second = MockHandler(), MockHandler()

    for handler in (first, second):
        quantizer.register("Dummy", handler)

    assert quantizer.handlers["Dummy"] is second
239
+
240
+
241
@pytest.mark.unit
def test_check_empty_model() -> None:
    """``check_model`` accepts a model whose graph contains no nodes."""
    empty_graph = helper.make_graph([], "empty", [], [])
    model = helper.make_model(empty_graph)
    # Must complete without raising.
    ONNXOpQuantizer().check_model(model)
@@ -0,0 +1,121 @@
1
+ # This file performs very basic integration tests on each registered quantizer
2
+
3
+ import numpy as np
4
+ import onnx
5
+ import pytest
6
+ from onnx import helper
7
+
8
+ from python.core.model_processing.onnx_quantizer.layers.base import ScaleConfig
9
+ from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import (
10
+ ONNXOpQuantizer,
11
+ )
12
+ from python.tests.onnx_quantizer_tests import TEST_RNG_SEED
13
+
14
+
15
@pytest.fixture
def dummy_graph() -> onnx.GraphProto:
    """Provide a blank ``GraphProto`` (no name, nodes, inputs or outputs)."""
    return onnx.GraphProto()
18
+
19
+
20
def mock_initializer_map(input_names: list[str]) -> dict[str, onnx.TensorProto]:
    """Build seeded random 2x2 float initializers, one per name in *input_names*.

    A fresh generator seeded with ``TEST_RNG_SEED`` is used on every call, so
    the same names always yield the same values.
    """
    rng = np.random.default_rng(TEST_RNG_SEED)
    initializers: dict[str, onnx.TensorProto] = {}
    for name in input_names:
        values = rng.random(4, dtype=np.float32).tolist()
        initializers[name] = onnx.helper.make_tensor(
            name=name,
            data_type=onnx.TensorProto.FLOAT,
            dims=[2, 2],  # minimal shape
            vals=values,
        )
    return initializers
31
+
32
+
33
def get_required_input_names(op_type: str) -> list[str]:
    """Return the names of the non-optional inputs in *op_type*'s ONNX schema.

    An unnamed formal input falls back to ``input{i}``. If no schema is
    registered for *op_type* (e.g. a custom op), ``["input0"]`` is returned as
    a placeholder.
    """
    try:
        schema = onnx.defs.get_schema(op_type)
    except onnx.defs.SchemaError:
        # Only a missing schema is expected here; other errors should surface
        # instead of being masked by a blanket ``except Exception``.
        return ["input0"]
    # Compare against the enum rather than the magic number 1.
    optional = onnx.defs.OpSchema.FormalParameterOption.Optional
    return [
        inp.name or f"input{i}"
        for i, inp in enumerate(schema.inputs)
        if inp.option != optional
    ]
43
+
44
+
45
def validate_quantized_node(node_result: onnx.NodeProto, op_type: str) -> None:
    """Validate a single quantized node.

    Runs basic structural assertions, then wraps the node in a throw-away
    single-node model and runs the ONNX checker on it. In that model, every
    distinct non-empty input/output name is declared as a 1-element FLOAT
    tensor, and an ``ai.onnx.contrib`` opset import is added for custom ops.

    Empty names mark omitted optional arguments in ONNX. They are not declared
    as graph inputs/outputs, because a value-info with an empty name is
    rejected by the checker.

    Fails the current test via ``pytest.fail`` if the checker rejects the model.
    """
    assert isinstance(node_result, onnx.NodeProto), f"Invalid node type for {op_type}"
    assert node_result.op_type, f"Missing op_type for {op_type}"
    assert node_result.output, f"Missing outputs for {op_type}"

    temp_graph = onnx.GraphProto()
    temp_graph.name = "temp_graph"

    # Declare each distinct, non-empty name once so the graph is self-contained.
    declared_inputs: set[str] = set()
    for inp in node_result.input:
        if inp and inp not in declared_inputs:
            declared_inputs.add(inp)
            temp_graph.input.append(
                onnx.helper.make_tensor_value_info(inp, onnx.TensorProto.FLOAT, [1]),
            )
    declared_outputs: set[str] = set()
    for out in node_result.output:
        if out and out not in declared_outputs:
            declared_outputs.add(out)
            temp_graph.output.append(
                onnx.helper.make_tensor_value_info(out, onnx.TensorProto.FLOAT, [1]),
            )

    temp_graph.node.append(node_result)
    temp_model = onnx.helper.make_model(temp_graph)
    custom_domain = onnx.helper.make_operatorsetid(
        domain="ai.onnx.contrib",
        version=1,
    )
    temp_model.opset_import.append(custom_domain)

    # Only the checker call is expected to raise ValidationError.
    try:
        onnx.checker.check_model(temp_model)
    except onnx.checker.ValidationError as e:
        pytest.fail(f"ONNX node validation failed for {op_type}: {e}")
86
+
87
+
88
@pytest.mark.integration
@pytest.mark.parametrize("op_type", list(ONNXOpQuantizer().handlers))
def test_registered_quantizer_quantize(
    op_type: str,
    dummy_graph: onnx.GraphProto,
) -> None:
    """Smoke-test ``quantize`` for every registered handler.

    A node of *op_type* is built whose required schema inputs are all backed by
    seeded random initializers. It is quantized with base 2, exponent 10 and
    rescaling enabled, and the returned node(s) are validated.
    """
    handler = ONNXOpQuantizer().handlers[op_type]

    input_names = get_required_input_names(op_type)
    initializers = mock_initializer_map(input_names)
    node = helper.make_node(
        op_type=op_type,
        inputs=input_names,
        outputs=["dummy_output"],
    )

    result = handler.quantize(
        node=node,
        graph=dummy_graph,
        scale_config=ScaleConfig(exponent=10, base=2, rescale=True),
        initializer_map=initializers,
    )
    assert result is not None

    if not isinstance(result, list):
        # A single returned node must at least carry inputs.
        assert result.input, f"Missing inputs for {op_type}"
        validate_quantized_node(result, op_type)
        return

    assert len(result) > 0, f"Quantize returned empty list for {op_type}"
    for quantized in result:
        validate_quantized_node(quantized, op_type)