JSTprove 1.0.0__py3-none-macosx_11_0_arm64.whl → 1.1.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/METADATA +2 -2
  2. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/RECORD +51 -24
  3. python/core/binaries/onnx_generic_circuit_1-1-0 +0 -0
  4. python/core/circuit_models/generic_onnx.py +43 -9
  5. python/core/circuits/base.py +231 -71
  6. python/core/model_processing/converters/onnx_converter.py +86 -32
  7. python/core/model_processing/onnx_custom_ops/maxpool.py +1 -1
  8. python/core/model_processing/onnx_custom_ops/relu.py +1 -1
  9. python/core/model_processing/onnx_quantizer/layers/add.py +54 -0
  10. python/core/model_processing/onnx_quantizer/layers/base.py +121 -1
  11. python/core/model_processing/onnx_quantizer/layers/constant.py +1 -1
  12. python/core/model_processing/onnx_quantizer/layers/conv.py +20 -68
  13. python/core/model_processing/onnx_quantizer/layers/gemm.py +20 -66
  14. python/core/model_processing/onnx_quantizer/layers/maxpool.py +53 -43
  15. python/core/model_processing/onnx_quantizer/layers/relu.py +20 -35
  16. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +6 -1
  17. python/core/utils/general_layer_functions.py +17 -12
  18. python/core/utils/model_registry.py +6 -3
  19. python/tests/circuit_e2e_tests/other_e2e_test.py +202 -9
  20. python/tests/circuit_parent_classes/test_circuit.py +561 -38
  21. python/tests/circuit_parent_classes/test_onnx_converter.py +22 -13
  22. python/tests/onnx_quantizer_tests/__init__.py +1 -0
  23. python/tests/onnx_quantizer_tests/layers/__init__.py +13 -0
  24. python/tests/onnx_quantizer_tests/layers/add_config.py +102 -0
  25. python/tests/onnx_quantizer_tests/layers/base.py +279 -0
  26. python/tests/onnx_quantizer_tests/layers/constant_config.py +39 -0
  27. python/tests/onnx_quantizer_tests/layers/conv_config.py +154 -0
  28. python/tests/onnx_quantizer_tests/layers/factory.py +142 -0
  29. python/tests/onnx_quantizer_tests/layers/flatten_config.py +61 -0
  30. python/tests/onnx_quantizer_tests/layers/gemm_config.py +160 -0
  31. python/tests/onnx_quantizer_tests/layers/maxpool_config.py +82 -0
  32. python/tests/onnx_quantizer_tests/layers/relu_config.py +61 -0
  33. python/tests/onnx_quantizer_tests/layers/reshape_config.py +61 -0
  34. python/tests/onnx_quantizer_tests/layers_tests/__init__.py +0 -0
  35. python/tests/onnx_quantizer_tests/layers_tests/base_test.py +94 -0
  36. python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py +115 -0
  37. python/tests/onnx_quantizer_tests/layers_tests/test_e2e.py +196 -0
  38. python/tests/onnx_quantizer_tests/layers_tests/test_error_cases.py +59 -0
  39. python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +198 -0
  40. python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +265 -0
  41. python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py +109 -0
  42. python/tests/onnx_quantizer_tests/layers_tests/test_validation.py +45 -0
  43. python/tests/onnx_quantizer_tests/test_base_layer.py +228 -0
  44. python/tests/onnx_quantizer_tests/test_exceptions.py +99 -0
  45. python/tests/onnx_quantizer_tests/test_onnx_op_quantizer.py +246 -0
  46. python/tests/onnx_quantizer_tests/test_registered_quantizers.py +121 -0
  47. python/tests/onnx_quantizer_tests/testing_helper_functions.py +17 -0
  48. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  49. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/WHEEL +0 -0
  50. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/entry_points.txt +0 -0
  51. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/licenses/LICENSE +0 -0
  52. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,61 @@
