JSTprove 1.0.0__py3-none-macosx_11_0_arm64.whl → 1.2.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/METADATA +3 -3
  2. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/RECORD +60 -25
  3. python/core/binaries/onnx_generic_circuit_1-2-0 +0 -0
  4. python/core/circuit_models/generic_onnx.py +43 -9
  5. python/core/circuits/base.py +231 -71
  6. python/core/model_processing/converters/onnx_converter.py +114 -59
  7. python/core/model_processing/onnx_custom_ops/batchnorm.py +64 -0
  8. python/core/model_processing/onnx_custom_ops/maxpool.py +1 -1
  9. python/core/model_processing/onnx_custom_ops/mul.py +66 -0
  10. python/core/model_processing/onnx_custom_ops/relu.py +1 -1
  11. python/core/model_processing/onnx_quantizer/layers/add.py +54 -0
  12. python/core/model_processing/onnx_quantizer/layers/base.py +188 -1
  13. python/core/model_processing/onnx_quantizer/layers/batchnorm.py +224 -0
  14. python/core/model_processing/onnx_quantizer/layers/constant.py +1 -1
  15. python/core/model_processing/onnx_quantizer/layers/conv.py +20 -68
  16. python/core/model_processing/onnx_quantizer/layers/gemm.py +20 -66
  17. python/core/model_processing/onnx_quantizer/layers/maxpool.py +53 -43
  18. python/core/model_processing/onnx_quantizer/layers/mul.py +53 -0
  19. python/core/model_processing/onnx_quantizer/layers/relu.py +20 -35
  20. python/core/model_processing/onnx_quantizer/layers/sub.py +54 -0
  21. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +43 -1
  22. python/core/utils/general_layer_functions.py +17 -12
  23. python/core/utils/model_registry.py +6 -3
  24. python/scripts/gen_and_bench.py +2 -2
  25. python/tests/circuit_e2e_tests/other_e2e_test.py +202 -9
  26. python/tests/circuit_parent_classes/test_circuit.py +561 -38
  27. python/tests/circuit_parent_classes/test_onnx_converter.py +22 -13
  28. python/tests/onnx_quantizer_tests/__init__.py +1 -0
  29. python/tests/onnx_quantizer_tests/layers/__init__.py +13 -0
  30. python/tests/onnx_quantizer_tests/layers/add_config.py +102 -0
  31. python/tests/onnx_quantizer_tests/layers/base.py +279 -0
  32. python/tests/onnx_quantizer_tests/layers/batchnorm_config.py +190 -0
  33. python/tests/onnx_quantizer_tests/layers/constant_config.py +39 -0
  34. python/tests/onnx_quantizer_tests/layers/conv_config.py +154 -0
  35. python/tests/onnx_quantizer_tests/layers/factory.py +142 -0
  36. python/tests/onnx_quantizer_tests/layers/flatten_config.py +61 -0
  37. python/tests/onnx_quantizer_tests/layers/gemm_config.py +160 -0
  38. python/tests/onnx_quantizer_tests/layers/maxpool_config.py +82 -0
  39. python/tests/onnx_quantizer_tests/layers/mul_config.py +102 -0
  40. python/tests/onnx_quantizer_tests/layers/relu_config.py +61 -0
  41. python/tests/onnx_quantizer_tests/layers/reshape_config.py +61 -0
  42. python/tests/onnx_quantizer_tests/layers/sub_config.py +102 -0
  43. python/tests/onnx_quantizer_tests/layers_tests/__init__.py +0 -0
  44. python/tests/onnx_quantizer_tests/layers_tests/base_test.py +94 -0
  45. python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py +115 -0
  46. python/tests/onnx_quantizer_tests/layers_tests/test_e2e.py +196 -0
  47. python/tests/onnx_quantizer_tests/layers_tests/test_error_cases.py +59 -0
  48. python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +198 -0
  49. python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +267 -0
  50. python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py +109 -0
  51. python/tests/onnx_quantizer_tests/layers_tests/test_validation.py +45 -0
  52. python/tests/onnx_quantizer_tests/test_base_layer.py +228 -0
  53. python/tests/onnx_quantizer_tests/test_exceptions.py +99 -0
  54. python/tests/onnx_quantizer_tests/test_onnx_op_quantizer.py +246 -0
  55. python/tests/onnx_quantizer_tests/test_registered_quantizers.py +121 -0
  56. python/tests/onnx_quantizer_tests/testing_helper_functions.py +17 -0
  57. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  58. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/WHEEL +0 -0
  59. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/entry_points.txt +0 -0
  60. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/licenses/LICENSE +0 -0
  61. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,267 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+ from unittest.mock import Mock
