JSTprove 1.0.0-py3-none-macosx_11_0_arm64.whl → 1.2.0-py3-none-macosx_11_0_arm64.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

This version of JSTprove has been flagged as potentially problematic.

Files changed (61)
  1. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/METADATA +3 -3
  2. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/RECORD +60 -25
  3. python/core/binaries/onnx_generic_circuit_1-2-0 +0 -0
  4. python/core/circuit_models/generic_onnx.py +43 -9
  5. python/core/circuits/base.py +231 -71
  6. python/core/model_processing/converters/onnx_converter.py +114 -59
  7. python/core/model_processing/onnx_custom_ops/batchnorm.py +64 -0
  8. python/core/model_processing/onnx_custom_ops/maxpool.py +1 -1
  9. python/core/model_processing/onnx_custom_ops/mul.py +66 -0
  10. python/core/model_processing/onnx_custom_ops/relu.py +1 -1
  11. python/core/model_processing/onnx_quantizer/layers/add.py +54 -0
  12. python/core/model_processing/onnx_quantizer/layers/base.py +188 -1
  13. python/core/model_processing/onnx_quantizer/layers/batchnorm.py +224 -0
  14. python/core/model_processing/onnx_quantizer/layers/constant.py +1 -1
  15. python/core/model_processing/onnx_quantizer/layers/conv.py +20 -68
  16. python/core/model_processing/onnx_quantizer/layers/gemm.py +20 -66
  17. python/core/model_processing/onnx_quantizer/layers/maxpool.py +53 -43
  18. python/core/model_processing/onnx_quantizer/layers/mul.py +53 -0
  19. python/core/model_processing/onnx_quantizer/layers/relu.py +20 -35
  20. python/core/model_processing/onnx_quantizer/layers/sub.py +54 -0
  21. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +43 -1
  22. python/core/utils/general_layer_functions.py +17 -12
  23. python/core/utils/model_registry.py +6 -3
  24. python/scripts/gen_and_bench.py +2 -2
  25. python/tests/circuit_e2e_tests/other_e2e_test.py +202 -9
  26. python/tests/circuit_parent_classes/test_circuit.py +561 -38
  27. python/tests/circuit_parent_classes/test_onnx_converter.py +22 -13
  28. python/tests/onnx_quantizer_tests/__init__.py +1 -0
  29. python/tests/onnx_quantizer_tests/layers/__init__.py +13 -0
  30. python/tests/onnx_quantizer_tests/layers/add_config.py +102 -0
  31. python/tests/onnx_quantizer_tests/layers/base.py +279 -0
  32. python/tests/onnx_quantizer_tests/layers/batchnorm_config.py +190 -0
  33. python/tests/onnx_quantizer_tests/layers/constant_config.py +39 -0
  34. python/tests/onnx_quantizer_tests/layers/conv_config.py +154 -0
  35. python/tests/onnx_quantizer_tests/layers/factory.py +142 -0
  36. python/tests/onnx_quantizer_tests/layers/flatten_config.py +61 -0
  37. python/tests/onnx_quantizer_tests/layers/gemm_config.py +160 -0
  38. python/tests/onnx_quantizer_tests/layers/maxpool_config.py +82 -0
  39. python/tests/onnx_quantizer_tests/layers/mul_config.py +102 -0
  40. python/tests/onnx_quantizer_tests/layers/relu_config.py +61 -0
  41. python/tests/onnx_quantizer_tests/layers/reshape_config.py +61 -0
  42. python/tests/onnx_quantizer_tests/layers/sub_config.py +102 -0
  43. python/tests/onnx_quantizer_tests/layers_tests/__init__.py +0 -0
  44. python/tests/onnx_quantizer_tests/layers_tests/base_test.py +94 -0
  45. python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py +115 -0
  46. python/tests/onnx_quantizer_tests/layers_tests/test_e2e.py +196 -0
  47. python/tests/onnx_quantizer_tests/layers_tests/test_error_cases.py +59 -0
  48. python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +198 -0
  49. python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +267 -0
  50. python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py +109 -0
  51. python/tests/onnx_quantizer_tests/layers_tests/test_validation.py +45 -0
  52. python/tests/onnx_quantizer_tests/test_base_layer.py +228 -0
  53. python/tests/onnx_quantizer_tests/test_exceptions.py +99 -0
  54. python/tests/onnx_quantizer_tests/test_onnx_op_quantizer.py +246 -0
  55. python/tests/onnx_quantizer_tests/test_registered_quantizers.py +121 -0
  56. python/tests/onnx_quantizer_tests/testing_helper_functions.py +17 -0
  57. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  58. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/WHEEL +0 -0
  59. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/entry_points.txt +0 -0
  60. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/licenses/LICENSE +0 -0
  61. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/top_level.txt +0 -0
python/tests/onnx_quantizer_tests/layers/batchnorm_config.py
@@ -0,0 +1,190 @@
+ import numpy as np
+
+ from python.tests.onnx_quantizer_tests import TEST_RNG_SEED
+ from python.tests.onnx_quantizer_tests.layers.base import (
+     BaseLayerConfigProvider,
+     LayerTestConfig,
+     LayerTestSpec,
+     e2e_test,
+     valid_test,
+ )
+
+
+ class BatchNormConfigProvider(BaseLayerConfigProvider):
+     """Test configuration provider for BatchNorm (ONNX BatchNormalization op)"""
+
+     @property
+     def layer_name(self) -> str:
+         return "BatchNormalization"
+
+     def get_config(self) -> LayerTestConfig:
+         rng = np.random.default_rng(TEST_RNG_SEED)
+
+         # default shapes: N x C x H x W
+         default_input_shape = [1, 3, 4, 4]
+         c = default_input_shape[1]
+
+         # typical required initializers (scale, bias, mean, var) are length C
+         return LayerTestConfig(
+             op_type="BatchNormalization",
+             valid_inputs=["X", "scale", "B", "input_mean", "input_var"],
+             valid_attributes={
+                 "epsilon": 1e-5,
+                 "momentum": 0.9,
+                 "training_mode": 0,
+             },
+             required_initializers={
+                 # Defaults are stored as numpy arrays with shape (C,)
+                 "scale": rng.normal(1.0, 0.5, c).astype(np.float32),
+                 "B": rng.normal(0.0, 0.5, c).astype(np.float32),
+                 "input_mean": rng.normal(0.0, 1.0, c).astype(np.float32),
+                 "input_var": np.abs(rng.normal(1.0, 0.5, c)).astype(np.float32),
+             },
+             input_shapes={"X": default_input_shape},
+             output_shapes={"batchnormalization_output": default_input_shape},
+         )
+
+     def get_test_specs(self) -> list[LayerTestSpec]:
+         rng = np.random.default_rng(TEST_RNG_SEED)
+         c = 3
+
+         return [
+             # Basic valid tests
+             valid_test("basic_inference")
+             .description("Basic BatchNormalization inference: standard shapes")
+             .tags("basic", "inference", "batchnorm")
+             .build(),
+             valid_test("different_input_shape")
+             .description("Inference with different spatial dims")
+             .override_input_shapes(X=[2, c, 8, 8])
+             .override_output_shapes(batchnormalization_output=[2, c, 8, 8])
+             .tags("inference", "spatial")
+             .build(),
+             valid_test("epsilon_variation")
+             .description("Inference with larger epsilon for numerical stability")
+             .override_attrs(epsilon=1e-3)
+             .tags("epsilon")
+             .build(),
+             valid_test("momentum_variation")
+             .description(
+                 "Inference with non-default momentum (has no effect in inference mode)",
+             )
+             .override_attrs(momentum=0.5)
+             .tags("momentum")
+             .build(),
+             valid_test("zero_mean_input")
+             .description("Input with zero mean")
+             .override_initializer("input_mean", np.zeros((c,), dtype=np.float32))
+             .tags("edge", "zero_mean")
+             .build(),
+             # Scalar / broadcast style tests
+             valid_test("per_channel_zero_variance")
+             .description(
+                 "Edge case: very small variance values (clamped by epsilon), inference",
+             )
+             .override_initializer("input_var", np.full((c,), 1e-8, dtype=np.float32))
+             .override_attrs(epsilon=1e-5)
+             .tags("edge", "small_variance")
+             .build(),
+             # E2E tests that set explicit initializer values
+             e2e_test("e2e_inference")
+             .description("E2E inference test with explicit initializers")
+             .override_input_shapes(X=[1, c, 2, 2])
+             .override_output_shapes(batchnormalization_output=[1, c, 2, 2])
+             .override_initializer("scale", rng.normal(1.0, 0.1, c).astype(np.float32))
+             .override_initializer("B", rng.normal(0.0, 0.1, c).astype(np.float32))
+             .override_initializer(
+                 "input_mean",
+                 rng.normal(0.0, 0.1, c).astype(np.float32),
+             )
+             .override_initializer(
+                 "input_var",
+                 np.abs(rng.normal(0.5, 0.2, c)).astype(np.float32),
+             )
+             .tags("e2e", "inference")
+             .build(),
+             e2e_test("e2e_inference_small_2x2")
+             .description("E2E inference with small 2x2 spatial input")
+             .override_input_shapes(X=[1, 3, 2, 2])
+             .override_output_shapes(batchnormalization_output=[1, 3, 2, 2])
+             .override_initializer("scale", np.array([1.0, 0.9, 1.1], dtype=np.float32))
+             .override_initializer("B", np.array([0.0, 0.1, -0.1], dtype=np.float32))
+             .override_initializer(
+                 "input_mean",
+                 np.array([0.5, -0.5, 0.0], dtype=np.float32),
+             )
+             .override_initializer(
+                 "input_var",
+                 np.array([0.25, 0.5, 0.1], dtype=np.float32),
+             )
+             .tags("e2e", "small", "2x2")
+             .build(),
+             e2e_test("e2e_inference_wide_input")
+             .description("E2E inference with wider input shape (C=4, H=2, W=8)")
+             .override_input_shapes(X=[2, 4, 2, 8])
+             .override_output_shapes(batchnormalization_output=[2, 4, 2, 8])
+             .override_initializer(
+                 "scale",
+                 np.array([1.0, 0.8, 1.2, 0.9], dtype=np.float32),
+             )
+             .override_initializer(
+                 "B",
+                 np.array([0.0, 0.1, -0.1, 0.05], dtype=np.float32),
+             )
+             .override_initializer(
+                 "input_mean",
+                 np.array([0.0, 0.5, -0.5, 0.2], dtype=np.float32),
+             )
+             .override_initializer(
+                 "input_var",
+                 np.array([1.0, 0.5, 0.25, 0.1], dtype=np.float32),
+             )
+             .tags("e2e", "wide", "C4")
+             .build(),
+             e2e_test("e2e_inference_batch2_channels3")
+             .description("E2E inference with batch size 2 and 3 channels")
+             .override_input_shapes(X=[2, 3, 4, 4])
+             .override_output_shapes(batchnormalization_output=[2, 3, 4, 4])
+             .override_initializer("scale", np.array([0.5, 1.0, 1.5], dtype=np.float32))
+             .override_initializer("B", np.array([0.0, 0.0, 0.0], dtype=np.float32))
+             .override_initializer(
+                 "input_mean",
+                 np.array([-0.5, 0.0, 0.5], dtype=np.float32),
+             )
+             .override_initializer(
+                 "input_var",
+                 np.array([0.2, 0.5, 0.8], dtype=np.float32),
+             )
+             .tags("e2e", "batch2", "C3")
+             .build(),
+             e2e_test("e2e_inference_high_epsilon")
+             .description("E2E inference with high epsilon for numerical stability")
+             .override_input_shapes(X=[1, 2, 4, 4])
+             .override_output_shapes(batchnormalization_output=[1, 2, 4, 4])
+             .override_initializer("scale", np.array([1.0, 1.0], dtype=np.float32))
+             .override_initializer("B", np.array([0.1, -0.1], dtype=np.float32))
+             .override_initializer("input_mean", np.array([0.0, 0.5], dtype=np.float32))
+             .override_initializer(
+                 "input_var",
+                 np.array([0.0, 0.0], dtype=np.float32),
+             )  # tiny variance
+             .override_attrs(epsilon=1e-2)
+             .tags("e2e", "high_epsilon", "numerical_stability")
+             .build(),
+             e2e_test("e2e_inference_non_square")
+             .description("E2E inference with non-square spatial dimensions")
+             .override_input_shapes(X=[1, 3, 2, 5])
+             .override_output_shapes(batchnormalization_output=[1, 3, 2, 5])
+             .override_initializer("scale", np.array([1.0, 0.9, 1.1], dtype=np.float32))
+             .override_initializer("B", np.array([0.0, 0.1, -0.1], dtype=np.float32))
+             .override_initializer(
+                 "input_mean",
+                 np.array([0.1, -0.1, 0.0], dtype=np.float32),
+             )
+             .override_initializer(
+                 "input_var",
+                 np.array([0.5, 0.25, 0.75], dtype=np.float32),
+             )
+             .tags("e2e", "non_square", "C3")
+             .build(),
+         ]
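
For reference, the arithmetic these inference specs exercise is the standard ONNX BatchNormalization formula, y = scale * (X - input_mean) / sqrt(input_var + epsilon) + B, broadcast per channel. A minimal numpy sketch of that reference (illustrative only, not code from this package), using the values from the e2e_inference_small_2x2 spec above:

    import numpy as np

    def batchnorm_reference(x, scale, b, mean, var, epsilon=1e-5):
        # Inference-mode BatchNormalization per the ONNX spec; the per-channel
        # parameters broadcast over N, H, W (channel axis is 1).
        shp = (1, x.shape[1], 1, 1)
        return (scale.reshape(shp) * (x - mean.reshape(shp))
                / np.sqrt(var.reshape(shp) + epsilon) + b.reshape(shp))

    x = np.zeros((1, 3, 2, 2), dtype=np.float32)
    scale = np.array([1.0, 0.9, 1.1], dtype=np.float32)
    b = np.array([0.0, 0.1, -0.1], dtype=np.float32)
    mean = np.array([0.5, -0.5, 0.0], dtype=np.float32)
    var = np.array([0.25, 0.5, 0.1], dtype=np.float32)
    print(batchnorm_reference(x, scale, b, mean, var)[0, :, 0, 0])
    # channel 0: 1.0 * (0 - 0.5) / sqrt(0.25 + 1e-5) + 0.0 ≈ -1.0
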
python/tests/onnx_quantizer_tests/layers/constant_config.py
@@ -0,0 +1,39 @@
+ import numpy as np
+ from onnx import numpy_helper
+
+ from python.tests.onnx_quantizer_tests.layers.base import e2e_test, valid_test
+ from python.tests.onnx_quantizer_tests.layers.factory import (
+     BaseLayerConfigProvider,
+     LayerTestConfig,
+ )
+
+
+ class ConstantConfigProvider(BaseLayerConfigProvider):
+     """Test configuration provider for Constant layers"""
+
+     @property
+     def layer_name(self) -> str:
+         return "Constant"
+
+     def get_config(self) -> LayerTestConfig:
+         return LayerTestConfig(
+             op_type="Constant",
+             valid_inputs=[],
+             valid_attributes={
+                 "value": numpy_helper.from_array(np.array([1.0]), name="const_value"),
+             },
+             required_initializers={},
+         )
+
+     def get_test_specs(self) -> list:
+         return [
+             valid_test("basic")
+             .description("Basic Constant node returning scalar 1.0")
+             .tags("basic", "constant")
+             .build(),
+             e2e_test("e2e_basic")
+             .description("End-to-end test for Constant node")
+             .override_output_shapes(constant_output=[1])
+             .tags("e2e", "constant")
+             .build(),
+         ]
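
For context: the `value` attribute of an ONNX Constant node carries a TensorProto, which is why the config above wraps the numpy array with numpy_helper.from_array rather than passing it raw. A minimal sketch of the node this config describes, built with the standard onnx.helper API (illustrative, not package code):

    import numpy as np
    from onnx import helper, numpy_helper

    # Constant takes no inputs; its payload rides in the `value` attribute.
    const_node = helper.make_node(
        "Constant",
        inputs=[],
        outputs=["constant_output"],
        value=numpy_helper.from_array(np.array([1.0]), name="const_value"),
    )
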
python/tests/onnx_quantizer_tests/layers/conv_config.py
@@ -0,0 +1,154 @@
+ import numpy as np
+
+ from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError
+ from python.tests.onnx_quantizer_tests import TEST_RNG_SEED
+ from python.tests.onnx_quantizer_tests.layers.base import (
+     LayerTestSpec,
+     e2e_test,
+     error_test,
+     valid_test,
+ )
+ from python.tests.onnx_quantizer_tests.layers.factory import (
+     BaseLayerConfigProvider,
+     LayerTestConfig,
+ )
+
+
+ class ConvConfigProvider(BaseLayerConfigProvider):
+     """Test configuration provider for Conv layers"""
+
+     @property
+     def layer_name(self) -> str:
+         return "Conv"
+
+     def get_config(self) -> LayerTestConfig:
+         rng = np.random.default_rng(TEST_RNG_SEED)
+         return LayerTestConfig(
+             op_type="Conv",
+             valid_inputs=["input", "conv_weight", "conv_bias"],
+             valid_attributes={
+                 "strides": [1, 1],
+                 "kernel_shape": [3, 3],
+                 "dilations": [1, 1],
+                 "pads": [1, 1, 1, 1],
+             },
+             required_initializers={
+                 "conv_weight": rng.normal(0, 1, (32, 16, 3, 3)),
+                 "conv_bias": rng.normal(0, 1, 32),
+             },
+         )
+
+     def get_test_specs(self) -> list[LayerTestSpec]:
+         """Return all test specifications for Conv layers"""
+         rng = np.random.default_rng(TEST_RNG_SEED)
+         return [
+             # Valid variations
+             valid_test("basic")
+             .description("Basic 2D convolution")
+             .tags("basic", "2d")
+             .build(),
+             valid_test("different_padding")
+             .description("Convolution with different padding")
+             .override_attrs(pads=[2, 2, 2, 2], kernel_shape=[5, 5])
+             .override_initializer("conv_weight", rng.normal(0, 1, (32, 16, 5, 5)))
+             .tags("padding", "5x5_kernel")
+             .build(),
+             # E2E test
+             e2e_test("e2e_basic")
+             .description("End-to-end test for basic 2D convolution")
+             .override_input_shapes(input=[1, 3, 4, 4])
+             .override_output_shapes(conv_output=[1, 8, 4, 4])
+             .override_initializer("conv_weight", rng.normal(0, 1, (8, 3, 3, 3)))
+             .override_initializer("conv_bias", rng.normal(0, 1, 8))
+             .tags("e2e", "basic", "2d")
+             .build(),
+             # Error cases
+             error_test("no_bias")
+             .description("2D convolution without bias")
+             .override_inputs("input", "conv_weight")
+             .override_attrs(strides=[2, 2], kernel_shape=[5, 5])
+             .override_initializer("conv_weight", rng.normal(0, 1, (64, 16, 5, 5)))
+             .expects_error(
+                 InvalidParamError,
+                 "Expected at least 3 inputs (input, weights, bias), got 2",
+             )
+             .tags("no_bias", "stride_2")
+             .build(),
+             error_test("conv3d_unsupported")
+             .description("3D convolution should raise error")
+             .override_attrs(
+                 kernel_shape=[3, 3, 3],
+                 strides=[1, 1, 1],
+                 dilations=[1, 1, 1],
+                 pads=[1, 1, 1, 1, 1, 1],
+             )
+             .override_initializer(
+                 "conv_weight",
+                 rng.normal(0, 1, (32, 16, 3, 3, 3)),
+             )
+             .expects_error(
+                 InvalidParamError,
+                 "Unsupported Conv weight dimensionality 5",
+             )
+             .tags("3d", "unsupported")
+             .build(),
+             error_test("invalid_stride")
+             .description("Invalid stride values")
+             .override_attrs(strides=[0, 1])
+             .override_inputs("input", "conv_weight")
+             .expects_error(InvalidParamError, "stride must be positive")
+             .tags("invalid_params")
+             .skip("Not yet supported")
+             .build(),
+             error_test("negative_dilation")
+             .description("Negative dilation values")
+             .override_attrs(dilations=[-1, 1])
+             .expects_error(InvalidParamError, "dilation must be positive")
+             .tags("invalid_params")
+             .skip("Not yet supported")
+             .build(),
+             error_test("invalid_kernel_shape_long")
+             .description("kernel_shape too long (length 3)")
+             .override_attrs(kernel_shape=[3, 3, 3])
+             .override_initializer(
+                 "conv_weight",
+                 rng.normal(0, 1, (32, 16, 3, 3, 3)),
+             )
+             .expects_error(InvalidParamError, "kernel_shape")
+             .tags("invalid_attr_length")
+             .build(),
+             # Missing required attributes
+             error_test("missing_strides")
+             .description("Conv node missing 'strides' attribute")
+             .omit_attrs("strides")  # exclude strides
+             .override_attrs(
+                 kernel_shape=[3, 3],
+                 dilations=[1, 1],
+                 pads=[1, 1, 1, 1],
+             )  # supply others explicitly
+             .expects_error(InvalidParamError, "strides")
+             .tags("missing_attr")
+             .build(),
+             error_test("missing_kernel_shape")
+             .description("Conv node missing 'kernel_shape' attribute")
+             .omit_attrs("kernel_shape")  # exclude kernel_shape
+             .override_attrs(
+                 strides=[3, 3],
+                 dilations=[1, 1],
+                 pads=[1, 1, 1, 1],
+             )  # supply others explicitly
+             .expects_error(InvalidParamError, "kernel_shape")
+             .tags("missing_attr")
+             .build(),
+             error_test("missing_dilations")
+             .description("Conv node missing 'dilations' attribute")
+             .omit_attrs("dilations")  # exclude dilations
+             .override_attrs(
+                 strides=[3, 3],
+                 kernel_shape=[3, 3],
+                 pads=[1, 1, 1, 1],
+             )  # supply others explicitly
+             .expects_error(InvalidParamError, "dilations")
+             .tags("missing_attr")
+             .build(),
+         ]
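
As a sanity check on the e2e_basic shapes: for each spatial axis, ONNX Conv produces floor((size + pad_begin + pad_end - dilation*(kernel - 1) - 1) / stride) + 1 outputs, and the output channel count equals the weight tensor's first dimension. Plain arithmetic, not package code:

    def conv_out_dim(size, kernel, pad_begin, pad_end, stride, dilation):
        # Standard ONNX Conv output size for one spatial axis.
        return (size + pad_begin + pad_end - dilation * (kernel - 1) - 1) // stride + 1

    # e2e_basic: input [1, 3, 4, 4], weight (8, 3, 3, 3), pads [1, 1, 1, 1], stride 1
    h = conv_out_dim(4, 3, 1, 1, 1, 1)  # 4
    w = conv_out_dim(4, 3, 1, 1, 1, 1)  # 4
    print([1, 8, h, w])  # [1, 8, 4, 4], matching conv_output above
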
python/tests/onnx_quantizer_tests/layers/factory.py
@@ -0,0 +1,142 @@
+ from __future__ import annotations
+
+ import importlib
+ import inspect
+ import logging
+ import typing
+ from pathlib import Path
+
+ from .base import BaseLayerConfigProvider, LayerTestConfig, LayerTestSpec, SpecType
+
+ logger = logging.getLogger(__name__)
+
+
+ class TestLayerFactory:
+     """Enhanced factory for creating test configurations for different layer types"""
+
+     _providers: typing.ClassVar = {}
+     _initialized = False
+
+     @classmethod
+     def _discover_providers(cls: TestLayerFactory) -> None:
+         """Automatically discover all BaseLayerConfigProvider subclasses"""
+         if cls._initialized:
+             return
+
+         current_dir = Path(__file__).parent
+         config_files = [
+             f.stem
+             for f in Path(current_dir).iterdir()
+             if f.is_file() and f.name.endswith("_config.py") and f.name != "__init__.py"
+         ]
+         logger.debug("Discovered config files: %s", config_files)
+
+         for module_name in config_files:
+             try:
+                 module = importlib.import_module(f".{module_name}", package=__package__)
+
+                 for _, obj in inspect.getmembers(module, inspect.isclass):
+                     if (
+                         issubclass(obj, BaseLayerConfigProvider)
+                         and obj is not BaseLayerConfigProvider
+                     ):
+                         provider_instance = obj()
+                         cls._providers[provider_instance.layer_name] = provider_instance
+
+             except ImportError as e:  # noqa: PERF203
+                 msg = f"Warning: Could not import {module_name}: {e}"
+                 logger.warning(msg)
+
+         cls._initialized = True
+
+     # Existing methods (keep your current implementation)
+     @classmethod
+     def get_layer_configs(cls) -> dict[str, LayerTestConfig]:
+         """Get test configurations for all supported layers"""
+         cls._discover_providers()
+         logger.debug("Retrieved layer configs: %s", list(cls._providers.keys()))
+         return {
+             name: provider.get_config() for name, provider in cls._providers.items()
+         }
+
+     @classmethod
+     def get_layer_config(cls, layer_name: str) -> LayerTestConfig:
+         """Get test configuration for a specific layer"""
+         cls._discover_providers()
+         if layer_name not in cls._providers:
+             msg = f"No config provider found for layer: {layer_name}"
+             raise ValueError(msg)
+         return cls._providers[layer_name].get_config()
+
+     @classmethod
+     def get_available_layers(cls) -> list[str]:
+         """Get list of all available layer types"""
+         cls._discover_providers()
+         return list(cls._providers.keys())
+
+     @classmethod
+     def register_provider(cls, provider: BaseLayerConfigProvider) -> None:
+         """Register a new config provider"""
+         cls._providers[provider.layer_name] = provider
+
+     # NEW enhanced methods for test specifications
+     @classmethod
+     def get_all_test_cases(cls) -> list[tuple[str, LayerTestConfig, LayerTestSpec]]:
+         """Get all test cases as (layer_name, config, test_spec) tuples"""
+         cls._discover_providers()
+         test_cases = []
+
+         for layer_name, provider in cls._providers.items():
+             config = provider.get_config()
+             test_specs = provider.get_test_specs()
+
+             # If no test specs defined, create a basic valid test
+             if not test_specs:
+                 test_specs = [LayerTestSpec("basic", SpecType.VALID, "Basic test")]
+
+             test_cases.extend((layer_name, config, spec) for spec in test_specs)
+
+         return test_cases
+
+     @classmethod
+     def get_test_cases_by_type(
+         cls,
+         test_type: SpecType,
+     ) -> list[tuple[str, LayerTestConfig, LayerTestSpec]]:
+         """Get test cases of a specific type"""
+         return [
+             (layer, config, spec)
+             for layer, config, spec in cls.get_all_test_cases()
+             if spec.spec_type == test_type
+         ]
+
+     @classmethod
+     def get_test_cases_by_layer(
+         cls,
+         layer_name: str,
+     ) -> list[tuple[str, LayerTestConfig, LayerTestSpec]]:
+         """Get test cases for a specific layer"""
+         return [
+             (layer, config, spec)
+             for layer, config, spec in cls.get_all_test_cases()
+             if layer == layer_name
+         ]
+
+     @classmethod
+     def get_test_cases_by_tag(
+         cls,
+         tag: str,
+     ) -> list[tuple[str, LayerTestConfig, LayerTestSpec]]:
+         """Get test cases with a specific tag"""
+         result = [
+             (layer, config, spec)
+             for layer, config, spec in cls.get_all_test_cases()
+             if tag in spec.tags
+         ]
+         logger.debug("Found tests %s", result)
+         if not result:
+             msg = f"No test cases found for tag: {tag}"
+             raise ValueError(msg)
+         return result
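
The tuple-returning accessors are shaped for pytest parametrization. A hedged sketch of how a consuming test module might use them (pytest wiring assumed; the actual layers_tests modules are listed in the file table above but not shown in this diff):

    import pytest

    from python.tests.onnx_quantizer_tests.layers.base import SpecType
    from python.tests.onnx_quantizer_tests.layers.factory import TestLayerFactory

    # Provider discovery runs once, inside get_all_test_cases().
    VALID_CASES = TestLayerFactory.get_test_cases_by_type(SpecType.VALID)

    @pytest.mark.parametrize(("layer_name", "config", "spec"), VALID_CASES)
    def test_valid_specs_are_typed(layer_name, config, spec):
        # Each case is a (layer_name, LayerTestConfig, LayerTestSpec) tuple.
        assert spec.spec_type == SpecType.VALID
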
python/tests/onnx_quantizer_tests/layers/flatten_config.py
@@ -0,0 +1,61 @@
+ from python.tests.onnx_quantizer_tests.layers.base import (
+     e2e_test,
+     valid_test,
+ )
+ from python.tests.onnx_quantizer_tests.layers.factory import (
+     BaseLayerConfigProvider,
+     LayerTestConfig,
+ )
+
+
+ class FlattenConfigProvider(BaseLayerConfigProvider):
+     """Test configuration provider for Flatten layers"""
+
+     @property
+     def layer_name(self) -> str:
+         return "Flatten"
+
+     def get_config(self) -> LayerTestConfig:
+         return LayerTestConfig(
+             op_type="Flatten",
+             valid_inputs=["input"],
+             valid_attributes={"axis": 1},
+             required_initializers={},
+         )
+
+     def get_test_specs(self) -> list:
+         return [
+             # --- VALID TESTS ---
+             valid_test("basic")
+             .description("Basic Flatten from (1,3,4,4) to (1,48)")
+             .tags("basic", "flatten")
+             .build(),
+             valid_test("flatten_axis0")
+             .description("Flatten with axis=0 (entire tensor flattened)")
+             .override_attrs(axis=0)
+             .tags("flatten", "axis0")
+             .build(),
+             valid_test("flatten_axis2")
+             .description("Flatten starting at axis=2")
+             .override_attrs(axis=2)
+             .tags("flatten", "axis2")
+             .build(),
+             valid_test("flatten_axis3")
+             .description("Flatten starting at axis=3 (minimal flatten)")
+             .override_attrs(axis=3)
+             .tags("flatten", "axis3")
+             .build(),
+             e2e_test("e2e_basic")
+             .description("End-to-end test for Flatten layer")
+             .override_input_shapes(input=[1, 3, 4, 4])
+             .override_output_shapes(flatten_output=[1, 48])
+             .tags("e2e", "flatten")
+             .build(),
+             # --- EDGE CASE / SKIPPED TEST ---
+             valid_test("large_input")
+             .description("Large input flatten (performance test)")
+             .override_input_shapes(input=[1, 3, 256, 256])
+             .tags("flatten", "large", "performance")
+             .skip("Performance test, skipped by default")
+             .build(),
+         ]
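
For reference, ONNX Flatten always yields a 2-D tensor, (prod(shape[:axis]), prod(shape[axis:])), which pins down the expected outputs of the axis variants above. A quick check (illustrative only, not package code):

    import math

    def flatten_shape(shape, axis):
        # ONNX Flatten: 2-D output (prod(shape[:axis]), prod(shape[axis:])).
        return [math.prod(shape[:axis]), math.prod(shape[axis:])]

    dims = [1, 3, 4, 4]
    print(flatten_shape(dims, 1))  # [1, 48] -- the e2e_basic expectation
    print(flatten_shape(dims, 0))  # [1, 48] (empty prefix product is 1)
    print(flatten_shape(dims, 2))  # [3, 16]
    print(flatten_shape(dims, 3))  # [12, 4]
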