JSTprove 1.0.0__py3-none-macosx_11_0_arm64.whl → 1.2.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of JSTprove might be problematic; see the package's registry page for more details.

Files changed (61)
  1. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/METADATA +3 -3
  2. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/RECORD +60 -25
  3. python/core/binaries/onnx_generic_circuit_1-2-0 +0 -0
  4. python/core/circuit_models/generic_onnx.py +43 -9
  5. python/core/circuits/base.py +231 -71
  6. python/core/model_processing/converters/onnx_converter.py +114 -59
  7. python/core/model_processing/onnx_custom_ops/batchnorm.py +64 -0
  8. python/core/model_processing/onnx_custom_ops/maxpool.py +1 -1
  9. python/core/model_processing/onnx_custom_ops/mul.py +66 -0
  10. python/core/model_processing/onnx_custom_ops/relu.py +1 -1
  11. python/core/model_processing/onnx_quantizer/layers/add.py +54 -0
  12. python/core/model_processing/onnx_quantizer/layers/base.py +188 -1
  13. python/core/model_processing/onnx_quantizer/layers/batchnorm.py +224 -0
  14. python/core/model_processing/onnx_quantizer/layers/constant.py +1 -1
  15. python/core/model_processing/onnx_quantizer/layers/conv.py +20 -68
  16. python/core/model_processing/onnx_quantizer/layers/gemm.py +20 -66
  17. python/core/model_processing/onnx_quantizer/layers/maxpool.py +53 -43
  18. python/core/model_processing/onnx_quantizer/layers/mul.py +53 -0
  19. python/core/model_processing/onnx_quantizer/layers/relu.py +20 -35
  20. python/core/model_processing/onnx_quantizer/layers/sub.py +54 -0
  21. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +43 -1
  22. python/core/utils/general_layer_functions.py +17 -12
  23. python/core/utils/model_registry.py +6 -3
  24. python/scripts/gen_and_bench.py +2 -2
  25. python/tests/circuit_e2e_tests/other_e2e_test.py +202 -9
  26. python/tests/circuit_parent_classes/test_circuit.py +561 -38
  27. python/tests/circuit_parent_classes/test_onnx_converter.py +22 -13
  28. python/tests/onnx_quantizer_tests/__init__.py +1 -0
  29. python/tests/onnx_quantizer_tests/layers/__init__.py +13 -0
  30. python/tests/onnx_quantizer_tests/layers/add_config.py +102 -0
  31. python/tests/onnx_quantizer_tests/layers/base.py +279 -0
  32. python/tests/onnx_quantizer_tests/layers/batchnorm_config.py +190 -0
  33. python/tests/onnx_quantizer_tests/layers/constant_config.py +39 -0
  34. python/tests/onnx_quantizer_tests/layers/conv_config.py +154 -0
  35. python/tests/onnx_quantizer_tests/layers/factory.py +142 -0
  36. python/tests/onnx_quantizer_tests/layers/flatten_config.py +61 -0
  37. python/tests/onnx_quantizer_tests/layers/gemm_config.py +160 -0
  38. python/tests/onnx_quantizer_tests/layers/maxpool_config.py +82 -0
  39. python/tests/onnx_quantizer_tests/layers/mul_config.py +102 -0
  40. python/tests/onnx_quantizer_tests/layers/relu_config.py +61 -0
  41. python/tests/onnx_quantizer_tests/layers/reshape_config.py +61 -0
  42. python/tests/onnx_quantizer_tests/layers/sub_config.py +102 -0
  43. python/tests/onnx_quantizer_tests/layers_tests/__init__.py +0 -0
  44. python/tests/onnx_quantizer_tests/layers_tests/base_test.py +94 -0
  45. python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py +115 -0
  46. python/tests/onnx_quantizer_tests/layers_tests/test_e2e.py +196 -0
  47. python/tests/onnx_quantizer_tests/layers_tests/test_error_cases.py +59 -0
  48. python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +198 -0
  49. python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +267 -0
  50. python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py +109 -0
  51. python/tests/onnx_quantizer_tests/layers_tests/test_validation.py +45 -0
  52. python/tests/onnx_quantizer_tests/test_base_layer.py +228 -0
  53. python/tests/onnx_quantizer_tests/test_exceptions.py +99 -0
  54. python/tests/onnx_quantizer_tests/test_onnx_op_quantizer.py +246 -0
  55. python/tests/onnx_quantizer_tests/test_registered_quantizers.py +121 -0
  56. python/tests/onnx_quantizer_tests/testing_helper_functions.py +17 -0
  57. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  58. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/WHEEL +0 -0
  59. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/entry_points.txt +0 -0
  60. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/licenses/LICENSE +0 -0
  61. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,11 @@