5
+
6
+ import onnx
7
+ import pytest
8
+ from onnx import NodeProto
9
+
10
+ from python.core.model_processing.onnx_quantizer.exceptions import UnsupportedOpError
11
+
12
+ if TYPE_CHECKING:
13
+ from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import (
14
+ ONNXOpQuantizer,
15
+ )
16
+ from python.tests.onnx_quantizer_tests.layers.base import (
17
+ LayerTestConfig,
18
+ LayerTestSpec,
19
+ SpecType,
20
+ )
21
+ from python.tests.onnx_quantizer_tests.layers.factory import TestLayerFactory
22
+ from python.tests.onnx_quantizer_tests.layers_tests.base_test import (
23
+ BaseQuantizerTest,
24
+ )
25
+
26
+
27
class TestQuantize(BaseQuantizerTest):
    """Tests for quantization functionality."""

    __test__ = True

    def setup_quantize_test(
        self,
        test_case_data: tuple[str, LayerTestConfig, LayerTestSpec],
        quantizer: ONNXOpQuantizer,
        scale_exponent: int = 2,
        scale_base: int = 10,
        *,
        rescale: bool = True,
    ) -> tuple[
        onnx.NodeProto | list[onnx.NodeProto],
        tuple[str, LayerTestConfig, LayerTestSpec, NodeProto],
    ]:
        """Common setup for quantization tests.

        Builds the model for *test_case_data*, quantizes its first node and
        returns ``(result, (layer_name, config, test_spec, node))``.

        NOTE: ``quantizer.quantize`` may return either a single node or a
        list of nodes — the tests below branch on ``isinstance(result, list)``
        — so the return annotation reflects that union.
        """
        layer_name, config, test_spec = test_case_data

        # Skip early when the matching ONNX-validation test already failed.
        self._check_validation_dependency(test_case_data)

        if test_spec.skip_reason:
            pytest.skip(f"{layer_name}_{test_spec.name}: {test_spec.skip_reason}")

        model = config.create_test_model(test_spec)
        node = model.graph.node[0]
        initializer_map = {init.name: init for init in model.graph.initializer}

        mock_graph = Mock()
        if node.op_type == "Constant":
            # Constant quantization inspects consumers of the constant's
            # output, so fake one downstream node wired to it.
            mock_data_node = Mock()
            mock_data_node.input = [node.output[0]]
            mock_graph.node = [mock_data_node]

        result = quantizer.quantize(
            node=node,
            rescale=rescale,
            graph=mock_graph,
            scale_exponent=scale_exponent,
            scale_base=scale_base,
            initializer_map=initializer_map,
        )

        return result, (layer_name, config, test_spec, node)

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "test_case_data",
        TestLayerFactory.get_test_cases_by_type(SpecType.VALID),  # type: ignore[arg-type]
        ids=BaseQuantizerTest._generate_test_id,
    )
    def test_quantize_individual_valid_cases(
        self: TestQuantize,
        quantizer: ONNXOpQuantizer,
        test_case_data: tuple[str, LayerTestConfig, LayerTestSpec],
    ) -> None:
        """Each valid case quantizes to a NodeProto or a non-empty list of them."""

        scale_exponent, scale_base = 2, 10
        rescale = True

        result, (layer_name, _config, test_spec, _node) = self.setup_quantize_test(
            test_case_data,
            quantizer,
            scale_exponent,
            scale_base,
            rescale=rescale,
        )

        # Check None first: the isinstance assertions below would otherwise
        # fire for a None result and mask the real failure mode.
        assert (
            result is not None
        ), f"Quantize returned None for {layer_name}.{test_spec.name}"

        # Test that the output of the quantizer quantize is in fact a node
        if isinstance(result, list):
            assert (
                len(result) > 0
            ), f"Quantize returned empty list for {layer_name}.{test_spec.name}"
            for node_result in result:
                assert isinstance(
                    node_result,
                    onnx.NodeProto,
                ), f"Invalid node type returned for {layer_name}.{test_spec.name}"
        else:
            assert isinstance(
                result,
                onnx.NodeProto,
            ), f"Quantize returned none node for {layer_name}.{test_spec.name}"

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "test_case_data",
        TestLayerFactory.get_test_cases_by_type(SpecType.VALID),  # type: ignore[arg-type]
        ids=BaseQuantizerTest._generate_test_id,
    )
    def test_quantize_preserves_node_names(
        self: TestQuantize,
        quantizer: ONNXOpQuantizer,
        test_case_data: tuple[str, LayerTestConfig, LayerTestSpec],
    ) -> None:
        """Quantized output must contain a node corresponding to the original,
        with the original node's attributes preserved."""

        scale_exponent, scale_base = 2, 10
        rescale = True
        result, (_layer_name, config, _test_spec, node) = self.setup_quantize_test(
            test_case_data,
            quantizer,
            scale_exponent,
            scale_base,
            rescale=rescale,
        )
        is_node_present = False

        def check_node_and_analyze_parameters(
            node: NodeProto,
            result_node: NodeProto,
        ) -> bool:
            # True when result_node corresponds to the pre-quantized node.
            if node.op_type == "BatchNormalization":
                pytest.skip(f"{node.op_type} alters the node structure by design")
            if node.op_type in result_node.op_type:
                # Assert there are no less attributes in the new node
                assert len(node.attribute) <= len(result_node.attribute)
                # Ensure that each original node's attributes
                # are contained in the new nodes
                result_attr_names = [a.name for a in result_node.attribute]
                for att in node.attribute:
                    assert att.name in result_attr_names
                return True
            return False

        # Check that result nodes have meaningful names and the relevant node is present
        # And ensure that the new node has the same parameters as the old node
        if isinstance(result, list):
            for result_node in result:
                assert (
                    result_node.name
                ), f"Quantized node missing name for {config.op_type}"
                assert (
                    result_node.op_type
                ), f"Quantized node missing op_type for {config.op_type}"

                is_node_present = is_node_present or check_node_and_analyze_parameters(
                    node,
                    result_node,
                )
        else:
            assert result.name, f"Quantized node missing name for {config.op_type}"
            # Mirror the list branch: op_type must be present here as well.
            assert (
                result.op_type
            ), f"Quantized node missing op_type for {config.op_type}"
            is_node_present = is_node_present or check_node_and_analyze_parameters(
                node,
                result,
            )

        # Assert that the node is in fact present
        assert (
            is_node_present
        ), "Cannot find quantized node relating to prequantized node"

    @pytest.mark.unit
    @pytest.mark.parametrize("scale_params", [(2, 10), (0, 5)])
    @pytest.mark.parametrize(
        "test_case_data",
        TestLayerFactory.get_test_cases_by_type(SpecType.VALID),  # type: ignore[arg-type]
        ids=BaseQuantizerTest._generate_test_id,
    )
    def test_quantize_with_different_scales(
        self: TestQuantize,
        quantizer: ONNXOpQuantizer,
        test_case_data: tuple[str, LayerTestConfig, LayerTestSpec],
        scale_params: tuple[int, int],
    ) -> None:
        """Quantization succeeds for multiple (scale_exponent, scale_base) pairs."""

        # Test for both scale parameters
        scale_exponent, scale_base = scale_params
        rescale = True
        result, (_layer_name, _config, _test_spec, _node) = self.setup_quantize_test(
            test_case_data,
            quantizer,
            scale_exponent,
            scale_base,
            rescale=rescale,
        )

        # Should return valid result regardless of scale values
        assert (
            result is not None
        ), f"Quantize returned None for scale={scale_exponent}, scale_base={scale_base}"

    @pytest.mark.unit
    @pytest.mark.parametrize("rescale", [True, False])
    @pytest.mark.parametrize(
        "test_case_data",
        TestLayerFactory.get_test_cases_by_type(SpecType.VALID),  # type: ignore[arg-type]
        ids=BaseQuantizerTest._generate_test_id,
    )
    def test_quantize_with_different_rescales(
        self: TestQuantize,
        quantizer: ONNXOpQuantizer,
        test_case_data: tuple[str, LayerTestConfig, LayerTestSpec],
        *,
        rescale: bool,
    ) -> None:
        """Quantization succeeds with rescaling both enabled and disabled."""

        scale_exponent, scale_base = 2, 10

        # Test that quantizing works with both rescaling values
        result, (_layer_name, _config, _test_spec, _node) = self.setup_quantize_test(
            test_case_data,
            quantizer,
            scale_exponent,
            scale_base,
            rescale=rescale,
        )
        assert result is not None, f"Quantize failed with rescale={rescale}"

    @pytest.mark.unit
    def test_quantize_unsupported_layer_raises(
        self: TestQuantize,
        quantizer: ONNXOpQuantizer,
    ) -> None:
        """Quantizing an unsupported op type raises UnsupportedOpError.

        Renamed from ``test_quantize_unsupported_layer_returns_original``:
        the asserted behavior is an exception, not a pass-through of the
        original node.
        """
        mock_graph = Mock()
        scale_exponent, scale_base = 2, 10
        rescale = True

        # `onnx` is already imported at module level; no local import needed.
        unsupported_node = onnx.helper.make_node(
            "UnsupportedOp",
            inputs=["input"],
            outputs=["output"],
            name="unsupported",
        )
        with pytest.raises(UnsupportedOpError) as excinfo:
            _ = quantizer.quantize(
                node=unsupported_node,
                rescale=rescale,
                graph=mock_graph,
                scale_exponent=scale_exponent,
                scale_base=scale_base,
                initializer_map={},
            )
        assert "Unsupported op type: 'UnsupportedOp'" in str(excinfo.value)
