JSTprove 1.0.0__py3-none-macosx_11_0_arm64.whl → 1.1.0__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/METADATA +2 -2
- {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/RECORD +51 -24
- python/core/binaries/onnx_generic_circuit_1-1-0 +0 -0
- python/core/circuit_models/generic_onnx.py +43 -9
- python/core/circuits/base.py +231 -71
- python/core/model_processing/converters/onnx_converter.py +86 -32
- python/core/model_processing/onnx_custom_ops/maxpool.py +1 -1
- python/core/model_processing/onnx_custom_ops/relu.py +1 -1
- python/core/model_processing/onnx_quantizer/layers/add.py +54 -0
- python/core/model_processing/onnx_quantizer/layers/base.py +121 -1
- python/core/model_processing/onnx_quantizer/layers/constant.py +1 -1
- python/core/model_processing/onnx_quantizer/layers/conv.py +20 -68
- python/core/model_processing/onnx_quantizer/layers/gemm.py +20 -66
- python/core/model_processing/onnx_quantizer/layers/maxpool.py +53 -43
- python/core/model_processing/onnx_quantizer/layers/relu.py +20 -35
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +6 -1
- python/core/utils/general_layer_functions.py +17 -12
- python/core/utils/model_registry.py +6 -3
- python/tests/circuit_e2e_tests/other_e2e_test.py +202 -9
- python/tests/circuit_parent_classes/test_circuit.py +561 -38
- python/tests/circuit_parent_classes/test_onnx_converter.py +22 -13
- python/tests/onnx_quantizer_tests/__init__.py +1 -0
- python/tests/onnx_quantizer_tests/layers/__init__.py +13 -0
- python/tests/onnx_quantizer_tests/layers/add_config.py +102 -0
- python/tests/onnx_quantizer_tests/layers/base.py +279 -0
- python/tests/onnx_quantizer_tests/layers/constant_config.py +39 -0
- python/tests/onnx_quantizer_tests/layers/conv_config.py +154 -0
- python/tests/onnx_quantizer_tests/layers/factory.py +142 -0
- python/tests/onnx_quantizer_tests/layers/flatten_config.py +61 -0
- python/tests/onnx_quantizer_tests/layers/gemm_config.py +160 -0
- python/tests/onnx_quantizer_tests/layers/maxpool_config.py +82 -0
- python/tests/onnx_quantizer_tests/layers/relu_config.py +61 -0
- python/tests/onnx_quantizer_tests/layers/reshape_config.py +61 -0
- python/tests/onnx_quantizer_tests/layers_tests/__init__.py +0 -0
- python/tests/onnx_quantizer_tests/layers_tests/base_test.py +94 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py +115 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_e2e.py +196 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_error_cases.py +59 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +198 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +265 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py +109 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_validation.py +45 -0
- python/tests/onnx_quantizer_tests/test_base_layer.py +228 -0
- python/tests/onnx_quantizer_tests/test_exceptions.py +99 -0
- python/tests/onnx_quantizer_tests/test_onnx_op_quantizer.py +246 -0
- python/tests/onnx_quantizer_tests/test_registered_quantizers.py +121 -0
- python/tests/onnx_quantizer_tests/testing_helper_functions.py +17 -0
- python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/WHEEL +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/entry_points.txt +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/licenses/LICENSE +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/top_level.txt +0 -0
|
import numpy as np

from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError
from python.tests.onnx_quantizer_tests import TEST_RNG_SEED
from python.tests.onnx_quantizer_tests.layers.base import (
    LayerTestSpec,
    e2e_test,
    error_test,
    valid_test,
)
from python.tests.onnx_quantizer_tests.layers.factory import (
    BaseLayerConfigProvider,
    LayerTestConfig,
)


class ConvConfigProvider(BaseLayerConfigProvider):
    """Test configuration provider for Conv layers."""

    @property
    def layer_name(self) -> str:
        """Name of the ONNX op this provider configures."""
        return "Conv"

    def get_config(self) -> LayerTestConfig:
        """Baseline Conv config: 3x3 kernel, stride 1, symmetric padding."""
        rng = np.random.default_rng(TEST_RNG_SEED)
        # NOTE: weight is drawn before bias so the RNG stream order is fixed.
        initializers = {
            "conv_weight": rng.normal(0, 1, (32, 16, 3, 3)),
            "conv_bias": rng.normal(0, 1, 32),
        }
        attributes = {
            "strides": [1, 1],
            "kernel_shape": [3, 3],
            "dilations": [1, 1],
            "pads": [1, 1, 1, 1],
        }
        return LayerTestConfig(
            op_type="Conv",
            valid_inputs=["input", "conv_weight", "conv_bias"],
            valid_attributes=attributes,
            required_initializers=initializers,
        )

    def get_test_specs(self) -> list[LayerTestSpec]:
        """Return all test specifications for Conv layers."""
        rng = np.random.default_rng(TEST_RNG_SEED)

        # --- valid variations ---
        basic = (
            valid_test("basic")
            .description("Basic 2D convolution")
            .tags("basic", "2d")
            .build()
        )
        different_padding = (
            valid_test("different_padding")
            .description("Convolution with different padding")
            .override_attrs(pads=[2, 2, 2, 2], kernel_shape=[5, 5])
            .override_initializer("conv_weight", rng.normal(0, 1, (32, 16, 5, 5)))
            .tags("padding", "5x5_kernel")
            .build()
        )
        # --- end-to-end ---
        e2e_basic = (
            e2e_test("e2e_basic")
            .description("End-to-end test for basic 2D convolution")
            .override_input_shapes(input=[1, 3, 4, 4])
            .override_output_shapes(conv_output=[1, 8, 4, 4])
            .override_initializer("conv_weight", rng.normal(0, 1, (8, 3, 3, 3)))
            .override_initializer("conv_bias", rng.normal(0, 1, 8))
            .tags("e2e", "basic", "2d")
            .build()
        )
        # --- error cases ---
        no_bias = (
            error_test("no_bias")
            .description("2D convolution without bias")
            .override_inputs("input", "conv_weight")
            .override_attrs(strides=[2, 2], kernel_shape=[5, 5])
            .override_initializer("conv_weight", rng.normal(0, 1, (64, 16, 5, 5)))
            .expects_error(
                InvalidParamError,
                "Expected at least 3 inputs (input, weights, bias), got 2",
            )
            .tags("no_bias", "stride_2")
            .build()
        )
        conv3d_unsupported = (
            error_test("conv3d_unsupported")
            .description("3D convolution should raise error")
            .override_attrs(
                kernel_shape=[3, 3, 3],
                strides=[1, 1, 1],
                dilations=[1, 1, 1],
                pads=[1, 1, 1, 1, 1, 1],
            )
            .override_initializer(
                "conv_weight",
                rng.normal(0, 1, (32, 16, 3, 3, 3)),
            )
            .expects_error(
                InvalidParamError,
                "Unsupported Conv weight dimensionality 5",
            )
            .tags("3d", "unsupported")
            .build()
        )
        invalid_stride = (
            error_test("invalid_stride")
            .description("Invalid stride values")
            .override_attrs(strides=[0, 1])
            .override_inputs("input", "conv_weight")
            .expects_error(InvalidParamError, "stride must be positive")
            .tags("invalid_params")
            .skip("Not yet supported")
            .build()
        )
        negative_dilation = (
            error_test("negative_dilation")
            .description("Negative dilation values")
            .override_attrs(dilations=[-1, 1])
            .expects_error(InvalidParamError, "dilation must be positive")
            .tags("invalid_params")
            .skip("Not yet supported")
            .build()
        )
        invalid_kernel_shape_long = (
            error_test("invalid_kernel_shape_long")
            .description("kernel_shape too long (length 3)")
            .override_attrs(kernel_shape=[3, 3, 3])
            .override_initializer(
                "conv_weight",
                rng.normal(0, 1, (32, 16, 3, 3, 3)),
            )
            .expects_error(InvalidParamError, "kernel_shape")
            .tags("invalid_attr_length")
            .build()
        )
        # --- missing required attributes ---
        missing_strides = (
            error_test("missing_strides")
            .description("Conv node missing 'strides' attribute")
            .omit_attrs("strides")  # exclude strides
            .override_attrs(  # supply the others explicitly
                kernel_shape=[3, 3],
                dilations=[1, 1],
                pads=[1, 1, 1, 1],
            )
            .expects_error(InvalidParamError, "strides")
            .tags("missing_attr")
            .build()
        )
        missing_kernel_shape = (
            error_test("missing_kernel_shape")
            .description("Conv node missing 'kernel_shape' attribute")
            .omit_attrs("kernel_shape")  # exclude kernel_shape
            .override_attrs(  # supply the others explicitly
                strides=[3, 3],
                dilations=[1, 1],
                pads=[1, 1, 1, 1],
            )
            .expects_error(InvalidParamError, "kernel_shape")
            .tags("missing_attr")
            .build()
        )
        missing_dilations = (
            error_test("missing_dilations")
            .description("Conv node missing 'dilations' attribute")
            .omit_attrs("dilations")  # exclude dilations
            .override_attrs(  # supply the others explicitly
                strides=[3, 3],
                kernel_shape=[3, 3],
                pads=[1, 1, 1, 1],
            )
            .expects_error(InvalidParamError, "dilations")
            .tags("missing_attr")
            .build()
        )

        return [
            basic,
            different_padding,
            e2e_basic,
            no_bias,
            conv3d_unsupported,
            invalid_stride,
            negative_dilation,
            invalid_kernel_shape_long,
            missing_strides,
            missing_kernel_shape,
            missing_dilations,
        ]
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import importlib
|
|
4
|
+
import inspect
|
|
5
|
+
import logging
|
|
6
|
+
import typing
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
|
|
9
|
+
from .base import BaseLayerConfigProvider, LayerTestConfig, LayerTestSpec, SpecType
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TestLayerFactory:
|
|
15
|
+
"""Enhanced factory for creating test configurations for different layer types"""
|
|
16
|
+
|
|
17
|
+
_providers: typing.ClassVar = {}
|
|
18
|
+
_initialized = False
|
|
19
|
+
|
|
20
|
+
@classmethod
|
|
21
|
+
def _discover_providers(cls: TestLayerFactory) -> None:
|
|
22
|
+
"""Automatically discover all BaseLayerConfigProvider subclasses"""
|
|
23
|
+
if cls._initialized:
|
|
24
|
+
return
|
|
25
|
+
|
|
26
|
+
current_dir = Path(__file__).parent
|
|
27
|
+
config_files = [
|
|
28
|
+
f.stem
|
|
29
|
+
for f in Path(current_dir).iterdir()
|
|
30
|
+
if f.is_file() and f.name.endswith("_config.py") and f.name != "__init__.py"
|
|
31
|
+
]
|
|
32
|
+
logger.debug("Discovered config files: %s", config_files)
|
|
33
|
+
|
|
34
|
+
for module_name in config_files:
|
|
35
|
+
try:
|
|
36
|
+
module = importlib.import_module(f".{module_name}", package=__package__)
|
|
37
|
+
|
|
38
|
+
for _, obj in inspect.getmembers(module, inspect.isclass):
|
|
39
|
+
if (
|
|
40
|
+
issubclass(obj, BaseLayerConfigProvider)
|
|
41
|
+
and obj is not BaseLayerConfigProvider
|
|
42
|
+
):
|
|
43
|
+
|
|
44
|
+
provider_instance = obj()
|
|
45
|
+
cls._providers[provider_instance.layer_name] = provider_instance
|
|
46
|
+
|
|
47
|
+
except ImportError as e: # noqa: PERF203
|
|
48
|
+
msg = f"Warning: Could not import {module_name}: {e}"
|
|
49
|
+
logger.warning(msg)
|
|
50
|
+
|
|
51
|
+
cls._initialized = True
|
|
52
|
+
|
|
53
|
+
# Existing methods (keep your current implementation)
|
|
54
|
+
@classmethod
|
|
55
|
+
def get_layer_configs(cls) -> dict[str, LayerTestConfig]:
|
|
56
|
+
"""Get test configurations for all supported layers"""
|
|
57
|
+
cls._discover_providers()
|
|
58
|
+
logger.debug("Retrieved layer configs: %s", list(cls._providers.keys()))
|
|
59
|
+
return {
|
|
60
|
+
name: provider.get_config() for name, provider in cls._providers.items()
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
@classmethod
|
|
64
|
+
def get_layer_config(cls, layer_name: str) -> LayerTestConfig:
|
|
65
|
+
"""Get test configuration for a specific layer"""
|
|
66
|
+
cls._discover_providers()
|
|
67
|
+
if layer_name not in cls._providers:
|
|
68
|
+
msg = f"No config provider found for layer: {layer_name}"
|
|
69
|
+
raise ValueError(msg)
|
|
70
|
+
return cls._providers[layer_name].get_config()
|
|
71
|
+
|
|
72
|
+
@classmethod
|
|
73
|
+
def get_available_layers(cls) -> list[str]:
|
|
74
|
+
"""Get list of all available layer types"""
|
|
75
|
+
cls._discover_providers()
|
|
76
|
+
return list(cls._providers.keys())
|
|
77
|
+
|
|
78
|
+
@classmethod
|
|
79
|
+
def register_provider(cls, provider: BaseLayerConfigProvider) -> None:
|
|
80
|
+
"""Register a new config provider"""
|
|
81
|
+
cls._providers[provider.layer_name] = provider
|
|
82
|
+
|
|
83
|
+
# NEW enhanced methods for test specifications
|
|
84
|
+
@classmethod
|
|
85
|
+
def get_all_test_cases(cls) -> list[tuple[str, LayerTestConfig, LayerTestSpec]]:
|
|
86
|
+
"""Get all test cases as (layer_name, config, test_spec) tuples"""
|
|
87
|
+
cls._discover_providers()
|
|
88
|
+
test_cases = []
|
|
89
|
+
|
|
90
|
+
for layer_name, provider in cls._providers.items():
|
|
91
|
+
config = provider.get_config()
|
|
92
|
+
test_specs = provider.get_test_specs()
|
|
93
|
+
|
|
94
|
+
# If no test specs defined, create a basic valid test
|
|
95
|
+
if not test_specs:
|
|
96
|
+
|
|
97
|
+
test_specs = [LayerTestSpec("basic", SpecType.VALID, "Basic test")]
|
|
98
|
+
|
|
99
|
+
test_cases.extend((layer_name, config, spec) for spec in test_specs)
|
|
100
|
+
|
|
101
|
+
return test_cases
|
|
102
|
+
|
|
103
|
+
@classmethod
|
|
104
|
+
def get_test_cases_by_type(
|
|
105
|
+
cls,
|
|
106
|
+
test_type: SpecType,
|
|
107
|
+
) -> list[tuple[str, LayerTestConfig, LayerTestSpec]]:
|
|
108
|
+
"""Get test cases of a specific type"""
|
|
109
|
+
return [
|
|
110
|
+
(layer, config, spec)
|
|
111
|
+
for layer, config, spec in cls.get_all_test_cases()
|
|
112
|
+
if spec.spec_type == test_type
|
|
113
|
+
]
|
|
114
|
+
|
|
115
|
+
@classmethod
|
|
116
|
+
def get_test_cases_by_layer(
|
|
117
|
+
cls,
|
|
118
|
+
layer_name: str,
|
|
119
|
+
) -> list[tuple[str, LayerTestConfig, LayerTestSpec]]:
|
|
120
|
+
"""Get test cases for a specific layer"""
|
|
121
|
+
return [
|
|
122
|
+
(layer, config, spec)
|
|
123
|
+
for layer, config, spec in cls.get_all_test_cases()
|
|
124
|
+
if layer == layer_name
|
|
125
|
+
]
|
|
126
|
+
|
|
127
|
+
@classmethod
|
|
128
|
+
def get_test_cases_by_tag(
|
|
129
|
+
cls,
|
|
130
|
+
tag: str,
|
|
131
|
+
) -> list[tuple[str, LayerTestConfig, LayerTestSpec]]:
|
|
132
|
+
"""Get test cases with a specific tag"""
|
|
133
|
+
result = [
|
|
134
|
+
(layer, config, spec)
|
|
135
|
+
for layer, config, spec in cls.get_all_test_cases()
|
|
136
|
+
if tag in spec.tags
|
|
137
|
+
]
|
|
138
|
+
logger.debug("Found tests %s", result)
|
|
139
|
+
if not result:
|
|
140
|
+
msg = f"No test cases found for tag: {tag}"
|
|
141
|
+
raise ValueError(msg)
|
|
142
|
+
return result
|
|
from python.tests.onnx_quantizer_tests.layers.base import (
    e2e_test,
    valid_test,
)
from python.tests.onnx_quantizer_tests.layers.factory import (
    BaseLayerConfigProvider,
    LayerTestConfig,
)


class FlattenConfigProvider(BaseLayerConfigProvider):
    """Test configuration provider for Flatten layers."""

    @property
    def layer_name(self) -> str:
        """Name of the ONNX op this provider configures."""
        return "Flatten"

    def get_config(self) -> LayerTestConfig:
        """Baseline Flatten config: single input, flatten from axis 1."""
        return LayerTestConfig(
            op_type="Flatten",
            valid_inputs=["input"],
            valid_attributes={"axis": 1},
            required_initializers={},
        )

    def get_test_specs(self) -> list:
        """Return valid, e2e, and skipped edge-case specs for Flatten."""
        # --- valid tests ---
        basic = (
            valid_test("basic")
            .description("Basic Flatten from (1,3,4,4) to (1,48)")
            .tags("basic", "flatten")
            .build()
        )
        axis0 = (
            valid_test("flatten_axis0")
            .description("Flatten with axis=0 (entire tensor flattened)")
            .override_attrs(axis=0)
            .tags("flatten", "axis0")
            .build()
        )
        axis2 = (
            valid_test("flatten_axis2")
            .description("Flatten starting at axis=2")
            .override_attrs(axis=2)
            .tags("flatten", "axis2")
            .build()
        )
        axis3 = (
            valid_test("flatten_axis3")
            .description("Flatten starting at axis=3 (minimal flatten)")
            .override_attrs(axis=3)
            .tags("flatten", "axis3")
            .build()
        )
        end_to_end = (
            e2e_test("e2e_basic")
            .description("End-to-end test for Flatten layer")
            .override_input_shapes(input=[1, 3, 4, 4])
            .override_output_shapes(flatten_output=[1, 48])
            .tags("e2e", "flatten")
            .build()
        )
        # --- edge case: registered but skipped by default ---
        large_input = (
            valid_test("large_input")
            .description("Large input flatten (performance test)")
            .override_input_shapes(input=[1, 3, 256, 256])
            .tags("flatten", "large", "performance")
            .skip("Performance test, skipped by default")
            .build()
        )
        return [basic, axis0, axis2, axis3, end_to_end, large_input]
|
|
import numpy as np

from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError
from python.tests.onnx_quantizer_tests import TEST_RNG_SEED
from python.tests.onnx_quantizer_tests.layers.base import (
    LayerTestSpec,
    e2e_test,
    error_test,
    valid_test,
)
from python.tests.onnx_quantizer_tests.layers.factory import (
    BaseLayerConfigProvider,
    LayerTestConfig,
)


class GemmConfigProvider(BaseLayerConfigProvider):
    """Test configuration provider for Gemm layers."""

    @property
    def layer_name(self) -> str:
        """Name of the ONNX op this provider configures."""
        return "Gemm"

    def get_config(self) -> LayerTestConfig:
        """Baseline Gemm config: (1,128) @ (128,256) + (1,256), no transposes."""
        rng = np.random.default_rng(TEST_RNG_SEED)
        return LayerTestConfig(
            op_type="Gemm",
            valid_inputs=["input", "gemm_weight", "gemm_bias"],
            valid_attributes={"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0},
            required_initializers={
                "gemm_weight": rng.normal(0, 1, (128, 256)),
                "gemm_bias": rng.normal(0, 1, (1, 256)),
            },
            input_shapes={"input": [1, 128]},  # Match weight input dimension K=128
            output_shapes={
                "gemm_output": [1, 256],
            },  # Match weight output dimension N=256
        )

    def get_test_specs(self) -> list[LayerTestSpec]:
        """Return test specifications for Gemm layers."""
        rng = np.random.default_rng(TEST_RNG_SEED)
        return [
            # --- VALID TESTS ---
            valid_test("basic")
            .description("Basic Gemm operation (no transposes, alpha=1, beta=1)")
            .tags("basic")
            .build(),
            valid_test("transposed_weights")
            .description("Gemm with transposed weight matrix (transB=1)")
            .override_attrs(transB=1)
            .override_initializer(
                "gemm_weight",
                rng.normal(0, 1, (256, 128)),
            )  # Transposed shape
            .tags("transpose", "transB")
            .build(),
            valid_test("transposed_input")
            .description("Gemm with transposed input (transA=1)")
            .override_attrs(transA=1)
            .override_input_shapes(input=[128, 1])  # Aᵀ shape → (K, M)
            .override_output_shapes(gemm_output=[1, 256])
            .tags("transpose", "transA")
            .build(),
            valid_test("double_transpose")
            .description("Gemm with transA=1 and transB=1")
            .override_attrs(transA=1, transB=1)
            .override_input_shapes(input=[128, 1])
            .override_initializer("gemm_weight", rng.normal(0, 1, (256, 128)))
            .override_output_shapes(gemm_output=[1, 256])
            .tags("transpose", "transA", "transB")
            .build(),
            e2e_test("e2e_basic")
            .description("End-to-end test for basic Gemm layer")
            .override_attrs(alpha=1.0, beta=1.0, transA=0, transB=0)
            .override_input_shapes(input=[1, 4])
            .override_output_shapes(gemm_output=[1, 8])
            .override_initializer("gemm_weight", rng.normal(0, 1, (4, 8)))
            .override_initializer("gemm_bias", rng.normal(0, 1, (1, 8)))
            .tags("e2e", "basic")
            .build(),
            e2e_test("e2e_transA_small")
            .description("Small end-to-end Gemm test with transposed input (transA=1)")
            .override_attrs(transA=1, transB=0, alpha=1.0, beta=1.0)
            .override_input_shapes(input=[4, 1])  # A^T shape → (K, M)
            .override_output_shapes(gemm_output=[1, 6])
            .override_initializer("gemm_weight", rng.normal(0, 1, (4, 6)))
            .override_initializer("gemm_bias", rng.normal(0, 1, (1, 6)))
            .tags("e2e", "transpose", "transA", "small")
            .build(),
            e2e_test("e2e_transB_small")
            .description(
                "Small end-to-end Gemm test with transposed weights (transB=1)",
            )
            .override_attrs(transA=0, transB=1, alpha=1.0, beta=1.0)
            .override_input_shapes(input=[1, 4])  # A shape
            .override_output_shapes(gemm_output=[1, 6])
            .override_initializer("gemm_weight", rng.normal(0, 1, (6, 4)))  # B^T shape
            .override_initializer("gemm_bias", rng.normal(0, 1, (1, 6)))
            .tags("e2e", "transpose", "transB", "small")
            .build(),
            e2e_test("e2e_transA_transB_small")
            .description("Small end-to-end Gemm test with both matrices transposed")
            .override_attrs(transA=1, transB=1, alpha=1.0, beta=1.0)
            .override_input_shapes(input=[4, 1])  # A^T shape
            .override_output_shapes(gemm_output=[1, 6])
            .override_initializer("gemm_weight", rng.normal(0, 1, (6, 4)))  # B^T shape
            .override_initializer("gemm_bias", rng.normal(0, 1, (1, 6)))
            .tags("e2e", "transpose", "transA", "transB", "small")
            .build(),
            # --- ERROR TESTS ---
            # TODO: Add check on weights matrix in check_supported
            error_test("invalid_alpha_type")
            .description("Unsupported alpha value (quantizer only accepts alpha=1.0)")
            .override_attrs(alpha=-1.0)
            .expects_error(
                InvalidParamError,
                "alpha value of -1.0 not supported [Attribute: alpha] [Expected: 1.0]",
            )
            .tags("invalid_param", "alpha")
            .build(),
            error_test("no_bias")
            .description("Gemm without bias input raises (3 inputs are required)")
            .override_inputs("input", "gemm_weight")
            .override_attrs(beta=0.0)
            .expects_error(InvalidParamError, match="3 inputs")
            .tags("no_bias")
            .build(),
            error_test("different_alpha_beta")
            .description("Gemm with different alpha and beta scaling factors")
            .override_attrs(alpha=0.5, beta=2.0)
            .expects_error(
                InvalidParamError,
                "alpha value of 0.5 not supported [Attribute: alpha] [Expected: 1.0]",
            )
            .tags("scaling", "alpha_beta")
            .build(),
            error_test("invalid_transA_value")
            .description("transA must be 0 or 1")
            .override_attrs(transA=2)
            .expects_error(InvalidParamError, "transA value of 2 not supported")
            .tags("transpose", "invalid_attr")
            .build(),
            error_test("invalid_transB_value")
            .description("transB must be 0 or 1")
            .override_attrs(transB=-1)
            .expects_error(InvalidParamError, "transB value of -1 not supported")
            .tags("transpose", "invalid_attr")
            .build(),
            # --- EDGE CASE / SKIPPED TESTS ---
            valid_test("large_matrix")
            .description("Large matrix multiplication performance test")
            .override_initializer("gemm_weight", rng.normal(0, 1, (1024, 2048)))
            .override_initializer("gemm_bias", rng.normal(0, 1, (1, 2048)))
            .override_input_shapes(input=[1, 1024])
            .override_output_shapes(gemm_output=[1, 2048])
            .tags("large", "performance")
            .skip("Performance test, not run by default")
            .build(),
        ]
|
|
from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError
from python.tests.onnx_quantizer_tests.layers.base import (
    e2e_test,
    error_test,
    valid_test,
)
from python.tests.onnx_quantizer_tests.layers.factory import (
    BaseLayerConfigProvider,
    LayerTestConfig,
)


class MaxPoolConfigProvider(BaseLayerConfigProvider):
    """Test configuration provider for MaxPool layers."""

    @property
    def layer_name(self) -> str:
        """Name of the ONNX op this provider configures."""
        return "MaxPool"

    def get_config(self) -> LayerTestConfig:
        """Baseline MaxPool config: 2x2 kernel, stride 2, no padding."""
        pool_attrs = {
            "kernel_shape": [2, 2],
            "strides": [2, 2],
            "dilations": [1, 1],
            "pads": [0, 0, 0, 0],
        }
        return LayerTestConfig(
            op_type="MaxPool",
            valid_inputs=["input"],
            valid_attributes=pool_attrs,
            required_initializers={},
        )

    def get_test_specs(self) -> list:
        """Return valid, e2e, error, and skipped specs for MaxPool."""
        specs = []
        # --- valid tests ---
        specs.append(
            valid_test("basic")
            .description("Basic MaxPool with 2x2 kernel and stride 2")
            .tags("basic", "pool", "2d")
            .build(),
        )
        specs.append(
            valid_test("larger_kernel")
            .description("MaxPool with 3x3 kernel and stride 1")
            .override_attrs(kernel_shape=[3, 3], strides=[1, 1])
            .tags("kernel_3x3", "stride_1", "pool")
            .build(),
        )
        specs.append(
            valid_test("dilated_pool")
            .description("MaxPool with dilation > 1")
            .override_attrs(dilations=[2, 2])
            .tags("dilation", "pool")
            .build(),
        )
        specs.append(
            valid_test("stride_one")
            .description("MaxPool with stride 1 (overlapping windows)")
            .override_attrs(strides=[1, 1])
            .tags("stride_1", "pool", "overlap")
            .build(),
        )
        specs.append(
            e2e_test("e2e_basic")
            .description("End-to-end test for 2D MaxPool")
            .override_input_shapes(input=[1, 3, 4, 4])
            .override_output_shapes(maxpool_output=[1, 3, 2, 2])
            .tags("e2e", "pool", "2d")
            .build(),
        )
        # --- error tests ---
        specs.append(
            error_test("asymmetric_padding")
            .description("MaxPool with asymmetric padding")
            .override_attrs(pads=[1, 0, 2, 1])
            .expects_error(InvalidParamError, "pads[2]=2 >= kernel[0]=2")
            .tags("padding", "asymmetric", "pool")
            .build(),
        )
        specs.append(
            error_test("invalid_kernel_shape")
            .description("Invalid kernel shape length (3D instead of 2D)")
            .override_attrs(kernel_shape=[2, 2, 2])
            .expects_error(InvalidParamError, "Currently only MaxPool2D is supported")
            .tags("invalid_attr_length", "kernel_shape")
            .build(),
        )
        # --- edge case: registered but skipped by default ---
        specs.append(
            valid_test("large_input")
            .description("Large MaxPool input (performance/stress test)")
            .override_input_shapes(input=[1, 3, 64, 64])
            .override_attrs(kernel_shape=[3, 3], strides=[2, 2])
            .tags("large", "performance", "pool")
            .skip("Performance test, skipped by default")
            .build(),
        )
        return specs
|
|
from python.tests.onnx_quantizer_tests.layers.base import (
    e2e_test,
    valid_test,
)
from python.tests.onnx_quantizer_tests.layers.factory import (
    BaseLayerConfigProvider,
    LayerTestConfig,
)


class ReluConfigProvider(BaseLayerConfigProvider):
    """Test configuration provider for Relu layers."""

    @property
    def layer_name(self) -> str:
        """Name of the ONNX op this provider configures."""
        return "Relu"

    def get_config(self) -> LayerTestConfig:
        """Baseline Relu config: single input, no attributes or initializers."""
        return LayerTestConfig(
            op_type="Relu",
            valid_inputs=["input"],
            valid_attributes={},
            required_initializers={},
        )

    def get_test_specs(self) -> list:
        """Return valid, e2e, and skipped edge-case specs for Relu."""
        # --- valid tests ---
        basic = (
            valid_test("basic")
            .description("Basic ReLU activation")
            .tags("basic", "activation")
            .build()
        )
        negative_inputs = (
            valid_test("negative_inputs")
            .description("ReLU should zero out negative input values")
            .override_input_shapes(input=[1, 3, 4, 4])
            .tags("activation", "negative_values")
            .build()
        )
        high_dim = (
            valid_test("high_dimension_input")
            .description("ReLU applied to a 5D input tensor (NCHWT layout)")
            .override_input_shapes(input=[1, 3, 4, 4, 2])
            .tags("activation", "high_dim", "5d")
            .build()
        )
        scalar = (
            valid_test("scalar_input")
            .description("ReLU with scalar input (edge case)")
            .override_input_shapes(input=[1])
            .tags("activation", "scalar")
            .build()
        )
        end_to_end = (
            e2e_test("e2e_basic")
            .description("End-to-end test for ReLU activation")
            .override_input_shapes(input=[1, 3, 4, 4])
            .override_output_shapes(relu_output=[1, 3, 4, 4])
            .tags("e2e", "activation")
            .build()
        )
        # --- edge case: registered but skipped by default ---
        large_input = (
            valid_test("large_input")
            .description("Large input tensor for ReLU (performance/stress test)")
            .override_input_shapes(input=[1, 3, 512, 512])
            .tags("large", "performance", "activation")
            .skip("Performance test, skipped by default")
            .build()
        )
        return [basic, negative_inputs, high_dim, scalar, end_to_end, large_input]
|