1
1
  # test_converter.py
2
2
  import tempfile
3
+ from collections.abc import Generator
3
4
  from pathlib import Path
4
- from typing import Any, Generator
5
+ from typing import Any
5
6
  from unittest.mock import MagicMock, patch
6
7
 
8
+ import numpy as np
7
9
  import onnx
8
10
  import onnxruntime as ort
9
11
  import pytest
@@ -13,7 +15,7 @@ from onnx import TensorProto, helper
13
15
  from python.core.model_processing.converters.onnx_converter import ONNXConverter
14
16
 
15
17
 
16
- @pytest.fixture()
18
+ @pytest.fixture
17
19
  def temp_model_path(
18
20
  tmp_path: Generator[Path, None, None],
19
21
  ) -> Generator[Path, Any, None]:
@@ -26,7 +28,7 @@ def temp_model_path(
26
28
  model_path.unlink()
27
29
 
28
30
 
29
- @pytest.fixture()
31
+ @pytest.fixture
30
32
  def temp_quant_model_path(
31
33
  tmp_path: Generator[Path, None, None],
32
34
  ) -> Generator[Path, Any, None]:
@@ -39,7 +41,7 @@ def temp_quant_model_path(
39
41
  model_path.unlink()
40
42
 
41
43
 
42
- @pytest.fixture()
44
+ @pytest.fixture
43
45
  def converter() -> ONNXConverter:
44
46
  conv = ONNXConverter()
45
47
  conv.model = MagicMock(name="model")
@@ -47,7 +49,7 @@ def converter() -> ONNXConverter:
47
49
  return conv
48
50
 
49
51
 
50
- @pytest.mark.unit()
52
+ @pytest.mark.unit
51
53
  @patch("python.core.model_processing.converters.onnx_converter.onnx.save")
52
54
  def test_save_model(mock_save: MagicMock, converter: ONNXConverter) -> None:
53
55
  path = "model.onnx"
@@ -55,7 +57,7 @@ def test_save_model(mock_save: MagicMock, converter: ONNXConverter) -> None:
55
57
  mock_save.assert_called_once_with(converter.model, path)
56
58
 
57
59
 
58
- @pytest.mark.unit()
60
+ @pytest.mark.unit
59
61
  @patch("python.core.model_processing.converters.onnx_converter.onnx.load")
60
62
  def test_load_model(mock_load: MagicMock, converter: ONNXConverter) -> None:
61
63
  fake_model = MagicMock(name="onnx_model")
@@ -68,7 +70,7 @@ def test_load_model(mock_load: MagicMock, converter: ONNXConverter) -> None:
68
70
  assert converter.model == fake_model
69
71
 
70
72
 
71
- @pytest.mark.unit()
73
+ @pytest.mark.unit
72
74
  @patch("python.core.model_processing.converters.onnx_converter.onnx.save")
73
75
  def test_save_quantized_model(mock_save: MagicMock, converter: ONNXConverter) -> None:
74
76
  path = "quantized_model.onnx"
@@ -76,7 +78,7 @@ def test_save_quantized_model(mock_save: MagicMock, converter: ONNXConverter) ->
76
78
  mock_save.assert_called_once_with(converter.quantized_model, path)
77
79
 
78
80
 
79
- @pytest.mark.unit()
81
+ @pytest.mark.unit
80
82
  @patch("python.core.model_processing.converters.onnx_converter.Path.exists")
81
83
  @patch("python.core.model_processing.converters.onnx_converter.SessionOptions")
82
84
  @patch("python.core.model_processing.converters.onnx_converter.InferenceSession")
@@ -108,10 +110,12 @@ def test_load_quantized_model(
108
110
  assert converter.quantized_model == fake_model
109
111
 
110
112
 
111
- @pytest.mark.unit()
113
+ @pytest.mark.unit
112
114
  def test_get_outputs_with_mocked_session(converter: ONNXConverter) -> None:
113
- dummy_input = [[1.0]]
115
+ dummy_input = np.array([[1.0]]) # Use np.ndarray, not list
114
116
  dummy_output = [[2.0]]
117
+ converter.scale_base = 2
118
+ converter.scale_exponent = 10
115
119
 
116
120
  mock_sess = MagicMock()
117
121
 
@@ -132,7 +136,10 @@ def test_get_outputs_with_mocked_session(converter: ONNXConverter) -> None:
132
136
 
133
137
  result = converter.get_outputs(dummy_input)
134
138
 
135
- mock_sess.run.assert_called_once_with(["output"], {"input": dummy_input})
139
+ # Expect NumPy array to be passed into ort_sess.run()
140
+ expected_call_inputs = {"input": np.asarray(dummy_input)}
141
+ mock_sess.run.assert_called_once_with(["output"], expected_call_inputs)
142
+
136
143
  assert result == dummy_output
137
144
 
138
145
 
@@ -148,7 +155,7 @@ def create_dummy_model() -> onnx.ModelProto:
148
155
  return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 21)])
