JSTprove 2.0.0__py3-none-macosx_11_0_arm64.whl → 2.1.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of JSTprove might be problematic; see the registry's advisory page for more details.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: JSTprove
3
- Version: 2.0.0
3
+ Version: 2.1.0
4
4
  Summary: Zero-knowledge proofs of ML inference on ONNX models
5
5
  Author: Inference Labs Inc
6
6
  Requires-Python: >=3.10
@@ -1,9 +1,9 @@
1
- jstprove-2.0.0.dist-info/licenses/LICENSE,sha256=UXQRcYRUH-PfN27n3P-FMaZFY6jr9jFPKcwT7CWbljw,1160
1
+ jstprove-2.1.0.dist-info/licenses/LICENSE,sha256=UXQRcYRUH-PfN27n3P-FMaZFY6jr9jFPKcwT7CWbljw,1160
2
2
  python/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  python/core/__init__.py,sha256=RlfbqGAaUulKl44QGMCkkGJBQZ8R_AgC5bU5zS7BjnA,97
4
4
  python/core/binaries/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
5
  python/core/binaries/expander-exec,sha256=C_1JcezdfLp9sFOQ2z3wp2gcq1k8zjIR09CxJKGGIuM,7095168
6
- python/core/binaries/onnx_generic_circuit_2-0-0,sha256=mPtP1TSpMBtVw56vtGv7Gr7PztEwbdVQ5qwrg6pwDCM,3237712
6
+ python/core/binaries/onnx_generic_circuit_2-1-0,sha256=RAA6W15aObCMa-0FnwIFHyMStTAuhNJsiaE95dxTBPE,3323248
7
7
  python/core/circuit_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
8
  python/core/circuit_models/generic_onnx.py,sha256=P65UZkfVBTE6YhaQ951S6QoTHPuU5ntDt8QL5pXghvw,8787
9
9
  python/core/circuit_models/simple_circuit.py,sha256=igQrZtQyreyHc26iAgCyDb0TuD2bJAoumYhc1pYPDzQ,4682
@@ -15,7 +15,7 @@ python/core/model_processing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NM
15
15
  python/core/model_processing/errors.py,sha256=uh2YFjuuU5JM3anMtSTLAH-zjlNAKStmLDZqRUgBWS8,4611
16
16
  python/core/model_processing/converters/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
17
  python/core/model_processing/converters/base.py,sha256=o6bNwmqD9sOM9taqMb0ed6804RugQiU3va0rY_EA5SE,4265
18
- python/core/model_processing/converters/onnx_converter.py,sha256=-eXdF6tfluFRxGgnQtJQ8R2309aYX-8z8HzMxk_Qv8I,44340
18
+ python/core/model_processing/converters/onnx_converter.py,sha256=nAvcGzNqNJGHTUgdphvkQAg7JXBHMchIgzp-HTYAkJw,44353
19
19
  python/core/model_processing/onnx_custom_ops/__init__.py,sha256=ZKUC4ToRxgEEMHcTyERATVEN0KSDs-9cM1T-tTw3I1g,525
20
20
  python/core/model_processing/onnx_custom_ops/batchnorm.py,sha256=8kg4iGGdt6B_fIJkpt4v5eNFpoHa4bjTB0NnCSmKFvE,1693
21
21
  python/core/model_processing/onnx_custom_ops/conv.py,sha256=6jJm3fcGWzcU4RjVgf179mPFCqsl4C3AR7bqQTffDgA,3464
@@ -27,7 +27,7 @@ python/core/model_processing/onnx_custom_ops/onnx_helpers.py,sha256=utnJuc5sgb_z
27
27
  python/core/model_processing/onnx_custom_ops/relu.py,sha256=pZsPXC_r0FPggURKDphh8P1IRXY0w4hH7ExBmYTlWjE,1202
28
28
  python/core/model_processing/onnx_quantizer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
29
29
  python/core/model_processing/onnx_quantizer/exceptions.py,sha256=vzxBRbpvk4ZZbgacDISnqmQQKj7Ls46V08ilHnhaJy0,5645
30
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py,sha256=5I67frJn4j2T1LTvODHixQK4VaqazJFJ0T1BCvqLPgg,9655
30
+ python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py,sha256=fgN-2-oAqwu436yi_dvmdBApXq2T3DpKGZOqdZybjJ8,9996
31
31
  python/core/model_processing/onnx_quantizer/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
32
32
  python/core/model_processing/onnx_quantizer/layers/add.py,sha256=AGxzqMa0jABIEKOIgPqEAA7EpZtynQtnD9nxI2NHc0s,1409
33
33
  python/core/model_processing/onnx_quantizer/layers/base.py,sha256=H48okaq1tvl6ckGZV6BlXCVLLKIy1HpvomGPWBLI-8Q,22544
@@ -41,7 +41,9 @@ python/core/model_processing/onnx_quantizer/layers/maxpool.py,sha256=NUa63fYNWrO
41
41
  python/core/model_processing/onnx_quantizer/layers/min.py,sha256=cQbXzGOApR6HUJZMARXy87W8IbUC562jnAQm8J8ynQI,1709
42
42
  python/core/model_processing/onnx_quantizer/layers/mul.py,sha256=qHsmnYPH-c5uiFeDCvV6e1xSgmIXJ64Sjvh0LYDYEqQ,1396
43
43
  python/core/model_processing/onnx_quantizer/layers/relu.py,sha256=d-5fyeKNLTgKKnqCwURpxkjl7QdbJQpuovtCFBM03FA,1685
44
+ python/core/model_processing/onnx_quantizer/layers/squeeze.py,sha256=haQ9BaeE9w4XxBdTvKIuWSAhZ65Q7A7MhbZlzbFiyDE,4848
44
45
  python/core/model_processing/onnx_quantizer/layers/sub.py,sha256=M7D98TZBNP9-2R9MX6mcpYlrWFxTiX9JCs3XNcg1U-Q,1409
46
+ python/core/model_processing/onnx_quantizer/layers/unsqueeze.py,sha256=qCS8BNsffDesseP15h2s9GDfw7Vg51XqZNVaTclKquQ,9807
45
47
  python/core/model_templates/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
46
48
  python/core/model_templates/circuit_template.py,sha256=OAqMRshi9OiJYoqpjkg5tUfNf18MfZmhsxxD6SANm_4,2106
47
49
  python/core/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -54,10 +56,11 @@ python/core/utils/model_registry.py,sha256=aZg_9LEqsBXK84oxQ8A3NGZl-9aGnLgfR-kgx
54
56
  python/core/utils/scratch_tests.py,sha256=o2VDTk8QBKA3UHHE-h7Ghtoge6kGG7G-8qwvesuTFFc,2281
55
57
  python/core/utils/witness_utils.py,sha256=ukvbF6EaHMPzRQVZad9wQ9gISRwBGQ1hEAHzc5TpGuw,9488
56
58
  python/frontend/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
57
- python/frontend/cli.py,sha256=lkvhzQC6bv0AgWUypg_cH-JT574r89qgTIsgHDT9GRg,3106
58
- python/frontend/commands/__init__.py,sha256=HKfKIM8wKMvzUPlBMJCSAqRurPIp85btGFCNjr7DbyE,575
59
+ python/frontend/cli.py,sha256=HU0nYkp0kFgxP1BCs_CX4ERd7VvqzKFMooBK5LkEK2c,3142
60
+ python/frontend/commands/__init__.py,sha256=b6N5jB72TbacBNTraYwqFCei8gepw06id8skZOkKMOw,651
59
61
  python/frontend/commands/args.py,sha256=JmG4q-tbEy8_YcQsph_WLEAs_w7y7GiR22PhTrc14v4,2255
60
62
  python/frontend/commands/base.py,sha256=a7NWoXB9VL8It1TpuL2vmR7J2bhejHDllrNMBEm-JLE,6368
63
+ python/frontend/commands/batch.py,sha256=-uy1PLgeEDIzgRet8JPcUb2xTkrQHW3v_x0_om7hr6U,6069
61
64
  python/frontend/commands/compile.py,sha256=-mE4LjBEXzgsnzTJCeas0ZkZgD-kdATpYLk52ljBw88,1905