@@ -0,0 +1,109 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+ import pytest
5
+
6
+ from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import (
7
+ ONNXOpQuantizer,
8
+ )
9
+ from python.tests.onnx_quantizer_tests.layers.base import LayerTestConfig, SpecType
10
+ from python.tests.onnx_quantizer_tests.layers.factory import TestLayerFactory
11
+
12
+
13
class TestScalability:
    """Meta-tests: the layer-testing framework must scale as layers are added."""

    @pytest.mark.unit
    def test_adding_new_layer_config(self: TestScalability) -> None:
        """A brand-new layer config should work without framework changes."""
        expected_input_count = 2

        # Simulate registering a layer type the framework has never seen.
        config = LayerTestConfig(
            op_type="NewCustomOp",
            valid_inputs=["input", "custom_param"],
            valid_attributes={"custom_attr": 42},
            required_initializers={"custom_param": np.array([1, 2, 3])},
        )

        # The config must be able to build a node for the new op...
        node = config.create_node()
        assert node.op_type == "NewCustomOp"
        assert len(node.input) == expected_input_count

        # ...as well as its required initializers.
        initializers = config.create_initializers()
        assert "custom_param" in initializers

    @pytest.mark.unit
    def test_layer_config_extensibility(self: TestScalability) -> None:
        """Every registered quantizer handler must have a layer test config."""
        configs = TestLayerFactory.get_layer_configs()

        # Handlers without a matching config represent a test-coverage gap.
        unsupported = ONNXOpQuantizer().handlers.keys() - set(configs.keys())
        assert unsupported == set(), (
            f"The following layers are not being configured for testing: {unsupported}."
            " Please add configuration in tests/onnx_quantizer_tests/layers/"
        )

        # Each config must carry the pieces the test machinery relies on.
        for layer_type, cfg in configs.items():
            err_msg = (
                f"Quantization test config is not supported yet for {layer_type}"
                " and must be implemented"
            )
            assert cfg.op_type == layer_type, err_msg
            assert isinstance(cfg.valid_inputs, list), err_msg
            assert isinstance(cfg.valid_attributes, dict), err_msg
            assert isinstance(cfg.required_initializers, dict), err_msg

    @pytest.mark.unit
    def test_every_layer_has_basic_and_e2e(self: TestScalability) -> None:
        """Each registered layer must have at least one basic/valid test
        and one e2e test."""

        def is_basic(spec: object) -> bool:
            # A spec counts as "basic" when it is tagged 'basic' or 'valid',
            # or when its spec_type is SpecType.VALID.
            tags = getattr(spec, "tags", set())
            return (
                "basic" in tags
                or "valid" in tags
                or getattr(spec, "spec_type", None) == SpecType.VALID
            )

        def is_e2e(spec: object) -> bool:
            # A spec counts as "e2e" when it is tagged 'e2e',
            # or when its spec_type is SpecType.E2E.
            return "e2e" in getattr(spec, "tags", set()) or (
                getattr(spec, "spec_type", None) == SpecType.E2E
            )

        missing_basic = []
        missing_e2e = []

        # Walk every registered layer and collect the ones lacking coverage.
        for layer_name in TestLayerFactory.get_available_layers():
            specs = [
                spec
                for _, _config, spec in TestLayerFactory.get_test_cases_by_layer(
                    layer_name,
                )
            ]
            if not any(is_basic(s) for s in specs):
                missing_basic.append(layer_name)
            if not any(is_e2e(s) for s in specs):
                missing_e2e.append(layer_name)

        assert not missing_basic, f"Layers missing a basic/valid test: {missing_basic}"
        assert not missing_e2e, f"Layers missing an e2e test: {missing_e2e}"