1
+ import numpy as np
2
+
3
+ from python.tests.onnx_quantizer_tests.layers.base import (
4
+ e2e_test,
5
+ valid_test,
6
+ )
7
+ from python.tests.onnx_quantizer_tests.layers.factory import (
8
+ BaseLayerConfigProvider,
9
+ LayerTestConfig,
10
+ )
11
+
12
+
class ReshapeConfigProvider(BaseLayerConfigProvider):
    """Test configuration provider for Reshape layers.

    The base configuration wires an ``input`` tensor and a ``shape``
    initializer (default ``[1, -1]``) into a single Reshape node. Individual
    specs override the input shape and/or the ``shape`` initializer so that
    the reshape they perform matches their description.
    """

    @property
    def layer_name(self) -> str:
        """Name under which this provider's layer is registered."""
        return "Reshape"

    def get_config(self) -> LayerTestConfig:
        """Return the base Reshape configuration shared by every spec."""
        return LayerTestConfig(
            op_type="Reshape",
            valid_inputs=["input", "shape"],
            valid_attributes={},
            required_initializers={"shape": np.array([1, -1])},
        )

    def get_test_specs(self) -> list:
        """Return the Reshape test specs.

        Every spec that needs a target shape other than the default
        ``[1, -1]`` overrides the ``shape`` initializer explicitly, so the
        input shape, the target shape and the description stay consistent.
        """
        return [
            # --- VALID TESTS ---
            valid_test("basic")
            .description("Basic Reshape from (1,2,3,4) to (1,24)")
            .tags("basic", "reshape")
            .build(),
            # (1, 24) reshaped with target [1, 3, -1] -> (1, 3, 8).
            valid_test("reshape_expand_dims")
            .description("Reshape expanding dimensions (1,24) → (1,3,8)")
            .override_input_shapes(input=[1, 24])
            .override_initializer("shape", np.array([1, 3, -1]))
            .tags("reshape", "expand")
            .build(),
            # (1, 3, 4, 4) reshaped with target [1, -1] -> (1, 48).
            valid_test("reshape_flatten")
            .description("Reshape to flatten spatial dimensions (1,3,4,4) → (1,48)")
            .override_input_shapes(input=[1, 3, 4, 4])
            .override_initializer("shape", np.array([1, -1]))
            .tags("reshape", "flatten")
            .build(),
            e2e_test("e2e_basic")
            .description("End-to-end test for Reshape layer")
            .override_input_shapes(input=[1, 2, 3, 4])
            .override_output_shapes(reshape_output=[1, 24])
            .override_initializer("shape", np.array([1, -1]))
            .tags("e2e", "reshape")
            .build(),
            # --- EDGE CASE / PERFORMANCE TEST ---
            # NOTE: this spec runs by default; chain
            # `.skip("Performance test, skipped by default")` before
            # `.build()` to opt out.
            valid_test("large_input")
            .description("Large reshape performance test")
            .override_input_shapes(input=[1, 3, 256, 256])
            .override_initializer("shape", np.array([1, -1]))
            .tags("large", "performance", "reshape")
            .build(),
        ]