62
65
  python/frontend/commands/constants.py,sha256=feCVczqP6xphHUta2ZMaAuYyVeemZgwU_sCWr6ky5X8,164
63
66
  python/frontend/commands/model_check.py,sha256=xsXvXzYpIagFGnUk_UYGZYx-onP0Opes95XismqvY64,1806
@@ -72,7 +75,7 @@ python/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
72
75
  python/scripts/benchmark_runner.py,sha256=sjbqaLrdjt94AoyQXAxT4FhsN6aRu5idTRQ5uHmZOWM,28593
73
76
  python/scripts/gen_and_bench.py,sha256=V36x7djYmHlveAJgYzMlXwnmF0gAGO3-1mg9PWOmpj8,16249
74
77
  python/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
75
- python/tests/test_cli.py,sha256=OiAyG3aBpukk0i5FFWbiKaF42wf-7By-UWDHNjwtsqo,27042
78
+ python/tests/test_cli.py,sha256=hhlhsEx-6jJEtw4SKwZFBki5iKs8Z8KPtQccxgy1AMo,31755
76
79
  python/tests/circuit_e2e_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
77
80
  python/tests/circuit_e2e_tests/circuit_model_developer_test.py,sha256=8hl8SKw7obXplo0jsiKoKIZLxlu1_HhXvGDeSBDBars,39456
78
81
  python/tests/circuit_e2e_tests/helper_fns_for_tests.py,sha256=uEThqTsRdNJivHwAv-aJIUtSPlmVHdhMZqZSH1OqhDE,5177
@@ -89,7 +92,7 @@ python/tests/onnx_quantizer_tests/test_registered_quantizers.py,sha256=lw_jYSbQ9
89
92
  python/tests/onnx_quantizer_tests/testing_helper_functions.py,sha256=N0fQv2pYzUCVZ7wkcR8gEKs5zTXT1hWrK-HKSTQYvYU,534
90
93
  python/tests/onnx_quantizer_tests/layers/__init__.py,sha256=xP-RmW6LfIANgK1s9Q0KZet2yvNr-3c6YIVLAAQqGUY,404
91
94
  python/tests/onnx_quantizer_tests/layers/add_config.py,sha256=T3tGddupDtrvLck2SL2yETDblNtv0aU7Tl7fNyZUhO4,4133
92
- python/tests/onnx_quantizer_tests/layers/base.py,sha256=3nqmU2PgOdK_mPkz-YHg3idgr-PXYbu5kCIY-Uic5yo,9317
95
+ python/tests/onnx_quantizer_tests/layers/base.py,sha256=YbUx0CnGGtET9qsOd1aD2oN5GbhC6MYtguONp6fY6BA,9327
93
96
  python/tests/onnx_quantizer_tests/layers/batchnorm_config.py,sha256=P-sZuHAdEfNczcgTeLjqJnEbpqN3dKTsbqvY4-SBqiQ,8231
94
97
  python/tests/onnx_quantizer_tests/layers/clip_config.py,sha256=-OuhnUgz6xY4iW1jUR7W-J__Ie9lXI9vplmzp8qXqRc,4973
95
98
  python/tests/onnx_quantizer_tests/layers/constant_config.py,sha256=RdrKNMNZjI3Sk5o8WLNqmBUyYVJRWgtFbQ6oFWMwyQk,1193
@@ -103,7 +106,9 @@ python/tests/onnx_quantizer_tests/layers/min_config.py,sha256=izKtCaMXoQHiAfmcGl
103
106
  python/tests/onnx_quantizer_tests/layers/mul_config.py,sha256=_Oy4b97ORxFlF3w0BmJ94hNA968HQx2AvwYiASrGPxw,4135
104
107
  python/tests/onnx_quantizer_tests/layers/relu_config.py,sha256=_aHuddDApLUBOa0FiR9h4fNfmMSnH5r4JzOMLW0KaTk,2197
105
108
  python/tests/onnx_quantizer_tests/layers/reshape_config.py,sha256=fZchSqIAy76m7j97wVC_UI6slSpv8nbwukhkbGR2sRE,2203
109
+ python/tests/onnx_quantizer_tests/layers/squeeze_config.py,sha256=6Grn9yzgQzrLtjA5fs6E_DNZxtsyDfIqmwpr9l2jric,4909
106
110
  python/tests/onnx_quantizer_tests/layers/sub_config.py,sha256=IxF18mG9kjlEiKYSNG912CEcBxOFGxIWoRAwjvBXiRo,4133
111
+ python/tests/onnx_quantizer_tests/layers/unsqueeze_config.py,sha256=M707k0pFvlZOeehkE9TWurPNrYwlaH-IjjV6mZdFEPM,4782
107
112
  python/tests/onnx_quantizer_tests/layers_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
108
113
  python/tests/onnx_quantizer_tests/layers_tests/base_test.py,sha256=UgbcT97tgcuTtO1pOADpww9bz_JElKiI2mxLJYKyF1k,2992
109
114
  python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py,sha256=Vxn4LEWHZeGa_vS1-7ptFqSSBb0D-3BG-ETocP4pvsI,3651
@@ -115,8 +120,8 @@ python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py,sha256=RfnIIi
115
120
  python/tests/onnx_quantizer_tests/layers_tests/test_validation.py,sha256=jz-WtIEP-jjUklOOAnznwPUXbf07U2PAMGrhzMWP0JU,1371
116
121
  python/tests/utils_testing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
117
122
  python/tests/utils_testing/test_helper_functions.py,sha256=xmeGQieh4LE9U-CDKBlHhSWqH0cAmmDU3qXNbDkkvms,27192
118
- jstprove-2.0.0.dist-info/METADATA,sha256=2y7oLbQStZnZoukzpTT8Avi5d3gTWtidjoo56IWVyCg,14166
119
- jstprove-2.0.0.dist-info/WHEEL,sha256=W0m4OUfxniOjS_4shj7CwgYZPxSrjpkVN6CH28tx5ZE,106
120
- jstprove-2.0.0.dist-info/entry_points.txt,sha256=nGcTSO-4q08gPl1IoWdrPaiY7IbO7XvmXKkd34dYHc8,49
121
- jstprove-2.0.0.dist-info/top_level.txt,sha256=J-z0poNcsv31IHB413--iOY8LoHBKiTHeybHX3abokI,7
122
- jstprove-2.0.0.dist-info/RECORD,,
123
+ jstprove-2.1.0.dist-info/METADATA,sha256=dbpJDcKKKd0J8pLcyiz4f4Ohac7EI6GVK6muJf2Y52k,14166
124
+ jstprove-2.1.0.dist-info/WHEEL,sha256=W0m4OUfxniOjS_4shj7CwgYZPxSrjpkVN6CH28tx5ZE,106
125
+ jstprove-2.1.0.dist-info/entry_points.txt,sha256=nGcTSO-4q08gPl1IoWdrPaiY7IbO7XvmXKkd34dYHc8,49
126
+ jstprove-2.1.0.dist-info/top_level.txt,sha256=J-z0poNcsv31IHB413--iOY8LoHBKiTHeybHX3abokI,7
127
+ jstprove-2.1.0.dist-info/RECORD,,
@@ -1022,7 +1022,9 @@ class ONNXConverter(ModelConverter):
1022
1022
  ]
1023
1023
  self.input_shape = get_input_shapes(onnx_model)
1024
1024
 