@@ -0,0 +1,45 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ import onnx
6
+ import pytest
7
+
8
+ if TYPE_CHECKING:
9
+ from python.tests.onnx_quantizer_tests.layers.base import (
10
+ LayerTestConfig,
11
+ LayerTestSpec,
12
+ )
13
+ from python.tests.onnx_quantizer_tests.layers.factory import TestLayerFactory
14
+ from python.tests.onnx_quantizer_tests.layers_tests.base_test import (
15
+ BaseQuantizerTest,
16
+ )
17
+
18
+
19
class TestValidation(BaseQuantizerTest):
    """Ensure that layer factory models produce valid ONNX graphs."""

    __test__ = True

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "test_case_data",
        TestLayerFactory.get_all_test_cases(),
        ids=BaseQuantizerTest._generate_test_id,
    )
    def test_factory_models_pass_onnx_validation(
        self: TestValidation,
        test_case_data: tuple[str, LayerTestConfig, LayerTestSpec],
    ) -> None:
        """Every model built by the layer factory must pass onnx.checker."""
        layer_name, config, test_spec = test_case_data
        test_case_id = f"{layer_name}_{test_spec.name}"

        # A spec may opt out of validation via its skip_reason.
        if test_spec.skip_reason:
            pytest.skip(f"{test_case_id}: {test_spec.skip_reason}")

        model = config.create_test_model(test_spec)
        try:
            onnx.checker.check_model(model)
        except onnx.checker.ValidationError as e:
            # Record the failed case id (shared state from BaseQuantizerTest),
            # presumably so dependent tests can skip — confirm in base_test.py.
            self._validation_failed_cases.add(test_case_id)
            pytest.fail(f"Invalid ONNX model: {e}")