@@ -0,0 +1,94 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, ClassVar
4
+
5
+ import pytest
6
+ from onnx import TensorProto, helper
7
+
8
+ if TYPE_CHECKING:
9
+ from onnx import ModelProto
10
+
11
+ from python.tests.onnx_quantizer_tests.layers.base import (
12
+ LayerTestConfig,
13
+ LayerTestSpec,
14
+ )
15
+ from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import (
16
+ ONNXOpQuantizer,
17
+ )
18
+ from python.tests.onnx_quantizer_tests.layers.factory import TestLayerFactory
19
+
20
+
class BaseQuantizerTest:
    """Base test utilities for ONNX quantizer tests.

    Concrete test classes inherit from this and set ``__test__ = True`` so
    that pytest collects them.
    """

    __test__ = False  # Prevent pytest from collecting this class directly

    # Ids ("<layer>_<spec>") of cases whose ONNX validation failed. Tests call
    # ``_check_validation_dependency`` to skip such cases. This is a single
    # class-level set, so it is shared by every subclass.
    _validation_failed_cases: ClassVar[set[str]] = set()

    @pytest.fixture
    def quantizer(self) -> ONNXOpQuantizer:
        """Provide a fresh ``ONNXOpQuantizer`` for each test."""
        return ONNXOpQuantizer()

    @pytest.fixture
    def layer_configs(self) -> dict[str, LayerTestConfig]:
        """Provide the layer configs returned by ``TestLayerFactory``."""
        return TestLayerFactory.get_layer_configs()

    @staticmethod
    def _format_case_id(layer_name: str, spec_name: str) -> str:
        """Build the canonical ``<layer>_<spec>`` id for a test case."""
        return f"{layer_name}_{spec_name}"

    @staticmethod
    def _generate_test_id(
        test_case_tuple: tuple[str, LayerTestConfig, LayerTestSpec],
    ) -> str:
        """Return a readable pytest id for a ``(layer, config, spec)`` tuple.

        Falls back to ``str(test_case_tuple)`` when the value cannot be
        unpacked into three items or the spec has no ``name`` attribute.
        """
        try:
            layer_name, _, test_spec = test_case_tuple
            spec_name = test_spec.name
        except (TypeError, ValueError, AttributeError):
            return str(test_case_tuple)
        return BaseQuantizerTest._format_case_id(layer_name, spec_name)

    @classmethod
    def _check_validation_dependency(
        cls: type[BaseQuantizerTest],
        test_case_data: tuple[str, LayerTestConfig, LayerTestSpec],
    ) -> None:
        """Skip the current test if ONNX validation already failed for it."""
        layer_name, _, test_spec = test_case_data
        test_case_id = cls._format_case_id(layer_name, test_spec.name)
        if test_case_id in cls._validation_failed_cases:
            pytest.skip(f"Skipping because ONNX validation failed for {test_case_id}")

    @staticmethod
    def create_model_with_layers(
        layer_types: list[str],
        layer_configs: dict[str, LayerTestConfig],
        *,
        input_shape: list[int] | None = None,
        output_shape: list[int] | None = None,
    ) -> ModelProto:
        """Create a model that chains several layers one after another.

        Args:
            layer_types: Op names, looked up in ``layer_configs``, in graph
                order. Must not be empty.
            layer_configs: Mapping from op name to its test configuration.
            input_shape: Shape of the graph input ``"input"``. Defaults to
                ``[1, 16, 224, 224]``.
            output_shape: Declared shape of the graph output. Defaults to
                ``[1, 10]``.

        Returns:
            An ONNX model whose node ``i`` takes node ``i - 1``'s output as
            its first input.

        Raises:
            ValueError: If ``layer_types`` is empty.
            KeyError: If a layer type has no entry in ``layer_configs``.
        """
        if not layer_types:
            msg = "layer_types must contain at least one layer type"
            raise ValueError(msg)

        if input_shape is None:
            input_shape = [1, 16, 224, 224]
        if output_shape is None:
            output_shape = [1, 10]

        def output_name(index: int) -> str:
            # NOTE(review): assumes create_node(name_suffix=f"_{i}") names its
            # output "<op lowercased>_output_<i>" -- confirm in LayerTestConfig.
            return f"{layer_types[index].lower()}_output_{index}"

        nodes = []
        # Keyed by initializer name, so a later layer's initializer replaces
        # an earlier one with the same name.
        all_initializers = {}

        for index, layer_type in enumerate(layer_types):
            config = layer_configs[layer_type]
            node = config.create_node(name_suffix=f"_{index}")
            # Rewire the data input to the previous node's output.
            if index > 0 and node.input:
                node.input[0] = output_name(index - 1)
            nodes.append(node)
            all_initializers.update(config.create_initializers())

        graph = helper.make_graph(
            nodes,
            "test_graph",
            [
                helper.make_tensor_value_info(
                    "input",
                    TensorProto.FLOAT,
                    input_shape,
                ),
            ],
            [
                helper.make_tensor_value_info(
                    output_name(len(layer_types) - 1),
                    TensorProto.FLOAT,
                    output_shape,
                ),
            ],
            initializer=list(all_initializers.values()),
        )
        return helper.make_model(graph)
