JSTprove 1.0.0__py3-none-macosx_11_0_arm64.whl → 1.1.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52):
  1. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/METADATA +2 -2
  2. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/RECORD +51 -24
  3. python/core/binaries/onnx_generic_circuit_1-1-0 +0 -0
  4. python/core/circuit_models/generic_onnx.py +43 -9
  5. python/core/circuits/base.py +231 -71
  6. python/core/model_processing/converters/onnx_converter.py +86 -32
  7. python/core/model_processing/onnx_custom_ops/maxpool.py +1 -1
  8. python/core/model_processing/onnx_custom_ops/relu.py +1 -1
  9. python/core/model_processing/onnx_quantizer/layers/add.py +54 -0
  10. python/core/model_processing/onnx_quantizer/layers/base.py +121 -1
  11. python/core/model_processing/onnx_quantizer/layers/constant.py +1 -1
  12. python/core/model_processing/onnx_quantizer/layers/conv.py +20 -68
  13. python/core/model_processing/onnx_quantizer/layers/gemm.py +20 -66
  14. python/core/model_processing/onnx_quantizer/layers/maxpool.py +53 -43
  15. python/core/model_processing/onnx_quantizer/layers/relu.py +20 -35
  16. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +6 -1
  17. python/core/utils/general_layer_functions.py +17 -12
  18. python/core/utils/model_registry.py +6 -3
  19. python/tests/circuit_e2e_tests/other_e2e_test.py +202 -9
  20. python/tests/circuit_parent_classes/test_circuit.py +561 -38
  21. python/tests/circuit_parent_classes/test_onnx_converter.py +22 -13
  22. python/tests/onnx_quantizer_tests/__init__.py +1 -0
  23. python/tests/onnx_quantizer_tests/layers/__init__.py +13 -0
  24. python/tests/onnx_quantizer_tests/layers/add_config.py +102 -0
  25. python/tests/onnx_quantizer_tests/layers/base.py +279 -0
  26. python/tests/onnx_quantizer_tests/layers/constant_config.py +39 -0
  27. python/tests/onnx_quantizer_tests/layers/conv_config.py +154 -0
  28. python/tests/onnx_quantizer_tests/layers/factory.py +142 -0
  29. python/tests/onnx_quantizer_tests/layers/flatten_config.py +61 -0
  30. python/tests/onnx_quantizer_tests/layers/gemm_config.py +160 -0
  31. python/tests/onnx_quantizer_tests/layers/maxpool_config.py +82 -0
  32. python/tests/onnx_quantizer_tests/layers/relu_config.py +61 -0
  33. python/tests/onnx_quantizer_tests/layers/reshape_config.py +61 -0
  34. python/tests/onnx_quantizer_tests/layers_tests/__init__.py +0 -0
  35. python/tests/onnx_quantizer_tests/layers_tests/base_test.py +94 -0
  36. python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py +115 -0
  37. python/tests/onnx_quantizer_tests/layers_tests/test_e2e.py +196 -0
  38. python/tests/onnx_quantizer_tests/layers_tests/test_error_cases.py +59 -0
  39. python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +198 -0
  40. python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +265 -0
  41. python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py +109 -0
  42. python/tests/onnx_quantizer_tests/layers_tests/test_validation.py +45 -0
  43. python/tests/onnx_quantizer_tests/test_base_layer.py +228 -0
  44. python/tests/onnx_quantizer_tests/test_exceptions.py +99 -0
  45. python/tests/onnx_quantizer_tests/test_onnx_op_quantizer.py +246 -0
  46. python/tests/onnx_quantizer_tests/test_registered_quantizers.py +121 -0
  47. python/tests/onnx_quantizer_tests/testing_helper_functions.py +17 -0
  48. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  49. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/WHEEL +0 -0
  50. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/entry_points.txt +0 -0
  51. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/licenses/LICENSE +0 -0
  52. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,198 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+ from unittest.mock import Mock