@@ -0,0 +1,228 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ import numpy as np
6
+ import onnx
7
+ import pytest
8
+ from onnx import helper, numpy_helper
9
+
10
+ if TYPE_CHECKING:
11
+ from typing import Any
12
+
13
+ from onnx import GraphProto, ModelProto, NodeProto, TensorProto
14
+
15
+ from python.core.model_processing.onnx_quantizer.exceptions import (
16
+ HandlerImplementationError,
17
+ )
18
+ from python.core.model_processing.onnx_quantizer.layers.base import (
19
+ BaseOpQuantizer,
20
+ ScaleConfig,
21
+ )
22
+
23
+
24
class DummyQuantizer(BaseOpQuantizer):
    """Minimal concrete BaseOpQuantizer used to exercise base-class helpers.

    Does not call ``super().__init__()``; it only provides the
    ``new_initializers`` list that the helper methods under test append to.
    """

    def __init__(self: DummyQuantizer) -> None:
        # Collects the scale tensors created by rescale/insert helpers.
        self.new_initializers = []
27
+
28
+
29
@pytest.fixture
def dummy_tensor() -> TensorProto:
    """A 2x2 float tensor named "W" (stand-in weight initializer)."""
    return numpy_helper.from_array(np.array([[1.0, 2.0], [3.0, 4.0]]), name="W")