@@ -0,0 +1,115 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ import pytest
6
+ from onnx import TensorProto, helper
7
+
8
+ from python.core.model_processing.onnx_quantizer.exceptions import (
9
+ InvalidParamError,
10
+ UnsupportedOpError,
11
+ )
12
+
13
+ if TYPE_CHECKING:
14
+ from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import (
15
+ ONNXOpQuantizer,
16
+ )
17
+ from python.tests.onnx_quantizer_tests.layers.base import (
18
+ LayerTestConfig,
19
+ LayerTestSpec,
20
+ SpecType,
21
+ )
22
+ from python.tests.onnx_quantizer_tests.layers.factory import TestLayerFactory
23
+ from python.tests.onnx_quantizer_tests.layers_tests.base_test import (
24
+ BaseQuantizerTest,
25
+ )
26
+
27
+
class TestCheckModel(BaseQuantizerTest):
    """Tests for ``ONNXOpQuantizer.check_model`` on valid and unsupported models."""

    __test__ = True

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "test_case_data",
        TestLayerFactory.get_test_cases_by_type(SpecType.VALID),  # type: ignore[arg-type]
        ids=BaseQuantizerTest._generate_test_id,
    )
    def test_check_model_individual_valid_cases(
        self: TestCheckModel,
        quantizer: ONNXOpQuantizer,
        test_case_data: tuple[str, LayerTestConfig, LayerTestSpec],
    ) -> None:
        """Every valid test spec must pass ``check_model``."""
        layer_name, config, test_spec = test_case_data
        case_label = f"{layer_name}.{test_spec.name}"

        # Skips if layer is not a valid onnx layer
        self._check_validation_dependency(test_case_data)

        if test_spec.skip_reason:
            pytest.skip(f"{layer_name}_{test_spec.name}: {test_spec.skip_reason}")

        # Create model from layer specs
        model = config.create_test_model(test_spec)

        try:
            quantizer.check_model(model)
        except (InvalidParamError, UnsupportedOpError) as e:
            # The quantizer explicitly rejected a spec that should be valid.
            pytest.fail(f"Model check failed for {case_label}: {e}")
        except Exception as e:
            # Anything else is a crash inside check_model; report its type.
            pytest.fail(
                f"Model check failed for {case_label} with unexpected "
                f"{type(e).__name__}: {e}",
            )

    @pytest.mark.unit
    def test_check_model_unsupported_layer_fails(
        self: TestCheckModel,
        quantizer: ONNXOpQuantizer,
    ) -> None:
        """A model containing an unknown op type must raise UnsupportedOpError."""
        unsupported_node = helper.make_node(
            "UnsupportedOp",
            inputs=["input"],
            outputs=["output"],
            name="unsupported",
        )

        graph = helper.make_graph(
            [unsupported_node],
            "test_graph",
            [
                helper.make_tensor_value_info(
                    "input",
                    TensorProto.FLOAT,
                    [1, 16, 224, 224],
                ),
            ],
            [helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 10])],
        )

        model = helper.make_model(graph)

        with pytest.raises(UnsupportedOpError):
            quantizer.check_model(model)

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "layer_combination",
        [
            ["Conv", "Relu"],
            ["Conv", "Relu", "MaxPool"],
            ["Gemm", "Relu"],
            ["Conv", "Reshape", "Gemm"],
            ["Conv", "Flatten", "Gemm"],
        ],
    )
    def test_check_model_multi_layer_passes(
        self: TestCheckModel,
        quantizer: ONNXOpQuantizer,
        layer_configs: dict[str, LayerTestConfig],
        layer_combination: list[str],
    ) -> None:
        """Chains of supported layers must pass ``check_model``."""
        model = self.create_model_with_layers(layer_combination, layer_configs)
        # Should not raise any exception
        quantizer.check_model(model)