5
+
6
+ import numpy as np
7
+ import pytest
8
+ from onnxruntime import InferenceSession, SessionOptions
9
+ from onnxruntime_extensions import get_library_path
10
+
11
+ from python.core.model_processing.converters.onnx_converter import ONNXConverter
12
+
13
+ if TYPE_CHECKING:
14
+ from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import (
15
+ ONNXOpQuantizer,
16
+ )
17
+ from python.tests.onnx_quantizer_tests import TEST_RNG_SEED
18
+ from python.tests.onnx_quantizer_tests.layers.base import (
19
+ LayerTestConfig,
20
+ LayerTestSpec,
21
+ SpecType,
22
+ )
23
+ from python.tests.onnx_quantizer_tests.layers.factory import TestLayerFactory
24
+ from python.tests.onnx_quantizer_tests.layers_tests.base_test import (
25
+ BaseQuantizerTest,
26
+ )
27
+
28
+
29
class TestIntegration(BaseQuantizerTest):
    """Integration tests for the ONNX quantizer.

    Exercises the check-then-quantize workflow on small multi-layer models and
    compares onnxruntime outputs of each original test model against its
    quantized counterpart.
    """

    __test__ = True

    @pytest.mark.integration
    @pytest.mark.parametrize(
        "layer_combination",
        [["Conv", "Relu"], ["Gemm", "Relu"], ["Conv", "MaxPool", "Flatten", "Gemm"]],
    )
    def test_check_then_quantize_workflow(
        self: TestIntegration,
        quantizer: ONNXOpQuantizer,
        layer_configs: dict[str, LayerTestConfig],
        layer_combination: list[str],
    ) -> None:
        """Test the typical workflow: check model then quantize layers.

        Builds a model chaining ``layer_combination``, requires
        ``quantizer.check_model`` not to raise on it, then quantizes every node
        and asserts that each call returns a non-``None`` result.
        """
        mock_graph = Mock()
        scale_exponent, scale_base = 2, 10
        rescale = True

        # Step 1: Create and validate model
        model = self.create_model_with_layers(layer_combination, layer_configs)
        quantizer.check_model(model)  # Should not raise

        # Step 2: Quantize each layer
        initializer_map = quantizer.get_initializer_map(model)

        for node in model.graph.node:
            result = quantizer.quantize(
                node=node,
                rescale=rescale,
                graph=mock_graph,
                scale_exponent=scale_exponent,
                scale_base=scale_base,
                initializer_map=initializer_map,
            )
            assert result is not None, (
                f"Quantization failed for {node.op_type}"
                f" in combination {layer_combination}"
            )

    def skip_by_layer_name(
        self,
        layer_name: str,
        test_spec: LayerTestSpec,
        skip_layer: str,
    ) -> None:
        """Skip the running test when ``layer_name`` equals ``skip_layer``.

        The skip message says constants are scaled differently, so this helper
        is meant to be called with ``skip_layer="Constant"``.
        """
        if layer_name == skip_layer:
            pytest.skip(
                f"Skipping accuracy test for {layer_name}."
                f"{test_spec.name} as constants are scaled differently",
            )

    @staticmethod
    def _concrete_shape(shape: list[int | str | None]) -> tuple[int, ...]:
        """Return ``shape`` with every non-``int`` dimension replaced by 1.

        onnxruntime reports dynamic dimensions as strings (symbolic names) or
        ``None``. Random test inputs need a fully concrete shape.
        """
        return tuple(dim if isinstance(dim, int) else 1 for dim in shape)

    @pytest.mark.integration
    @pytest.mark.parametrize(
        "test_case_data",
        TestLayerFactory.get_test_cases_by_type(SpecType.VALID),  # type: ignore[arg-type]
        ids=BaseQuantizerTest._generate_test_id,
    )
    def test_end_to_end_quantization_accuracy(
        self: TestIntegration,
        test_case_data: tuple[str, LayerTestConfig, LayerTestSpec],
    ) -> None:
        """Test end-to-end quantization accuracy for each valid test case.

        Builds a model from the layer config and runs it with onnxruntime on
        random normal inputs. It then quantizes the model with ``ONNXConverter``
        and runs the quantized model on the same inputs. The quantized output is
        divided by ``scale_base ** scale_exponent``. The two outputs must agree:
        ``assert_allclose`` with rtol=atol=0.05, plus cosine similarity above
        0.995. Only the first graph output is compared.
        """
        cosine_similarity = 0.995
        rng = np.random.default_rng(TEST_RNG_SEED + 1)

        layer_name, config, test_spec = test_case_data
        self.skip_by_layer_name(layer_name, test_spec, skip_layer="Constant")

        # Skip if validation failed or test is skipped
        self._check_validation_dependency(test_case_data)
        if test_spec.skip_reason:
            pytest.skip(f"{layer_name}_{test_spec.name}: {test_spec.skip_reason}")

        # Create original model. The onnxruntime-extensions custom-ops library
        # is registered on the shared session options (presumably needed by
        # custom ops emitted into the quantized model).
        original_model = config.create_test_model(test_spec)
        opts = SessionOptions()
        opts.register_custom_ops_library(get_library_path())
        original_session = InferenceSession(
            original_model.SerializeToString(),
            opts,
            providers=["CPUExecutionProvider"],
        )

        # Dynamic (symbolic / unknown) dimensions are pinned to 1 so that
        # rng.normal receives a concrete shape.
        input_shapes = {
            i.name: self._concrete_shape(i.shape)
            for i in original_session.get_inputs()
        }

        # Skip if no inputs (e.g., Constant nodes)
        if not input_shapes:
            pytest.skip(
                f"No inputs for {layer_name}.{test_spec.name}, skipping accuracy test",
            )

        # Create random inputs for all graph inputs
        dummy_inputs = {
            name: rng.normal(0, 1, shape).astype(np.float32)
            for name, shape in input_shapes.items()
        }

        # Run inference on original model (first output only)
        output_name = original_session.get_outputs()[0].name
        original_output = original_session.run([output_name], dummy_inputs)[0]

        # Quantize the model; quantized outputs carry a 2**10 scale factor,
        # which is divided out below.
        converter = ONNXConverter()
        scale_base, scale_exponent = 2, 10
        quantized_model = converter.quantize_model(
            original_model,
            scale_base=scale_base,
            scale_exponent=scale_exponent,
            rescale_config=None,  # Use default rescale
        )

        # Run inference on quantized model
        quantized_session = InferenceSession(
            quantized_model.SerializeToString(),
            opts,
            providers=["CPUExecutionProvider"],
        )
        quantized_input_names = [inp.name for inp in quantized_session.get_inputs()]
        quantized_output_name = quantized_session.get_outputs()[0].name

        # Feed the same inputs, only cast to float64 (NOT multiplied by the
        # scale here).
        # NOTE(review): presumably the quantized graph applies input scaling
        # itself; confirm against ONNXConverter.quantize_model.
        scaled_inputs = {}
        for name in quantized_input_names:
            if name not in dummy_inputs:
                # The quantized model has an input the original did not.
                pytest.skip(
                    f"Quantized model input mismatch for {layer_name}.{test_spec.name}",
                )
            scaled_inputs[name] = dummy_inputs[name].astype(np.float64)

        quantized_output = quantized_session.run(
            [quantized_output_name],
            scaled_inputs,
        )[0]
        # Undo the fixed-point scale to compare in the original units.
        quantized_output = quantized_output / (scale_base**scale_exponent)

        # Diagnostic only: unstable wherever original values are near zero.
        ratio = np.mean(quantized_output / (original_output + 1e-12))
        print(f"Mean output ratio (quantized/original): {ratio:.4f}")

        # Compare outputs (quantized output should be close to original if rescale=True)
        # Allow some tolerance due to quantization
        np.testing.assert_allclose(
            original_output,
            quantized_output,
            rtol=0.05,  # Relative tolerance
            atol=0.05,  # Absolute tolerance
            err_msg=f"Quantization accuracy failed for {layer_name}.{test_spec.name}",
        )

        original_flat = np.asarray(original_output, dtype=np.float64).ravel()
        quantized_flat = np.asarray(quantized_output, dtype=np.float64).ravel()
        if not original_flat.any() and not quantized_flat.any():
            # Both outputs are identically zero: cosine similarity is undefined
            # (0/0 would evaluate to 0 and fail spuriously), and assert_allclose
            # above has already established agreement.
            return

        cos_sim = np.dot(original_flat, quantized_flat) / (
            np.linalg.norm(original_flat) * np.linalg.norm(quantized_flat) + 1e-12
        )
        print(f"Cosine similarity: {cos_sim:.6f}")
        assert cos_sim > cosine_similarity, f"Low cosine similarity ({cos_sim:.6f})"
