JSTprove 2.0.0__py3-none-macosx_11_0_arm64.whl → 2.1.0__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of JSTprove might be problematic. Click here for more details.
- {jstprove-2.0.0.dist-info → jstprove-2.1.0.dist-info}/METADATA +1 -1
- {jstprove-2.0.0.dist-info → jstprove-2.1.0.dist-info}/RECORD +18 -13
- python/core/binaries/onnx_generic_circuit_2-1-0 +0 -0
- python/core/model_processing/converters/onnx_converter.py +4 -2
- python/core/model_processing/onnx_quantizer/layers/squeeze.py +155 -0
- python/core/model_processing/onnx_quantizer/layers/unsqueeze.py +313 -0
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +6 -0
- python/frontend/cli.py +2 -0
- python/frontend/commands/__init__.py +2 -0
- python/frontend/commands/batch.py +192 -0
- python/tests/onnx_quantizer_tests/layers/base.py +1 -1
- python/tests/onnx_quantizer_tests/layers/squeeze_config.py +117 -0
- python/tests/onnx_quantizer_tests/layers/unsqueeze_config.py +114 -0
- python/tests/test_cli.py +183 -0
- python/core/binaries/onnx_generic_circuit_2-0-0 +0 -0
- {jstprove-2.0.0.dist-info → jstprove-2.1.0.dist-info}/WHEEL +0 -0
- {jstprove-2.0.0.dist-info → jstprove-2.1.0.dist-info}/entry_points.txt +0 -0
- {jstprove-2.0.0.dist-info → jstprove-2.1.0.dist-info}/licenses/LICENSE +0 -0
- {jstprove-2.0.0.dist-info → jstprove-2.1.0.dist-info}/top_level.txt +0 -0
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
jstprove-2.
|
|
1
|
+
jstprove-2.1.0.dist-info/licenses/LICENSE,sha256=UXQRcYRUH-PfN27n3P-FMaZFY6jr9jFPKcwT7CWbljw,1160
|
|
2
2
|
python/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
3
|
python/core/__init__.py,sha256=RlfbqGAaUulKl44QGMCkkGJBQZ8R_AgC5bU5zS7BjnA,97
|
|
4
4
|
python/core/binaries/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
5
5
|
python/core/binaries/expander-exec,sha256=C_1JcezdfLp9sFOQ2z3wp2gcq1k8zjIR09CxJKGGIuM,7095168
|
|
6
|
-
python/core/binaries/onnx_generic_circuit_2-
|
|
6
|
+
python/core/binaries/onnx_generic_circuit_2-1-0,sha256=RAA6W15aObCMa-0FnwIFHyMStTAuhNJsiaE95dxTBPE,3323248
|
|
7
7
|
python/core/circuit_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
8
8
|
python/core/circuit_models/generic_onnx.py,sha256=P65UZkfVBTE6YhaQ951S6QoTHPuU5ntDt8QL5pXghvw,8787
|
|
9
9
|
python/core/circuit_models/simple_circuit.py,sha256=igQrZtQyreyHc26iAgCyDb0TuD2bJAoumYhc1pYPDzQ,4682
|
|
@@ -15,7 +15,7 @@ python/core/model_processing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NM
|
|
|
15
15
|
python/core/model_processing/errors.py,sha256=uh2YFjuuU5JM3anMtSTLAH-zjlNAKStmLDZqRUgBWS8,4611
|
|
16
16
|
python/core/model_processing/converters/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
17
17
|
python/core/model_processing/converters/base.py,sha256=o6bNwmqD9sOM9taqMb0ed6804RugQiU3va0rY_EA5SE,4265
|
|
18
|
-
python/core/model_processing/converters/onnx_converter.py,sha256
|
|
18
|
+
python/core/model_processing/converters/onnx_converter.py,sha256=nAvcGzNqNJGHTUgdphvkQAg7JXBHMchIgzp-HTYAkJw,44353
|
|
19
19
|
python/core/model_processing/onnx_custom_ops/__init__.py,sha256=ZKUC4ToRxgEEMHcTyERATVEN0KSDs-9cM1T-tTw3I1g,525
|
|
20
20
|
python/core/model_processing/onnx_custom_ops/batchnorm.py,sha256=8kg4iGGdt6B_fIJkpt4v5eNFpoHa4bjTB0NnCSmKFvE,1693
|
|
21
21
|
python/core/model_processing/onnx_custom_ops/conv.py,sha256=6jJm3fcGWzcU4RjVgf179mPFCqsl4C3AR7bqQTffDgA,3464
|
|
@@ -27,7 +27,7 @@ python/core/model_processing/onnx_custom_ops/onnx_helpers.py,sha256=utnJuc5sgb_z
|
|
|
27
27
|
python/core/model_processing/onnx_custom_ops/relu.py,sha256=pZsPXC_r0FPggURKDphh8P1IRXY0w4hH7ExBmYTlWjE,1202
|
|
28
28
|
python/core/model_processing/onnx_quantizer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
29
29
|
python/core/model_processing/onnx_quantizer/exceptions.py,sha256=vzxBRbpvk4ZZbgacDISnqmQQKj7Ls46V08ilHnhaJy0,5645
|
|
30
|
-
python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py,sha256=
|
|
30
|
+
python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py,sha256=fgN-2-oAqwu436yi_dvmdBApXq2T3DpKGZOqdZybjJ8,9996
|
|
31
31
|
python/core/model_processing/onnx_quantizer/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
32
32
|
python/core/model_processing/onnx_quantizer/layers/add.py,sha256=AGxzqMa0jABIEKOIgPqEAA7EpZtynQtnD9nxI2NHc0s,1409
|
|
33
33
|
python/core/model_processing/onnx_quantizer/layers/base.py,sha256=H48okaq1tvl6ckGZV6BlXCVLLKIy1HpvomGPWBLI-8Q,22544
|
|
@@ -41,7 +41,9 @@ python/core/model_processing/onnx_quantizer/layers/maxpool.py,sha256=NUa63fYNWrO
|
|
|
41
41
|
python/core/model_processing/onnx_quantizer/layers/min.py,sha256=cQbXzGOApR6HUJZMARXy87W8IbUC562jnAQm8J8ynQI,1709
|
|
42
42
|
python/core/model_processing/onnx_quantizer/layers/mul.py,sha256=qHsmnYPH-c5uiFeDCvV6e1xSgmIXJ64Sjvh0LYDYEqQ,1396
|
|
43
43
|
python/core/model_processing/onnx_quantizer/layers/relu.py,sha256=d-5fyeKNLTgKKnqCwURpxkjl7QdbJQpuovtCFBM03FA,1685
|
|
44
|
+
python/core/model_processing/onnx_quantizer/layers/squeeze.py,sha256=haQ9BaeE9w4XxBdTvKIuWSAhZ65Q7A7MhbZlzbFiyDE,4848
|
|
44
45
|
python/core/model_processing/onnx_quantizer/layers/sub.py,sha256=M7D98TZBNP9-2R9MX6mcpYlrWFxTiX9JCs3XNcg1U-Q,1409
|
|
46
|
+
python/core/model_processing/onnx_quantizer/layers/unsqueeze.py,sha256=qCS8BNsffDesseP15h2s9GDfw7Vg51XqZNVaTclKquQ,9807
|
|
45
47
|
python/core/model_templates/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
46
48
|
python/core/model_templates/circuit_template.py,sha256=OAqMRshi9OiJYoqpjkg5tUfNf18MfZmhsxxD6SANm_4,2106
|
|
47
49
|
python/core/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -54,10 +56,11 @@ python/core/utils/model_registry.py,sha256=aZg_9LEqsBXK84oxQ8A3NGZl-9aGnLgfR-kgx
|
|
|
54
56
|
python/core/utils/scratch_tests.py,sha256=o2VDTk8QBKA3UHHE-h7Ghtoge6kGG7G-8qwvesuTFFc,2281
|
|
55
57
|
python/core/utils/witness_utils.py,sha256=ukvbF6EaHMPzRQVZad9wQ9gISRwBGQ1hEAHzc5TpGuw,9488
|
|
56
58
|
python/frontend/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
57
|
-
python/frontend/cli.py,sha256=
|
|
58
|
-
python/frontend/commands/__init__.py,sha256=
|
|
59
|
+
python/frontend/cli.py,sha256=HU0nYkp0kFgxP1BCs_CX4ERd7VvqzKFMooBK5LkEK2c,3142
|
|
60
|
+
python/frontend/commands/__init__.py,sha256=b6N5jB72TbacBNTraYwqFCei8gepw06id8skZOkKMOw,651
|
|
59
61
|
python/frontend/commands/args.py,sha256=JmG4q-tbEy8_YcQsph_WLEAs_w7y7GiR22PhTrc14v4,2255
|
|
60
62
|
python/frontend/commands/base.py,sha256=a7NWoXB9VL8It1TpuL2vmR7J2bhejHDllrNMBEm-JLE,6368
|
|
63
|
+
python/frontend/commands/batch.py,sha256=-uy1PLgeEDIzgRet8JPcUb2xTkrQHW3v_x0_om7hr6U,6069
|
|
61
64
|
python/frontend/commands/compile.py,sha256=-mE4LjBEXzgsnzTJCeas0ZkZgD-kdATpYLk52ljBw88,1905
|
|
62
65
|
python/frontend/commands/constants.py,sha256=feCVczqP6xphHUta2ZMaAuYyVeemZgwU_sCWr6ky5X8,164
|
|
63
66
|
python/frontend/commands/model_check.py,sha256=xsXvXzYpIagFGnUk_UYGZYx-onP0Opes95XismqvY64,1806
|
|
@@ -72,7 +75,7 @@ python/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
|
72
75
|
python/scripts/benchmark_runner.py,sha256=sjbqaLrdjt94AoyQXAxT4FhsN6aRu5idTRQ5uHmZOWM,28593
|
|
73
76
|
python/scripts/gen_and_bench.py,sha256=V36x7djYmHlveAJgYzMlXwnmF0gAGO3-1mg9PWOmpj8,16249
|
|
74
77
|
python/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
75
|
-
python/tests/test_cli.py,sha256=
|
|
78
|
+
python/tests/test_cli.py,sha256=hhlhsEx-6jJEtw4SKwZFBki5iKs8Z8KPtQccxgy1AMo,31755
|
|
76
79
|
python/tests/circuit_e2e_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
77
80
|
python/tests/circuit_e2e_tests/circuit_model_developer_test.py,sha256=8hl8SKw7obXplo0jsiKoKIZLxlu1_HhXvGDeSBDBars,39456
|
|
78
81
|
python/tests/circuit_e2e_tests/helper_fns_for_tests.py,sha256=uEThqTsRdNJivHwAv-aJIUtSPlmVHdhMZqZSH1OqhDE,5177
|
|
@@ -89,7 +92,7 @@ python/tests/onnx_quantizer_tests/test_registered_quantizers.py,sha256=lw_jYSbQ9
|
|
|
89
92
|
python/tests/onnx_quantizer_tests/testing_helper_functions.py,sha256=N0fQv2pYzUCVZ7wkcR8gEKs5zTXT1hWrK-HKSTQYvYU,534
|
|
90
93
|
python/tests/onnx_quantizer_tests/layers/__init__.py,sha256=xP-RmW6LfIANgK1s9Q0KZet2yvNr-3c6YIVLAAQqGUY,404
|
|
91
94
|
python/tests/onnx_quantizer_tests/layers/add_config.py,sha256=T3tGddupDtrvLck2SL2yETDblNtv0aU7Tl7fNyZUhO4,4133
|
|
92
|
-
python/tests/onnx_quantizer_tests/layers/base.py,sha256=
|
|
95
|
+
python/tests/onnx_quantizer_tests/layers/base.py,sha256=YbUx0CnGGtET9qsOd1aD2oN5GbhC6MYtguONp6fY6BA,9327
|
|
93
96
|
python/tests/onnx_quantizer_tests/layers/batchnorm_config.py,sha256=P-sZuHAdEfNczcgTeLjqJnEbpqN3dKTsbqvY4-SBqiQ,8231
|
|
94
97
|
python/tests/onnx_quantizer_tests/layers/clip_config.py,sha256=-OuhnUgz6xY4iW1jUR7W-J__Ie9lXI9vplmzp8qXqRc,4973
|
|
95
98
|
python/tests/onnx_quantizer_tests/layers/constant_config.py,sha256=RdrKNMNZjI3Sk5o8WLNqmBUyYVJRWgtFbQ6oFWMwyQk,1193
|
|
@@ -103,7 +106,9 @@ python/tests/onnx_quantizer_tests/layers/min_config.py,sha256=izKtCaMXoQHiAfmcGl
|
|
|
103
106
|
python/tests/onnx_quantizer_tests/layers/mul_config.py,sha256=_Oy4b97ORxFlF3w0BmJ94hNA968HQx2AvwYiASrGPxw,4135
|
|
104
107
|
python/tests/onnx_quantizer_tests/layers/relu_config.py,sha256=_aHuddDApLUBOa0FiR9h4fNfmMSnH5r4JzOMLW0KaTk,2197
|
|
105
108
|
python/tests/onnx_quantizer_tests/layers/reshape_config.py,sha256=fZchSqIAy76m7j97wVC_UI6slSpv8nbwukhkbGR2sRE,2203
|
|
109
|
+
python/tests/onnx_quantizer_tests/layers/squeeze_config.py,sha256=6Grn9yzgQzrLtjA5fs6E_DNZxtsyDfIqmwpr9l2jric,4909
|
|
106
110
|
python/tests/onnx_quantizer_tests/layers/sub_config.py,sha256=IxF18mG9kjlEiKYSNG912CEcBxOFGxIWoRAwjvBXiRo,4133
|
|
111
|
+
python/tests/onnx_quantizer_tests/layers/unsqueeze_config.py,sha256=M707k0pFvlZOeehkE9TWurPNrYwlaH-IjjV6mZdFEPM,4782
|
|
107
112
|
python/tests/onnx_quantizer_tests/layers_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
108
113
|
python/tests/onnx_quantizer_tests/layers_tests/base_test.py,sha256=UgbcT97tgcuTtO1pOADpww9bz_JElKiI2mxLJYKyF1k,2992
|
|
109
114
|
python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py,sha256=Vxn4LEWHZeGa_vS1-7ptFqSSBb0D-3BG-ETocP4pvsI,3651
|
|
@@ -115,8 +120,8 @@ python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py,sha256=RfnIIi
|
|
|
115
120
|
python/tests/onnx_quantizer_tests/layers_tests/test_validation.py,sha256=jz-WtIEP-jjUklOOAnznwPUXbf07U2PAMGrhzMWP0JU,1371
|
|
116
121
|
python/tests/utils_testing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
117
122
|
python/tests/utils_testing/test_helper_functions.py,sha256=xmeGQieh4LE9U-CDKBlHhSWqH0cAmmDU3qXNbDkkvms,27192
|
|
118
|
-
jstprove-2.
|
|
119
|
-
jstprove-2.
|
|
120
|
-
jstprove-2.
|
|
121
|
-
jstprove-2.
|
|
122
|
-
jstprove-2.
|
|
123
|
+
jstprove-2.1.0.dist-info/METADATA,sha256=dbpJDcKKKd0J8pLcyiz4f4Ohac7EI6GVK6muJf2Y52k,14166
|
|
124
|
+
jstprove-2.1.0.dist-info/WHEEL,sha256=W0m4OUfxniOjS_4shj7CwgYZPxSrjpkVN6CH28tx5ZE,106
|
|
125
|
+
jstprove-2.1.0.dist-info/entry_points.txt,sha256=nGcTSO-4q08gPl1IoWdrPaiY7IbO7XvmXKkd34dYHc8,49
|
|
126
|
+
jstprove-2.1.0.dist-info/top_level.txt,sha256=J-z0poNcsv31IHB413--iOY8LoHBKiTHeybHX3abokI,7
|
|
127
|
+
jstprove-2.1.0.dist-info/RECORD,,
|
|
Binary file
|
|
@@ -1022,7 +1022,9 @@ class ONNXConverter(ModelConverter):
|
|
|
1022
1022
|
]
|
|
1023
1023
|
self.input_shape = get_input_shapes(onnx_model)
|
|
1024
1024
|
|
|
1025
|
-
def get_weights(
|
|
1025
|
+
def get_weights(
|
|
1026
|
+
self: ONNXConverter,
|
|
1027
|
+
) -> tuple[
|
|
1026
1028
|
dict[str, list[ONNXLayerDict]],
|
|
1027
1029
|
dict[str, list[ONNXLayerDict]],
|
|
1028
1030
|
CircuitParamsDict,
|
|
@@ -1055,7 +1057,7 @@ class ONNXConverter(ModelConverter):
|
|
|
1055
1057
|
scale_base=scale_base,
|
|
1056
1058
|
)
|
|
1057
1059
|
# Get layers in correct format
|
|
1058
|
-
|
|
1060
|
+
architecture, w_and_b = self.analyze_layers(
|
|
1059
1061
|
scaled_and_transformed_model,
|
|
1060
1062
|
output_name_to_shape,
|
|
1061
1063
|
)
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, ClassVar
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from onnx import numpy_helper
|
|
7
|
+
|
|
8
|
+
from python.core.model_processing.onnx_quantizer.exceptions import (
|
|
9
|
+
InvalidParamError,
|
|
10
|
+
)
|
|
11
|
+
from python.core.model_processing.onnx_quantizer.layers.base import (
|
|
12
|
+
BaseOpQuantizer,
|
|
13
|
+
QuantizerBase,
|
|
14
|
+
ScaleConfig,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
import onnx
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class QuantizeSqueeze(QuantizerBase):
    """Static operator metadata for ONNX ``Squeeze``.

    Both the weight/bias flag (``USE_WB``) and the scaling flag
    (``USE_SCALING``) are switched off for this op.
    """

    # Identity of the handled operator; "" is the default ONNX operator domain.
    DOMAIN = ""
    OP_TYPE = "Squeeze"

    # Feature switches consumed by the shared quantizer machinery.
    USE_SCALING = False
    USE_WB = False

    # Scale planning considers input 0 (the data tensor) only.
    SCALE_PLAN: ClassVar[dict[int, int]] = {0: 1}
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class SqueezeQuantizer(BaseOpQuantizer, QuantizeSqueeze):
    """
    Quantizer for ONNX Squeeze.

    Squeeze is scale-preserving (pure shape/view transform):
    - No arithmetic
    - No rescaling
    - No custom op

    We support:
    - axes as an attribute (older opsets)
    - axes as a constant initializer input (newer opsets)
    - axes omitted entirely, either as a missing second input or as an
      empty-named second input (ONNX's marker for an omitted optional input)

    We do NOT support dynamic axes provided at runtime.
    """

    # Valid input counts: ``Squeeze(data)`` or ``Squeeze(data, axes)``.
    _N_INPUTS_NO_AXES: ClassVar[int] = 1
    _N_INPUTS_WITH_AXES: ClassVar[int] = 2

    def __init__(
        self: SqueezeQuantizer,
        new_initializers: list[onnx.TensorProto] | None = None,
    ) -> None:
        """Create the quantizer.

        Args:
            new_initializers: Optional list bound to ``self.new_initializers``.
                When None, the attribute is not assigned here.
        """
        super().__init__()
        if new_initializers is not None:
            self.new_initializers = new_initializers

    def quantize(
        self: SqueezeQuantizer,
        node: onnx.NodeProto,
        graph: onnx.GraphProto,
        scale_config: ScaleConfig,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> list[onnx.NodeProto]:
        """Return the quantized node list for ``node``.

        Squeeze has no op-specific handling, so this delegates straight to
        ``QuantizeSqueeze.quantize``.
        """
        # Pure passthrough; QuantizerBase handles standard bookkeeping.
        return QuantizeSqueeze.quantize(
            self,
            node,
            graph,
            scale_config,
            initializer_map,
        )

    def _get_axes_from_attribute(self, node: onnx.NodeProto) -> list[int] | None:
        """Return the ``axes`` attribute as a list of ints, or None if absent."""
        for attr in node.attribute:
            if attr.name == "axes":
                return list(attr.ints)
        return None

    def _get_axes_from_initializer_input(
        self,
        node: onnx.NodeProto,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> list[int]:
        """Read axes from the initializer named by ``node.input[1]``.

        Raises:
            InvalidParamError: if the input is not an initializer, is not an
                integer tensor, or is neither 0-D nor 1-D.
        """
        axes_name = node.input[1]
        if axes_name not in initializer_map:
            raise InvalidParamError(
                node_name=node.name,
                op_type=node.op_type,
                message=(
                    f"Dynamic axes input is not supported for Squeeze "
                    f"(expected axes '{axes_name}' to be an initializer)."
                ),
            )

        axes_tensor = initializer_map[axes_name]
        arr = numpy_helper.to_array(axes_tensor)

        if not np.issubdtype(arr.dtype, np.integer):
            raise InvalidParamError(
                node_name=node.name,
                op_type=node.op_type,
                message=f"Squeeze axes initializer must be integer, got {arr.dtype}.",
                attr_key="axes",
                expected="integer tensor (0-D or 1-D)",
            )

        if arr.ndim == 0:
            return [int(arr)]
        if arr.ndim == 1:
            return [int(x) for x in arr.tolist()]

        raise InvalidParamError(
            node_name=node.name,
            op_type=node.op_type,
            message=f"Squeeze axes initializer must be 0-D or 1-D, got {arr.ndim}-D.",
            attr_key="axes",
            expected="0-D scalar or 1-D list of axes",
        )

    def check_supported(
        self: SqueezeQuantizer,
        node: onnx.NodeProto,
        initializer_map: dict[str, onnx.TensorProto] | None = None,
    ) -> None:
        """Validate that ``node`` is a Squeeze form this quantizer can handle.

        Args:
            node: The Squeeze node to check.
            initializer_map: Graph initializers by name; needed when axes come
                from the second input.

        Raises:
            InvalidParamError: on a wrong input count, non-constant/ill-typed
                axes input, or duplicate axes.
        """
        self.validate_node_has_output(node)
        initializer_map = initializer_map or {}

        n_inputs = len(node.input)
        if n_inputs not in (self._N_INPUTS_NO_AXES, self._N_INPUTS_WITH_AXES):
            raise InvalidParamError(
                node_name=node.name,
                op_type=node.op_type,
                message=f"Squeeze expects 1 or 2 inputs, got {n_inputs}.",
            )

        axes = self._get_axes_from_attribute(node)

        # If axes is provided as a second input, it must be a constant
        # initializer. An empty input name is how ONNX encodes an omitted
        # optional input, so it is treated exactly like having no axes input.
        if (
            axes is None
            and n_inputs == self._N_INPUTS_WITH_AXES
            and node.input[1]
        ):
            axes = self._get_axes_from_initializer_input(node, initializer_map)

        # If axes is omitted entirely, ONNX semantics are "remove all dims of size 1".
        # We can't validate legality here without rank/shape; defer to Rust.
        if axes is None:
            return

        # Only literal duplicates are caught; e.g. -1 and rank-1 aliasing
        # would need the input rank, which is not available here.
        if len(set(axes)) != len(axes):
            raise InvalidParamError(
                node_name=node.name,
                op_type=node.op_type,
                message=f"axes must not contain duplicates: {axes}",
                attr_key="axes",
                expected="axes list with unique entries",
            )
|
|
@@ -0,0 +1,313 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, ClassVar
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from onnx import helper, numpy_helper
|
|
7
|
+
|
|
8
|
+
from python.core.model_processing.converters.base import ModelType
|
|
9
|
+
from python.core.model_processing.errors import LayerAnalysisError
|
|
10
|
+
from python.core.model_processing.onnx_custom_ops.onnx_helpers import parse_attributes
|
|
11
|
+
from python.core.model_processing.onnx_quantizer.exceptions import (
|
|
12
|
+
InvalidParamError,
|
|
13
|
+
)
|
|
14
|
+
from python.core.model_processing.onnx_quantizer.layers.base import (
|
|
15
|
+
BaseOpQuantizer,
|
|
16
|
+
QuantizerBase,
|
|
17
|
+
ScaleConfig,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
import onnx
|
|
22
|
+
|
|
23
|
+
# Input count required by ``_extract_unsqueeze_axes_into_params``: the
# ``Unsqueeze(data, axes)`` form where axes arrive as a tensor input
# (the "opset >= 13 style" described on UnsqueezeQuantizer).
_N_UNSQUEEZE_INPUTS: int = 2
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class QuantizeUnsqueeze(QuantizerBase):
    """Static operator metadata for ONNX ``Unsqueeze``.

    Both the weight/bias flag (``USE_WB``) and the scaling flag
    (``USE_SCALING``) are switched off for this op.
    """

    # Identity of the handled operator; "" is the default ONNX operator domain.
    DOMAIN = ""
    OP_TYPE = "Unsqueeze"

    # Feature switches consumed by the shared quantizer machinery.
    USE_SCALING = False
    USE_WB = False

    # Scale planning considers input 0 (the data tensor) only.
    SCALE_PLAN: ClassVar[dict[int, int]] = {0: 1}
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class UnsqueezeQuantizer(BaseOpQuantizer, QuantizeUnsqueeze):
    """
    Quantizer for ONNX Unsqueeze.

    Unsqueeze is scale-preserving (pure shape/view transform):
    - No arithmetic
    - No rescaling
    - No custom op

    Semantics:
    - Inserts new dimensions of size 1 at the specified axes positions.

    We support:
    - axes as an attribute (older opsets)
    - axes as a constant initializer input (opset >= 13 style)

    We do NOT support dynamic axes provided at runtime.
    """

    # Valid input counts: ``Unsqueeze(data)`` (axes attribute) or
    # ``Unsqueeze(data, axes)`` (axes input).
    _N_INPUTS_NO_AXES: ClassVar[int] = 1
    _N_INPUTS_WITH_AXES: ClassVar[int] = 2

    def __init__(
        self: UnsqueezeQuantizer,
        new_initializers: list[onnx.TensorProto] | None = None,
    ) -> None:
        """Create the quantizer.

        Args:
            new_initializers: Optional list bound to ``self.new_initializers``.
                When None, the attribute is not assigned here.
        """
        super().__init__()
        if new_initializers is not None:
            self.new_initializers = new_initializers

    def quantize(
        self: UnsqueezeQuantizer,
        node: onnx.NodeProto,
        graph: onnx.GraphProto,
        scale_config: ScaleConfig,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> list[onnx.NodeProto]:
        """Return the quantized node list for ``node``.

        Unsqueeze has no op-specific handling, so this delegates straight to
        ``QuantizeUnsqueeze.quantize``.
        """
        # Pure passthrough; QuantizerBase handles standard bookkeeping.
        return QuantizeUnsqueeze.quantize(
            self,
            node,
            graph,
            scale_config,
            initializer_map,
        )

    def pre_analysis_transform(
        self: UnsqueezeQuantizer,
        node: onnx.NodeProto,
        graph: onnx.GraphProto,
        initializer_map: dict[str, onnx.TensorProto],
        scale_base: int,
        scale_exponent: int,
    ) -> None:
        """Copy a constant ``axes`` input onto ``node`` as an ``axes`` attribute.

        Does nothing for non-Unsqueeze nodes or when an ``axes`` attribute is
        already present. Otherwise the axes are resolved (initializer or
        Constant-node output) and appended to ``node.attribute`` in place.

        NOTE(review): the axes input itself is left in ``node.input``, so the
        node carries both forms afterwards; presumably downstream analysis
        reads the attribute. Confirm.

        Raises:
            LayerAnalysisError: if the axes cannot be resolved to an integer
                constant.
        """
        # Part of the hook signature but not needed for a shape-only op.
        _ = scale_base, scale_exponent
        if node.op_type != "Unsqueeze":
            return
        params = parse_attributes(node.attribute)
        if params and "axes" in params:
            return
        resolved = _extract_unsqueeze_axes_into_params(
            name=node.name,
            inputs=node.input,
            params=params,
            graph=graph,
            model_type=ModelType.ONNX,
            initializer_map=initializer_map,
        )
        node.attribute.append(helper.make_attribute("axes", resolved["axes"]))

    def _get_axes_from_attribute(self, node: onnx.NodeProto) -> list[int] | None:
        """Return the ``axes`` attribute as a list of ints, or None if absent."""
        for attr in node.attribute:
            if attr.name == "axes":
                return list(attr.ints)
        return None

    def _get_axes_from_initializer_input(
        self,
        node: onnx.NodeProto,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> list[int]:
        """Read axes from the initializer named by ``node.input[1]``.

        Raises:
            InvalidParamError: if the input is not an initializer, is not an
                integer tensor, or is neither 0-D nor 1-D.
        """
        axes_name = node.input[1]
        if axes_name not in initializer_map:
            raise InvalidParamError(
                node_name=node.name,
                op_type=node.op_type,
                message=(
                    f"Dynamic axes input is not supported for Unsqueeze "
                    f"(expected axes '{axes_name}' to be an initializer)."
                ),
            )

        axes_tensor = initializer_map[axes_name]
        arr = numpy_helper.to_array(axes_tensor)

        if not np.issubdtype(arr.dtype, np.integer):
            raise InvalidParamError(
                node_name=node.name,
                op_type=node.op_type,
                message=f"Unsqueeze axes initializer must be integer, got {arr.dtype}.",
                attr_key="axes",
                expected="integer tensor (0-D or 1-D)",
            )

        if arr.ndim == 0:
            return [int(arr)]
        if arr.ndim == 1:
            return [int(x) for x in arr.tolist()]

        raise InvalidParamError(
            node_name=node.name,
            op_type=node.op_type,
            message=f"Unsqueeze axes initializer must be 0-D or 1-D, got {arr.ndim}-D.",
            attr_key="axes",
            expected="0-D scalar or 1-D list of axes",
        )

    def check_supported(
        self: UnsqueezeQuantizer,
        node: onnx.NodeProto,
        initializer_map: dict[str, onnx.TensorProto] | None = None,
    ) -> None:
        """Validate that ``node`` is an Unsqueeze form this quantizer can handle.

        Args:
            node: The Unsqueeze node to check.
            initializer_map: Graph initializers by name; needed when axes come
                from the second input.

        Raises:
            InvalidParamError: on a wrong input count, missing axes,
                non-constant/ill-typed axes input, or duplicate axes.
        """
        self.validate_node_has_output(node)
        initializer_map = initializer_map or {}

        n_inputs = len(node.input)
        if n_inputs not in (self._N_INPUTS_NO_AXES, self._N_INPUTS_WITH_AXES):
            raise InvalidParamError(
                node_name=node.name,
                op_type=node.op_type,
                message=(
                    "Unsqueeze expects either 1 input (axes as attribute) or 2 inputs "
                    f"(axes as initializer), got {n_inputs}."
                ),
            )

        axes = self._get_axes_from_attribute(node)

        # ONNX Unsqueeze has two schema styles:
        # - newer: Unsqueeze(data, axes) -> 2 inputs, axes is initializer input
        # - older: Unsqueeze(data) with axes attribute -> 1 input
        # After this branch `axes` is always a list: each path either fills
        # it or raises.
        if n_inputs == self._N_INPUTS_NO_AXES:
            if axes is None:
                raise InvalidParamError(
                    node_name=node.name,
                    op_type=node.op_type,
                    message=(
                        "Unsqueeze with 1 input is only supported when 'axes' is "
                        "provided as an attribute (older opsets)."
                    ),
                    attr_key="axes",
                    expected="axes attribute",
                )
        elif axes is None:
            axes = self._get_axes_from_initializer_input(node, initializer_map)

        # Only literal duplicates are caught; negative/positive aliasing would
        # need the output rank, which is not available here.
        if len(set(axes)) != len(axes):
            raise InvalidParamError(
                node_name=node.name,
                op_type=node.op_type,
                message=f"axes must not contain duplicates: {axes}",
                attr_key="axes",
                expected="axes list with unique entries",
            )
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def _extract_unsqueeze_axes_into_params(  # noqa: PLR0913
    *,
    name: str,
    inputs: list[str] | tuple[str, ...],
    params: dict | None,
    graph: onnx.GraphProto,
    model_type: ModelType,
    initializer_map: dict[str, onnx.TensorProto] | None = None,
) -> dict:
    """Resolve an Unsqueeze node's constant ``axes`` input into a params dict.

    Args:
        name: Node name, used in error messages.
        inputs: The node's input names; must be exactly ``(data, axes)``.
        params: Existing parsed attributes, or None.
        graph: Graph searched for the axes initializer / Constant node.
        model_type: Forwarded to ``LayerAnalysisError``.
        initializer_map: Optional pre-built initializer lookup.

    Returns:
        A new dict holding the entries of ``params`` plus ``"axes"`` as a list
        of Python ints. ``params`` itself is never modified.

    Raises:
        LayerAnalysisError: on a wrong input count, non-constant axes, or
            non-integer axes.
    """
    if len(inputs) != _N_UNSQUEEZE_INPUTS:
        msg = (
            f"Unsqueeze '{name}' is missing axes input. "
            f"Expected 2 inputs (data, axes), got {len(inputs)}: {list(inputs)}"
        )
        raise LayerAnalysisError(model_type=model_type, reason=msg)

    axes_name = inputs[1]

    axes_arr = _resolve_unsqueeze_axes_array(
        name=name,
        axes_name=axes_name,
        graph=graph,
        model_type=model_type,
        initializer_map=initializer_map,
    )

    _validate_unsqueeze_axes_are_integer(
        name=name,
        axes_arr=axes_arr,
        model_type=model_type,
    )

    # Copy instead of writing into the caller's dict. The previous
    # `params or {}` mutated non-empty inputs but not empty ones, so the side
    # effect depended on the argument's contents.
    out_params = dict(params) if params else {}
    out_params["axes"] = _axes_array_to_int_list(axes_arr)
    return out_params
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def _resolve_unsqueeze_axes_array(
    *,
    name: str,
    axes_name: str,
    graph: onnx.GraphProto,
    model_type: ModelType,
    initializer_map: dict[str, onnx.TensorProto] | None = None,
) -> np.ndarray:
    """Return the constant tensor behind an Unsqueeze ``axes`` input as numpy.

    Lookup order:
    1. ``initializer_map``; when it is None or empty, a map built from
       ``graph.initializer`` is used instead.
    2. The output of a ``Constant`` node in ``graph``.

    Raises:
        LayerAnalysisError: when ``axes_name`` is found in neither place.
    """
    by_name = initializer_map or {init.name: init for init in graph.initializer}
    source = by_name.get(axes_name)
    if source is None:
        source = _find_constant_tensor_by_output_name(
            graph=graph,
            output_name=axes_name,
        )
    if source is None:
        msg = (
            f"Unsqueeze '{name}' has dynamic axes input '{axes_name}'. "
            "Only constant initializer axes or Constant-node axes are supported."
        )
        raise LayerAnalysisError(model_type=model_type, reason=msg)
    return numpy_helper.to_array(source)
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def _find_constant_tensor_by_output_name(
    *,
    graph: onnx.GraphProto,
    output_name: str,
) -> onnx.TensorProto | None:
    """Return the tensor produced by the ``Constant`` node emitting ``output_name``.

    Supported Constant value forms:
    - ``value``: the TensorProto is returned as-is.
    - ``value_ints``: converted to a 1-D int64 tensor.
    - ``value_int``: converted to a 0-D int64 tensor.

    Returns:
        The tensor, or None when no Constant node produces ``output_name`` or
        the producing Constant uses another value form (sparse/float/string).
    """
    for n in graph.node:
        if n.op_type != "Constant" or not n.output:
            continue
        if n.output[0] != output_name:
            continue

        for attr in n.attribute:
            # A protobuf sub-message field is never None (unset reads as an
            # empty default message), so presence must be tested via HasField.
            if attr.name == "value" and attr.HasField("t"):
                return attr.t
            if attr.name == "value_ints":
                return numpy_helper.from_array(
                    np.asarray(list(attr.ints), dtype=np.int64),
                )
            if attr.name == "value_int":
                return numpy_helper.from_array(np.asarray(attr.i, dtype=np.int64))

        # Constant node exists but doesn't carry a supported value form.
        return None

    return None
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def _validate_unsqueeze_axes_are_integer(
    *,
    name: str,
    axes_arr: np.ndarray,
    model_type: ModelType,
) -> None:
    """Ensure ``axes_arr`` has an integer dtype.

    Raises:
        LayerAnalysisError: if the dtype is not a numpy integer type.
    """
    if np.issubdtype(axes_arr.dtype, np.integer):
        return
    msg = f"Unsqueeze '{name}' axes must be integer, got dtype {axes_arr.dtype}."
    raise LayerAnalysisError(model_type=model_type, reason=msg)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def _axes_array_to_int_list(axes_arr: np.ndarray) -> list[int]:
    """Flatten ``axes_arr`` (0-D scalar or any rank, C order) into Python ints."""
    # ravel() turns a 0-D array into a one-element 1-D array, so the scalar
    # case needs no special branch.
    flattened = axes_arr.ravel()
    return [int(value) for value in flattened.tolist()]
|
|
@@ -31,7 +31,11 @@ from python.core.model_processing.onnx_quantizer.layers.maxpool import MaxpoolQu
|
|
|
31
31
|
from python.core.model_processing.onnx_quantizer.layers.min import MinQuantizer
|
|
32
32
|
from python.core.model_processing.onnx_quantizer.layers.mul import MulQuantizer
|
|
33
33
|
from python.core.model_processing.onnx_quantizer.layers.relu import ReluQuantizer
|
|
34
|
+
from python.core.model_processing.onnx_quantizer.layers.squeeze import SqueezeQuantizer
|
|
34
35
|
from python.core.model_processing.onnx_quantizer.layers.sub import SubQuantizer
|
|
36
|
+
from python.core.model_processing.onnx_quantizer.layers.unsqueeze import (
|
|
37
|
+
UnsqueezeQuantizer,
|
|
38
|
+
)
|
|
35
39
|
|
|
36
40
|
|
|
37
41
|
class ONNXOpQuantizer:
|
|
@@ -90,6 +94,8 @@ class ONNXOpQuantizer:
|
|
|
90
94
|
self.register("Max", MaxQuantizer(self.new_initializers))
|
|
91
95
|
self.register("Min", MinQuantizer(self.new_initializers))
|
|
92
96
|
self.register("BatchNormalization", BatchnormQuantizer(self.new_initializers))
|
|
97
|
+
self.register("Squeeze", SqueezeQuantizer(self.new_initializers))
|
|
98
|
+
self.register("Unsqueeze", UnsqueezeQuantizer(self.new_initializers))
|
|
93
99
|
|
|
94
100
|
def register(
|
|
95
101
|
self: ONNXOpQuantizer,
|
python/frontend/cli.py
CHANGED
|
@@ -12,6 +12,7 @@ if TYPE_CHECKING:
|
|
|
12
12
|
from python.frontend.commands import BaseCommand
|
|
13
13
|
|
|
14
14
|
from python.frontend.commands import (
|
|
15
|
+
BatchCommand,
|
|
15
16
|
BenchCommand,
|
|
16
17
|
CompileCommand,
|
|
17
18
|
ModelCheckCommand,
|
|
@@ -41,6 +42,7 @@ COMMANDS: list[type[BaseCommand]] = [
|
|
|
41
42
|
WitnessCommand,
|
|
42
43
|
ProveCommand,
|
|
43
44
|
VerifyCommand,
|
|
45
|
+
BatchCommand,
|
|
44
46
|
BenchCommand,
|
|
45
47
|
]
|
|
46
48
|
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from python.frontend.commands.base import BaseCommand
|
|
2
|
+
from python.frontend.commands.batch import BatchCommand
|
|
2
3
|
from python.frontend.commands.bench import BenchCommand
|
|
3
4
|
from python.frontend.commands.compile import CompileCommand
|
|
4
5
|
from python.frontend.commands.model_check import ModelCheckCommand
|
|
@@ -8,6 +9,7 @@ from python.frontend.commands.witness import WitnessCommand
|
|
|
8
9
|
|
|
9
10
|
__all__ = [
|
|
10
11
|
"BaseCommand",
|
|
12
|
+
"BatchCommand",
|
|
11
13
|
"BenchCommand",
|
|
12
14
|
"CompileCommand",
|
|
13
15
|
"ModelCheckCommand",
|
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import TYPE_CHECKING, Any, ClassVar
|
|
5
|
+
|
|
6
|
+
from python.core.utils.helper_functions import (
|
|
7
|
+
read_from_json,
|
|
8
|
+
run_cargo_command,
|
|
9
|
+
to_json,
|
|
10
|
+
)
|
|
11
|
+
from python.frontend.commands.args import CIRCUIT_PATH
|
|
12
|
+
from python.frontend.commands.base import BaseCommand
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
import argparse
|
|
16
|
+
from collections.abc import Callable
|
|
17
|
+
|
|
18
|
+
from python.core.circuits.base import Circuit
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _preprocess_manifest(
|
|
22
|
+
circuit: Circuit,
|
|
23
|
+
manifest_path: str,
|
|
24
|
+
circuit_path: str,
|
|
25
|
+
transform_job: Callable[[Circuit, dict[str, Any]], None],
|
|
26
|
+
) -> str:
|
|
27
|
+
circuit_file = Path(circuit_path)
|
|
28
|
+
quantized_path = str(
|
|
29
|
+
circuit_file.parent / f"{circuit_file.stem}_quantized_model.onnx",
|
|
30
|
+
)
|
|
31
|
+
circuit.load_quantized_model(quantized_path)
|
|
32
|
+
|
|
33
|
+
manifest: dict[str, Any] = read_from_json(manifest_path)
|
|
34
|
+
if not isinstance(manifest, dict) or not isinstance(
|
|
35
|
+
manifest.get("jobs"),
|
|
36
|
+
list,
|
|
37
|
+
):
|
|
38
|
+
msg = f"Invalid manifest: expected {{'jobs': [...]}} in {manifest_path}"
|
|
39
|
+
raise TypeError(msg)
|
|
40
|
+
for job in manifest["jobs"]:
|
|
41
|
+
transform_job(circuit, job)
|
|
42
|
+
|
|
43
|
+
manifest_file = Path(manifest_path)
|
|
44
|
+
processed_path = str(
|
|
45
|
+
manifest_file.with_name(
|
|
46
|
+
manifest_file.stem + "_processed" + manifest_file.suffix,
|
|
47
|
+
),
|
|
48
|
+
)
|
|
49
|
+
to_json(manifest, processed_path)
|
|
50
|
+
return processed_path
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _validate_job_keys(job: dict[str, Any], *keys: str) -> None:
|
|
54
|
+
missing = [k for k in keys if k not in job]
|
|
55
|
+
if missing:
|
|
56
|
+
msg = f"Job missing required keys {missing}: {job}"
|
|
57
|
+
raise ValueError(msg)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _transform_witness_job(circuit: Circuit, job: dict[str, Any]) -> None:
|
|
61
|
+
_validate_job_keys(job, "input", "output")
|
|
62
|
+
inputs = read_from_json(job["input"])
|
|
63
|
+
scaled = circuit.scale_inputs_only(inputs)
|
|
64
|
+
|
|
65
|
+
inference_inputs = circuit.reshape_inputs_for_inference(scaled)
|
|
66
|
+
circuit_inputs = circuit.reshape_inputs_for_circuit(scaled)
|
|
67
|
+
|
|
68
|
+
path = Path(job["input"])
|
|
69
|
+
adjusted_path = str(path.with_name(path.stem + "_adjusted" + path.suffix))
|
|
70
|
+
to_json(circuit_inputs, adjusted_path)
|
|
71
|
+
|
|
72
|
+
outputs = circuit.get_outputs(inference_inputs)
|
|
73
|
+
formatted = circuit.format_outputs(outputs)
|
|
74
|
+
to_json(formatted, job["output"])
|
|
75
|
+
|
|
76
|
+
job["input"] = adjusted_path
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _transform_verify_job(circuit: Circuit, job: dict[str, Any]) -> None:
|
|
80
|
+
_validate_job_keys(job, "input")
|
|
81
|
+
inputs = read_from_json(job["input"])
|
|
82
|
+
circuit_inputs = circuit.reshape_inputs_for_circuit(inputs)
|
|
83
|
+
|
|
84
|
+
path = Path(job["input"])
|
|
85
|
+
processed_path = str(path.with_name(path.stem + "_veri" + path.suffix))
|
|
86
|
+
to_json(circuit_inputs, processed_path)
|
|
87
|
+
|
|
88
|
+
job["input"] = processed_path
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class BatchCommand(BaseCommand):
|
|
92
|
+
"""Run batch operations on multiple inputs."""
|
|
93
|
+
|
|
94
|
+
name: ClassVar[str] = "batch"
|
|
95
|
+
aliases: ClassVar[list[str]] = []
|
|
96
|
+
help: ClassVar[str] = "Run batch witness/prove/verify operations."
|
|
97
|
+
|
|
98
|
+
@classmethod
|
|
99
|
+
def _add_batch_args(cls, subparser: argparse.ArgumentParser) -> None:
|
|
100
|
+
CIRCUIT_PATH.add_to_parser(subparser)
|
|
101
|
+
subparser.add_argument(
|
|
102
|
+
"-f",
|
|
103
|
+
"--manifest",
|
|
104
|
+
required=True,
|
|
105
|
+
help="Path to batch manifest JSON file.",
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
@classmethod
|
|
109
|
+
def configure_parser(
|
|
110
|
+
cls: type[BatchCommand],
|
|
111
|
+
parser: argparse.ArgumentParser,
|
|
112
|
+
) -> None:
|
|
113
|
+
subparsers = parser.add_subparsers(dest="batch_mode", required=True)
|
|
114
|
+
|
|
115
|
+
for name, help_text in [
|
|
116
|
+
("witness", "Generate witnesses for multiple inputs."),
|
|
117
|
+
("prove", "Generate proofs for multiple witnesses."),
|
|
118
|
+
("verify", "Verify multiple proofs."),
|
|
119
|
+
]:
|
|
120
|
+
sub = subparsers.add_parser(name, help=help_text)
|
|
121
|
+
cls._add_batch_args(sub)
|
|
122
|
+
|
|
123
|
+
@classmethod
|
|
124
|
+
def run(cls: type[BatchCommand], args: argparse.Namespace) -> None:
|
|
125
|
+
circuit_path = getattr(args, "circuit_path", None) or getattr(
|
|
126
|
+
args,
|
|
127
|
+
"pos_circuit_path",
|
|
128
|
+
None,
|
|
129
|
+
)
|
|
130
|
+
if not circuit_path:
|
|
131
|
+
msg = "Missing required argument: circuit_path"
|
|
132
|
+
raise ValueError(msg)
|
|
133
|
+
|
|
134
|
+
circuit_file = Path(circuit_path)
|
|
135
|
+
if not circuit_file.is_file():
|
|
136
|
+
msg = f"Circuit file not found: {circuit_path}"
|
|
137
|
+
raise FileNotFoundError(msg)
|
|
138
|
+
|
|
139
|
+
manifest_path = args.manifest
|
|
140
|
+
if not Path(manifest_path).is_file():
|
|
141
|
+
msg = f"Manifest file not found: {manifest_path}"
|
|
142
|
+
raise FileNotFoundError(msg)
|
|
143
|
+
|
|
144
|
+
batch_mode = args.batch_mode
|
|
145
|
+
|
|
146
|
+
circuit = cls._build_circuit("cli")
|
|
147
|
+
|
|
148
|
+
run_type_map = {
|
|
149
|
+
"witness": "run_batch_witness",
|
|
150
|
+
"prove": "run_batch_prove",
|
|
151
|
+
"verify": "run_batch_verify",
|
|
152
|
+
}
|
|
153
|
+
|
|
154
|
+
run_type_str = run_type_map.get(batch_mode)
|
|
155
|
+
if not run_type_str:
|
|
156
|
+
msg = f"Unknown batch mode: {batch_mode}"
|
|
157
|
+
raise ValueError(msg)
|
|
158
|
+
|
|
159
|
+
circuit_dir = circuit_file.parent
|
|
160
|
+
name = circuit_file.stem
|
|
161
|
+
metadata_path = str(circuit_dir / f"{name}_metadata.json")
|
|
162
|
+
if not Path(metadata_path).is_file():
|
|
163
|
+
msg = f"Metadata file not found: {metadata_path}"
|
|
164
|
+
raise FileNotFoundError(msg)
|
|
165
|
+
|
|
166
|
+
preprocess_map = {
|
|
167
|
+
"witness": _transform_witness_job,
|
|
168
|
+
"verify": _transform_verify_job,
|
|
169
|
+
}
|
|
170
|
+
if batch_mode in preprocess_map:
|
|
171
|
+
manifest_path = _preprocess_manifest(
|
|
172
|
+
circuit,
|
|
173
|
+
manifest_path,
|
|
174
|
+
circuit_path,
|
|
175
|
+
preprocess_map[batch_mode],
|
|
176
|
+
)
|
|
177
|
+
|
|
178
|
+
try:
|
|
179
|
+
run_cargo_command(
|
|
180
|
+
binary_name=circuit.name,
|
|
181
|
+
command_type=run_type_str,
|
|
182
|
+
args={
|
|
183
|
+
"c": circuit_path,
|
|
184
|
+
"f": manifest_path,
|
|
185
|
+
"m": metadata_path,
|
|
186
|
+
},
|
|
187
|
+
dev_mode=False,
|
|
188
|
+
)
|
|
189
|
+
except Exception as e:
|
|
190
|
+
raise RuntimeError(e) from e
|
|
191
|
+
|
|
192
|
+
print(f"[batch {batch_mode}] complete") # noqa: T201
|
|
@@ -98,7 +98,7 @@ class LayerTestConfig:
|
|
|
98
98
|
combined_inits = {**self.required_initializers, **initializer_overrides}
|
|
99
99
|
for name, data in combined_inits.items():
|
|
100
100
|
# Special handling for shape tensors in Reshape, etc.
|
|
101
|
-
if name
|
|
101
|
+
if name in {"shape", "axes"}:
|
|
102
102
|
tensor = numpy_helper.from_array(data.astype(np.int64), name=name)
|
|
103
103
|
else:
|
|
104
104
|
tensor = numpy_helper.from_array(data.astype(np.float32), name=name)
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError
|
|
6
|
+
from python.tests.onnx_quantizer_tests.layers.base import (
|
|
7
|
+
e2e_test,
|
|
8
|
+
error_test,
|
|
9
|
+
valid_test,
|
|
10
|
+
)
|
|
11
|
+
from python.tests.onnx_quantizer_tests.layers.factory import (
|
|
12
|
+
BaseLayerConfigProvider,
|
|
13
|
+
LayerTestConfig,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class SqueezeConfigProvider(BaseLayerConfigProvider):
|
|
18
|
+
"""Test configuration provider for Squeeze"""
|
|
19
|
+
|
|
20
|
+
@property
|
|
21
|
+
def layer_name(self) -> str:
|
|
22
|
+
return "Squeeze"
|
|
23
|
+
|
|
24
|
+
def get_config(self) -> LayerTestConfig:
|
|
25
|
+
# Test opset-newer form: Squeeze(data, axes) where axes is an int64 initializer.
|
|
26
|
+
return LayerTestConfig(
|
|
27
|
+
op_type="Squeeze",
|
|
28
|
+
valid_inputs=["A", "axes"],
|
|
29
|
+
valid_attributes={}, # no attribute-based axes
|
|
30
|
+
required_initializers={},
|
|
31
|
+
input_shapes={
|
|
32
|
+
"A": [1, 3, 1, 5],
|
|
33
|
+
# "axes" is removed from graph inputs automatically
|
|
34
|
+
# when it is an initializer.
|
|
35
|
+
"axes": [2],
|
|
36
|
+
},
|
|
37
|
+
output_shapes={
|
|
38
|
+
"squeeze_output": [3, 5],
|
|
39
|
+
},
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
def get_test_specs(self) -> list:
|
|
43
|
+
|
|
44
|
+
return [
|
|
45
|
+
# --- VALID TESTS ---
|
|
46
|
+
valid_test("axes_omitted")
|
|
47
|
+
.description("Squeeze with no axes input: removes all dims of size 1")
|
|
48
|
+
.override_inputs("A") # only data input
|
|
49
|
+
.override_input_shapes(A=[1, 3, 1, 5])
|
|
50
|
+
.override_output_shapes(squeeze_output=[3, 5])
|
|
51
|
+
.tags("basic", "squeeze", "axes_omitted")
|
|
52
|
+
.build(),
|
|
53
|
+
valid_test("axes_init_basic")
|
|
54
|
+
.description("Squeeze with axes initializer [0,2] on [1,3,1,5] -> [3,5]")
|
|
55
|
+
.override_inputs("A", "axes")
|
|
56
|
+
.override_initializer("axes", np.array([0, 2], dtype=np.int64))
|
|
57
|
+
.override_input_shapes(A=[1, 3, 1, 5])
|
|
58
|
+
.override_output_shapes(squeeze_output=[3, 5])
|
|
59
|
+
.tags("basic", "squeeze", "axes_initializer")
|
|
60
|
+
.build(),
|
|
61
|
+
valid_test("axes_init_singleton_middle")
|
|
62
|
+
.description("Squeeze with axes initializer [1] on [2,1,4] -> [2,4]")
|
|
63
|
+
.override_inputs("A", "axes")
|
|
64
|
+
.override_initializer("axes", np.array([1], dtype=np.int64))
|
|
65
|
+
.override_input_shapes(A=[2, 1, 4])
|
|
66
|
+
.override_output_shapes(squeeze_output=[2, 4])
|
|
67
|
+
.tags("squeeze", "axes_initializer")
|
|
68
|
+
.build(),
|
|
69
|
+
valid_test("axes_init_negative")
|
|
70
|
+
.description("Squeeze with axes initializer [-2] on [2,1,4] -> [2,4]")
|
|
71
|
+
.override_inputs("A", "axes")
|
|
72
|
+
.override_initializer("axes", np.array([-2], dtype=np.int64))
|
|
73
|
+
.override_input_shapes(A=[2, 1, 4])
|
|
74
|
+
.override_output_shapes(squeeze_output=[2, 4])
|
|
75
|
+
.tags("squeeze", "axes_initializer", "negative_axis")
|
|
76
|
+
.build(),
|
|
77
|
+
# --- ERROR TESTS ---
|
|
78
|
+
error_test("duplicate_axes_init")
|
|
79
|
+
.description("Duplicate axes in initializer should be rejected")
|
|
80
|
+
.override_inputs("A", "axes")
|
|
81
|
+
.override_initializer("axes", np.array([1, 1], dtype=np.int64))
|
|
82
|
+
.override_input_shapes(A=[2, 1, 4])
|
|
83
|
+
.override_output_shapes(squeeze_output=[2, 4])
|
|
84
|
+
.expects_error(InvalidParamError, match="axes must not contain duplicates")
|
|
85
|
+
.tags("error", "squeeze", "axes_initializer")
|
|
86
|
+
.build(),
|
|
87
|
+
error_test("dynamic_axes_input_not_supported")
|
|
88
|
+
.description(
|
|
89
|
+
"Squeeze with runtime axes (2 inputs but axes is NOT an initializer) "
|
|
90
|
+
"should be rejected",
|
|
91
|
+
)
|
|
92
|
+
.override_inputs("A", "axes") # axes provided as graph input (unsupported)
|
|
93
|
+
.override_input_shapes(A=[1, 3, 1, 5], axes=[2])
|
|
94
|
+
.override_output_shapes(squeeze_output=[3, 5])
|
|
95
|
+
.expects_error(
|
|
96
|
+
InvalidParamError,
|
|
97
|
+
match="Dynamic axes input is not supported",
|
|
98
|
+
)
|
|
99
|
+
.tags("error", "squeeze", "axes_input")
|
|
100
|
+
.build(),
|
|
101
|
+
# --- E2E TESTS ---
|
|
102
|
+
e2e_test("e2e_axes_omitted")
|
|
103
|
+
.description("End-to-end Squeeze test (axes omitted)")
|
|
104
|
+
.override_inputs("A")
|
|
105
|
+
.override_input_shapes(A=[1, 3, 1, 5])
|
|
106
|
+
.override_output_shapes(squeeze_output=[3, 5])
|
|
107
|
+
.tags("e2e", "squeeze")
|
|
108
|
+
.build(),
|
|
109
|
+
e2e_test("e2e_axes_init")
|
|
110
|
+
.description("End-to-end Squeeze test (axes initializer)")
|
|
111
|
+
.override_inputs("A", "axes")
|
|
112
|
+
.override_initializer("axes", np.array([0, 2], dtype=np.int64))
|
|
113
|
+
.override_input_shapes(A=[1, 3, 1, 5])
|
|
114
|
+
.override_output_shapes(squeeze_output=[3, 5])
|
|
115
|
+
.tags("e2e", "squeeze", "axes_initializer")
|
|
116
|
+
.build(),
|
|
117
|
+
]
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError
|
|
6
|
+
from python.tests.onnx_quantizer_tests.layers.base import (
|
|
7
|
+
e2e_test,
|
|
8
|
+
error_test,
|
|
9
|
+
valid_test,
|
|
10
|
+
)
|
|
11
|
+
from python.tests.onnx_quantizer_tests.layers.factory import (
|
|
12
|
+
BaseLayerConfigProvider,
|
|
13
|
+
LayerTestConfig,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class UnsqueezeConfigProvider(BaseLayerConfigProvider):
|
|
18
|
+
"""Test configuration provider for Unsqueeze"""
|
|
19
|
+
|
|
20
|
+
@property
|
|
21
|
+
def layer_name(self) -> str:
|
|
22
|
+
return "Unsqueeze"
|
|
23
|
+
|
|
24
|
+
def get_config(self) -> LayerTestConfig:
|
|
25
|
+
# Test opset-newer form: Unsqueeze(data, axes)
|
|
26
|
+
# where axes is an int64 initializer.
|
|
27
|
+
return LayerTestConfig(
|
|
28
|
+
op_type="Unsqueeze",
|
|
29
|
+
valid_inputs=["A", "axes"],
|
|
30
|
+
valid_attributes={}, # no attribute-based axes
|
|
31
|
+
required_initializers={},
|
|
32
|
+
input_shapes={
|
|
33
|
+
"A": [3, 5],
|
|
34
|
+
# "axes" will be removed from graph inputs automatically
|
|
35
|
+
# when it is an initializer.
|
|
36
|
+
"axes": [2],
|
|
37
|
+
},
|
|
38
|
+
output_shapes={
|
|
39
|
+
"unsqueeze_output": [1, 3, 1, 5],
|
|
40
|
+
},
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
def get_test_specs(self) -> list:
|
|
44
|
+
|
|
45
|
+
return [
|
|
46
|
+
# --- VALID TESTS ---
|
|
47
|
+
valid_test("axes_init_basic")
|
|
48
|
+
.description("Unsqueeze with axes initializer [0,2] on [3,5] -> [1,3,1,5]")
|
|
49
|
+
.override_inputs("A", "axes")
|
|
50
|
+
.override_initializer("axes", np.array([0, 2], dtype=np.int64))
|
|
51
|
+
.override_input_shapes(A=[3, 5])
|
|
52
|
+
.override_output_shapes(unsqueeze_output=[1, 3, 1, 5])
|
|
53
|
+
.tags("basic", "unsqueeze", "axes_initializer")
|
|
54
|
+
.build(),
|
|
55
|
+
valid_test("axes_init_single_axis")
|
|
56
|
+
.description("Unsqueeze with axes initializer [1] on [3,5] -> [3,1,5]")
|
|
57
|
+
.override_inputs("A", "axes")
|
|
58
|
+
.override_initializer("axes", np.array([1], dtype=np.int64))
|
|
59
|
+
.override_input_shapes(A=[3, 5])
|
|
60
|
+
.override_output_shapes(unsqueeze_output=[3, 1, 5])
|
|
61
|
+
.tags("unsqueeze", "axes_initializer")
|
|
62
|
+
.build(),
|
|
63
|
+
valid_test("axes_init_negative")
|
|
64
|
+
.description("Unsqueeze with negative axis [-1] on [3,5] -> [3,5,1]")
|
|
65
|
+
.override_inputs("A", "axes")
|
|
66
|
+
.override_initializer("axes", np.array([-1], dtype=np.int64))
|
|
67
|
+
.override_input_shapes(A=[3, 5])
|
|
68
|
+
.override_output_shapes(unsqueeze_output=[3, 5, 1])
|
|
69
|
+
.tags("unsqueeze", "axes_initializer", "negative_axis")
|
|
70
|
+
.build(),
|
|
71
|
+
valid_test("axes_init_two_axes_append")
|
|
72
|
+
.description("Unsqueeze with axes [2,3] on [3,5] -> [3,5,1,1]")
|
|
73
|
+
.override_inputs("A", "axes")
|
|
74
|
+
.override_initializer("axes", np.array([2, 3], dtype=np.int64))
|
|
75
|
+
.override_input_shapes(A=[3, 5])
|
|
76
|
+
.override_output_shapes(unsqueeze_output=[3, 5, 1, 1])
|
|
77
|
+
.tags("unsqueeze", "axes_initializer")
|
|
78
|
+
.build(),
|
|
79
|
+
# --- ERROR TESTS ---
|
|
80
|
+
error_test("duplicate_axes_init")
|
|
81
|
+
.description("Duplicate axes in initializer should be rejected")
|
|
82
|
+
.override_inputs("A", "axes")
|
|
83
|
+
.override_initializer("axes", np.array([1, 1], dtype=np.int64))
|
|
84
|
+
.override_input_shapes(A=[3, 5])
|
|
85
|
+
.override_output_shapes(
|
|
86
|
+
unsqueeze_output=[3, 1, 5],
|
|
87
|
+
) # not used; kept consistent
|
|
88
|
+
.expects_error(InvalidParamError, match="axes must not contain duplicates")
|
|
89
|
+
.tags("error", "unsqueeze", "axes_initializer")
|
|
90
|
+
.build(),
|
|
91
|
+
error_test("dynamic_axes_input_not_supported")
|
|
92
|
+
.description(
|
|
93
|
+
"Unsqueeze with runtime axes (2 inputs but axes is NOT an initializer) "
|
|
94
|
+
"should be rejected",
|
|
95
|
+
)
|
|
96
|
+
.override_inputs("A", "axes") # axes provided as graph input (unsupported)
|
|
97
|
+
.override_input_shapes(A=[3, 5], axes=[2])
|
|
98
|
+
.override_output_shapes(unsqueeze_output=[1, 3, 1, 5])
|
|
99
|
+
.expects_error(
|
|
100
|
+
InvalidParamError,
|
|
101
|
+
match="Dynamic axes input is not supported",
|
|
102
|
+
)
|
|
103
|
+
.tags("error", "unsqueeze", "axes_input")
|
|
104
|
+
.build(),
|
|
105
|
+
# --- E2E TESTS ---
|
|
106
|
+
e2e_test("e2e_axes_init")
|
|
107
|
+
.description("End-to-end Unsqueeze test (axes initializer)")
|
|
108
|
+
.override_inputs("A", "axes")
|
|
109
|
+
.override_initializer("axes", np.array([0, 2], dtype=np.int64))
|
|
110
|
+
.override_input_shapes(A=[3, 5])
|
|
111
|
+
.override_output_shapes(unsqueeze_output=[1, 3, 1, 5])
|
|
112
|
+
.tags("e2e", "unsqueeze", "axes_initializer")
|
|
113
|
+
.build(),
|
|
114
|
+
]
|
python/tests/test_cli.py
CHANGED
|
@@ -1019,3 +1019,186 @@ def test_bench_with_iterations(tmp_path: Path) -> None:
|
|
|
1019
1019
|
assert "--iterations" in cmd
|
|
1020
1020
|
idx = cmd.index("--iterations")
|
|
1021
1021
|
assert cmd[idx + 1] == "10"
|
|
1022
|
+
|
|
1023
|
+
|
|
1024
|
+
@pytest.mark.unit
|
|
1025
|
+
def test_batch_prove_dispatch(tmp_path: Path) -> None:
|
|
1026
|
+
circuit = tmp_path / "circuit.txt"
|
|
1027
|
+
circuit.write_text("ok")
|
|
1028
|
+
|
|
1029
|
+
manifest = tmp_path / "manifest.json"
|
|
1030
|
+
manifest.write_text('{"jobs": []}')
|
|
1031
|
+
|
|
1032
|
+
metadata = tmp_path / "circuit_metadata.json"
|
|
1033
|
+
metadata.write_text("{}")
|
|
1034
|
+
|
|
1035
|
+
fake_circuit = MagicMock()
|
|
1036
|
+
fake_circuit.name = "test_circuit"
|
|
1037
|
+
|
|
1038
|
+
with (
|
|
1039
|
+
patch(
|
|
1040
|
+
"python.frontend.commands.batch.BatchCommand._build_circuit",
|
|
1041
|
+
return_value=fake_circuit,
|
|
1042
|
+
),
|
|
1043
|
+
patch(
|
|
1044
|
+
"python.frontend.commands.batch.run_cargo_command",
|
|
1045
|
+
) as mock_cargo,
|
|
1046
|
+
):
|
|
1047
|
+
rc = main(
|
|
1048
|
+
[
|
|
1049
|
+
"--no-banner",
|
|
1050
|
+
"batch",
|
|
1051
|
+
"prove",
|
|
1052
|
+
"-c",
|
|
1053
|
+
str(circuit),
|
|
1054
|
+
"-f",
|
|
1055
|
+
str(manifest),
|
|
1056
|
+
],
|
|
1057
|
+
)
|
|
1058
|
+
|
|
1059
|
+
assert rc == 0
|
|
1060
|
+
mock_cargo.assert_called_once()
|
|
1061
|
+
call_kwargs = mock_cargo.call_args[1]
|
|
1062
|
+
assert call_kwargs["command_type"] == "run_batch_prove"
|
|
1063
|
+
assert call_kwargs["args"]["c"] == str(circuit)
|
|
1064
|
+
assert call_kwargs["args"]["f"] == str(manifest)
|
|
1065
|
+
|
|
1066
|
+
|
|
1067
|
+
@pytest.mark.unit
|
|
1068
|
+
def test_batch_verify_dispatch(tmp_path: Path) -> None:
|
|
1069
|
+
circuit = tmp_path / "circuit.txt"
|
|
1070
|
+
circuit.write_text("ok")
|
|
1071
|
+
|
|
1072
|
+
manifest = tmp_path / "manifest.json"
|
|
1073
|
+
manifest.write_text('{"jobs": []}')
|
|
1074
|
+
|
|
1075
|
+
metadata = tmp_path / "circuit_metadata.json"
|
|
1076
|
+
metadata.write_text("{}")
|
|
1077
|
+
|
|
1078
|
+
processed = tmp_path / "manifest_processed.json"
|
|
1079
|
+
processed.write_text('{"jobs": []}')
|
|
1080
|
+
|
|
1081
|
+
fake_circuit = MagicMock()
|
|
1082
|
+
fake_circuit.name = "test_circuit"
|
|
1083
|
+
|
|
1084
|
+
with (
|
|
1085
|
+
patch(
|
|
1086
|
+
"python.frontend.commands.batch.BatchCommand._build_circuit",
|
|
1087
|
+
return_value=fake_circuit,
|
|
1088
|
+
),
|
|
1089
|
+
patch(
|
|
1090
|
+
"python.frontend.commands.batch._preprocess_manifest",
|
|
1091
|
+
return_value=str(processed),
|
|
1092
|
+
) as mock_preprocess,
|
|
1093
|
+
patch(
|
|
1094
|
+
"python.frontend.commands.batch.run_cargo_command",
|
|
1095
|
+
) as mock_cargo,
|
|
1096
|
+
):
|
|
1097
|
+
rc = main(
|
|
1098
|
+
[
|
|
1099
|
+
"--no-banner",
|
|
1100
|
+
"batch",
|
|
1101
|
+
"verify",
|
|
1102
|
+
"-c",
|
|
1103
|
+
str(circuit),
|
|
1104
|
+
"-f",
|
|
1105
|
+
str(manifest),
|
|
1106
|
+
],
|
|
1107
|
+
)
|
|
1108
|
+
|
|
1109
|
+
assert rc == 0
|
|
1110
|
+
mock_preprocess.assert_called_once()
|
|
1111
|
+
mock_cargo.assert_called_once()
|
|
1112
|
+
call_kwargs = mock_cargo.call_args[1]
|
|
1113
|
+
assert call_kwargs["command_type"] == "run_batch_verify"
|
|
1114
|
+
assert call_kwargs["args"]["f"] == str(processed)
|
|
1115
|
+
|
|
1116
|
+
|
|
1117
|
+
@pytest.mark.unit
|
|
1118
|
+
def test_batch_witness_dispatch(tmp_path: Path) -> None:
|
|
1119
|
+
circuit = tmp_path / "circuit.txt"
|
|
1120
|
+
circuit.write_text("ok")
|
|
1121
|
+
|
|
1122
|
+
manifest = tmp_path / "manifest.json"
|
|
1123
|
+
manifest.write_text('{"jobs": []}')
|
|
1124
|
+
|
|
1125
|
+
metadata = tmp_path / "circuit_metadata.json"
|
|
1126
|
+
metadata.write_text("{}")
|
|
1127
|
+
|
|
1128
|
+
processed = tmp_path / "manifest_processed.json"
|
|
1129
|
+
processed.write_text('{"jobs": []}')
|
|
1130
|
+
|
|
1131
|
+
fake_circuit = MagicMock()
|
|
1132
|
+
fake_circuit.name = "test_circuit"
|
|
1133
|
+
|
|
1134
|
+
with (
|
|
1135
|
+
patch(
|
|
1136
|
+
"python.frontend.commands.batch.BatchCommand._build_circuit",
|
|
1137
|
+
return_value=fake_circuit,
|
|
1138
|
+
),
|
|
1139
|
+
patch(
|
|
1140
|
+
"python.frontend.commands.batch._preprocess_manifest",
|
|
1141
|
+
return_value=str(processed),
|
|
1142
|
+
) as mock_preprocess,
|
|
1143
|
+
patch(
|
|
1144
|
+
"python.frontend.commands.batch.run_cargo_command",
|
|
1145
|
+
) as mock_cargo,
|
|
1146
|
+
):
|
|
1147
|
+
rc = main(
|
|
1148
|
+
[
|
|
1149
|
+
"--no-banner",
|
|
1150
|
+
"batch",
|
|
1151
|
+
"witness",
|
|
1152
|
+
"-c",
|
|
1153
|
+
str(circuit),
|
|
1154
|
+
"-f",
|
|
1155
|
+
str(manifest),
|
|
1156
|
+
],
|
|
1157
|
+
)
|
|
1158
|
+
|
|
1159
|
+
assert rc == 0
|
|
1160
|
+
mock_preprocess.assert_called_once()
|
|
1161
|
+
mock_cargo.assert_called_once()
|
|
1162
|
+
call_kwargs = mock_cargo.call_args[1]
|
|
1163
|
+
assert call_kwargs["command_type"] == "run_batch_witness"
|
|
1164
|
+
assert call_kwargs["args"]["f"] == str(processed)
|
|
1165
|
+
|
|
1166
|
+
|
|
1167
|
+
@pytest.mark.unit
|
|
1168
|
+
def test_batch_missing_circuit(tmp_path: Path) -> None:
|
|
1169
|
+
manifest = tmp_path / "manifest.json"
|
|
1170
|
+
manifest.write_text('{"jobs": []}')
|
|
1171
|
+
|
|
1172
|
+
rc = main(
|
|
1173
|
+
[
|
|
1174
|
+
"--no-banner",
|
|
1175
|
+
"batch",
|
|
1176
|
+
"prove",
|
|
1177
|
+
"-c",
|
|
1178
|
+
str(tmp_path / "nonexistent.txt"),
|
|
1179
|
+
"-f",
|
|
1180
|
+
str(manifest),
|
|
1181
|
+
],
|
|
1182
|
+
)
|
|
1183
|
+
|
|
1184
|
+
assert rc == 1
|
|
1185
|
+
|
|
1186
|
+
|
|
1187
|
+
@pytest.mark.unit
|
|
1188
|
+
def test_batch_missing_manifest(tmp_path: Path) -> None:
|
|
1189
|
+
circuit = tmp_path / "circuit.txt"
|
|
1190
|
+
circuit.write_text("ok")
|
|
1191
|
+
|
|
1192
|
+
rc = main(
|
|
1193
|
+
[
|
|
1194
|
+
"--no-banner",
|
|
1195
|
+
"batch",
|
|
1196
|
+
"prove",
|
|
1197
|
+
"-c",
|
|
1198
|
+
str(circuit),
|
|
1199
|
+
"-f",
|
|
1200
|
+
str(tmp_path / "nonexistent.json"),
|
|
1201
|
+
],
|
|
1202
|
+
)
|
|
1203
|
+
|
|
1204
|
+
assert rc == 1
|
|
Binary file
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|