@@ -0,0 +1,196 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ import pytest
6
+
7
+ from python.core.circuit_models.generic_onnx import GenericModelONNX
8
+ from python.core.utils.helper_functions import CircuitExecutionConfig, RunType
9
+ from python.tests.onnx_quantizer_tests.layers.base import SpecType
10
+ from python.tests.onnx_quantizer_tests.layers.factory import TestLayerFactory
11
+ from python.tests.onnx_quantizer_tests.layers_tests.base_test import BaseQuantizerTest
12
+
13
+ if TYPE_CHECKING:
14
+ from collections.abc import Generator
15
+ from pathlib import Path
16
+
17
+ from python.tests.onnx_quantizer_tests.layers.base import (
18
+ LayerTestConfig,
19
+ LayerTestSpec,
20
+ )
21
+
22
+
class TestE2EQuantizer(BaseQuantizerTest):
    """End-to-end tests for ONNX quantizer layers.

    Each E2E spec is turned into a model and run through circuit
    compilation, witness generation, proving and verification.
    """

    __test__ = True

    @pytest.fixture
    def temp_quantized_model(self, tmp_path: Path) -> Generator[Path, None, None]:
        """Temporary path for the ONNX model handed to ``GenericModelONNX``."""
        path = tmp_path / "quantized_model.onnx"
        yield path
        path.unlink(missing_ok=True)

    @pytest.fixture
    def temp_circuit_path(self, tmp_path: Path) -> Generator[Path, None, None]:
        """Temporary path for circuit file."""
        path = tmp_path / "circuit.txt"
        yield path
        path.unlink(missing_ok=True)

    @pytest.fixture
    def temp_witness_file(self, tmp_path: Path) -> Generator[Path, None, None]:
        """Temporary path for witness file."""
        path = tmp_path / "witness.bin"
        yield path
        path.unlink(missing_ok=True)

    @pytest.fixture
    def temp_input_file(self, tmp_path: Path) -> Generator[Path, None, None]:
        """Temporary path for input JSON file."""
        path = tmp_path / "input.json"
        yield path
        path.unlink(missing_ok=True)

    @pytest.fixture
    def temp_output_file(self, tmp_path: Path) -> Generator[Path, None, None]:
        """Temporary path for output JSON file."""
        path = tmp_path / "output.json"
        yield path
        path.unlink(missing_ok=True)

    @pytest.fixture
    def temp_proof_file(self, tmp_path: Path) -> Generator[Path, None, None]:
        """Temporary path for proof file."""
        path = tmp_path / "proof.bin"
        yield path
        path.unlink(missing_ok=True)

    @pytest.mark.e2e
    @pytest.mark.parametrize(
        "test_case_data",
        TestLayerFactory.get_test_cases_by_type(SpecType.E2E),  # type: ignore[arg-type]
        ids=BaseQuantizerTest._generate_test_id,
    )
    def test_e2e_quantize_compile_witness_prove_verify(
        self,
        test_case_data: tuple[str, LayerTestConfig, LayerTestSpec],
        temp_quantized_model: Path,
        temp_circuit_path: Path,
        temp_witness_file: Path,
        temp_input_file: Path,
        temp_output_file: Path,
        temp_proof_file: Path,
        capsys: pytest.CaptureFixture[str],
    ) -> None:
        """Test the end-to-end flow for one E2E spec.

        The model is compiled into a circuit; a witness is generated, proved
        and verified. Each step's artifact and the final stdout/stderr are
        then checked.
        """
        layer_name, config, test_spec = test_case_data
        case_label = f"{layer_name}.{test_spec.name}"

        # Skip if validation failed or test is skipped
        self._check_validation_dependency(test_case_data)
        if test_spec.skip_reason:
            pytest.skip(f"{layer_name}_{test_spec.name}: {test_spec.skip_reason}")

        if layer_name == "Constant":
            pytest.skip(f"No e2e test for layer {layer_name} yet")

        # Build the original (unquantized) model from the spec.
        original_model = config.create_test_model(test_spec)

        # Save the original model to a temp location. NOTE(review): the path is
        # named "quantized_model" but holds the unquantized model; presumably
        # GenericModelONNX performs the quantization -- confirm.
        import onnx  # noqa: PLC0415

        onnx.save(original_model, str(temp_quantized_model))

        # Create GenericONNX model instance
        model = GenericModelONNX(model_name=str(temp_quantized_model))

        # Step 1: Compile circuit
        model.base_testing(
            CircuitExecutionConfig(
                run_type=RunType.COMPILE_CIRCUIT,
                dev_mode=False,
                circuit_path=str(temp_circuit_path),
            ),
        )

        assert temp_circuit_path.exists(), f"Circuit file not created for {case_label}"

        # Step 2: Generate witness
        model.base_testing(
            CircuitExecutionConfig(
                run_type=RunType.GEN_WITNESS,
                dev_mode=False,
                witness_file=temp_witness_file,
                circuit_path=str(temp_circuit_path),
                input_file=temp_input_file,
                output_file=temp_output_file,
                write_json=True,
            ),
        )
        assert (
            temp_witness_file.exists()
        ), f"Witness file not generated for {case_label}"
        assert (
            temp_output_file.exists()
        ), f"Output file not generated for {case_label}"

        # Step 3: Prove
        model.base_testing(
            CircuitExecutionConfig(
                run_type=RunType.PROVE_WITNESS,
                dev_mode=False,
                witness_file=temp_witness_file,
                circuit_path=str(temp_circuit_path),
                input_file=temp_input_file,
                output_file=temp_output_file,
                proof_file=temp_proof_file,
            ),
        )

        assert temp_proof_file.exists(), f"Proof file not generated for {case_label}"

        # Step 4: Verify
        model.base_testing(
            CircuitExecutionConfig(
                run_type=RunType.GEN_VERIFY,
                dev_mode=False,
                witness_file=temp_witness_file,
                circuit_path=str(temp_circuit_path),
                input_file=temp_input_file,
                output_file=temp_output_file,
                proof_file=temp_proof_file,
            ),
        )

        # Capture output and check for success indicators
        captured = capsys.readouterr()
        stdout = captured.out
        stderr = captured.err

        # The message is parenthesized so that both f-string parts belong to
        # the assert. Previously the second line was a separate no-op
        # statement, so the failure message lost the case id and stderr.
        assert stderr == "", (
            f"Errors occurred during e2e test for {case_label}: {stderr}"
        )

        # Check for expected success messages (similar to circuit e2e tests)
        assert "Witness Generated" in stdout, (
            f"Witness generation failed for {case_label}"
        )
        assert "Proved" in stdout, f"Proving failed for {case_label}"
        assert "Verified" in stdout, f"Verification failed for {case_label}"