1025
- def get_weights(self: ONNXConverter) -> tuple[
1025
+ def get_weights(
1026
+ self: ONNXConverter,
1027
+ ) -> tuple[
1026
1028
  dict[str, list[ONNXLayerDict]],
1027
1029
  dict[str, list[ONNXLayerDict]],
1028
1030
  CircuitParamsDict,
@@ -1055,7 +1057,7 @@ class ONNXConverter(ModelConverter):
1055
1057
  scale_base=scale_base,
1056
1058
  )
1057
1059
  # Get layers in correct format
1058
- (architecture, w_and_b) = self.analyze_layers(
1060
+ architecture, w_and_b = self.analyze_layers(
1059
1061
  scaled_and_transformed_model,
1060
1062
  output_name_to_shape,
1061
1063
  )
@@ -0,0 +1,155 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, ClassVar
4
+
5
+ import numpy as np
6
+ from onnx import numpy_helper
7
+
8
+ from python.core.model_processing.onnx_quantizer.exceptions import (
9
+ InvalidParamError,
10
+ )
11
+ from python.core.model_processing.onnx_quantizer.layers.base import (
12
+ BaseOpQuantizer,
13
+ QuantizerBase,
14
+ ScaleConfig,
15
+ )
16
+
17
+ if TYPE_CHECKING:
18
+ import onnx
19
+
20
+
21
class QuantizeSqueeze(QuantizerBase):
    # Declarative configuration for quantizing the standard ONNX Squeeze op.
    # Squeeze is a pure shape transform, so the generic QuantizerBase
    # machinery needs no weights/biases and no rescaling for it.
    OP_TYPE = "Squeeze"
    DOMAIN = ""  # standard ONNX domain — no custom op is emitted
    USE_WB = False  # no weight/bias tensors to quantize
    USE_SCALING = False  # scale-preserving: output scale equals input scale
    # Only the data input is relevant for scale-planning.
    SCALE_PLAN: ClassVar = {0: 1}
28
+
29
+
30
class SqueezeQuantizer(BaseOpQuantizer, QuantizeSqueeze):
    """
    Quantizer for ONNX Squeeze.

    Squeeze is scale-preserving (pure shape/view transform):
    - No arithmetic
    - No rescaling
    - No custom op

    We support:
    - axes as an attribute (older opsets)
    - axes as a constant initializer input (newer opsets)

    We do NOT support dynamic axes provided at runtime.
    """

    def __init__(
        self: SqueezeQuantizer,
        new_initializers: list[onnx.TensorProto] | None = None,
    ) -> None:
        """Optionally adopt the caller's shared list of new initializers."""
        super().__init__()
        if new_initializers is not None:
            # Keep a reference (not a copy) so tensors appended here are
            # visible to the owning quantizer that passed the list in.
            self.new_initializers = new_initializers

    def quantize(
        self: SqueezeQuantizer,
        node: onnx.NodeProto,
        graph: onnx.GraphProto,
        scale_config: ScaleConfig,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> list[onnx.NodeProto]:
        """Return the quantized replacement node list for a Squeeze node."""
        # Pure passthrough; QuantizerBase handles standard bookkeeping.
        # NOTE: QuantizeSqueeze.quantize is invoked explicitly (not via
        # super()) to pin the MRO branch that supplies the generic
        # implementation despite the multiple-inheritance layout.
        return QuantizeSqueeze.quantize(
            self,
            node,
            graph,
            scale_config,
            initializer_map,
        )

    # Squeeze(data) -> axes given as attribute (older opsets);
    # Squeeze(data, axes) -> axes given as a second input (newer opsets).
    _N_INPUTS_NO_AXES: ClassVar[int] = 1
    _N_INPUTS_WITH_AXES: ClassVar[int] = 2

    def _get_axes_from_attribute(self, node: onnx.NodeProto) -> list[int] | None:
        """Return the 'axes' attribute as a list of ints, or None if absent."""
        for attr in node.attribute:
            if attr.name == "axes":
                return list(attr.ints)
        return None

    def _get_axes_from_initializer_input(
        self,
        node: onnx.NodeProto,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> list[int]:
        """Resolve axes from the node's second input.

        Raises:
            InvalidParamError: if the axes input is not a constant
                initializer, is non-integer, or is not 0-D/1-D.
        """
        axes_name = node.input[1]
        if axes_name not in initializer_map:
            # Runtime-computed axes cannot be folded at quantization time.
            raise InvalidParamError(
                node_name=node.name,
                op_type=node.op_type,
                message=(
                    f"Dynamic axes input is not supported for Squeeze "
                    f"(expected axes '{axes_name}' to be an initializer)."
                ),
            )

        axes_tensor = initializer_map[axes_name]
        arr = numpy_helper.to_array(axes_tensor)

        if not np.issubdtype(arr.dtype, np.integer):
            raise InvalidParamError(
                node_name=node.name,
                op_type=node.op_type,
                message=f"Squeeze axes initializer must be integer, got {arr.dtype}.",
                attr_key="axes",
                expected="integer tensor (0-D or 1-D)",
            )

        # Normalize both scalar and vector encodings to a plain int list.
        if arr.ndim == 0:
            return [int(arr)]
        if arr.ndim == 1:
            return [int(x) for x in arr.tolist()]

        raise InvalidParamError(
            node_name=node.name,
            op_type=node.op_type,
            message=f"Squeeze axes initializer must be 0-D or 1-D, got {arr.ndim}-D.",
            attr_key="axes",
            expected="0-D scalar or 1-D list of axes",
        )

    def check_supported(
        self: SqueezeQuantizer,
        node: onnx.NodeProto,
        initializer_map: dict[str, onnx.TensorProto] | None = None,
    ) -> None:
        """Validate that this Squeeze node is expressible in the circuit.

        Raises:
            InvalidParamError: on a bad input count, dynamic/ill-typed axes,
                or duplicate axes entries.
        """
        self.validate_node_has_output(node)
        initializer_map = initializer_map or {}

        n_inputs = len(node.input)
        if n_inputs not in (self._N_INPUTS_NO_AXES, self._N_INPUTS_WITH_AXES):
            raise InvalidParamError(
                node_name=node.name,
                op_type=node.op_type,
                message=f"Squeeze expects 1 or 2 inputs, got {n_inputs}.",
            )

        axes = self._get_axes_from_attribute(node)

        # If axes is provided as a second input, it must be a constant initializer.
        if axes is None and n_inputs == self._N_INPUTS_WITH_AXES:
            axes = self._get_axes_from_initializer_input(node, initializer_map)

        # If axes is omitted entirely, ONNX semantics are "remove all dims of size 1".
        # We can't validate legality here without rank/shape; defer to Rust.

        if axes is None:
            return

        if len(set(axes)) != len(axes):
            raise InvalidParamError(
                node_name=node.name,
                op_type=node.op_type,
                message=f"axes must not contain duplicates: {axes}",
                attr_key="axes",
                expected="axes list with unique entries",
            )
@@ -0,0 +1,313 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, ClassVar
4
+
5
+ import numpy as np
6
+ from onnx import helper, numpy_helper
7
+
8
+ from python.core.model_processing.converters.base import ModelType
9
+ from python.core.model_processing.errors import LayerAnalysisError
10
+ from python.core.model_processing.onnx_custom_ops.onnx_helpers import parse_attributes
11
+ from python.core.model_processing.onnx_quantizer.exceptions import (
12
+ InvalidParamError,
13
+ )
14
+ from python.core.model_processing.onnx_quantizer.layers.base import (
15
+ BaseOpQuantizer,
16
+ QuantizerBase,
17
+ ScaleConfig,
18
+ )
19
+
20
+ if TYPE_CHECKING:
21
+ import onnx
22
+
23
+ _N_UNSQUEEZE_INPUTS: int = 2
24
+
25
+
26
class QuantizeUnsqueeze(QuantizerBase):
    # Declarative configuration for quantizing the standard ONNX Unsqueeze op.
    # Unsqueeze only inserts size-1 dimensions, so the generic QuantizerBase
    # machinery needs no weights/biases and no rescaling for it.
    OP_TYPE = "Unsqueeze"
    DOMAIN = ""  # standard ONNX domain — no custom op is emitted
    USE_WB = False  # no weight/bias tensors to quantize
    USE_SCALING = False  # scale-preserving: output scale equals input scale
    # Only the data input is relevant for scale-planning.
    SCALE_PLAN: ClassVar = {0: 1}
33
+
34
+
35
class UnsqueezeQuantizer(BaseOpQuantizer, QuantizeUnsqueeze):
    """
    Quantizer for ONNX Unsqueeze.

    Unsqueeze is scale-preserving (pure shape/view transform):
    - No arithmetic
    - No rescaling
    - No custom op

    Semantics:
    - Inserts new dimensions of size 1 at the specified axes positions.

    We support:
    - axes as an attribute (older opsets)
    - axes as a constant initializer input (opset >= 13 style)

    We do NOT support dynamic axes provided at runtime.
    """

    def __init__(
        self: UnsqueezeQuantizer,
        new_initializers: list[onnx.TensorProto] | None = None,
    ) -> None:
        """Optionally adopt the caller's shared list of new initializers."""
        super().__init__()
        if new_initializers is not None:
            # Keep a reference (not a copy) so tensors appended here are
            # visible to the owning quantizer that passed the list in.
            self.new_initializers = new_initializers

    def quantize(
        self: UnsqueezeQuantizer,
        node: onnx.NodeProto,
        graph: onnx.GraphProto,
        scale_config: ScaleConfig,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> list[onnx.NodeProto]:
        """Return the quantized replacement node list for an Unsqueeze node."""
        # Pure passthrough; QuantizerBase handles standard bookkeeping.
        # NOTE: QuantizeUnsqueeze.quantize is invoked explicitly (not via
        # super()) to pin the MRO branch that supplies the generic
        # implementation despite the multiple-inheritance layout.
        return QuantizeUnsqueeze.quantize(
            self,
            node,
            graph,
            scale_config,
            initializer_map,
        )

    def pre_analysis_transform(
        self: UnsqueezeQuantizer,
        node: onnx.NodeProto,
        graph: onnx.GraphProto,
        initializer_map: dict[str, onnx.TensorProto],
        scale_base: int,
        scale_exponent: int,
    ) -> None:
        """Normalize opset-13-style Unsqueeze nodes before layer analysis.

        If axes arrive as a constant second input, fold them into an
        'axes' attribute on the node (mutated in place) so downstream
        analysis sees one uniform encoding. No-op when the node already
        carries an 'axes' attribute or is not an Unsqueeze.
        """
        # NOTE(review): scale_base/scale_exponent are intentionally unused
        # here; initializer_map is listed too but IS used in the call below.
        _ = initializer_map, scale_base, scale_exponent
        model_type = ModelType.ONNX
        params = parse_attributes(node.attribute)
        if node.op_type != "Unsqueeze":
            return
        if params and "axes" in params:
            # Already attribute-style; nothing to fold.
            return
        axes = _extract_unsqueeze_axes_into_params(
            name=node.name,
            inputs=node.input,
            params=params,
            graph=graph,
            model_type=model_type,
            initializer_map=initializer_map,
        )
        # Mutate the node: attach the resolved axes as an attribute.
        attr = helper.make_attribute("axes", axes["axes"])
        node.attribute.append(attr)

    # Unsqueeze(data) -> axes given as attribute (older opsets);
    # Unsqueeze(data, axes) -> axes given as a second input (opset >= 13).
    _N_INPUTS_NO_AXES: ClassVar[int] = 1
    _N_INPUTS_WITH_AXES: ClassVar[int] = 2

    def _get_axes_from_attribute(self, node: onnx.NodeProto) -> list[int] | None:
        """Return the 'axes' attribute as a list of ints, or None if absent."""
        for attr in node.attribute:
            if attr.name == "axes":
                return list(attr.ints)
        return None

    def _get_axes_from_initializer_input(
        self,
        node: onnx.NodeProto,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> list[int]:
        """Resolve axes from the node's second input.

        Raises:
            InvalidParamError: if the axes input is not a constant
                initializer, is non-integer, or is not 0-D/1-D.
        """
        axes_name = node.input[1]
        if axes_name not in initializer_map:
            # Runtime-computed axes cannot be folded at quantization time.
            raise InvalidParamError(
                node_name=node.name,
                op_type=node.op_type,
                message=(
                    f"Dynamic axes input is not supported for Unsqueeze "
                    f"(expected axes '{axes_name}' to be an initializer)."
                ),
            )

        axes_tensor = initializer_map[axes_name]
        arr = numpy_helper.to_array(axes_tensor)

        if not np.issubdtype(arr.dtype, np.integer):
            raise InvalidParamError(
                node_name=node.name,
                op_type=node.op_type,
                message=f"Unsqueeze axes initializer must be integer, got {arr.dtype}.",
                attr_key="axes",
                expected="integer tensor (0-D or 1-D)",
            )

        # Normalize both scalar and vector encodings to a plain int list.
        if arr.ndim == 0:
            return [int(arr)]
        if arr.ndim == 1:
            return [int(x) for x in arr.tolist()]

        raise InvalidParamError(
            node_name=node.name,
            op_type=node.op_type,
            message=f"Unsqueeze axes initializer must be 0-D or 1-D, got {arr.ndim}-D.",
            attr_key="axes",
            expected="0-D scalar or 1-D list of axes",
        )

    def check_supported(
        self: UnsqueezeQuantizer,
        node: onnx.NodeProto,
        initializer_map: dict[str, onnx.TensorProto] | None = None,
    ) -> None:
        """Validate that this Unsqueeze node is expressible in the circuit.

        Unlike Squeeze, axes may never be omitted for Unsqueeze, so this
        always ends with a concrete axes list or raises.

        Raises:
            InvalidParamError: on a bad input count, missing/dynamic/
                ill-typed axes, or duplicate axes entries.
        """
        self.validate_node_has_output(node)
        initializer_map = initializer_map or {}

        n_inputs = len(node.input)
        if n_inputs not in (self._N_INPUTS_NO_AXES, self._N_INPUTS_WITH_AXES):
            raise InvalidParamError(
                node_name=node.name,
                op_type=node.op_type,
                message=(
                    "Unsqueeze expects either 1 input (axes as attribute) or 2 inputs "
                    f"(axes as initializer), got {n_inputs}."
                ),
            )

        axes = self._get_axes_from_attribute(node)

        # ONNX Unsqueeze has two schema styles:
        #   - newer: Unsqueeze(data, axes) -> 2 inputs, axes is initializer input
        #   - older: Unsqueeze(data) with axes attribute -> 1 input
        if n_inputs == self._N_INPUTS_NO_AXES:
            if axes is None:
                raise InvalidParamError(
                    node_name=node.name,
                    op_type=node.op_type,
                    message=(
                        "Unsqueeze with 1 input is only supported when 'axes' is "
                        "provided as an attribute (older opsets)."
                    ),
                    attr_key="axes",
                    expected="axes attribute",
                )
        elif axes is None:
            axes = self._get_axes_from_initializer_input(node, initializer_map)

        if axes is None:
            raise InvalidParamError(
                node_name=node.name,
                op_type=node.op_type,
                message="Unsqueeze requires 'axes' to be provided.",
                attr_key="axes",
                expected="axes attribute or initializer input",
            )

        if len(set(axes)) != len(axes):
            raise InvalidParamError(
                node_name=node.name,
                op_type=node.op_type,
                message=f"axes must not contain duplicates: {axes}",
                attr_key="axes",
                expected="axes list with unique entries",
            )
+ )
210
+
211
+
212
def _extract_unsqueeze_axes_into_params(  # noqa: PLR0913
    *,
    name: str,
    inputs: list[str] | tuple[str, ...],
    params: dict | None,
    graph: onnx.GraphProto,
    model_type: ModelType,
    initializer_map: dict[str, onnx.TensorProto] | None = None,
) -> dict:
    """Fold a constant Unsqueeze axes input into the params dict.

    Resolves the node's second input to a concrete integer array and
    stores it under the "axes" key of *params* (a new dict is used when
    *params* is None/empty).

    Raises:
        LayerAnalysisError: on a wrong input count, a dynamic axes input,
            or a non-integer axes tensor.
    """
    n_inputs = len(inputs)
    if n_inputs != _N_UNSQUEEZE_INPUTS:
        msg = (
            f"Unsqueeze '{name}' is missing axes input. "
            f"Expected 2 inputs (data, axes), got {n_inputs}: {list(inputs)}"
        )
        raise LayerAnalysisError(model_type=model_type, reason=msg)

    # Second input names the axes tensor; materialize it as numpy.
    resolved = _resolve_unsqueeze_axes_array(
        name=name,
        axes_name=inputs[1],
        graph=graph,
        model_type=model_type,
        initializer_map=initializer_map,
    )

    _validate_unsqueeze_axes_are_integer(
        name=name,
        axes_arr=resolved,
        model_type=model_type,
    )

    result = params or {}
    result["axes"] = _axes_array_to_int_list(resolved)
    return result
247
+
248
+
249
def _resolve_unsqueeze_axes_array(
    *,
    name: str,
    axes_name: str,
    graph: onnx.GraphProto,
    model_type: ModelType,
    initializer_map: dict[str, onnx.TensorProto] | None = None,
) -> np.ndarray:
    """Materialize an Unsqueeze node's axes tensor as a numpy array.

    The tensor is searched first among the graph initializers and then
    among Constant-node outputs; any other source is considered dynamic
    and rejected.

    Raises:
        LayerAnalysisError: when the axes input cannot be resolved to a
            constant tensor.
    """
    # Build the name -> tensor lookup lazily when the caller supplied none.
    lookup = initializer_map or {init.name: init for init in graph.initializer}

    tensor = lookup.get(axes_name)
    if tensor is not None:
        return numpy_helper.to_array(tensor)

    constant = _find_constant_tensor_by_output_name(
        graph=graph,
        output_name=axes_name,
    )
    if constant is not None:
        return numpy_helper.to_array(constant)

    msg = (
        f"Unsqueeze '{name}' has dynamic axes input '{axes_name}'. "
        "Only constant initializer axes or Constant-node axes are supported."
    )
    raise LayerAnalysisError(model_type=model_type, reason=msg)
276
+
277
+
278
def _find_constant_tensor_by_output_name(
    *,
    graph: onnx.GraphProto,
    output_name: str,
) -> onnx.TensorProto | None:
    """Find the TensorProto produced by a Constant node with *output_name*.

    Returns None when no such Constant node exists, or when the matching
    Constant does not carry its payload in the 'value' tensor attribute
    (e.g. it uses value_ints / value_floats / sparse_value instead).
    """
    for node in graph.node:
        if node.op_type != "Constant" or not node.output:
            continue
        if node.output[0] != output_name:
            continue

        for attr in node.attribute:
            # BUG FIX: protobuf singular message fields are never None --
            # `attr.t` yields a default (empty) TensorProto even when the
            # field is unset, so `attr.t is not None` was always true and
            # could return a bogus empty tensor for Constants encoded via
            # value_ints/value_floats/etc. Use HasField to test presence.
            if attr.name == "value" and attr.HasField("t"):
                return attr.t

        # Constant node exists but doesn't have the expected tensor attribute.
        return None

    return None
297
+
298
+
299
+ def _validate_unsqueeze_axes_are_integer(
300
+ *,
301
+ name: str,
302
+ axes_arr: np.ndarray,
303
+ model_type: ModelType,
304
+ ) -> None:
305
+ if not np.issubdtype(axes_arr.dtype, np.integer):
306
+ msg = f"Unsqueeze '{name}' axes must be integer, got dtype {axes_arr.dtype}."
307
+ raise LayerAnalysisError(model_type=model_type, reason=msg)
308
+
309
+
310
+ def _axes_array_to_int_list(axes_arr: np.ndarray) -> list[int]:
311
+ if axes_arr.ndim == 0:
312
+ return [int(axes_arr)]
313
+ return [int(x) for x in axes_arr.reshape(-1).tolist()]
@@ -31,7 +31,11 @@ from python.core.model_processing.onnx_quantizer.layers.maxpool import MaxpoolQu
31
31
  from python.core.model_processing.onnx_quantizer.layers.min import MinQuantizer
32
32
  from python.core.model_processing.onnx_quantizer.layers.mul import MulQuantizer
33
33
  from python.core.model_processing.onnx_quantizer.layers.relu import ReluQuantizer
34
+ from python.core.model_processing.onnx_quantizer.layers.squeeze import SqueezeQuantizer
34
35
  from python.core.model_processing.onnx_quantizer.layers.sub import SubQuantizer
36
+ from python.core.model_processing.onnx_quantizer.layers.unsqueeze import (
37
+ UnsqueezeQuantizer,
38
+ )
35
39
 
36
40
 
37
41
  class ONNXOpQuantizer:
@@ -90,6 +94,8 @@ class ONNXOpQuantizer:
90
94
  self.register("Max", MaxQuantizer(self.new_initializers))
91
95
  self.register("Min", MinQuantizer(self.new_initializers))
92
96
  self.register("BatchNormalization", BatchnormQuantizer(self.new_initializers))
97
+ self.register("Squeeze", SqueezeQuantizer(self.new_initializers))
98
+ self.register("Unsqueeze", UnsqueezeQuantizer(self.new_initializers))
93
99
 
94
100
  def register(
95
101
  self: ONNXOpQuantizer,
python/frontend/cli.py CHANGED
@@ -12,6 +12,7 @@ if TYPE_CHECKING:
12
12
  from python.frontend.commands import BaseCommand
13
13
 
14
14
  from python.frontend.commands import (
15
+ BatchCommand,
15
16
  BenchCommand,
16
17
  CompileCommand,
17
18
  ModelCheckCommand,
@@ -41,6 +42,7 @@ COMMANDS: list[type[BaseCommand]] = [
41
42
  WitnessCommand,
42
43
  ProveCommand,
43
44
  VerifyCommand,
45
+ BatchCommand,
44
46
  BenchCommand,
45
47
  ]
46
48
 
@@ -1,4 +1,5 @@
1
1
  from python.frontend.commands.base import BaseCommand
2
+ from python.frontend.commands.batch import BatchCommand
2
3
  from python.frontend.commands.bench import BenchCommand
3
4
  from python.frontend.commands.compile import CompileCommand
4
5
  from python.frontend.commands.model_check import ModelCheckCommand
@@ -8,6 +9,7 @@ from python.frontend.commands.witness import WitnessCommand
8
9
 
9
10
  __all__ = [
10
11
  "BaseCommand",
12
+ "BatchCommand",
11
13
  "BenchCommand",
12
14
  "CompileCommand",
13
15
  "ModelCheckCommand",
@@ -0,0 +1,192 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import TYPE_CHECKING, Any, ClassVar
5
+
6
+ from python.core.utils.helper_functions import (
7
+ read_from_json,
8
+ run_cargo_command,
9
+ to_json,
10
+ )
11
+ from python.frontend.commands.args import CIRCUIT_PATH
12
+ from python.frontend.commands.base import BaseCommand
13
+
14
+ if TYPE_CHECKING:
15
+ import argparse
16
+ from collections.abc import Callable
17
+
18
+ from python.core.circuits.base import Circuit
19
+
20
+
21
def _preprocess_manifest(
    circuit: Circuit,
    manifest_path: str,
    circuit_path: str,
    transform_job: Callable[[Circuit, dict[str, Any]], None],
) -> str:
    """Apply *transform_job* to every job in the manifest.

    Loads the quantized model sitting next to the circuit file, runs the
    per-job transform (which may mutate each job in place), and writes the
    rewritten manifest alongside the original with a "_processed" suffix.

    Returns:
        Path of the processed manifest file.

    Raises:
        TypeError: when the manifest is not a dict containing a "jobs" list.
    """
    source = Path(circuit_path)
    # The quantized model is expected beside the circuit file by convention.
    circuit.load_quantized_model(
        str(source.parent / f"{source.stem}_quantized_model.onnx"),
    )

    manifest: dict[str, Any] = read_from_json(manifest_path)
    jobs = manifest.get("jobs") if isinstance(manifest, dict) else None
    if not isinstance(jobs, list):
        msg = f"Invalid manifest: expected {{'jobs': [...]}} in {manifest_path}"
        raise TypeError(msg)

    for job in jobs:
        transform_job(circuit, job)

    original = Path(manifest_path)
    destination = str(
        original.with_name(f"{original.stem}_processed{original.suffix}"),
    )
    to_json(manifest, destination)
    return destination
51
+
52
+
53
+ def _validate_job_keys(job: dict[str, Any], *keys: str) -> None:
54
+ missing = [k for k in keys if k not in job]
55
+ if missing:
56
+ msg = f"Job missing required keys {missing}: {job}"
57
+ raise ValueError(msg)
58
+
59
+
60
def _transform_witness_job(circuit: Circuit, job: dict[str, Any]) -> None:
    """Scale/reshape one witness job's inputs and precompute its outputs.

    Writes the circuit-shaped inputs next to the original input file with
    an "_adjusted" suffix, writes the model outputs to job["output"], and
    repoints job["input"] at the adjusted file (mutating *job* in place).
    """
    _validate_job_keys(job, "input", "output")

    raw_input_path = Path(job["input"])
    raw = read_from_json(job["input"])
    scaled_values = circuit.scale_inputs_only(raw)

    # Two views of the same scaled data: one laid out for model inference,
    # one in the layout the circuit expects.
    for_inference = circuit.reshape_inputs_for_inference(scaled_values)
    for_circuit = circuit.reshape_inputs_for_circuit(scaled_values)

    adjusted_path = str(
        raw_input_path.with_name(
            f"{raw_input_path.stem}_adjusted{raw_input_path.suffix}",
        ),
    )
    to_json(for_circuit, adjusted_path)

    # Run the model to produce the expected outputs for this job.
    formatted_outputs = circuit.format_outputs(circuit.get_outputs(for_inference))
    to_json(formatted_outputs, job["output"])

    job["input"] = adjusted_path
77
+
78
+
79
def _transform_verify_job(circuit: Circuit, job: dict[str, Any]) -> None:
    """Reshape one verify job's inputs into circuit layout.

    Writes the reshaped inputs next to the original input file with a
    "_veri" suffix and repoints job["input"] at it (mutating *job* in
    place).
    """
    _validate_job_keys(job, "input")

    original_path = Path(job["input"])
    reshaped = circuit.reshape_inputs_for_circuit(read_from_json(job["input"]))

    destination = str(
        original_path.with_name(
            f"{original_path.stem}_veri{original_path.suffix}",
        ),
    )
    to_json(reshaped, destination)

    job["input"] = destination
89
+
90
+
91
class BatchCommand(BaseCommand):
    """Run batch operations on multiple inputs."""

    # CLI registration metadata consumed by the frontend command framework.
    name: ClassVar[str] = "batch"
    aliases: ClassVar[list[str]] = []
    help: ClassVar[str] = "Run batch witness/prove/verify operations."

    @classmethod
    def _add_batch_args(cls, subparser: argparse.ArgumentParser) -> None:
        """Attach the arguments shared by every batch sub-mode."""
        CIRCUIT_PATH.add_to_parser(subparser)
        subparser.add_argument(
            "-f",
            "--manifest",
            required=True,
            help="Path to batch manifest JSON file.",
        )

    @classmethod
    def configure_parser(
        cls: type[BatchCommand],
        parser: argparse.ArgumentParser,
    ) -> None:
        """Register the witness/prove/verify sub-modes under `batch`."""
        subparsers = parser.add_subparsers(dest="batch_mode", required=True)

        for name, help_text in [
            ("witness", "Generate witnesses for multiple inputs."),
            ("prove", "Generate proofs for multiple witnesses."),
            ("verify", "Verify multiple proofs."),
        ]:
            sub = subparsers.add_parser(name, help=help_text)
            cls._add_batch_args(sub)

    @classmethod
    def run(cls: type[BatchCommand], args: argparse.Namespace) -> None:
        """Execute one batch sub-mode end to end.

        Validates paths, optionally preprocesses the manifest (witness and
        verify modes), then delegates the heavy lifting to the external
        circuit binary via run_cargo_command.

        Raises:
            ValueError: missing circuit path or unknown batch mode.
            FileNotFoundError: circuit, manifest, or metadata file missing.
            RuntimeError: wraps any failure from the cargo invocation.
        """
        # The circuit path may arrive under either argument name depending
        # on how the CLI parsed it (flag vs positional).
        circuit_path = getattr(args, "circuit_path", None) or getattr(
            args,
            "pos_circuit_path",
            None,
        )
        if not circuit_path:
            msg = "Missing required argument: circuit_path"
            raise ValueError(msg)

        circuit_file = Path(circuit_path)
        if not circuit_file.is_file():
            msg = f"Circuit file not found: {circuit_path}"
            raise FileNotFoundError(msg)

        manifest_path = args.manifest
        if not Path(manifest_path).is_file():
            msg = f"Manifest file not found: {manifest_path}"
            raise FileNotFoundError(msg)

        batch_mode = args.batch_mode

        circuit = cls._build_circuit("cli")

        # Map CLI sub-mode -> command_type understood by the circuit binary.
        run_type_map = {
            "witness": "run_batch_witness",
            "prove": "run_batch_prove",
            "verify": "run_batch_verify",
        }

        run_type_str = run_type_map.get(batch_mode)
        if not run_type_str:
            # Defensive: argparse should already reject unknown sub-modes.
            msg = f"Unknown batch mode: {batch_mode}"
            raise ValueError(msg)

        # Metadata is expected beside the circuit file by convention.
        circuit_dir = circuit_file.parent
        name = circuit_file.stem
        metadata_path = str(circuit_dir / f"{name}_metadata.json")
        if not Path(metadata_path).is_file():
            msg = f"Metadata file not found: {metadata_path}"
            raise FileNotFoundError(msg)

        # Witness and verify need their job inputs rewritten into circuit
        # layout first; prove consumes the manifest as-is.
        preprocess_map = {
            "witness": _transform_witness_job,
            "verify": _transform_verify_job,
        }
        if batch_mode in preprocess_map:
            manifest_path = _preprocess_manifest(
                circuit,
                manifest_path,
                circuit_path,
                preprocess_map[batch_mode],
            )

        try:
            run_cargo_command(
                binary_name=circuit.name,
                command_type=run_type_str,
                args={
                    "c": circuit_path,
                    "f": manifest_path,
                    "m": metadata_path,
                },
                dev_mode=False,
            )
        except Exception as e:
            # Surface any subprocess failure as a single RuntimeError.
            raise RuntimeError(e) from e

        print(f"[batch {batch_mode}] complete")  # noqa: T201
@@ -98,7 +98,7 @@ class LayerTestConfig:
98
98
  combined_inits = {**self.required_initializers, **initializer_overrides}
99
99
  for name, data in combined_inits.items():
100
100
  # Special handling for shape tensors in Reshape, etc.
101
- if name == "shape":
101
+ if name in {"shape", "axes"}:
102
102
  tensor = numpy_helper.from_array(data.astype(np.int64), name=name)
103
103
  else:
104
104
  tensor = numpy_helper.from_array(data.astype(np.float32), name=name)
@@ -0,0 +1,117 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+
5
+ from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError
6
+ from python.tests.onnx_quantizer_tests.layers.base import (
7
+ e2e_test,
8
+ error_test,
9
+ valid_test,
10
+ )
11
+ from python.tests.onnx_quantizer_tests.layers.factory import (
12
+ BaseLayerConfigProvider,
13
+ LayerTestConfig,
14
+ )
15
+
16
+
17
+ class SqueezeConfigProvider(BaseLayerConfigProvider):
18
+ """Test configuration provider for Squeeze"""
19
+
20
+ @property
21
+ def layer_name(self) -> str:
22
+ return "Squeeze"
23
+
24
+ def get_config(self) -> LayerTestConfig:
25
+ # Test opset-newer form: Squeeze(data, axes) where axes is an int64 initializer.
26
+ return LayerTestConfig(
27
+ op_type="Squeeze",
28
+ valid_inputs=["A", "axes"],
29
+ valid_attributes={}, # no attribute-based axes
30
+ required_initializers={},
31
+ input_shapes={
32
+ "A": [1, 3, 1, 5],
33
+ # "axes" is removed from graph inputs automatically
34
+ # when it is an initializer.
35
+ "axes": [2],
36
+ },
37
+ output_shapes={
38
+ "squeeze_output": [3, 5],
39
+ },
40
+ )
41
+
42
+ def get_test_specs(self) -> list:
43
+
44
+ return [
45
+ # --- VALID TESTS ---
46
+ valid_test("axes_omitted")
47
+ .description("Squeeze with no axes input: removes all dims of size 1")
48
+ .override_inputs("A") # only data input
49
+ .override_input_shapes(A=[1, 3, 1, 5])
50
+ .override_output_shapes(squeeze_output=[3, 5])
51
+ .tags("basic", "squeeze", "axes_omitted")
52
+ .build(),
53
+ valid_test("axes_init_basic")
54
+ .description("Squeeze with axes initializer [0,2] on [1,3,1,5] -> [3,5]")
55
+ .override_inputs("A", "axes")
56
+ .override_initializer("axes", np.array([0, 2], dtype=np.int64))
57
+ .override_input_shapes(A=[1, 3, 1, 5])
58
+ .override_output_shapes(squeeze_output=[3, 5])
59
+ .tags("basic", "squeeze", "axes_initializer")
60
+ .build(),
61
+ valid_test("axes_init_singleton_middle")
62
+ .description("Squeeze with axes initializer [1] on [2,1,4] -> [2,4]")
63
+ .override_inputs("A", "axes")
64
+ .override_initializer("axes", np.array([1], dtype=np.int64))
65
+ .override_input_shapes(A=[2, 1, 4])
66
+ .override_output_shapes(squeeze_output=[2, 4])
67
+ .tags("squeeze", "axes_initializer")
68
+ .build(),
69
+ valid_test("axes_init_negative")
70
+ .description("Squeeze with axes initializer [-2] on [2,1,4] -> [2,4]")
71
+ .override_inputs("A", "axes")
72
+ .override_initializer("axes", np.array([-2], dtype=np.int64))
73
+ .override_input_shapes(A=[2, 1, 4])
74
+ .override_output_shapes(squeeze_output=[2, 4])
75
+ .tags("squeeze", "axes_initializer", "negative_axis")
76
+ .build(),
77
+ # --- ERROR TESTS ---
78
+ error_test("duplicate_axes_init")
79
+ .description("Duplicate axes in initializer should be rejected")
80
+ .override_inputs("A", "axes")
81
+ .override_initializer("axes", np.array([1, 1], dtype=np.int64))
82
+ .override_input_shapes(A=[2, 1, 4])
83
+ .override_output_shapes(squeeze_output=[2, 4])
84
+ .expects_error(InvalidParamError, match="axes must not contain duplicates")
85
+ .tags("error", "squeeze", "axes_initializer")
86
+ .build(),
87
+ error_test("dynamic_axes_input_not_supported")
88
+ .description(
89
+ "Squeeze with runtime axes (2 inputs but axes is NOT an initializer) "
90
+ "should be rejected",
91
+ )
92
+ .override_inputs("A", "axes") # axes provided as graph input (unsupported)
93
+ .override_input_shapes(A=[1, 3, 1, 5], axes=[2])
94
+ .override_output_shapes(squeeze_output=[3, 5])
95
+ .expects_error(
96
+ InvalidParamError,
97
+ match="Dynamic axes input is not supported",
98
+ )
99
+ .tags("error", "squeeze", "axes_input")
100
+ .build(),
101
+ # --- E2E TESTS ---
102
+ e2e_test("e2e_axes_omitted")
103
+ .description("End-to-end Squeeze test (axes omitted)")
104
+ .override_inputs("A")
105
+ .override_input_shapes(A=[1, 3, 1, 5])
106
+ .override_output_shapes(squeeze_output=[3, 5])
107
+ .tags("e2e", "squeeze")
108
+ .build(),
109
+ e2e_test("e2e_axes_init")
110
+ .description("End-to-end Squeeze test (axes initializer)")
111
+ .override_inputs("A", "axes")
112
+ .override_initializer("axes", np.array([0, 2], dtype=np.int64))
113
+ .override_input_shapes(A=[1, 3, 1, 5])
114
+ .override_output_shapes(squeeze_output=[3, 5])
115
+ .tags("e2e", "squeeze", "axes_initializer")
116
+ .build(),
117
+ ]
@@ -0,0 +1,114 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+
5
+ from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError
6
+ from python.tests.onnx_quantizer_tests.layers.base import (
7
+ e2e_test,
8
+ error_test,
9
+ valid_test,
10
+ )
11
+ from python.tests.onnx_quantizer_tests.layers.factory import (
12
+ BaseLayerConfigProvider,
13
+ LayerTestConfig,
14
+ )
15
+
16
+
17
+ class UnsqueezeConfigProvider(BaseLayerConfigProvider):
18
+ """Test configuration provider for Unsqueeze"""
19
+
20
+ @property
21
+ def layer_name(self) -> str:
22
+ return "Unsqueeze"
23
+
24
+ def get_config(self) -> LayerTestConfig:
25
+ # Test opset-newer form: Unsqueeze(data, axes)
26
+ # where axes is an int64 initializer.
27
+ return LayerTestConfig(
28
+ op_type="Unsqueeze",
29
+ valid_inputs=["A", "axes"],
30
+ valid_attributes={}, # no attribute-based axes
31
+ required_initializers={},
32
+ input_shapes={
33
+ "A": [3, 5],
34
+ # "axes" will be removed from graph inputs automatically
35
+ # when it is an initializer.
36
+ "axes": [2],
37
+ },
38
+ output_shapes={
39
+ "unsqueeze_output": [1, 3, 1, 5],
40
+ },
41
+ )
42
+
43
+ def get_test_specs(self) -> list:
44
+
45
+ return [
46
+ # --- VALID TESTS ---
47
+ valid_test("axes_init_basic")
48
+ .description("Unsqueeze with axes initializer [0,2] on [3,5] -> [1,3,1,5]")
49
+ .override_inputs("A", "axes")
50
+ .override_initializer("axes", np.array([0, 2], dtype=np.int64))
51
+ .override_input_shapes(A=[3, 5])
52
+ .override_output_shapes(unsqueeze_output=[1, 3, 1, 5])
53
+ .tags("basic", "unsqueeze", "axes_initializer")
54
+ .build(),
55
+ valid_test("axes_init_single_axis")
56
+ .description("Unsqueeze with axes initializer [1] on [3,5] -> [3,1,5]")
57
+ .override_inputs("A", "axes")
58
+ .override_initializer("axes", np.array([1], dtype=np.int64))
59
+ .override_input_shapes(A=[3, 5])
60
+ .override_output_shapes(unsqueeze_output=[3, 1, 5])
61
+ .tags("unsqueeze", "axes_initializer")
62
+ .build(),
63
+ valid_test("axes_init_negative")
64
+ .description("Unsqueeze with negative axis [-1] on [3,5] -> [3,5,1]")
65
+ .override_inputs("A", "axes")
66
+ .override_initializer("axes", np.array([-1], dtype=np.int64))
67
+ .override_input_shapes(A=[3, 5])
68
+ .override_output_shapes(unsqueeze_output=[3, 5, 1])
69
+ .tags("unsqueeze", "axes_initializer", "negative_axis")
70
+ .build(),
71
+ valid_test("axes_init_two_axes_append")
72
+ .description("Unsqueeze with axes [2,3] on [3,5] -> [3,5,1,1]")
73
+ .override_inputs("A", "axes")
74
+ .override_initializer("axes", np.array([2, 3], dtype=np.int64))
75
+ .override_input_shapes(A=[3, 5])
76
+ .override_output_shapes(unsqueeze_output=[3, 5, 1, 1])
77
+ .tags("unsqueeze", "axes_initializer")
78
+ .build(),
79
+ # --- ERROR TESTS ---
80
+ error_test("duplicate_axes_init")
81
+ .description("Duplicate axes in initializer should be rejected")
82
+ .override_inputs("A", "axes")
83
+ .override_initializer("axes", np.array([1, 1], dtype=np.int64))
84
+ .override_input_shapes(A=[3, 5])
85
+ .override_output_shapes(
86
+ unsqueeze_output=[3, 1, 5],
87
+ ) # not used; kept consistent
88
+ .expects_error(InvalidParamError, match="axes must not contain duplicates")
89
+ .tags("error", "unsqueeze", "axes_initializer")
90
+ .build(),
91
+ error_test("dynamic_axes_input_not_supported")
92
+ .description(
93
+ "Unsqueeze with runtime axes (2 inputs but axes is NOT an initializer) "
94
+ "should be rejected",
95
+ )
96
+ .override_inputs("A", "axes") # axes provided as graph input (unsupported)
97
+ .override_input_shapes(A=[3, 5], axes=[2])
98
+ .override_output_shapes(unsqueeze_output=[1, 3, 1, 5])
99
+ .expects_error(
100
+ InvalidParamError,
101
+ match="Dynamic axes input is not supported",
102
+ )
103
+ .tags("error", "unsqueeze", "axes_input")
104
+ .build(),
105
+ # --- E2E TESTS ---
106
+ e2e_test("e2e_axes_init")
107
+ .description("End-to-end Unsqueeze test (axes initializer)")
108
+ .override_inputs("A", "axes")
109
+ .override_initializer("axes", np.array([0, 2], dtype=np.int64))
110
+ .override_input_shapes(A=[3, 5])
111
+ .override_output_shapes(unsqueeze_output=[1, 3, 1, 5])
112
+ .tags("e2e", "unsqueeze", "axes_initializer")
113
+ .build(),
114
+ ]
python/tests/test_cli.py CHANGED
@@ -1019,3 +1019,186 @@ def test_bench_with_iterations(tmp_path: Path) -> None:
1019
1019
  assert "--iterations" in cmd
1020
1020
  idx = cmd.index("--iterations")
1021
1021
  assert cmd[idx + 1] == "10"
1022
+
1023
+
1024
+ @pytest.mark.unit
1025
+ def test_batch_prove_dispatch(tmp_path: Path) -> None:
1026
+ circuit = tmp_path / "circuit.txt"
1027
+ circuit.write_text("ok")
1028
+
1029
+ manifest = tmp_path / "manifest.json"
1030
+ manifest.write_text('{"jobs": []}')
1031
+
1032
+ metadata = tmp_path / "circuit_metadata.json"
1033
+ metadata.write_text("{}")
1034
+
1035
+ fake_circuit = MagicMock()
1036
+ fake_circuit.name = "test_circuit"
1037
+
1038
+ with (
1039
+ patch(
1040
+ "python.frontend.commands.batch.BatchCommand._build_circuit",
1041
+ return_value=fake_circuit,
1042
+ ),
1043
+ patch(
1044
+ "python.frontend.commands.batch.run_cargo_command",
1045
+ ) as mock_cargo,
1046
+ ):
1047
+ rc = main(
1048
+ [
1049
+ "--no-banner",
1050
+ "batch",
1051
+ "prove",
1052
+ "-c",
1053
+ str(circuit),
1054
+ "-f",
1055
+ str(manifest),
1056
+ ],
1057
+ )
1058
+
1059
+ assert rc == 0
1060
+ mock_cargo.assert_called_once()
1061
+ call_kwargs = mock_cargo.call_args[1]
1062
+ assert call_kwargs["command_type"] == "run_batch_prove"
1063
+ assert call_kwargs["args"]["c"] == str(circuit)
1064
+ assert call_kwargs["args"]["f"] == str(manifest)
1065
+
1066
+
1067
+ @pytest.mark.unit
1068
+ def test_batch_verify_dispatch(tmp_path: Path) -> None:
1069
+ circuit = tmp_path / "circuit.txt"
1070
+ circuit.write_text("ok")
1071
+
1072
+ manifest = tmp_path / "manifest.json"
1073
+ manifest.write_text('{"jobs": []}')
1074
+
1075
+ metadata = tmp_path / "circuit_metadata.json"
1076
+ metadata.write_text("{}")
1077
+
1078
+ processed = tmp_path / "manifest_processed.json"
1079
+ processed.write_text('{"jobs": []}')
1080
+
1081
+ fake_circuit = MagicMock()
1082
+ fake_circuit.name = "test_circuit"
1083
+
1084
+ with (
1085
+ patch(
1086
+ "python.frontend.commands.batch.BatchCommand._build_circuit",
1087
+ return_value=fake_circuit,
1088
+ ),
1089
+ patch(
1090
+ "python.frontend.commands.batch._preprocess_manifest",
1091
+ return_value=str(processed),
1092
+ ) as mock_preprocess,
1093
+ patch(
1094
+ "python.frontend.commands.batch.run_cargo_command",
1095
+ ) as mock_cargo,
1096
+ ):
1097
+ rc = main(
1098
+ [
1099
+ "--no-banner",
1100
+ "batch",
1101
+ "verify",
1102
+ "-c",
1103
+ str(circuit),
1104
+ "-f",
1105
+ str(manifest),
1106
+ ],
1107
+ )
1108
+
1109
+ assert rc == 0
1110
+ mock_preprocess.assert_called_once()
1111
+ mock_cargo.assert_called_once()
1112
+ call_kwargs = mock_cargo.call_args[1]
1113
+ assert call_kwargs["command_type"] == "run_batch_verify"
1114
+ assert call_kwargs["args"]["f"] == str(processed)
1115
+
1116
+
1117
+ @pytest.mark.unit
1118
+ def test_batch_witness_dispatch(tmp_path: Path) -> None:
1119
+ circuit = tmp_path / "circuit.txt"
1120
+ circuit.write_text("ok")
1121
+
1122
+ manifest = tmp_path / "manifest.json"
1123
+ manifest.write_text('{"jobs": []}')
1124
+
1125
+ metadata = tmp_path / "circuit_metadata.json"
1126
+ metadata.write_text("{}")
1127
+
1128
+ processed = tmp_path / "manifest_processed.json"
1129
+ processed.write_text('{"jobs": []}')
1130
+
1131
+ fake_circuit = MagicMock()
1132
+ fake_circuit.name = "test_circuit"
1133
+
1134
+ with (
1135
+ patch(
1136
+ "python.frontend.commands.batch.BatchCommand._build_circuit",
1137
+ return_value=fake_circuit,
1138
+ ),
1139
+ patch(
1140
+ "python.frontend.commands.batch._preprocess_manifest",
1141
+ return_value=str(processed),
1142
+ ) as mock_preprocess,
1143
+ patch(
1144
+ "python.frontend.commands.batch.run_cargo_command",
1145
+ ) as mock_cargo,
1146
+ ):
1147
+ rc = main(
1148
+ [
1149
+ "--no-banner",
1150
+ "batch",
1151
+ "witness",
1152
+ "-c",
1153
+ str(circuit),
1154
+ "-f",
1155
+ str(manifest),
1156
+ ],
1157
+ )
1158
+
1159
+ assert rc == 0
1160
+ mock_preprocess.assert_called_once()
1161
+ mock_cargo.assert_called_once()
1162
+ call_kwargs = mock_cargo.call_args[1]
1163
+ assert call_kwargs["command_type"] == "run_batch_witness"
1164
+ assert call_kwargs["args"]["f"] == str(processed)
1165
+
1166
+
1167
+ @pytest.mark.unit
1168
+ def test_batch_missing_circuit(tmp_path: Path) -> None:
1169
+ manifest = tmp_path / "manifest.json"
1170
+ manifest.write_text('{"jobs": []}')
1171
+
1172
+ rc = main(
1173
+ [
1174
+ "--no-banner",
1175
+ "batch",
1176
+ "prove",
1177
+ "-c",
1178
+ str(tmp_path / "nonexistent.txt"),
1179
+ "-f",
1180
+ str(manifest),
1181
+ ],
1182
+ )
1183
+
1184
+ assert rc == 1
1185
+
1186
+
1187
+ @pytest.mark.unit
1188
+ def test_batch_missing_manifest(tmp_path: Path) -> None:
1189
+ circuit = tmp_path / "circuit.txt"
1190
+ circuit.write_text("ok")
1191
+
1192
+ rc = main(
1193
+ [
1194
+ "--no-banner",
1195
+ "batch",
1196
+ "prove",
1197
+ "-c",
1198
+ str(circuit),
1199
+ "-f",
1200
+ str(tmp_path / "nonexistent.json"),
1201
+ ],
1202
+ )
1203
+
1204
+ assert rc == 1