@@ -0,0 +1,265 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+ from unittest.mock import Mock
5
+
6
+ import onnx
7
+ import pytest
8
+ from onnx import NodeProto
9
+
10
+ from python.core.model_processing.onnx_quantizer.exceptions import UnsupportedOpError
11
+
12
+ if TYPE_CHECKING:
13
+ from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import (
14
+ ONNXOpQuantizer,
15
+ )
16
+ from python.tests.onnx_quantizer_tests.layers.base import (
17
+ LayerTestConfig,
18
+ LayerTestSpec,
19
+ SpecType,
20
+ )
21
+ from python.tests.onnx_quantizer_tests.layers.factory import TestLayerFactory
22
+ from python.tests.onnx_quantizer_tests.layers_tests.base_test import (
23
+ BaseQuantizerTest,
24
+ )
25
+
26
+
27
class TestQuantize(BaseQuantizerTest):
    """Unit tests for ``ONNXOpQuantizer.quantize`` on single-layer test models."""

    __test__ = True

    def setup_quantize_test(
        self,
        test_case_data: tuple[str, LayerTestConfig, LayerTestSpec],
        quantizer: ONNXOpQuantizer,
        scale_exponent: int = 2,
        scale_base: int = 10,
        *,
        rescale: bool = True,
    ) -> tuple[
        onnx.NodeProto | list[onnx.NodeProto],
        tuple[str, LayerTestConfig, LayerTestSpec, NodeProto],
    ]:
        """Quantize the first node of the model built for ``test_case_data``.

        First defers to ``_check_validation_dependency``, which presumably skips
        cases whose ONNX validation failed. It then skips the test when the spec
        declares a ``skip_reason``.

        Returns:
            ``(result, (layer_name, config, test_spec, node))``. ``result`` is
            whatever ``quantizer.quantize`` returned: a single node or a list of
            nodes. ``node`` is the original, unquantized node.
        """
        layer_name, config, test_spec = test_case_data

        self._check_validation_dependency(test_case_data)

        if test_spec.skip_reason:
            pytest.skip(f"{layer_name}_{test_spec.name}: {test_spec.skip_reason}")

        model = config.create_test_model(test_spec)
        node = model.graph.node[0]
        initializer_map = {init.name: init for init in model.graph.initializer}

        mock_graph = Mock()
        if node.op_type == "Constant":
            # Give the graph one fake node consuming the constant's output
            # (presumably the Constant quantizer looks up its consumers).
            mock_data_node = Mock()
            mock_data_node.input = [node.output[0]]
            mock_graph.node = [mock_data_node]

        result = quantizer.quantize(
            node=node,
            rescale=rescale,
            graph=mock_graph,
            scale_exponent=scale_exponent,
            scale_base=scale_base,
            initializer_map=initializer_map,
        )

        return result, (layer_name, config, test_spec, node)

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "test_case_data",
        TestLayerFactory.get_test_cases_by_type(SpecType.VALID),  # type: ignore[arg-type]
        ids=BaseQuantizerTest._generate_test_id,
    )
    def test_quantize_individual_valid_cases(
        self: TestQuantize,
        quantizer: ONNXOpQuantizer,
        test_case_data: tuple[str, LayerTestConfig, LayerTestSpec],
    ) -> None:
        """Quantizing a valid case yields a NodeProto or a non-empty list of them."""
        scale_exponent, scale_base = 2, 10
        rescale = True

        result, (layer_name, _config, test_spec, _node) = self.setup_quantize_test(
            test_case_data,
            quantizer,
            scale_exponent,
            scale_base,
            rescale=rescale,
        )

        # Check for None first so the failure message is specific.
        assert (
            result is not None
        ), f"Quantize returned None for {layer_name}.{test_spec.name}"

        # Test that the output of the quantizer quantize is in fact a node
        if isinstance(result, list):
            assert (
                len(result) > 0
            ), f"Quantize returned empty list for {layer_name}.{test_spec.name}"
            for node_result in result:
                assert isinstance(
                    node_result,
                    onnx.NodeProto,
                ), f"Invalid node type returned for {layer_name}.{test_spec.name}"
        else:
            assert isinstance(
                result,
                onnx.NodeProto,
            ), f"Quantize returned a non-NodeProto result for {layer_name}.{test_spec.name}"

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "test_case_data",
        TestLayerFactory.get_test_cases_by_type(SpecType.VALID),  # type: ignore[arg-type]
        ids=BaseQuantizerTest._generate_test_id,
    )
    def test_quantize_preserves_node_names(
        self: TestQuantize,
        quantizer: ONNXOpQuantizer,
        test_case_data: tuple[str, LayerTestConfig, LayerTestSpec],
    ) -> None:
        """Quantized nodes are named and the original op keeps its attributes.

        Every returned node must have a non-empty ``name`` and ``op_type``. At
        least one returned node must have an ``op_type`` containing the original
        ``op_type`` (substring match). Every such matching node must carry all
        of the original node's attribute names.
        """
        scale_exponent, scale_base = 2, 10
        rescale = True
        result, (_layer_name, config, _test_spec, node) = self.setup_quantize_test(
            test_case_data,
            quantizer,
            scale_exponent,
            scale_base,
            rescale=rescale,
        )

        # Treat single-node and list results uniformly.
        result_nodes = result if isinstance(result, list) else [result]
        original_attr_names = [att.name for att in node.attribute]

        is_node_present = False
        for result_node in result_nodes:
            assert result_node.name, f"Quantized node missing name for {config.op_type}"
            assert (
                result_node.op_type
            ), f"Quantized node missing op_type for {config.op_type}"

            if node.op_type not in result_node.op_type:
                continue
            # Check every matching node, not just the first one found.
            is_node_present = True
            # Assert there are no fewer attributes in the new node
            assert len(node.attribute) <= len(result_node.attribute)
            # Each original attribute must also appear on the new node
            result_attr_names = {a.name for a in result_node.attribute}
            for att_name in original_attr_names:
                assert att_name in result_attr_names

        # Assert that the node is in fact present
        assert (
            is_node_present
        ), "Cannot find quantized node relating to prequantized node"

    @pytest.mark.unit
    @pytest.mark.parametrize("scale_params", [(2, 10), (0, 5)])
    @pytest.mark.parametrize(
        "test_case_data",
        TestLayerFactory.get_test_cases_by_type(SpecType.VALID),  # type: ignore[arg-type]
        ids=BaseQuantizerTest._generate_test_id,
    )
    def test_quantize_with_different_scales(
        self: TestQuantize,
        quantizer: ONNXOpQuantizer,
        test_case_data: tuple[str, LayerTestConfig, LayerTestSpec],
        scale_params: tuple[int, int],
    ) -> None:
        """Quantization returns a result for each (scale_exponent, scale_base) pair."""
        scale_exponent, scale_base = scale_params
        rescale = True
        result, (_layer_name, _config, _test_spec, _node) = self.setup_quantize_test(
            test_case_data,
            quantizer,
            scale_exponent,
            scale_base,
            rescale=rescale,
        )

        # Should return valid result regardless of scale values
        assert (
            result is not None
        ), f"Quantize returned None for scale={scale_exponent}, scale_base={scale_base}"

    @pytest.mark.unit
    @pytest.mark.parametrize("rescale", [True, False])
    @pytest.mark.parametrize(
        "test_case_data",
        TestLayerFactory.get_test_cases_by_type(SpecType.VALID),  # type: ignore[arg-type]
        ids=BaseQuantizerTest._generate_test_id,
    )
    def test_quantize_with_different_rescales(
        self: TestQuantize,
        quantizer: ONNXOpQuantizer,
        test_case_data: tuple[str, LayerTestConfig, LayerTestSpec],
        *,
        rescale: bool,
    ) -> None:
        """Quantization returns a result for both ``rescale=True`` and ``False``."""
        scale_exponent, scale_base = 2, 10

        result, (_layer_name, _config, _test_spec, _node) = self.setup_quantize_test(
            test_case_data,
            quantizer,
            scale_exponent,
            scale_base,
            rescale=rescale,
        )
        assert result is not None, f"Quantize failed with rescale={rescale}"

    @pytest.mark.unit
    def test_quantize_unsupported_layer_returns_original(
        self: TestQuantize,
        quantizer: ONNXOpQuantizer,
    ) -> None:
        """Quantizing an unknown op type raises ``UnsupportedOpError``.

        Despite its name, this test expects an exception, not the original node.
        """
        mock_graph = Mock()
        scale_exponent, scale_base = 2, 10
        rescale = True

        unsupported_node = onnx.helper.make_node(
            "UnsupportedOp",
            inputs=["input"],
            outputs=["output"],
            name="unsupported",
        )
        with pytest.raises(UnsupportedOpError) as excinfo:
            quantizer.quantize(
                node=unsupported_node,
                rescale=rescale,
                graph=mock_graph,
                scale_exponent=scale_exponent,
                scale_base=scale_base,
                initializer_map={},
            )
        assert "Unsupported op type: 'UnsupportedOp'" in str(excinfo.value)
