JSTprove 1.3.0__py3-none-macosx_11_0_arm64.whl → 2.0.0__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of JSTprove might be problematic. Click here for more details.
- {jstprove-1.3.0.dist-info → jstprove-2.0.0.dist-info}/METADATA +5 -1
- {jstprove-1.3.0.dist-info → jstprove-2.0.0.dist-info}/RECORD +10 -10
- {jstprove-1.3.0.dist-info → jstprove-2.0.0.dist-info}/WHEEL +1 -1
- python/core/binaries/onnx_generic_circuit_2-0-0 +0 -0
- python/core/model_processing/onnx_quantizer/layers/base.py +45 -2
- python/core/model_processing/onnx_quantizer/layers/maxpool.py +79 -4
- python/tests/onnx_quantizer_tests/layers/maxpool_config.py +106 -0
- python/core/binaries/onnx_generic_circuit_1-3-0 +0 -0
- {jstprove-1.3.0.dist-info → jstprove-2.0.0.dist-info}/entry_points.txt +0 -0
- {jstprove-1.3.0.dist-info → jstprove-2.0.0.dist-info}/licenses/LICENSE +0 -0
- {jstprove-1.3.0.dist-info → jstprove-2.0.0.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: JSTprove
|
|
3
|
-
Version:
|
|
3
|
+
Version: 2.0.0
|
|
4
4
|
Summary: Zero-knowledge proofs of ML inference on ONNX models
|
|
5
5
|
Author: Inference Labs Inc
|
|
6
6
|
Requires-Python: >=3.10
|
|
@@ -239,6 +239,10 @@ To keep paths simple (and to match our scripts), **clone Expander as a subfolder
|
|
|
239
239
|
git clone https://github.com/PolyhedraZK/Expander.git
|
|
240
240
|
cd Expander
|
|
241
241
|
|
|
242
|
+
git fetch
|
|
243
|
+
git checkout af1b7473bc858d250e481d6bb7db98a1ee6b7fc5
|
|
244
|
+
|
|
245
|
+
|
|
242
246
|
# Build (uses the toolchain you configured with rustup)
|
|
243
247
|
cargo build --release
|
|
244
248
|
```
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
jstprove-
|
|
1
|
+
jstprove-2.0.0.dist-info/licenses/LICENSE,sha256=UXQRcYRUH-PfN27n3P-FMaZFY6jr9jFPKcwT7CWbljw,1160
|
|
2
2
|
python/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
3
|
python/core/__init__.py,sha256=RlfbqGAaUulKl44QGMCkkGJBQZ8R_AgC5bU5zS7BjnA,97
|
|
4
4
|
python/core/binaries/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
5
5
|
python/core/binaries/expander-exec,sha256=C_1JcezdfLp9sFOQ2z3wp2gcq1k8zjIR09CxJKGGIuM,7095168
|
|
6
|
-
python/core/binaries/
|
|
6
|
+
python/core/binaries/onnx_generic_circuit_2-0-0,sha256=mPtP1TSpMBtVw56vtGv7Gr7PztEwbdVQ5qwrg6pwDCM,3237712
|
|
7
7
|
python/core/circuit_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
8
8
|
python/core/circuit_models/generic_onnx.py,sha256=P65UZkfVBTE6YhaQ951S6QoTHPuU5ntDt8QL5pXghvw,8787
|
|
9
9
|
python/core/circuit_models/simple_circuit.py,sha256=igQrZtQyreyHc26iAgCyDb0TuD2bJAoumYhc1pYPDzQ,4682
|
|
@@ -30,14 +30,14 @@ python/core/model_processing/onnx_quantizer/exceptions.py,sha256=vzxBRbpvk4ZZbga
|
|
|
30
30
|
python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py,sha256=5I67frJn4j2T1LTvODHixQK4VaqazJFJ0T1BCvqLPgg,9655
|
|
31
31
|
python/core/model_processing/onnx_quantizer/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
32
32
|
python/core/model_processing/onnx_quantizer/layers/add.py,sha256=AGxzqMa0jABIEKOIgPqEAA7EpZtynQtnD9nxI2NHc0s,1409
|
|
33
|
-
python/core/model_processing/onnx_quantizer/layers/base.py,sha256=
|
|
33
|
+
python/core/model_processing/onnx_quantizer/layers/base.py,sha256=H48okaq1tvl6ckGZV6BlXCVLLKIy1HpvomGPWBLI-8Q,22544
|
|
34
34
|
python/core/model_processing/onnx_quantizer/layers/batchnorm.py,sha256=KSBDPHd52f5Qyf-cnIDFPmfzssaJgMPiTmpIWEdM41U,7718
|
|
35
35
|
python/core/model_processing/onnx_quantizer/layers/clip.py,sha256=HrhiLtqC3cIAvU0wRCqp8_8ZSFH8a3F1Jf_qkXlY44s,3043
|
|
36
36
|
python/core/model_processing/onnx_quantizer/layers/constant.py,sha256=l1IvgvXkmFMiaBsym8wchPF-y1ZH-c5PmFUy92IXWok,3694
|
|
37
37
|
python/core/model_processing/onnx_quantizer/layers/conv.py,sha256=TlUpCRO6PPqH7MPkIrEiEcVfzuiN1WMYEiNIjhYXtWM,4451
|
|
38
38
|
python/core/model_processing/onnx_quantizer/layers/gemm.py,sha256=7fCUMv8OLVZ45a2lYjA2XNvcW3By7lSbX7zeForNK-0,3950
|
|
39
39
|
python/core/model_processing/onnx_quantizer/layers/max.py,sha256=3gUxrdXwcVAtgR-_j4xQ0085Wj0oEBLT897TImxF2d4,1343
|
|
40
|
-
python/core/model_processing/onnx_quantizer/layers/maxpool.py,sha256=
|
|
40
|
+
python/core/model_processing/onnx_quantizer/layers/maxpool.py,sha256=NUa63fYNWrOy1zQMP8YFH_fknvh_3ZSMDfXU7xIJ9tc,7134
|
|
41
41
|
python/core/model_processing/onnx_quantizer/layers/min.py,sha256=cQbXzGOApR6HUJZMARXy87W8IbUC562jnAQm8J8ynQI,1709
|
|
42
42
|
python/core/model_processing/onnx_quantizer/layers/mul.py,sha256=qHsmnYPH-c5uiFeDCvV6e1xSgmIXJ64Sjvh0LYDYEqQ,1396
|
|
43
43
|
python/core/model_processing/onnx_quantizer/layers/relu.py,sha256=d-5fyeKNLTgKKnqCwURpxkjl7QdbJQpuovtCFBM03FA,1685
|
|
@@ -98,7 +98,7 @@ python/tests/onnx_quantizer_tests/layers/factory.py,sha256=WLLEP9ECmSpTliSjhtdWO
|
|
|
98
98
|
python/tests/onnx_quantizer_tests/layers/flatten_config.py,sha256=Xln5Hh6gyeM5gGRCjLGvIL-u08NEs1tXSF32urCqPfE,2110
|
|
99
99
|
python/tests/onnx_quantizer_tests/layers/gemm_config.py,sha256=t7nJY-Wnj6YUD821-jaWzgrQVPa6ytwER3hFMsvyY6Y,7294
|
|
100
100
|
python/tests/onnx_quantizer_tests/layers/max_config.py,sha256=vzR8-2wbPGcH0GMmAJ_sXSEdMtZOjVNGufU__N3Jfyw,3906
|
|
101
|
-
python/tests/onnx_quantizer_tests/layers/maxpool_config.py,sha256=
|
|
101
|
+
python/tests/onnx_quantizer_tests/layers/maxpool_config.py,sha256=nQ7VQ5peq01ODrRoMl7HuMPN958juXteboD12dRsZS8,8275
|
|
102
102
|
python/tests/onnx_quantizer_tests/layers/min_config.py,sha256=izKtCaMXoQHiAfmcGlJRQdKMQz3Su8n0p2mEn0y56Do,3774
|
|
103
103
|
python/tests/onnx_quantizer_tests/layers/mul_config.py,sha256=_Oy4b97ORxFlF3w0BmJ94hNA968HQx2AvwYiASrGPxw,4135
|
|
104
104
|
python/tests/onnx_quantizer_tests/layers/relu_config.py,sha256=_aHuddDApLUBOa0FiR9h4fNfmMSnH5r4JzOMLW0KaTk,2197
|
|
@@ -115,8 +115,8 @@ python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py,sha256=RfnIIi
|
|
|
115
115
|
python/tests/onnx_quantizer_tests/layers_tests/test_validation.py,sha256=jz-WtIEP-jjUklOOAnznwPUXbf07U2PAMGrhzMWP0JU,1371
|
|
116
116
|
python/tests/utils_testing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
117
117
|
python/tests/utils_testing/test_helper_functions.py,sha256=xmeGQieh4LE9U-CDKBlHhSWqH0cAmmDU3qXNbDkkvms,27192
|
|
118
|
-
jstprove-
|
|
119
|
-
jstprove-
|
|
120
|
-
jstprove-
|
|
121
|
-
jstprove-
|
|
122
|
-
jstprove-
|
|
118
|
+
jstprove-2.0.0.dist-info/METADATA,sha256=2y7oLbQStZnZoukzpTT8Avi5d3gTWtidjoo56IWVyCg,14166
|
|
119
|
+
jstprove-2.0.0.dist-info/WHEEL,sha256=W0m4OUfxniOjS_4shj7CwgYZPxSrjpkVN6CH28tx5ZE,106
|
|
120
|
+
jstprove-2.0.0.dist-info/entry_points.txt,sha256=nGcTSO-4q08gPl1IoWdrPaiY7IbO7XvmXKkd34dYHc8,49
|
|
121
|
+
jstprove-2.0.0.dist-info/top_level.txt,sha256=J-z0poNcsv31IHB413--iOY8LoHBKiTHeybHX3abokI,7
|
|
122
|
+
jstprove-2.0.0.dist-info/RECORD,,
|
|
Binary file
|
|
@@ -483,12 +483,13 @@ class QuantizerBase:
|
|
|
483
483
|
node.input[:] = new_inputs
|
|
484
484
|
|
|
485
485
|
# (2) Collect & merge attributes
|
|
486
|
+
self.apply_default_attrs(node)
|
|
486
487
|
attrs = extract_attributes(node)
|
|
487
|
-
for k, v in self.DEFAULT_ATTRS.items():
|
|
488
|
-
attrs.setdefault(k, v)
|
|
489
488
|
if self.USE_SCALING:
|
|
490
489
|
attrs["rescale"] = int(scale_config.rescale)
|
|
491
490
|
|
|
491
|
+
attrs = self._serialize_quantized_attrs(attrs)
|
|
492
|
+
|
|
492
493
|
# (3) Add scaling constant if needed
|
|
493
494
|
if self.USE_SCALING:
|
|
494
495
|
scale_value = self.get_scaling(scale_config.base, scale_config.exponent)
|
|
@@ -546,6 +547,7 @@ class QuantizerBase:
|
|
|
546
547
|
- The resulting model will not make accurate prediction and should be
|
|
547
548
|
used solely for analysis and keeping track of w_and_b
|
|
548
549
|
"""
|
|
550
|
+
self.apply_default_attrs(node)
|
|
549
551
|
# If subclass does not want auto-scaling, do nothing
|
|
550
552
|
if not getattr(self, "USE_WB", False):
|
|
551
553
|
return
|
|
@@ -580,6 +582,47 @@ class QuantizerBase:
|
|
|
580
582
|
|
|
581
583
|
initializer_map[tensor.name] = new_tensor
|
|
582
584
|
|
|
585
|
+
def apply_default_attrs(self, node: onnx.NodeProto) -> None:
|
|
586
|
+
"""
|
|
587
|
+
Ensure DEFAULT_ATTRS are explicitly present on the node.
|
|
588
|
+
Does not overwrite existing attributes.
|
|
589
|
+
"""
|
|
590
|
+
if not getattr(self, "DEFAULT_ATTRS", None):
|
|
591
|
+
return
|
|
592
|
+
|
|
593
|
+
existing = {attr.name for attr in node.attribute}
|
|
594
|
+
|
|
595
|
+
for name, value in self.DEFAULT_ATTRS.items():
|
|
596
|
+
if name in existing:
|
|
597
|
+
continue
|
|
598
|
+
|
|
599
|
+
try:
|
|
600
|
+
attr = onnx.helper.make_attribute(name, value)
|
|
601
|
+
except Exception as e:
|
|
602
|
+
raise HandlerImplementationError(
|
|
603
|
+
op_type=node.op_type,
|
|
604
|
+
message=f"Failed to create default attribute '{name}': {e}",
|
|
605
|
+
) from e
|
|
606
|
+
|
|
607
|
+
node.attribute.append(attr)
|
|
608
|
+
|
|
609
|
+
def _serialize_quantized_attrs(self, attrs: dict) -> dict:
|
|
610
|
+
"""
|
|
611
|
+
Convert logical attribute values into the serialized form expected
|
|
612
|
+
by quantized custom ops.
|
|
613
|
+
|
|
614
|
+
Lists are converted to comma-separated strings.
|
|
615
|
+
"""
|
|
616
|
+
serialized = {}
|
|
617
|
+
|
|
618
|
+
for name, value in attrs.items():
|
|
619
|
+
if isinstance(value, list):
|
|
620
|
+
serialized[name] = ", ".join(str(v) for v in value)
|
|
621
|
+
else:
|
|
622
|
+
serialized[name] = value
|
|
623
|
+
|
|
624
|
+
return serialized
|
|
625
|
+
|
|
583
626
|
|
|
584
627
|
class PassthroughQuantizer(BaseOpQuantizer):
|
|
585
628
|
"""
|
|
@@ -3,9 +3,12 @@ from __future__ import annotations
|
|
|
3
3
|
from typing import TYPE_CHECKING
|
|
4
4
|
|
|
5
5
|
if TYPE_CHECKING:
|
|
6
|
+
from typing import ClassVar
|
|
7
|
+
|
|
6
8
|
import onnx
|
|
7
9
|
|
|
8
10
|
from python.core.model_processing.onnx_custom_ops.onnx_helpers import (
|
|
11
|
+
extract_attributes,
|
|
9
12
|
get_attribute_ints,
|
|
10
13
|
)
|
|
11
14
|
from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError
|
|
@@ -21,6 +24,12 @@ class QuantizeMaxpool(QuantizerBase):
|
|
|
21
24
|
USE_WB = False
|
|
22
25
|
USE_SCALING = False
|
|
23
26
|
|
|
27
|
+
DEFAULT_ATTRS: ClassVar = {
|
|
28
|
+
"dilations": [1],
|
|
29
|
+
"pads": [0],
|
|
30
|
+
"strides": [1],
|
|
31
|
+
}
|
|
32
|
+
|
|
24
33
|
|
|
25
34
|
class MaxpoolQuantizer(BaseOpQuantizer, QuantizeMaxpool):
|
|
26
35
|
"""
|
|
@@ -72,6 +81,35 @@ class MaxpoolQuantizer(BaseOpQuantizer, QuantizeMaxpool):
|
|
|
72
81
|
InvalidParamError: If any requirement is not met.
|
|
73
82
|
"""
|
|
74
83
|
_ = initializer_map
|
|
84
|
+
attributes = extract_attributes(node)
|
|
85
|
+
ceil_mode = attributes.get("ceil_mode", None)
|
|
86
|
+
auto_pad = attributes.get("auto_pad", None)
|
|
87
|
+
storage_order = attributes.get("storage_order", None)
|
|
88
|
+
|
|
89
|
+
if ceil_mode != 0 and ceil_mode is not None:
|
|
90
|
+
raise InvalidParamError(
|
|
91
|
+
node.name,
|
|
92
|
+
node.op_type,
|
|
93
|
+
"ceil_mode must be 0",
|
|
94
|
+
"ceil_mode",
|
|
95
|
+
"0",
|
|
96
|
+
)
|
|
97
|
+
if auto_pad != "NOTSET" and auto_pad is not None:
|
|
98
|
+
raise InvalidParamError(
|
|
99
|
+
node.name,
|
|
100
|
+
node.op_type,
|
|
101
|
+
"auto_pad must be NOTSET",
|
|
102
|
+
"auto_pad",
|
|
103
|
+
"NOTSET",
|
|
104
|
+
)
|
|
105
|
+
if storage_order != 0 and storage_order is not None:
|
|
106
|
+
raise InvalidParamError(
|
|
107
|
+
node.name,
|
|
108
|
+
node.op_type,
|
|
109
|
+
"storage_order must be 0",
|
|
110
|
+
"storage_order",
|
|
111
|
+
"0",
|
|
112
|
+
)
|
|
75
113
|
self.check_all_params_exist(node)
|
|
76
114
|
self.check_params_size(node)
|
|
77
115
|
self.check_pool_pads(node)
|
|
@@ -85,8 +123,8 @@ class MaxpoolQuantizer(BaseOpQuantizer, QuantizeMaxpool):
|
|
|
85
123
|
Raises:
|
|
86
124
|
InvalidParamError: If shape requirement is not met.
|
|
87
125
|
"""
|
|
88
|
-
|
|
89
|
-
|
|
126
|
+
required_attrs = ["kernel_shape"]
|
|
127
|
+
|
|
90
128
|
self.validate_required_attrs(node, required_attrs)
|
|
91
129
|
|
|
92
130
|
# Check dimension of kernel
|
|
@@ -121,11 +159,23 @@ class MaxpoolQuantizer(BaseOpQuantizer, QuantizeMaxpool):
|
|
|
121
159
|
|
|
122
160
|
def check_pool_pads(self: MaxpoolQuantizer, node: onnx.NodeProto) -> None:
|
|
123
161
|
kernel_shape = get_attribute_ints(node, "kernel_shape", default=[])
|
|
124
|
-
|
|
162
|
+
pads_raw = get_attribute_ints(
|
|
163
|
+
node,
|
|
164
|
+
"pads",
|
|
165
|
+
default=self.DEFAULT_ATTRS.get("pads", None),
|
|
166
|
+
)
|
|
167
|
+
pads = self.adjust_pads(node, pads_raw)
|
|
168
|
+
|
|
125
169
|
if pads is None:
|
|
126
170
|
return
|
|
127
171
|
num_dims = len(kernel_shape)
|
|
128
|
-
|
|
172
|
+
|
|
173
|
+
if len(pads) == 1:
|
|
174
|
+
pads = pads * 2 * num_dims
|
|
175
|
+
elif len(pads) == num_dims:
|
|
176
|
+
# If only beginning pads given, repeat for end pads
|
|
177
|
+
pads = pads + pads
|
|
178
|
+
elif len(pads) != num_dims * 2:
|
|
129
179
|
raise InvalidParamError(
|
|
130
180
|
node.name,
|
|
131
181
|
node.op_type,
|
|
@@ -148,3 +198,28 @@ class MaxpoolQuantizer(BaseOpQuantizer, QuantizeMaxpool):
|
|
|
148
198
|
node.op_type,
|
|
149
199
|
f"pads[{dim + num_dims}]={pad_after} >= kernel[{dim}]={kernel}",
|
|
150
200
|
)
|
|
201
|
+
|
|
202
|
+
def adjust_pads(
|
|
203
|
+
self: MaxpoolQuantizer,
|
|
204
|
+
node: onnx.NodeProto,
|
|
205
|
+
pads_raw: str | int | list[int] | None,
|
|
206
|
+
) -> list[int]:
|
|
207
|
+
if pads_raw is None:
|
|
208
|
+
pads: list[int] = []
|
|
209
|
+
elif isinstance(pads_raw, str):
|
|
210
|
+
# single string, could be "0" or "1 2"
|
|
211
|
+
pads = [int(x) for x in pads_raw.split()]
|
|
212
|
+
elif isinstance(pads_raw, int):
|
|
213
|
+
# single integer
|
|
214
|
+
pads = [pads_raw]
|
|
215
|
+
elif isinstance(pads_raw, (list, tuple)):
|
|
216
|
+
# already a list of numbers (may be strings)
|
|
217
|
+
pads = [int(x) for x in pads_raw]
|
|
218
|
+
else:
|
|
219
|
+
raise InvalidParamError(
|
|
220
|
+
node.name,
|
|
221
|
+
node.op_type,
|
|
222
|
+
f"Cannot parse pads: {pads_raw}",
|
|
223
|
+
)
|
|
224
|
+
|
|
225
|
+
return pads
|
|
@@ -28,6 +28,8 @@ class MaxPoolConfigProvider(BaseLayerConfigProvider):
|
|
|
28
28
|
"pads": [0, 0, 0, 0],
|
|
29
29
|
},
|
|
30
30
|
required_initializers={},
|
|
31
|
+
input_shapes={"input": [1, 3, 4, 4]},
|
|
32
|
+
output_shapes={"maxpool_output": [1, 3, 2, 2]},
|
|
31
33
|
)
|
|
32
34
|
|
|
33
35
|
def get_test_specs(self) -> list:
|
|
@@ -52,12 +54,97 @@ class MaxPoolConfigProvider(BaseLayerConfigProvider):
|
|
|
52
54
|
.override_attrs(strides=[1, 1])
|
|
53
55
|
.tags("stride_1", "pool", "overlap")
|
|
54
56
|
.build(),
|
|
57
|
+
valid_test("missing_dilations_attr")
|
|
58
|
+
.description("MaxPool without dilations attribute should default to [1, 1]")
|
|
59
|
+
.override_attrs(dilations=None)
|
|
60
|
+
.tags("default_attr", "dilations", "pool")
|
|
61
|
+
.build(),
|
|
62
|
+
valid_test("non_default_dilations")
|
|
63
|
+
.description("MaxPool with explicit non-default dilations")
|
|
64
|
+
.override_attrs(dilations=[2, 2])
|
|
65
|
+
.tags("dilations", "non_default", "pool")
|
|
66
|
+
.override_output_shapes(maxpool_output=[1, 3, 1, 1])
|
|
67
|
+
.build(),
|
|
68
|
+
valid_test("missing_optional_attrs")
|
|
69
|
+
.description(
|
|
70
|
+
"MaxPool without pads/strides/dilations should use default values",
|
|
71
|
+
)
|
|
72
|
+
.override_attrs(
|
|
73
|
+
pads=None,
|
|
74
|
+
strides=None,
|
|
75
|
+
dilations=None,
|
|
76
|
+
)
|
|
77
|
+
.tags("defaults", "optional_attrs", "pool")
|
|
78
|
+
.build(),
|
|
79
|
+
valid_test("non_default_pads")
|
|
80
|
+
.description("MaxPool with explicit non-default pads")
|
|
81
|
+
.override_attrs(pads=[1, 1, 1, 1])
|
|
82
|
+
.override_output_shapes(maxpool_output=[1, 3, 3, 3])
|
|
83
|
+
.tags("pads", "non_default", "pool")
|
|
84
|
+
.build(),
|
|
85
|
+
valid_test("non_default_strides")
|
|
86
|
+
.description("MaxPool with explicit non-default strides [3, 3]")
|
|
87
|
+
.override_attrs(strides=[3, 3])
|
|
88
|
+
.override_output_shapes(maxpool_output=[1, 3, 1, 1])
|
|
89
|
+
.tags("strides", "non_default", "pool")
|
|
90
|
+
.build(),
|
|
91
|
+
valid_test("rectangular_strides")
|
|
92
|
+
.description("MaxPool with non-square strides [2, 1]")
|
|
93
|
+
.override_attrs(strides=[2, 1])
|
|
94
|
+
.override_output_shapes(maxpool_output=[1, 3, 2, 3])
|
|
95
|
+
.tags("strides", "non_square", "pool")
|
|
96
|
+
.build(),
|
|
97
|
+
# --- E2E TESTS ---
|
|
55
98
|
e2e_test("e2e_basic")
|
|
56
99
|
.description("End-to-end test for 2D MaxPool")
|
|
57
100
|
.override_input_shapes(input=[1, 3, 4, 4])
|
|
58
101
|
.override_output_shapes(maxpool_output=[1, 3, 2, 2])
|
|
59
102
|
.tags("e2e", "pool", "2d")
|
|
60
103
|
.build(),
|
|
104
|
+
e2e_test("missing_dilations_attr")
|
|
105
|
+
.description("MaxPool without dilations attribute should default to [1, 1]")
|
|
106
|
+
.override_attrs(dilations=None)
|
|
107
|
+
.tags("default_attr", "dilations", "pool")
|
|
108
|
+
.build(),
|
|
109
|
+
e2e_test("non_default_dilations")
|
|
110
|
+
.description("MaxPool with explicit non-default dilations")
|
|
111
|
+
.override_attrs(dilations=[2, 2])
|
|
112
|
+
.override_output_shapes(maxpool_output=[1, 3, 1, 1])
|
|
113
|
+
.tags("dilations", "non_default", "pool")
|
|
114
|
+
.build(),
|
|
115
|
+
e2e_test("e2e_defaults_applied")
|
|
116
|
+
.description("E2E MaxPool with default pads/strides/dilations applied")
|
|
117
|
+
.override_attrs(
|
|
118
|
+
pads=None,
|
|
119
|
+
strides=None,
|
|
120
|
+
dilations=None,
|
|
121
|
+
)
|
|
122
|
+
.override_input_shapes(input=[1, 1, 4, 4])
|
|
123
|
+
.override_output_shapes(maxpool_output=[1, 1, 3, 3])
|
|
124
|
+
.tags("e2e", "defaults", "pool")
|
|
125
|
+
.build(),
|
|
126
|
+
e2e_test("non_default_pads")
|
|
127
|
+
.description("MaxPool with explicit non-default pads")
|
|
128
|
+
.override_attrs(pads=[1, 1, 1, 1])
|
|
129
|
+
.override_output_shapes(maxpool_output=[1, 3, 3, 3])
|
|
130
|
+
.tags("pads", "non_default", "pool")
|
|
131
|
+
.build(),
|
|
132
|
+
e2e_test("non_default_strides")
|
|
133
|
+
.description("E2E MaxPool with explicit non-default strides")
|
|
134
|
+
.override_attrs(strides=[3, 3])
|
|
135
|
+
.override_input_shapes(input=[1, 3, 4, 4])
|
|
136
|
+
.override_output_shapes(maxpool_output=[1, 3, 1, 1])
|
|
137
|
+
.tags("e2e", "strides", "non_default", "pool")
|
|
138
|
+
.build(),
|
|
139
|
+
e2e_test("missing_strides_attr_defaults_applied")
|
|
140
|
+
.description(
|
|
141
|
+
"E2E MaxPool without strides attribute should default to [1, 1]",
|
|
142
|
+
)
|
|
143
|
+
.override_attrs(strides=None)
|
|
144
|
+
.override_input_shapes(input=[1, 1, 4, 4])
|
|
145
|
+
.override_output_shapes(maxpool_output=[1, 1, 3, 3])
|
|
146
|
+
.tags("e2e", "defaults", "strides", "pool")
|
|
147
|
+
.build(),
|
|
61
148
|
# # --- ERROR TESTS ---
|
|
62
149
|
error_test("asymmetric_padding")
|
|
63
150
|
.description("MaxPool with asymmetric padding")
|
|
@@ -71,6 +158,25 @@ class MaxPoolConfigProvider(BaseLayerConfigProvider):
|
|
|
71
158
|
.expects_error(InvalidParamError, "Currently only MaxPool2D is supported")
|
|
72
159
|
.tags("invalid_attr_length", "kernel_shape")
|
|
73
160
|
.build(),
|
|
161
|
+
error_test("auto_pad_not_supported")
|
|
162
|
+
.description("MaxPool with auto_pad should be rejected")
|
|
163
|
+
.override_attrs(auto_pad="SAME_UPPER")
|
|
164
|
+
.expects_error(InvalidParamError, "auto_pad must be NOTSET")
|
|
165
|
+
.tags("invalid", "auto_pad", "pool")
|
|
166
|
+
.build(),
|
|
167
|
+
error_test("ceil_mode_not_supported")
|
|
168
|
+
.description("MaxPool with ceil_mode != 0 should be rejected")
|
|
169
|
+
.override_attrs(ceil_mode=1)
|
|
170
|
+
.expects_error(InvalidParamError, "ceil_mode must be 0")
|
|
171
|
+
.tags("invalid", "ceil_mode", "pool")
|
|
172
|
+
.build(),
|
|
173
|
+
error_test("storage_order_not_supported")
|
|
174
|
+
.description("MaxPool with storage_order != 0 should be rejected")
|
|
175
|
+
.override_attrs(storage_order=1)
|
|
176
|
+
.expects_error(InvalidParamError, "storage_order must be 0")
|
|
177
|
+
.tags("invalid", "storage_order", "pool")
|
|
178
|
+
.build(),
|
|
179
|
+
# This can be removed if we start to support explicit None strides
|
|
74
180
|
# --- EDGE CASE / SKIPPED TEST ---
|
|
75
181
|
valid_test("large_input")
|
|
76
182
|
.description("Large MaxPool input (performance/stress test)")
|
|
Binary file
|
|
File without changes
|
|
File without changes
|
|
File without changes
|