32
+
33
+
34
@pytest.fixture
def dummy_bias() -> TensorProto:
    """A length-2 float tensor named "B" (stand-in bias initializer)."""
    return numpy_helper.from_array(np.array([1.0, 2.0]), name="B")
37
+
38
+
39
@pytest.fixture
def dummy_node() -> NodeProto:
    """A fictitious "DummyOp" node wired to X (data), W (weight), B (bias)."""
    return helper.make_node(
        "DummyOp",
        inputs=["X", "W", "B"],
        outputs=["Y"],
        name="DummyOp",
    )
47
+
48
+
49
@pytest.fixture
def dummy_graph() -> GraphProto:
    """An empty placeholder graph passed to helpers that require one."""
    return helper.make_graph([], "dummy_graph", inputs=[], outputs=[])
52
+
53
+
54
@pytest.fixture
def initializer_map(
    dummy_tensor: TensorProto,
    dummy_bias: TensorProto,
) -> dict[str, TensorProto]:
    """Initializer lookup matching dummy_node's "W" and "B" inputs."""
    return {"W": dummy_tensor, "B": dummy_bias}
60
+
61
+
62
@pytest.fixture
def minimal_model() -> ModelProto:
    """A structurally valid model whose graph contains no nodes at all."""
    return onnx.helper.make_model(
        onnx.helper.make_graph(
            nodes=[],
            name="EmptyGraph",
            inputs=[],
            outputs=[],
            initializer=[],
        ),
    )
72
+
73
+
74
@pytest.fixture
def unsupported_model() -> ModelProto:
    """A model whose single node has an op type with no registered handler."""
    bad_node = onnx.helper.make_node("UnsupportedOp", ["X"], ["Y"])
    return onnx.helper.make_model(
        onnx.helper.make_graph(
            nodes=[bad_node],
            name="UnsupportedGraph",
            inputs=[],
            outputs=[],
            initializer=[],
        ),
    )
85
+
86
+
87
@pytest.mark.unit
def test_quantize_raises_not_implemented() -> None:
    """Base-class quantize() must raise; subclasses are required to override."""
    quantizer = BaseOpQuantizer()
    with pytest.raises(
        HandlerImplementationError,
    ) as excinfo:
        quantizer.quantize(
            node=None,
            graph=None,
            scale_config=ScaleConfig(exponent=1, base=1, rescale=False),
            initializer_map={},
        )
    assert "quantize() not implemented in subclass." in str(excinfo.value)
100
+
101
+
102
@pytest.mark.unit
def test_check_supported_raises_not_implemented(dummy_node: NodeProto) -> None:
    """check_supported() must raise when a subclass does not override it.

    Renamed from ``test_check_supported_returns_none``: the old name
    contradicted the asserted behavior, which is an exception — not a
    ``None`` return.
    """
    quantizer = DummyQuantizer()
    with pytest.raises(HandlerImplementationError) as excinfo:
        quantizer.check_supported(dummy_node, {})

    # The error message names the handler class, not the ONNX op.
    assert (
        "Handler implementation error for operator 'DummyQuantizer':"
        " check_supported() not implemented in subclass." in str(excinfo.value)
    )
112
+
113
+
114
@pytest.mark.unit
def test_rescale_layer_modifies_node_output(
    dummy_node: NodeProto,
    dummy_graph: GraphProto,
) -> None:
    """rescale_layer must reroute the node's output through a Div by the scale.

    Expected rewiring: the original node now emits "Y_raw", and an appended
    Div node computes Y = Y_raw / (scale_base ** scale_exponent).
    """
    quantizer = DummyQuantizer()
    result_nodes = quantizer.rescale_layer(
        dummy_node,
        scale_base=10,
        scale_exponent=2,
        graph=dummy_graph,
    )
    total_scale = 100.0  # 10 ** 2
    count_nodes = 2

    assert len(result_nodes) == count_nodes
    assert dummy_node.output[0] == "Y_raw"
    assert result_nodes[1].op_type == "Div"
    assert result_nodes[1].output[0] == "Y"

    # Check if scale tensor added
    assert len(quantizer.new_initializers) == 1
    scale_tensor = quantizer.new_initializers[0]
    assert scale_tensor.name.endswith("_scale")
    assert scale_tensor.data_type == onnx.TensorProto.INT64
    assert onnx.numpy_helper.to_array(scale_tensor)[0] == total_scale

    # Validate that result_nodes are valid ONNX nodes
    for node in result_nodes:
        assert isinstance(node, onnx.NodeProto)
        assert node.name
        assert node.op_type
        assert node.input
        assert node.output

    # Check Div node inputs: should divide Y_raw by scale
    div_node = result_nodes[1]
    assert len(div_node.input) == count_nodes
    assert div_node.input[0] == "Y_raw"
    assert div_node.input[1] == scale_tensor.name