149
156
 
150
157
 
151
- @pytest.mark.integration()
158
+ @pytest.mark.integration
152
159
  def test_save_and_load_real_model() -> None:
153
160
  converter = ONNXConverter()
154
161
  model = create_dummy_model()
@@ -181,10 +188,12 @@ def test_save_and_load_real_model() -> None:
181
188
  assert converter.model.graph.node[0].op_type == "Identity"
182
189
 
183
190
 
184
- @pytest.mark.integration()
191
+ @pytest.mark.integration
185
192
  def test_real_inference_from_onnx() -> None:
186
193
  converter = ONNXConverter()
187
194
  converter.model = create_dummy_model()
195
+ converter.scale_base = 2
196
+ converter.scale_exponent = 10
188
197
 
189
198
  # Save and load into onnxruntime
190
199
  with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp:
@@ -0,0 +1 @@
# Seed for the NumPy random generators used by the quantizer test configs
# (e.g. np.random.default_rng(TEST_RNG_SEED)), so generated data is reproducible.
TEST_RNG_SEED = 2
@@ -0,0 +1,13 @@
1
+ from .base import BaseLayerConfigProvider, LayerTestConfig
2
+ from .factory import TestLayerFactory
3
+
# Importing this package asks the factory for every layer config, which runs
# provider discovery as an import-time side effect. The result is kept in a
# private name only so that the call happens.
_all_configs = TestLayerFactory.get_layer_configs()