@@ -0,0 +1,109 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+ import pytest
5
+
6
+ from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import (
7
+ ONNXOpQuantizer,
8
+ )
9
+ from python.tests.onnx_quantizer_tests.layers.base import LayerTestConfig, SpecType
10
+ from python.tests.onnx_quantizer_tests.layers.factory import TestLayerFactory
11
+
12
+
13
class TestScalability:
    """Meta-tests verifying that the test framework scales as new layers are added."""

    @staticmethod
    def _spec_tags(spec: object) -> set[str]:
        """Return ``spec.tags``, treating a missing or ``None`` value as empty."""
        return getattr(spec, "tags", None) or set()

    @pytest.mark.unit
    def test_adding_new_layer_config(self: TestScalability) -> None:
        """A new ``LayerTestConfig`` can build its node and initializers."""
        # Simulate adding a new layer type
        new_layer_config = LayerTestConfig(
            op_type="NewCustomOp",
            valid_inputs=["input", "custom_param"],
            valid_attributes={"custom_attr": 42},
            required_initializers={"custom_param": np.array([1, 2, 3])},
        )

        # Verify config can create nodes and initializers
        node = new_layer_config.create_node()
        assert node.op_type == "NewCustomOp"
        # One node input per declared valid input.
        assert len(node.input) == len(new_layer_config.valid_inputs)

        initializers = new_layer_config.create_initializers()
        assert "custom_param" in initializers

    @pytest.mark.unit
    def test_layer_config_extensibility(self: TestScalability) -> None:
        """Test configs exist for every registered handler and are well-formed."""
        configs = TestLayerFactory.get_layer_configs()

        # Verify all expected layers are present
        unsupported = ONNXOpQuantizer().handlers.keys() - set(configs.keys())
        assert unsupported == set(), (
            f"The following layers are not being configured for testing: {unsupported}."
            " Please add configuration in tests/onnx_quantizer_tests/layers/"
        )

        # Verify each config has required components
        for layer_type, config in configs.items():
            err_msg = (
                f"Quantization test config is not supported yet for {layer_type}"
                " and must be implemented"
            )
            assert config.op_type == layer_type, (
                f"Config registered under {layer_type!r}"
                f" declares op_type {config.op_type!r}"
            )
            assert isinstance(config.valid_inputs, list), err_msg
            assert isinstance(config.valid_attributes, dict), err_msg
            assert isinstance(config.required_initializers, dict), err_msg

    @pytest.mark.unit
    def test_every_layer_has_basic_and_e2e(self: TestScalability) -> None:
        """Each registered layer must have at least one basic/valid test
        and one e2e test."""
        missing_basic = []
        missing_e2e = []

        # iterate over registered layers
        for layer_name in TestLayerFactory.get_available_layers():
            cases = TestLayerFactory.get_test_cases_by_layer(layer_name)
            specs = [spec for _, _config, spec in cases]

            # "basic": tagged 'basic' or 'valid', or spec_type is SpecType.VALID.
            has_basic = any(
                "basic" in self._spec_tags(s)
                or "valid" in self._spec_tags(s)
                or getattr(s, "spec_type", None) == SpecType.VALID
                for s in specs
            )

            # "e2e": tagged 'e2e', or spec_type is SpecType.E2E.
            has_e2e = any(
                "e2e" in self._spec_tags(s)
                or getattr(s, "spec_type", None) == SpecType.E2E
                for s in specs
            )

            if not has_basic:
                missing_basic.append(layer_name)
            if not has_e2e:
                missing_e2e.append(layer_name)

        assert not missing_basic, f"Layers missing a basic/valid test: {missing_basic}"
        assert not missing_e2e, f"Layers missing an e2e test: {missing_e2e}"