154
+
155
+
156
@pytest.mark.unit
def test_add_nodes_w_and_b_creates_mul_and_cast(
    dummy_node: NodeProto,
    dummy_graph: GraphProto,
    initializer_map: dict[str, Any],
) -> None:
    """add_nodes_w_and_b must scale both W and B via a Mul followed by a Cast.

    Expected layout: [Mul(W), Cast(W), Mul(B), Cast(B)], with the node's
    inputs rewritten to the cast outputs.
    """
    _ = dummy_graph
    quantizer = DummyQuantizer()
    exp = 2
    base = 10
    nodes, new_inputs = quantizer.add_nodes_w_and_b(
        dummy_node,
        scale_exponent=exp,
        scale_base=base,
        initializer_map=initializer_map,
    )
    four = 4
    two = 2

    assert len(nodes) == four  # Mul + Cast for W, Mul + Cast for B
    assert nodes[0].op_type == "Mul"
    assert nodes[1].op_type == "Cast"
    assert nodes[2].op_type == "Mul"
    assert nodes[3].op_type == "Cast"
    assert new_inputs == ["X", "W_scaled_cast", "B_scaled_cast"]
    assert len(quantizer.new_initializers) == two

    # Bias uses base**(2*exp) — the square of the weight scale; presumably
    # because the bias adds to a product of two scaled operands (TODO confirm
    # against layers/base.py).
    weight_scaled = base**exp
    bias_scaled = base ** (exp * 2)

    # Enhanced assertions: check node inputs/outputs and tensor details
    # Mul for W: input W and W_scale, output W_scaled
    assert nodes[0].input == ["W", "W_scale"]
    assert nodes[0].output == ["W_scaled"]
    # Cast for W: input W_scaled, output W_scaled_cast
    assert nodes[1].input == ["W_scaled"]
    assert nodes[1].output == ["W_scaled_cast"]
    # Similarly for B
    assert nodes[2].input == ["B", "B_scale"]
    assert nodes[2].output == ["B_scaled"]
    assert nodes[3].input == ["B_scaled"]
    assert nodes[3].output == ["B_scaled_cast"]

    # Check scale tensors
    w_scale = quantizer.new_initializers[0]
    b_scale = quantizer.new_initializers[1]
    assert w_scale.name == "W_scale"
    assert b_scale.name == "B_scale"
    assert onnx.numpy_helper.to_array(w_scale)[0] == weight_scaled  # 10**2
    assert onnx.numpy_helper.to_array(b_scale)[0] == bias_scaled
206
+
207
+
208
@pytest.mark.unit
def test_insert_scale_node_creates_mul_and_cast(
    dummy_tensor: TensorProto,
    dummy_graph: GraphProto,
) -> None:
    """insert_scale_node must emit a Mul (by a new "*_scale" initializer)
    chained into a Cast, returning the cast output name."""
    _ = dummy_graph
    quantizer = DummyQuantizer()
    output_name, mul_node, cast_node = quantizer.insert_scale_node(
        dummy_tensor,
        scale_base=10,
        scale_exponent=1,
    )

    assert mul_node.op_type == "Mul"
    assert cast_node.op_type == "Cast"
    assert "_scaled" in mul_node.output[0]
    assert output_name.endswith("_cast")
    # Exactly one scale initializer is created for the tensor.
    assert len(quantizer.new_initializers) == 1
    assert quantizer.new_initializers[0].name.endswith("_scale")
    ten = 10.0  # 10 ** 1
    assert onnx.numpy_helper.to_array(quantizer.new_initializers[0])[0] == ten