# Public surface of the package: the factory plus the base config types.
__all__ = ["BaseLayerConfigProvider", "LayerTestConfig", "TestLayerFactory"]
@@ -0,0 +1,102 @@
1
+ import numpy as np
2
+
3
+ from python.tests.onnx_quantizer_tests import TEST_RNG_SEED
4
+ from python.tests.onnx_quantizer_tests.layers.base import (
5
+ BaseLayerConfigProvider,
6
+ LayerTestConfig,
7
+ LayerTestSpec,
8
+ e2e_test,
9
+ edge_case_test,
10
+ valid_test,
11
+ )
12
+
13
+
class AddConfigProvider(BaseLayerConfigProvider):
    """Test configuration provider for Add layer"""

    @property
    def layer_name(self) -> str:
        return "Add"

    def get_config(self) -> LayerTestConfig:
        """Return the base Add config: inputs ``A`` and ``B``, both [1, 3, 4, 4]."""
        return LayerTestConfig(
            op_type="Add",
            valid_inputs=["A", "B"],
            valid_attributes={},  # Add has no layer-specific attributes
            required_initializers={},
            input_shapes={
                "A": [1, 3, 4, 4],
                "B": [1, 3, 4, 4],
            },
            output_shapes={
                "add_output": [1, 3, 4, 4],
            },
        )

    def get_test_specs(self) -> list[LayerTestSpec]:
        """Return the valid, end-to-end and edge-case specs for Add.

        Random initializers come from a generator seeded with
        ``TEST_RNG_SEED``; they are drawn in list order, so the specs are
        identical across runs.
        """
        rng = np.random.default_rng(TEST_RNG_SEED)
        return [
            # --- VALID TESTS ---
            valid_test("basic")
            .description("Basic elementwise Add of two same-shaped tensors")
            .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 4, 4])
            .tags("basic", "elementwise", "add")
            .build(),
            valid_test("broadcast_add")
            .description("Add with Numpy-style broadcasting along spatial dimensions")
            .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 1, 1])
            .tags("broadcast", "elementwise", "add", "onnx14")
            .build(),
            valid_test("initializer_add")
            .description(
                "Add where second input (B) is a tensor initializer instead of input",
            )
            .override_input_shapes(A=[1, 3, 4, 4])
            .override_initializer("B", rng.normal(0, 1, (1, 3, 4, 4)))
            .tags("initializer", "elementwise", "add", "onnxruntime")
            .build(),
            valid_test("scalar_add")
            .description("Add scalar (initializer) to tensor")
            .override_input_shapes(A=[1, 3, 4, 4])
            .override_initializer("B", np.array([2.0], dtype=np.float32))
            .tags("scalar", "elementwise", "add")
            .build(),
            # --- E2E TESTS ---
            e2e_test("e2e_add")
            .description("End-to-end Add test with random inputs")
            .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 4, 4])
            .override_output_shapes(add_output=[1, 3, 4, 4])
            .tags("e2e", "add", "2d")
            .build(),
            e2e_test("e2e_initializer_add")
            .description(
                "Add where second input (B) is a tensor initializer instead of input",
            )
            .override_input_shapes(A=[1, 3, 4, 4])
            .override_initializer("B", rng.normal(0, 1, (1, 3, 4, 4)))
            .tags("initializer", "elementwise", "add", "onnxruntime")
            .build(),
            e2e_test("e2e_broadcast_add")
            .description("Add with Numpy-style broadcasting along spatial dimensions")
            .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 1, 1])
            .tags("broadcast", "elementwise", "add", "onnx14")
            .build(),
            e2e_test("e2e_scalar_add")
            .description("Add scalar (initializer) to tensor")
            .override_input_shapes(A=[1, 3, 4, 4])
            .override_initializer("B", np.array([2.0], dtype=np.float32))
            .tags("scalar", "elementwise", "add")
            .build(),
            # --- EDGE CASES ---
            # The declared output shape must match what Add actually produces for
            # the overridden inputs; otherwise the model's output value_info
            # contradicts the shapes the graph computes.
            edge_case_test("empty_tensor")
            .description("Add with empty tensor input (zero elements)")
            .override_input_shapes(A=[0], B=[0])
            .override_output_shapes(add_output=[0])
            .tags("edge", "empty", "add")
            .build(),
            edge_case_test("large_tensor")
            .description("Large tensor add performance/stress test")
            .override_input_shapes(A=[1, 64, 256, 256], B=[1, 64, 256, 256])
            .override_output_shapes(add_output=[1, 64, 256, 256])
            .tags("large", "performance", "add")
            .skip("Performance test, skipped by default")
            .build(),
        ]
@@ -0,0 +1,279 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from dataclasses import dataclass, field
5
+ from enum import Enum
6
+ from typing import TYPE_CHECKING, Any
7
+
8
+ if TYPE_CHECKING:
9
+ from collections.abc import Callable
10
+
11
+
12
+ import numpy as np
13
+ import onnx
14
+ from onnx import TensorProto, helper, numpy_helper
15
+
16
+
class SpecType(Enum):
    """Category of a layer test specification.

    The string values are the categories' identifiers; members are declared
    in the order valid, error, edge-case, end-to-end.
    """

    VALID = "valid"  # well-formed input, expected to succeed
    ERROR = "error"  # expected to raise (see LayerTestSpec.expected_error)
    EDGE_CASE = "edge_case"  # boundary / unusual shapes or values
    E2E = "e2e"  # end-to-end scenario