@@ -0,0 +1,59 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ import pytest
6
+
7
+ if TYPE_CHECKING:
8
+ from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import (
9
+ ONNXOpQuantizer,
10
+ )
11
+ from python.tests.onnx_quantizer_tests.layers.base import (
12
+ LayerTestConfig,
13
+ LayerTestSpec,
14
+ SpecType,
15
+ )
16
+ from python.tests.onnx_quantizer_tests.layers.factory import TestLayerFactory
17
+ from python.tests.onnx_quantizer_tests.layers_tests.base_test import (
18
+ BaseQuantizerTest,
19
+ )
20
+
21
+
class TestErrorCases(BaseQuantizerTest):
    """Tests that ``check_model`` rejects invalid layer configurations.

    Each ERROR spec declares the exception type ``check_model`` must raise.
    It may also declare one or more substrings that the error message must
    contain.
    """

    __test__ = True

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "test_case_data",
        TestLayerFactory.get_test_cases_by_type(SpecType.ERROR),  # type: ignore[arg-type]
        ids=BaseQuantizerTest._generate_test_id,
    )
    def test_check_model_individual_error_cases(
        self: TestErrorCases,
        quantizer: ONNXOpQuantizer,
        test_case_data: tuple[str, LayerTestConfig, LayerTestSpec],
    ) -> None:
        """Each error spec must raise its expected error with a matching message."""
        layer_name, config, test_spec = test_case_data

        # Skips if layer is not a valid onnx layer
        self._check_validation_dependency(test_case_data)

        if test_spec.skip_reason:
            pytest.skip(f"{layer_name}_{test_spec.name}: {test_spec.skip_reason}")

        # Create model from layer specs
        model = config.create_test_model(test_spec)

        # Ensure the expected error is in fact raised
        with pytest.raises(test_spec.expected_error) as exc:
            quantizer.check_model(model)

        # With no message expectation, only the exception type is asserted.
        # (Previously `None in str(...)` raised a TypeError.)
        expected_match = test_spec.error_match
        if expected_match is None:
            return

        # error_match may be a single substring or a list of substrings.
        fragments = (
            expected_match if isinstance(expected_match, list) else [expected_match]
        )
        message = str(exc.value)
        for fragment in fragments:
            assert fragment in message, (
                f"{layer_name}.{test_spec.name}: expected {fragment!r} "
                f"in error message {message!r}"
            )