@@ -0,0 +1,45 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ import onnx
6
+ import pytest
7
+
8
+ if TYPE_CHECKING:
9
+ from python.tests.onnx_quantizer_tests.layers.base import (
10
+ LayerTestConfig,
11
+ LayerTestSpec,
12
+ )
13
+ from python.tests.onnx_quantizer_tests.layers.factory import TestLayerFactory
14
+ from python.tests.onnx_quantizer_tests.layers_tests.base_test import (
15
+ BaseQuantizerTest,
16
+ )
17
+
18
+
19
class TestValidation(BaseQuantizerTest):
    """Check that every model built by the layer factory is a valid ONNX graph.

    Failing case ids are recorded in the inherited ``_validation_failed_cases``
    set (presumably consulted by ``_check_validation_dependency`` in other tests).
    """

    __test__ = True

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "test_case_data",
        TestLayerFactory.get_all_test_cases(),
        ids=BaseQuantizerTest._generate_test_id,
    )
    def test_factory_models_pass_onnx_validation(
        self: TestValidation,
        test_case_data: tuple[str, LayerTestConfig, LayerTestSpec],
    ) -> None:
        """Build the model for one test case and run ``onnx.checker`` on it."""
        name, cfg, spec = test_case_data
        case_id = f"{name}_{spec.name}"

        reason = spec.skip_reason
        if reason:
            pytest.skip(f"{case_id}: {reason}")

        candidate = cfg.create_test_model(spec)
        try:
            onnx.checker.check_model(candidate)
        except onnx.checker.ValidationError as err:
            # Record the failure before failing this test.
            self._validation_failed_cases.add(case_id)
            pytest.fail(f"Invalid ONNX model: {err}")
+ pytest.fail(f"Invalid ONNX model: {e}")