24
+
25
+
@dataclass
class LayerTestSpec:
    """One test case for a layer, expressed as overrides on a ``LayerTestConfig``.

    Every container-valued field defaults to a fresh, empty container per
    instance, so specs never share override state.
    """

    # Identity of the test case.
    name: str
    spec_type: SpecType
    description: str = ""

    # Overrides merged on top of the base LayerTestConfig.
    attr_overrides: dict[str, Any] = field(default_factory=dict)
    initializer_overrides: dict[str, np.ndarray] = field(default_factory=dict)
    input_overrides: list[str] = field(default_factory=list)
    input_shape_overrides: dict[str, list[int]] = field(default_factory=dict)
    output_shape_overrides: dict[str, list[int]] = field(default_factory=dict)

    # Expectations that only ERROR specs use.
    expected_error: type | None = None
    error_match: str | None = None

    # Optional custom validation hook.
    custom_validator: Callable | None = None

    # Free-form tags and an optional reason for skipping the case.
    tags: list[str] = field(default_factory=list)
    skip_reason: str | None = None

    # Attribute names to drop from the generated node entirely.
    omit_attrs: list[str] = field(default_factory=list)

    # No __post_init__: consistency checks (e.g. ERROR specs needing an
    # expected_error) are performed by TestSpecBuilder.build().
56
+
57
+
class LayerTestConfig:
    """Base configuration used to build single-node ONNX test models for one op.

    Holds the op type, the node's default inputs and attributes, the
    initializers the node always needs, and default graph input/output
    shapes. Each ``LayerTestSpec`` layers its overrides on top of these
    defaults in :meth:`create_test_model`.
    """

    def __init__(
        self: LayerTestConfig,
        op_type: str,
        valid_inputs: list[str],
        valid_attributes: dict[str, Any],
        required_initializers: dict[str, np.ndarray],
        input_shapes: dict[str, list[int]] | None = None,
        output_shapes: dict[str, list[int]] | None = None,
    ) -> None:
        """Store the layer defaults.

        Args:
            op_type: ONNX operator type, e.g. ``"Add"``.
            valid_inputs: Node input names used unless a spec overrides them.
            valid_attributes: Node attributes used unless a spec overrides them.
            required_initializers: Initializer arrays always added, by name.
            input_shapes: Graph input shapes by name. A falsy value (``None`` or
                ``{}``) falls back to ``{"input": [1, 16, 224, 224]}``.
            output_shapes: Graph output shapes by name. A falsy value falls back
                to ``{"<op_type lowercased>_output": [1, 10]}``.
        """
        self.op_type = op_type
        self.valid_inputs = valid_inputs
        self.valid_attributes = valid_attributes
        self.required_initializers = required_initializers
        # NOTE: ``or`` means an empty mapping also selects the default; kept
        # as-is for compatibility with existing configs.
        self.input_shapes = input_shapes or {"input": [1, 16, 224, 224]}
        self.output_shapes = output_shapes or {f"{op_type.lower()}_output": [1, 10]}

    def create_node(
        self: LayerTestConfig,
        name_suffix: str = "",
        **attr_overrides: Any,
    ) -> onnx.NodeProto:
        """Create one node of this op type wired to the default inputs.

        Args:
            name_suffix: Appended to both the node name and its output name so
                several such nodes can coexist in a graph.
            **attr_overrides: Attribute values that replace or extend
                ``valid_attributes``.

        Returns:
            The constructed ``onnx.NodeProto``.
        """
        attrs = {**self.valid_attributes, **attr_overrides}
        op = self.op_type.lower()
        return helper.make_node(
            self.op_type,
            inputs=self.valid_inputs,
            outputs=[f"{op}_output{name_suffix}"],
            name=f"test_{op}{name_suffix}",
            **attrs,
        )

    def create_initializers(
        self: LayerTestConfig,
        **initializer_overrides: Any,
    ) -> dict[str, onnx.TensorProto]:
        """Convert the required initializers plus overrides into ONNX tensors.

        An initializer named ``"shape"`` (as used by e.g. Reshape) is stored as
        int64; every other initializer is stored as float32. Values may be
        NumPy arrays or anything ``np.asarray`` accepts (scalars, lists).

        Args:
            **initializer_overrides: Extra or replacement initializer data.

        Returns:
            Mapping from initializer name to ``onnx.TensorProto``.
        """
        combined_inits = {**self.required_initializers, **initializer_overrides}
        initializers: dict[str, onnx.TensorProto] = {}
        for name, data in combined_inits.items():
            # Shape tensors must be integer-typed; everything else is float.
            dtype = np.int64 if name == "shape" else np.float32
            # np.asarray (rather than data.astype) also accepts plain Python
            # scalars/lists; for ndarrays the resulting values are identical.
            initializers[name] = numpy_helper.from_array(
                np.asarray(data, dtype=dtype),
                name=name,
            )
        return initializers

    def create_test_model(self, test_spec: LayerTestSpec) -> onnx.ModelProto:
        """Build a complete single-node model for ``test_spec``.

        The spec's input names, attributes (minus any omitted ones),
        initializers and shapes are merged over this config's defaults. Any
        tensor provided as an initializer is not declared as a graph input.

        Args:
            test_spec: The test case whose overrides are applied.

        Returns:
            The assembled ``onnx.ModelProto``.
        """
        # Explicit input overrides replace the default node inputs entirely.
        inputs = test_spec.input_overrides or self.valid_inputs

        attrs = {**self.valid_attributes, **test_spec.attr_overrides}
        for key in getattr(test_spec, "omit_attrs", []):
            attrs.pop(key, None)

        # Overrides may introduce new initializers or replace required ones.
        initializers = self.create_initializers(**test_spec.initializer_overrides)

        input_shapes = {**self.input_shapes, **test_spec.input_shape_overrides}
        output_shapes = {**self.output_shapes, **test_spec.output_shape_overrides}

        # A tensor supplied as an initializer must not also be a graph input.
        for init_name in initializers:
            input_shapes.pop(init_name, None)

        graph_inputs = [
            helper.make_tensor_value_info(name, TensorProto.FLOAT, shape)
            for name, shape in input_shapes.items()
        ]
        graph_outputs = [
            helper.make_tensor_value_info(name, TensorProto.FLOAT, shape)
            for name, shape in output_shapes.items()
        ]

        op = self.op_type.lower()
        node = helper.make_node(
            self.op_type,
            inputs=inputs,
            outputs=[f"{op}_output"],
            name=f"test_{op}_{test_spec.name}",
            **attrs,
        )

        graph = helper.make_graph(
            nodes=[node],
            name=f"{op}_test_graph_{test_spec.name}",
            inputs=graph_inputs,
            outputs=graph_outputs,
            initializer=list(initializers.values()),
        )
        return helper.make_model(graph)
