JSTprove 1.0.0__py3-none-macosx_11_0_arm64.whl → 1.2.0__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of JSTprove might be problematic; see the registry listing for details.
- {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/METADATA +3 -3
- {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/RECORD +60 -25
- python/core/binaries/onnx_generic_circuit_1-2-0 +0 -0
- python/core/circuit_models/generic_onnx.py +43 -9
- python/core/circuits/base.py +231 -71
- python/core/model_processing/converters/onnx_converter.py +114 -59
- python/core/model_processing/onnx_custom_ops/batchnorm.py +64 -0
- python/core/model_processing/onnx_custom_ops/maxpool.py +1 -1
- python/core/model_processing/onnx_custom_ops/mul.py +66 -0
- python/core/model_processing/onnx_custom_ops/relu.py +1 -1
- python/core/model_processing/onnx_quantizer/layers/add.py +54 -0
- python/core/model_processing/onnx_quantizer/layers/base.py +188 -1
- python/core/model_processing/onnx_quantizer/layers/batchnorm.py +224 -0
- python/core/model_processing/onnx_quantizer/layers/constant.py +1 -1
- python/core/model_processing/onnx_quantizer/layers/conv.py +20 -68
- python/core/model_processing/onnx_quantizer/layers/gemm.py +20 -66
- python/core/model_processing/onnx_quantizer/layers/maxpool.py +53 -43
- python/core/model_processing/onnx_quantizer/layers/mul.py +53 -0
- python/core/model_processing/onnx_quantizer/layers/relu.py +20 -35
- python/core/model_processing/onnx_quantizer/layers/sub.py +54 -0
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +43 -1
- python/core/utils/general_layer_functions.py +17 -12
- python/core/utils/model_registry.py +6 -3
- python/scripts/gen_and_bench.py +2 -2
- python/tests/circuit_e2e_tests/other_e2e_test.py +202 -9
- python/tests/circuit_parent_classes/test_circuit.py +561 -38
- python/tests/circuit_parent_classes/test_onnx_converter.py +22 -13
- python/tests/onnx_quantizer_tests/__init__.py +1 -0
- python/tests/onnx_quantizer_tests/layers/__init__.py +13 -0
- python/tests/onnx_quantizer_tests/layers/add_config.py +102 -0
- python/tests/onnx_quantizer_tests/layers/base.py +279 -0
- python/tests/onnx_quantizer_tests/layers/batchnorm_config.py +190 -0
- python/tests/onnx_quantizer_tests/layers/constant_config.py +39 -0
- python/tests/onnx_quantizer_tests/layers/conv_config.py +154 -0
- python/tests/onnx_quantizer_tests/layers/factory.py +142 -0
- python/tests/onnx_quantizer_tests/layers/flatten_config.py +61 -0
- python/tests/onnx_quantizer_tests/layers/gemm_config.py +160 -0
- python/tests/onnx_quantizer_tests/layers/maxpool_config.py +82 -0
- python/tests/onnx_quantizer_tests/layers/mul_config.py +102 -0
- python/tests/onnx_quantizer_tests/layers/relu_config.py +61 -0
- python/tests/onnx_quantizer_tests/layers/reshape_config.py +61 -0
- python/tests/onnx_quantizer_tests/layers/sub_config.py +102 -0
- python/tests/onnx_quantizer_tests/layers_tests/__init__.py +0 -0
- python/tests/onnx_quantizer_tests/layers_tests/base_test.py +94 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py +115 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_e2e.py +196 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_error_cases.py +59 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +198 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +267 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py +109 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_validation.py +45 -0
- python/tests/onnx_quantizer_tests/test_base_layer.py +228 -0
- python/tests/onnx_quantizer_tests/test_exceptions.py +99 -0
- python/tests/onnx_quantizer_tests/test_onnx_op_quantizer.py +246 -0
- python/tests/onnx_quantizer_tests/test_registered_quantizers.py +121 -0
- python/tests/onnx_quantizer_tests/testing_helper_functions.py +17 -0
- python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/WHEEL +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/entry_points.txt +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/top_level.txt +0 -0
python/tests/circuit_parent_classes/test_onnx_converter.py

@@ -1,9 +1,11 @@
 # test_converter.py
 import tempfile
+from collections.abc import Generator
 from pathlib import Path
-from typing import Any
+from typing import Any
 from unittest.mock import MagicMock, patch

+import numpy as np
 import onnx
 import onnxruntime as ort
 import pytest
@@ -13,7 +15,7 @@ from onnx import TensorProto, helper
 from python.core.model_processing.converters.onnx_converter import ONNXConverter


-@pytest.fixture
+@pytest.fixture
 def temp_model_path(
     tmp_path: Generator[Path, None, None],
 ) -> Generator[Path, Any, None]:
@@ -26,7 +28,7 @@ def temp_model_path(
     model_path.unlink()


-@pytest.fixture
+@pytest.fixture
 def temp_quant_model_path(
     tmp_path: Generator[Path, None, None],
 ) -> Generator[Path, Any, None]:
@@ -39,7 +41,7 @@ def temp_quant_model_path(
     model_path.unlink()


-@pytest.fixture
+@pytest.fixture
 def converter() -> ONNXConverter:
     conv = ONNXConverter()
     conv.model = MagicMock(name="model")
@@ -47,7 +49,7 @@ def converter() -> ONNXConverter:
     return conv


-@pytest.mark.unit
+@pytest.mark.unit
 @patch("python.core.model_processing.converters.onnx_converter.onnx.save")
 def test_save_model(mock_save: MagicMock, converter: ONNXConverter) -> None:
     path = "model.onnx"
@@ -55,7 +57,7 @@ def test_save_model(mock_save: MagicMock, converter: ONNXConverter) -> None:
     mock_save.assert_called_once_with(converter.model, path)


-@pytest.mark.unit
+@pytest.mark.unit
 @patch("python.core.model_processing.converters.onnx_converter.onnx.load")
 def test_load_model(mock_load: MagicMock, converter: ONNXConverter) -> None:
     fake_model = MagicMock(name="onnx_model")
@@ -68,7 +70,7 @@ def test_load_model(mock_load: MagicMock, converter: ONNXConverter) -> None:
     assert converter.model == fake_model


-@pytest.mark.unit
+@pytest.mark.unit
 @patch("python.core.model_processing.converters.onnx_converter.onnx.save")
 def test_save_quantized_model(mock_save: MagicMock, converter: ONNXConverter) -> None:
     path = "quantized_model.onnx"
@@ -76,7 +78,7 @@ def test_save_quantized_model(mock_save: MagicMock, converter: ONNXConverter) ->
     mock_save.assert_called_once_with(converter.quantized_model, path)


-@pytest.mark.unit
+@pytest.mark.unit
 @patch("python.core.model_processing.converters.onnx_converter.Path.exists")
 @patch("python.core.model_processing.converters.onnx_converter.SessionOptions")
 @patch("python.core.model_processing.converters.onnx_converter.InferenceSession")
@@ -108,10 +110,12 @@ def test_load_quantized_model(
     assert converter.quantized_model == fake_model


-@pytest.mark.unit
+@pytest.mark.unit
 def test_get_outputs_with_mocked_session(converter: ONNXConverter) -> None:
-    dummy_input = [[1.0]]
+    dummy_input = np.array([[1.0]])  # Use np.ndarray, not list
     dummy_output = [[2.0]]
+    converter.scale_base = 2
+    converter.scale_exponent = 10

     mock_sess = MagicMock()

@@ -132,7 +136,10 @@ def test_get_outputs_with_mocked_session(converter: ONNXConverter) -> None:

     result = converter.get_outputs(dummy_input)

-
+    # Expect NumPy array to be passed into ort_sess.run()
+    expected_call_inputs = {"input": np.asarray(dummy_input)}
+    mock_sess.run.assert_called_once_with(["output"], expected_call_inputs)
+
     assert result == dummy_output


@@ -148,7 +155,7 @@ def create_dummy_model() -> onnx.ModelProto:
     return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 21)])


-@pytest.mark.integration
+@pytest.mark.integration
 def test_save_and_load_real_model() -> None:
     converter = ONNXConverter()
     model = create_dummy_model()
@@ -181,10 +188,12 @@ def test_save_and_load_real_model() -> None:
     assert converter.model.graph.node[0].op_type == "Identity"


-@pytest.mark.integration
+@pytest.mark.integration
 def test_real_inference_from_onnx() -> None:
     converter = ONNXConverter()
     converter.model = create_dummy_model()
+    converter.scale_base = 2
+    converter.scale_exponent = 10

     # Save and load into onnxruntime
     with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp:
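The changes to test_get_outputs_with_mocked_session and test_real_inference_from_onnx point at two behavioural shifts in ONNXConverter between 1.0.0 and 1.2.0: get_outputs is now exercised with a NumPy array rather than a nested Python list, and the converter is expected to carry scale_base and scale_exponent attributes before inference. The short sketch below illustrates the fixed-point scaling these attribute values most plausibly imply; the formula is an assumption for illustration only, since the converter implementation itself is not part of this diff.

# Hypothetical illustration of scale_base=2, scale_exponent=10 (values the updated
# tests assign); JSTprove's actual quantization formula is not shown in this diff.
import numpy as np

scale_base, scale_exponent = 2, 10
scale = scale_base ** scale_exponent              # 2**10 = 1024

x = np.array([[1.0]], dtype=np.float32)           # get_outputs is now fed an ndarray, not a list
x_fixed = np.round(x * scale).astype(np.int64)    # assumed quantize step: 1.0 -> 1024
x_restored = x_fixed.astype(np.float32) / scale   # assumed dequantize step: 1024 -> 1.0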
python/tests/onnx_quantizer_tests/__init__.py

@@ -0,0 +1 @@
+TEST_RNG_SEED = 2
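TEST_RNG_SEED is the seed the new layer configs feed into NumPy's Generator, so randomly generated initializers stay reproducible across test runs. A minimal sketch of the same pattern add_config.py uses further down:

import numpy as np

from python.tests.onnx_quantizer_tests import TEST_RNG_SEED

rng = np.random.default_rng(TEST_RNG_SEED)     # same seeding pattern as add_config.py below
weights = rng.normal(0, 1, (1, 3, 4, 4))       # deterministic random initializer data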
python/tests/onnx_quantizer_tests/layers/__init__.py

@@ -0,0 +1,13 @@
+from .base import BaseLayerConfigProvider, LayerTestConfig
+from .factory import TestLayerFactory
+
+# Auto-discover and make available all config providers
+# This triggers the discovery process when the package is imported
+_all_configs = TestLayerFactory.get_layer_configs()
+
+# Export the factory and base classes
+__all__ = [
+    "BaseLayerConfigProvider",
+    "LayerTestConfig",
+    "TestLayerFactory",
+]
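Because the package __init__ calls TestLayerFactory.get_layer_configs() at import time, provider auto-discovery runs as soon as python.tests.onnx_quantizer_tests.layers is imported. A hedged sketch of how a test module might consume the discovered configs follows; the return shape of get_layer_configs() (here assumed to be a mapping of layer name to provider) is an assumption, since factory.py is not shown in this diff.

from python.tests.onnx_quantizer_tests.layers import TestLayerFactory

# Assumed structure: {layer_name: provider}; only the get_layer_configs() call
# itself appears in the diff above.
configs = TestLayerFactory.get_layer_configs()
for layer_name, provider in configs.items():
    base_config = provider.get_config()      # LayerTestConfig for this op type
    specs = provider.get_test_specs()        # list[LayerTestSpec]
    print(layer_name, base_config.op_type, len(specs))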
python/tests/onnx_quantizer_tests/layers/add_config.py

@@ -0,0 +1,102 @@
+import numpy as np
+
+from python.tests.onnx_quantizer_tests import TEST_RNG_SEED
+from python.tests.onnx_quantizer_tests.layers.base import (
+    BaseLayerConfigProvider,
+    LayerTestConfig,
+    LayerTestSpec,
+    e2e_test,
+    edge_case_test,
+    valid_test,
+)
+
+
+class AddConfigProvider(BaseLayerConfigProvider):
+    """Test configuration provider for Add layer"""
+
+    @property
+    def layer_name(self) -> str:
+        return "Add"
+
+    def get_config(self) -> LayerTestConfig:
+        return LayerTestConfig(
+            op_type="Add",
+            valid_inputs=["A", "B"],
+            valid_attributes={},  # Add has no layer-specific attributes
+            required_initializers={},
+            input_shapes={
+                "A": [1, 3, 4, 4],
+                "B": [1, 3, 4, 4],
+            },
+            output_shapes={
+                "add_output": [1, 3, 4, 4],
+            },
+        )
+
+    def get_test_specs(self) -> list[LayerTestSpec]:
+        rng = np.random.default_rng(TEST_RNG_SEED)
+        return [
+            # --- VALID TESTS ---
+            valid_test("basic")
+            .description("Basic elementwise Add of two same-shaped tensors")
+            .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 4, 4])
+            .tags("basic", "elementwise", "add")
+            .build(),
+            valid_test("broadcast_add")
+            .description("Add with Numpy-style broadcasting along spatial dimensions")
+            .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 1, 1])
+            .tags("broadcast", "elementwise", "add", "onnx14")
+            .build(),
+            valid_test("initializer_add")
+            .description(
+                "Add where second input (B) is a tensor initializer instead of input",
+            )
+            .override_input_shapes(A=[1, 3, 4, 4])
+            .override_initializer("B", rng.normal(0, 1, (1, 3, 4, 4)))
+            .tags("initializer", "elementwise", "add", "onnxruntime")
+            .build(),
+            valid_test("scalar_add")
+            .description("Add scalar (initializer) to tensor")
+            .override_input_shapes(A=[1, 3, 4, 4])
+            .override_initializer("B", np.array([2.0], dtype=np.float32))
+            .tags("scalar", "elementwise", "add")
+            .build(),
+            # --- E2E TESTS ---
+            e2e_test("e2e_add")
+            .description("End-to-end Add test with random inputs")
+            .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 4, 4])
+            .override_output_shapes(add_output=[1, 3, 4, 4])
+            .tags("e2e", "add", "2d")
+            .build(),
+            e2e_test("e2e_initializer_add")
+            .description(
+                "Add where second input (B) is a tensor initializer instead of input",
+            )
+            .override_input_shapes(A=[1, 3, 4, 4])
+            .override_initializer("B", rng.normal(0, 1, (1, 3, 4, 4)))
+            .tags("initializer", "elementwise", "add", "onnxruntime")
+            .build(),
+            e2e_test("e2e_broadcast_add")
+            .description("Add with Numpy-style broadcasting along spatial dimensions")
+            .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 1, 1])
+            .tags("broadcast", "elementwise", "add", "onnx14")
+            .build(),
+            e2e_test("e2e_scalar_add")
+            .description("Add scalar (initializer) to tensor")
+            .override_input_shapes(A=[1, 3, 4, 4])
+            .override_initializer("B", np.array([2.0], dtype=np.float32))
+            .tags("scalar", "elementwise", "add")
+            .build(),
+            # # --- EDGE CASES ---
+            edge_case_test("empty_tensor")
+            .description("Add with empty tensor input (zero elements)")
+            .override_input_shapes(A=[0], B=[0])
+            .tags("edge", "empty", "add")
+            .build(),
+            edge_case_test("large_tensor")
+            .description("Large tensor add performance/stress test")
+            .override_input_shapes(A=[1, 64, 256, 256], B=[1, 64, 256, 256])
+            .tags("large", "performance", "add")
+            .skip("Performance test, skipped by default")
+            .build(),
+        ]
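Each entry returned by get_test_specs() above is a LayerTestSpec produced by the fluent builders (valid_test, e2e_test, edge_case_test) that the next diff, layers/base.py, defines. A minimal sketch of one builder chain and the spec fields it populates, using only names that appear in these diffs:

import numpy as np

from python.tests.onnx_quantizer_tests.layers.base import SpecType, valid_test

# Builder chain equivalent to the "initializer_add" spec above.
spec = (
    valid_test("initializer_add")
    .description("Add where B is an initializer")
    .override_input_shapes(A=[1, 3, 4, 4])
    .override_initializer("B", np.ones((1, 3, 4, 4), dtype=np.float32))
    .tags("initializer", "add")
    .build()
)

assert spec.spec_type is SpecType.VALID
assert spec.input_shape_overrides["A"] == [1, 3, 4, 4]
assert "B" in spec.initializer_overrides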
python/tests/onnx_quantizer_tests/layers/base.py

@@ -0,0 +1,279 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+
+import numpy as np
+import onnx
+from onnx import TensorProto, helper, numpy_helper
+
+
+class SpecType(Enum):
+    """Types of test specifications that can be run"""
+
+    VALID = "valid"
+    ERROR = "error"
+    EDGE_CASE = "edge_case"
+    E2E = "e2e"
+
+
+@dataclass
+class LayerTestSpec:
+    """Individual test specification that can be applied to a LayerTestConfig"""
+
+    name: str
+    spec_type: SpecType
+    description: str = ""
+
+    # Overrides for the base config
+    attr_overrides: dict[str, Any] = field(default_factory=dict)
+    initializer_overrides: dict[str, np.ndarray] = field(default_factory=dict)
+    input_overrides: list[str] = field(default_factory=list)
+    input_shape_overrides: dict[str, list[int]] = field(default_factory=dict)
+    output_shape_overrides: dict[str, list[int]] = field(default_factory=dict)
+
+    # Error test specific
+    expected_error: type | None = None
+    error_match: str | None = None
+
+    # Custom validation
+    custom_validator: Callable | None = None
+
+    # Test metadata
+    tags: list[str] = field(default_factory=list)
+    skip_reason: str | None = None
+
+    # Omit attributes
+    omit_attrs: list[str] = field(default_factory=list)
+
+    # Remove __post_init__ validation - we'll validate in the builder instead
+
+
+class LayerTestConfig:
+    """Enhanced configuration class for layer-specific test data"""
+
+    def __init__(
+        self: LayerTestConfig,
+        op_type: str,
+        valid_inputs: list[str],
+        valid_attributes: dict[str, Any],
+        required_initializers: dict[str, np.ndarray],
+        input_shapes: dict[str, list[int]] | None = None,
+        output_shapes: dict[str, list[int]] | None = None,
+    ) -> None:
+        self.op_type = op_type
+        self.valid_inputs = valid_inputs
+        self.valid_attributes = valid_attributes
+        self.required_initializers = required_initializers
+        self.input_shapes = input_shapes or {"input": [1, 16, 224, 224]}
+        self.output_shapes = output_shapes or {f"{op_type.lower()}_output": [1, 10]}
+
+    def create_node(
+        self: LayerTestConfig,
+        name_suffix: str = "",
+        **attr_overrides: dict[str, Any],
+    ) -> onnx.NodeProto:
+        """Create a valid node for this layer type"""
+        attrs = {**self.valid_attributes, **attr_overrides}
+        return helper.make_node(
+            self.op_type,
+            inputs=self.valid_inputs,
+            outputs=[f"{self.op_type.lower()}_output{name_suffix}"],
+            name=f"test_{self.op_type.lower()}{name_suffix}",
+            **attrs,
+        )
+
+    def create_initializers(
+        self: LayerTestConfig,
+        **initializer_overrides: dict[str, Any],
+    ) -> dict[str, onnx.TensorProto]:
+        """Create initializer tensors for this layer"""
+        initializers = {}
+        combined_inits = {**self.required_initializers, **initializer_overrides}
+        for name, data in combined_inits.items():
+            # Special handling for shape tensors in Reshape, etc.
+            if name == "shape":
+                tensor = numpy_helper.from_array(data.astype(np.int64), name=name)
+            else:
+                tensor = numpy_helper.from_array(data.astype(np.float32), name=name)
+            initializers[name] = tensor
+        return initializers
+
+    def create_test_model(self, test_spec: LayerTestSpec) -> onnx.ModelProto:
+        """Create a complete model for a specific test case"""
+
+        # Determine node-level inputs.
+        # If dev overrides inputs explicitly,
+        # respect that; otherwise use original valid_inputs.
+        inputs = test_spec.input_overrides or self.valid_inputs
+
+        # Prepare attributes
+        attrs = {**self.valid_attributes, **test_spec.attr_overrides}
+        # Remove omitted attributes if specified
+        attrs = {**self.valid_attributes, **test_spec.attr_overrides}
+        for key in getattr(test_spec, "omit_attrs", []):
+            attrs.pop(key, None)
+
+        # Create initializers (may introduce overrides)
+        initializers = self.create_initializers(**test_spec.initializer_overrides)
+
+        # Apply shape overrides
+        input_shapes = {**self.input_shapes, **test_spec.input_shape_overrides}
+        output_shapes = {**self.output_shapes, **test_spec.output_shape_overrides}
+
+        # ----------------------------------------
+        # REMOVE graph inputs that are also initializers
+        # ----------------------------------------
+        initializer_names = set(initializers.keys())
+
+        # Also remove shapes for initializer inputs
+        for init_name in initializer_names:
+            input_shapes.pop(init_name, None)
+
+        # Create ONNX input value infos ONLY from filtered inputs
+        graph_inputs = [
+            helper.make_tensor_value_info(name, TensorProto.FLOAT, shape)
+            for name, shape in input_shapes.items()
+        ]
+
+        # Outputs stay unchanged
+        graph_outputs = [
+            helper.make_tensor_value_info(name, TensorProto.FLOAT, shape)
+            for name, shape in output_shapes.items()
+        ]
+
+        node = helper.make_node(
+            self.op_type,
+            inputs=inputs,
+            outputs=[f"{self.op_type.lower()}_output"],
+            name=f"test_{self.op_type.lower()}_{test_spec.name}",
+            **attrs,
+        )
+
+        # Build the graph
+        graph = helper.make_graph(
+            nodes=[node],
+            name=f"{self.op_type.lower()}_test_graph_{test_spec.name}",
+            inputs=graph_inputs,
+            outputs=graph_outputs,
+            initializer=list(initializers.values()),
+        )
+
+        return helper.make_model(graph)
+
+
+class TestSpecBuilder:
+    """Builder for creating test specifications"""
+
+    def __init__(self, name: str, spec_type: SpecType) -> None:
+        self._spec = LayerTestSpec(name=name, spec_type=spec_type)
+
+    def description(self, desc: str) -> TestSpecBuilder:
+        self._spec.description = desc
+        return self
+
+    def override_attrs(self, **attrs: dict[str, Any]) -> TestSpecBuilder:
+        self._spec.attr_overrides.update(attrs)
+        return self
+
+    def omit_attrs(self, *attrs: str) -> TestSpecBuilder:
+        self._spec.omit_attrs.extend(attrs)
+        return self
+
+    def override_initializer(self, name: str, data: np.ndarray) -> TestSpecBuilder:
+        self._spec.initializer_overrides[name] = data
+        return self
+
+    def override_inputs(self, *inputs: str) -> TestSpecBuilder:
+        self._spec.input_overrides = list(inputs)
+        return self
+
+    def override_input_shapes(self, **shapes: dict[str, list[int]]) -> TestSpecBuilder:
+        self._spec.input_shape_overrides.update(shapes)
+        return self
+
+    def override_output_shapes(self, **shapes: dict[str, list[int]]) -> TestSpecBuilder:
+        self._spec.output_shape_overrides.update(shapes)
+        return self
+
+    def expects_error(
+        self,
+        error_type: type,
+        match: str | None = None,
+    ) -> TestSpecBuilder:
+        if self._spec.spec_type != SpecType.ERROR:
+            msg = "expects_error can only be used with ERROR spec type"
+            raise ValueError(msg)
+        self._spec.expected_error = error_type
+        self._spec.error_match = match
+        return self
+
+    def tags(self, *tags: str) -> TestSpecBuilder:
+        self._spec.tags.extend(tags)
+        return self
+
+    def skip(self, reason: str) -> TestSpecBuilder:
+        self._spec.skip_reason = reason
+        return self
+
+    def build(self) -> LayerTestSpec:
+        # Validate before building
+        if self._spec.spec_type == SpecType.ERROR and not self._spec.expected_error:
+            msg = (
+                f"Error test {self._spec.name} must"
+                " specify expected_error using .expects_error()"
+            )
+            raise ValueError(msg)
+        return self._spec
+
+
+# Convenience functions
+def valid_test(name: str) -> TestSpecBuilder:
+    return TestSpecBuilder(name, SpecType.VALID)
+
+
+def error_test(name: str) -> TestSpecBuilder:
+    return TestSpecBuilder(name, SpecType.ERROR)
+
+
+def edge_case_test(name: str) -> TestSpecBuilder:
+    return TestSpecBuilder(name, SpecType.EDGE_CASE)
+
+
+def e2e_test(name: str) -> TestSpecBuilder:
+    return TestSpecBuilder(name, SpecType.E2E)
+
+
+class BaseLayerConfigProvider(ABC):
+    """Abstract base class for layer config providers"""
+
+    @abstractmethod
+    def get_config(self) -> LayerTestConfig:
+        """Return the base configuration for this layer"""
+
+    @property
+    @abstractmethod
+    def layer_name(self) -> str:
+        """Return the layer name/op_type"""
+
+    def get_test_specs(self) -> list[LayerTestSpec]:
+        """Return test specifications for this layer (override for custom tests)"""
+        return []
+
+    def get_valid_test_specs(self) -> list[LayerTestSpec]:
+        """Get only valid test specifications"""
+        return [
+            spec for spec in self.get_test_specs() if spec.spec_type == SpecType.VALID
+        ]
+
+    def get_error_test_specs(self) -> list[LayerTestSpec]:
+        """Get only error test specifications"""
+        return [
+            spec for spec in self.get_test_specs() if spec.spec_type == SpecType.ERROR
+        ]
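Taken together, the two new modules let a provider's specs be materialised into standalone single-node ONNX models via LayerTestConfig.create_test_model. The sketch below is illustrative rather than the release's own layers_tests harness; the onnx.checker call is just a structural sanity check on the generated graphs, and only classes shown in these diffs are used.

import onnx

from python.tests.onnx_quantizer_tests.layers.add_config import AddConfigProvider

provider = AddConfigProvider()
config = provider.get_config()

for spec in provider.get_valid_test_specs():       # SpecType.VALID specs only
    model = config.create_test_model(spec)         # one Add node, per-spec shapes/initializers
    onnx.checker.check_model(model)                # structural validation of the generated graph
    print(spec.name, [node.op_type for node in model.graph.node])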