169
+
170
+
class TestSpecBuilder:
    """Fluent builder for :class:`LayerTestSpec`.

    Every mutator returns ``self`` so calls chain, e.g.
    ``valid_test("basic").description("...").tags("x").build()``.
    """

    # The class name starts with "Test"; mark it as not-a-test so pytest does
    # not try to collect it when it is imported into a test module.
    __test__ = False

    def __init__(self, name: str, spec_type: SpecType) -> None:
        """Start a spec with the given name and type and no overrides."""
        self._spec = LayerTestSpec(name=name, spec_type=spec_type)

    def description(self, desc: str) -> TestSpecBuilder:
        """Set the human-readable description."""
        self._spec.description = desc
        return self

    def override_attrs(self, **attrs: Any) -> TestSpecBuilder:
        """Add or replace node attributes (merged over the config defaults)."""
        self._spec.attr_overrides.update(attrs)
        return self

    def omit_attrs(self, *attrs: str) -> TestSpecBuilder:
        """Name attributes to drop from the generated node."""
        self._spec.omit_attrs.extend(attrs)
        return self

    def override_initializer(self, name: str, data: np.ndarray) -> TestSpecBuilder:
        """Add or replace the initializer ``name`` with ``data``."""
        self._spec.initializer_overrides[name] = data
        return self

    def override_inputs(self, *inputs: str) -> TestSpecBuilder:
        """Replace the node's input names entirely."""
        self._spec.input_overrides = list(inputs)
        return self

    def override_input_shapes(self, **shapes: list[int]) -> TestSpecBuilder:
        """Add or replace graph input shapes, keyed by input name."""
        self._spec.input_shape_overrides.update(shapes)
        return self

    def override_output_shapes(self, **shapes: list[int]) -> TestSpecBuilder:
        """Add or replace graph output shapes, keyed by output name."""
        self._spec.output_shape_overrides.update(shapes)
        return self

    def expects_error(
        self,
        error_type: type,
        match: str | None = None,
    ) -> TestSpecBuilder:
        """Record the exception type (and optional match text) an ERROR spec expects.

        Raises:
            ValueError: If this builder is not for a ``SpecType.ERROR`` spec.
        """
        if self._spec.spec_type != SpecType.ERROR:
            msg = "expects_error can only be used with ERROR spec type"
            raise ValueError(msg)
        self._spec.expected_error = error_type
        self._spec.error_match = match
        return self

    def tags(self, *tags: str) -> TestSpecBuilder:
        """Append free-form tags."""
        self._spec.tags.extend(tags)
        return self

    def skip(self, reason: str) -> TestSpecBuilder:
        """Mark the spec as skipped, with ``reason``."""
        self._spec.skip_reason = reason
        return self

    def build(self) -> LayerTestSpec:
        """Validate and return the spec.

        Note that this returns the builder's own spec object, not a copy.

        Raises:
            ValueError: If an ERROR spec has no expected error type.
        """
        if self._spec.spec_type == SpecType.ERROR and not self._spec.expected_error:
            msg = (
                f"Error test {self._spec.name} must"
                " specify expected_error using .expects_error()"
            )
            raise ValueError(msg)
        return self._spec
234
+
235
+
# Convenience constructors: each returns a fresh TestSpecBuilder preset to one SpecType.
def valid_test(name: str) -> TestSpecBuilder:
    """Begin building a ``SpecType.VALID`` spec called ``name``."""
    return TestSpecBuilder(name=name, spec_type=SpecType.VALID)
239
+
240
+
def error_test(name: str) -> TestSpecBuilder:
    """Begin building a ``SpecType.ERROR`` spec called ``name``.

    The result must be given ``.expects_error(...)`` before ``.build()``.
    """
    return TestSpecBuilder(name=name, spec_type=SpecType.ERROR)
243
+
244
+
def edge_case_test(name: str) -> TestSpecBuilder:
    """Begin building a ``SpecType.EDGE_CASE`` spec called ``name``."""
    return TestSpecBuilder(name=name, spec_type=SpecType.EDGE_CASE)
247
+
248
+
def e2e_test(name: str) -> TestSpecBuilder:
    """Begin building a ``SpecType.E2E`` (end-to-end) spec called ``name``."""
    return TestSpecBuilder(name=name, spec_type=SpecType.E2E)
251
+
252
+
class BaseLayerConfigProvider(ABC):
    """Abstract base class for per-layer test configuration providers.

    Subclasses supply the base ``LayerTestConfig`` for an op and, optionally,
    a list of ``LayerTestSpec`` test cases; the ``get_*_test_specs`` helpers
    filter that list by ``SpecType``.
    """

    @abstractmethod
    def get_config(self) -> LayerTestConfig:
        """Return the base configuration for this layer"""

    @property
    @abstractmethod
    def layer_name(self) -> str:
        """Return the layer name/op_type"""

    def get_test_specs(self) -> list[LayerTestSpec]:
        """Return test specifications for this layer (override for custom tests)"""
        return []

    def get_test_specs_by_type(self, spec_type: SpecType) -> list[LayerTestSpec]:
        """Return the specs from :meth:`get_test_specs` whose type is ``spec_type``.

        Order is preserved. ``get_test_specs`` is called once per invocation.
        """
        return [spec for spec in self.get_test_specs() if spec.spec_type == spec_type]

    def get_valid_test_specs(self) -> list[LayerTestSpec]:
        """Get only valid test specifications"""
        return self.get_test_specs_by_type(SpecType.VALID)

    def get_error_test_specs(self) -> list[LayerTestSpec]:
        """Get only error test specifications"""
        return self.get_test_specs_by_type(SpecType.ERROR)

    def get_edge_case_test_specs(self) -> list[LayerTestSpec]:
        """Get only edge-case test specifications"""
        return self.get_test_specs_by_type(SpecType.EDGE_CASE)

    def get_e2e_test_specs(self) -> list[LayerTestSpec]:
        """Get only end-to-end test specifications"""
        return self.get_test_specs_by_type(SpecType.E2E)