JSTprove 1.0.0__py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jstprove-1.0.0.dist-info/METADATA +397 -0
- jstprove-1.0.0.dist-info/RECORD +81 -0
- jstprove-1.0.0.dist-info/WHEEL +6 -0
- jstprove-1.0.0.dist-info/entry_points.txt +2 -0
- jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
- jstprove-1.0.0.dist-info/top_level.txt +1 -0
- python/__init__.py +0 -0
- python/core/__init__.py +3 -0
- python/core/binaries/__init__.py +0 -0
- python/core/binaries/expander-exec +0 -0
- python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
- python/core/circuit_models/__init__.py +0 -0
- python/core/circuit_models/generic_onnx.py +231 -0
- python/core/circuit_models/simple_circuit.py +133 -0
- python/core/circuits/__init__.py +0 -0
- python/core/circuits/base.py +1000 -0
- python/core/circuits/errors.py +188 -0
- python/core/circuits/zk_model_base.py +25 -0
- python/core/model_processing/__init__.py +0 -0
- python/core/model_processing/converters/__init__.py +0 -0
- python/core/model_processing/converters/base.py +143 -0
- python/core/model_processing/converters/onnx_converter.py +1181 -0
- python/core/model_processing/errors.py +147 -0
- python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
- python/core/model_processing/onnx_custom_ops/conv.py +111 -0
- python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
- python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
- python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
- python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
- python/core/model_processing/onnx_custom_ops/relu.py +43 -0
- python/core/model_processing/onnx_quantizer/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
- python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
- python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
- python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
- python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
- python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
- python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
- python/core/model_templates/__init__.py +0 -0
- python/core/model_templates/circuit_template.py +57 -0
- python/core/utils/__init__.py +0 -0
- python/core/utils/benchmarking_helpers.py +163 -0
- python/core/utils/constants.py +4 -0
- python/core/utils/errors.py +117 -0
- python/core/utils/general_layer_functions.py +268 -0
- python/core/utils/helper_functions.py +1138 -0
- python/core/utils/model_registry.py +166 -0
- python/core/utils/scratch_tests.py +66 -0
- python/core/utils/witness_utils.py +291 -0
- python/frontend/__init__.py +0 -0
- python/frontend/cli.py +115 -0
- python/frontend/commands/__init__.py +17 -0
- python/frontend/commands/args.py +100 -0
- python/frontend/commands/base.py +199 -0
- python/frontend/commands/bench/__init__.py +54 -0
- python/frontend/commands/bench/list.py +42 -0
- python/frontend/commands/bench/model.py +172 -0
- python/frontend/commands/bench/sweep.py +212 -0
- python/frontend/commands/compile.py +58 -0
- python/frontend/commands/constants.py +5 -0
- python/frontend/commands/model_check.py +53 -0
- python/frontend/commands/prove.py +50 -0
- python/frontend/commands/verify.py +73 -0
- python/frontend/commands/witness.py +64 -0
- python/scripts/__init__.py +0 -0
- python/scripts/benchmark_runner.py +833 -0
- python/scripts/gen_and_bench.py +482 -0
- python/tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
- python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
- python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
- python/tests/circuit_parent_classes/__init__.py +0 -0
- python/tests/circuit_parent_classes/test_circuit.py +969 -0
- python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
- python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
- python/tests/test_cli.py +1021 -0
- python/tests/utils_testing/__init__.py +0 -0
- python/tests/utils_testing/test_helper_functions.py +891 -0
python/tests/test_cli.py
ADDED
|
@@ -0,0 +1,1021 @@
|
|
|
1
|
+
# python/tests/test_cli.py
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from unittest.mock import MagicMock, patch
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
|
|
7
|
+
from python.core.circuits.errors import CircuitRunError
|
|
8
|
+
from python.core.model_processing.onnx_quantizer.exceptions import (
|
|
9
|
+
UnsupportedOpError,
|
|
10
|
+
)
|
|
11
|
+
from python.core.utils.helper_functions import RunType
|
|
12
|
+
from python.frontend.cli import main
|
|
13
|
+
|
|
14
|
+
# -----------------------
|
|
15
|
+
# unit tests: dispatch only
|
|
16
|
+
# -----------------------
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@pytest.mark.unit
def test_witness_dispatch(tmp_path: Path) -> None:
    """``witness`` given flag-style paths builds a GEN_WITNESS run config."""
    # Pre-create the files the command reads; the output and witness
    # locations need not pre-exist.
    circuit_file = tmp_path / "circuit.txt"
    circuit_file.write_text("ok")

    quantized_model = tmp_path / "q.onnx"
    quantized_model.write_bytes(b"\x00")

    input_json = tmp_path / "in.json"
    input_json.write_text('{"input":[0]}')

    output_json = tmp_path / "out.json"
    witness_file = tmp_path / "w.bin"

    argv = [
        "--no-banner",
        "witness",
        "-c", str(circuit_file),
        "-i", str(input_json),
        "-o", str(output_json),
        "-w", str(witness_file),
    ]

    circuit_stub = MagicMock()
    with patch(
        "python.frontend.commands.witness.WitnessCommand._build_circuit",
        return_value=circuit_stub,
    ):
        rc = main(argv)

    assert rc == 0
    # The first positional argument to base_testing is the run config.
    config = circuit_stub.base_testing.call_args.args[0]
    assert config.run_type == RunType.GEN_WITNESS
    assert config.circuit_path == str(circuit_file)
    assert config.input_file == str(input_json)
    assert config.output_file == str(output_json)
    assert config.witness_file == str(witness_file)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@pytest.mark.unit
def test_witness_dispatch_positional(tmp_path: Path) -> None:
    """``witness`` accepts circuit, input, output and witness positionally."""
    circuit_file = tmp_path / "circuit.txt"
    circuit_file.write_text("ok")

    quantized_model = tmp_path / "q.onnx"
    quantized_model.write_bytes(b"\x00")

    input_json = tmp_path / "in.json"
    input_json.write_text('{"input":[0]}')

    output_json = tmp_path / "out.json"
    witness_file = tmp_path / "w.bin"

    # Order matters: circuit, input, output, witness.
    positional = [circuit_file, input_json, output_json, witness_file]
    argv = ["--no-banner", "witness", *map(str, positional)]

    circuit_stub = MagicMock()
    with patch(
        "python.frontend.commands.witness.WitnessCommand._build_circuit",
        return_value=circuit_stub,
    ):
        rc = main(argv)

    assert rc == 0
    config = circuit_stub.base_testing.call_args.args[0]
    assert config.run_type == RunType.GEN_WITNESS
    assert config.circuit_path == str(circuit_file)
    assert config.input_file == str(input_json)
    assert config.output_file == str(output_json)
    assert config.witness_file == str(witness_file)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
@pytest.mark.unit
def test_prove_dispatch(tmp_path: Path) -> None:
    """``prove`` given flag-style paths builds a PROVE_WITNESS run config."""
    circuit_file = tmp_path / "circuit.txt"
    circuit_file.write_text("ok")

    witness_file = tmp_path / "w.bin"
    witness_file.write_bytes(b"\x00")

    # The proof is an output location and need not pre-exist.
    proof_file = tmp_path / "p.bin"

    argv = [
        "--no-banner",
        "prove",
        "-c", str(circuit_file),
        "-w", str(witness_file),
        "-p", str(proof_file),
    ]

    circuit_stub = MagicMock()
    with patch(
        "python.frontend.commands.prove.ProveCommand._build_circuit",
        return_value=circuit_stub,
    ):
        rc = main(argv)

    assert rc == 0
    config = circuit_stub.base_testing.call_args.args[0]
    assert config.run_type == RunType.PROVE_WITNESS
    assert config.circuit_path == str(circuit_file)
    assert config.witness_file == str(witness_file)
    assert config.proof_file == str(proof_file)
    assert config.ecc is False
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
@pytest.mark.unit
def test_prove_dispatch_positional(tmp_path: Path) -> None:
    """``prove`` accepts circuit, witness and proof positionally."""
    circuit_file = tmp_path / "circuit.txt"
    circuit_file.write_text("ok")

    witness_file = tmp_path / "w.bin"
    witness_file.write_bytes(b"\x00")

    proof_file = tmp_path / "p.bin"

    # Order matters: circuit, witness, proof.
    argv = ["--no-banner", "prove"] + [
        str(p) for p in (circuit_file, witness_file, proof_file)
    ]

    circuit_stub = MagicMock()
    with patch(
        "python.frontend.commands.prove.ProveCommand._build_circuit",
        return_value=circuit_stub,
    ):
        rc = main(argv)

    assert rc == 0
    config = circuit_stub.base_testing.call_args.args[0]
    assert config.run_type == RunType.PROVE_WITNESS
    assert config.circuit_path == str(circuit_file)
    assert config.witness_file == str(witness_file)
    assert config.proof_file == str(proof_file)
    assert config.ecc is False
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
@pytest.mark.unit
def test_verify_dispatch(tmp_path: Path) -> None:
    """``verify`` given flag-style paths builds a GEN_VERIFY run config."""
    circuit_file = tmp_path / "circuit.txt"
    circuit_file.write_text("ok")

    input_json = tmp_path / "in.json"
    input_json.write_text('{"input":[0]}')

    # Unlike witness generation, verify expects the output file to exist.
    output_json = tmp_path / "out.json"
    output_json.write_text('{"output":[0]}')

    witness_file = tmp_path / "w.bin"
    witness_file.write_bytes(b"\x00")

    proof_file = tmp_path / "p.bin"
    proof_file.write_bytes(b"\x00")

    quantized_model = tmp_path / "q.onnx"
    quantized_model.write_bytes(b"\x00")

    argv = [
        "--no-banner",
        "verify",
        "-c", str(circuit_file),
        "-i", str(input_json),
        "-o", str(output_json),
        "-w", str(witness_file),
        "-p", str(proof_file),
    ]

    circuit_stub = MagicMock()
    with patch(
        "python.frontend.commands.verify.VerifyCommand._build_circuit",
        return_value=circuit_stub,
    ):
        rc = main(argv)

    assert rc == 0
    config = circuit_stub.base_testing.call_args.args[0]
    assert config.run_type == RunType.GEN_VERIFY
    assert config.circuit_path == str(circuit_file)
    assert config.input_file == str(input_json)
    assert config.output_file == str(output_json)
    assert config.witness_file == str(witness_file)
    assert config.proof_file == str(proof_file)
    assert config.ecc is False
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
@pytest.mark.unit
def test_verify_dispatch_positional(tmp_path: Path) -> None:
    """``verify`` accepts circuit, input, output, witness and proof positionally."""
    circuit_file = tmp_path / "circuit.txt"
    circuit_file.write_text("ok")

    input_json = tmp_path / "in.json"
    input_json.write_text('{"input":[0]}')

    output_json = tmp_path / "out.json"
    output_json.write_text('{"output":[0]}')

    witness_file = tmp_path / "w.bin"
    witness_file.write_bytes(b"\x00")

    proof_file = tmp_path / "p.bin"
    proof_file.write_bytes(b"\x00")

    quantized_model = tmp_path / "q.onnx"
    quantized_model.write_bytes(b"\x00")

    # Order matters: circuit, input, output, witness, proof.
    ordered = (circuit_file, input_json, output_json, witness_file, proof_file)
    argv = ["--no-banner", "verify", *map(str, ordered)]

    circuit_stub = MagicMock()
    with patch(
        "python.frontend.commands.verify.VerifyCommand._build_circuit",
        return_value=circuit_stub,
    ):
        rc = main(argv)

    assert rc == 0
    config = circuit_stub.base_testing.call_args.args[0]
    assert config.run_type == RunType.GEN_VERIFY
    assert config.circuit_path == str(circuit_file)
    assert config.input_file == str(input_json)
    assert config.output_file == str(output_json)
    assert config.witness_file == str(witness_file)
    assert config.proof_file == str(proof_file)
    assert config.ecc is False
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
@pytest.mark.unit
def test_compile_dispatch(tmp_path: Path) -> None:
    """``compile -m MODEL -c CIRCUIT`` wires the model and builds a compile config."""
    # Only the model is read; the circuit path is an output location.
    model = tmp_path / "model.onnx"
    model.write_bytes(b"\x00")
    circuit_file = tmp_path / "circuit.txt"

    argv = ["--no-banner", "compile", "-m", str(model), "-c", str(circuit_file)]

    circuit_stub = MagicMock()
    with patch(
        "python.frontend.commands.compile.CompileCommand._build_circuit",
        return_value=circuit_stub,
    ):
        rc = main(argv)

    assert rc == 0
    # All three model-location attributes on the circuit should point at -m.
    for attr in ("model_file_name", "onnx_path", "model_path"):
        assert getattr(circuit_stub, attr) == str(model)

    config = circuit_stub.base_testing.call_args.args[0]
    assert config.run_type == RunType.COMPILE_CIRCUIT
    assert config.circuit_path == str(circuit_file)
    assert config.dev_mode is False
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
@pytest.mark.unit
def test_compile_dispatch_positional(tmp_path: Path) -> None:
    """``compile MODEL CIRCUIT`` behaves like the flag form."""
    model = tmp_path / "model.onnx"
    model.write_bytes(b"\x00")
    circuit_file = tmp_path / "circuit.txt"

    argv = ["--no-banner", "compile", str(model), str(circuit_file)]

    circuit_stub = MagicMock()
    with patch(
        "python.frontend.commands.compile.CompileCommand._build_circuit",
        return_value=circuit_stub,
    ):
        rc = main(argv)

    assert rc == 0
    for attr in ("model_file_name", "onnx_path", "model_path"):
        assert getattr(circuit_stub, attr) == str(model)

    config = circuit_stub.base_testing.call_args.args[0]
    assert config.run_type == RunType.COMPILE_CIRCUIT
    assert config.circuit_path == str(circuit_file)
    assert config.dev_mode is False
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
@pytest.mark.unit
def test_compile_missing_model_path() -> None:
    """``compile`` without any model path exits with status 1."""
    argv = ["--no-banner", "compile", "-c", "circuit.txt"]
    assert main(argv) == 1
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
@pytest.mark.unit
def test_compile_missing_circuit_path() -> None:
    """``compile`` without any circuit path exits with status 1."""
    argv = ["--no-banner", "compile", "-m", "model.onnx"]
    assert main(argv) == 1
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
@pytest.mark.unit
def test_witness_missing_args() -> None:
    """``witness`` with only a circuit path exits with status 1."""
    argv = ["--no-banner", "witness", "-c", "circuit.txt"]
    assert main(argv) == 1
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
@pytest.mark.unit
def test_prove_missing_args() -> None:
    """``prove`` with only a circuit path exits with status 1."""
    argv = ["--no-banner", "prove", "-c", "circuit.txt"]
    assert main(argv) == 1
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
@pytest.mark.unit
def test_verify_missing_args() -> None:
    """``verify`` with only a circuit path exits with status 1."""
    argv = ["--no-banner", "verify", "-c", "circuit.txt"]
    assert main(argv) == 1
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
@pytest.mark.unit
def test_model_check_missing_model_path() -> None:
    """``model_check`` with no model argument exits with status 1."""
    argv = ["--no-banner", "model_check"]
    assert main(argv) == 1
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
@pytest.mark.unit
def test_compile_file_not_found(tmp_path: Path) -> None:
    """``compile`` pointed at a missing model file exits with status 1."""
    circuit_file = tmp_path / "circuit.txt"
    argv = [
        "--no-banner",
        "compile",
        "-m", "nonexistent.onnx",
        "-c", str(circuit_file),
    ]
    assert main(argv) == 1
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
@pytest.mark.unit
def test_witness_file_not_found(tmp_path: Path) -> None:
    """``witness`` with a missing circuit and input exits with status 1."""
    output_json = tmp_path / "out.json"
    witness_file = tmp_path / "w.bin"
    argv = [
        "--no-banner",
        "witness",
        "-c", "nonexistent.txt",
        "-i", "nonexistent.json",
        "-o", str(output_json),
        "-w", str(witness_file),
    ]
    assert main(argv) == 1
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
@pytest.mark.unit
def test_prove_file_not_found(tmp_path: Path) -> None:
    """``prove`` with a missing circuit and witness exits with status 1."""
    proof_file = tmp_path / "proof.bin"
    argv = [
        "--no-banner",
        "prove",
        "-c", "nonexistent.txt",
        "-w", "nonexistent.bin",
        "-p", str(proof_file),
    ]
    assert main(argv) == 1
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
@pytest.mark.unit
def test_verify_file_not_found(tmp_path: Path) -> None:
    """``verify`` where every referenced file is missing exits with status 1."""
    missing = {
        "-c": "nonexistent.txt",
        "-i": "nonexistent.json",
        "-o": "nonexistent_out.json",
        "-w": "nonexistent.bin",
        "-p": "nonexistent_proof.bin",
    }
    argv = ["--no-banner", "verify"]
    for flag, path in missing.items():
        argv += [flag, path]
    assert main(argv) == 1
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
@pytest.mark.unit
def test_model_check_file_not_found() -> None:
    """``model_check`` on a missing model file exits with status 1."""
    argv = ["--no-banner", "model_check", "-m", "nonexistent.onnx"]
    assert main(argv) == 1
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
@pytest.mark.unit
def test_compile_mixed_positional_and_flag(tmp_path: Path) -> None:
    """``compile MODEL -c CIRCUIT`` routes both paths like the pure forms do."""
    model = tmp_path / "model.onnx"
    model.write_bytes(b"\x00")
    circuit = tmp_path / "circuit.txt"

    fake_circuit = MagicMock()
    with patch(
        "python.frontend.commands.compile.CompileCommand._build_circuit",
        return_value=fake_circuit,
    ):
        rc = main(
            [
                "--no-banner",
                "compile",
                str(model),
                "-c",
                str(circuit),
            ],
        )

    assert rc == 0
    # A bare exit-code check would pass even if the positional model were
    # dropped; verify both paths actually reached the circuit/config.
    assert fake_circuit.model_path == str(model)
    config = fake_circuit.base_testing.call_args[0][0]
    assert config.run_type == RunType.COMPILE_CIRCUIT
    assert config.circuit_path == str(circuit)
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
@pytest.mark.unit
def test_witness_mixed_positional_and_flag(tmp_path: Path) -> None:
    """``witness CIRCUIT -i ... -o ... -w ...`` routes every path correctly."""
    circuit = tmp_path / "circuit.txt"
    circuit.write_text("ok")

    inputj = tmp_path / "in.json"
    inputj.write_text('{"input":[0]}')

    outputj = tmp_path / "out.json"
    witness = tmp_path / "w.bin"

    fake_circuit = MagicMock()
    with patch(
        "python.frontend.commands.witness.WitnessCommand._build_circuit",
        return_value=fake_circuit,
    ):
        rc = main(
            [
                "--no-banner",
                "witness",
                str(circuit),
                "-i",
                str(inputj),
                "-o",
                str(outputj),
                "-w",
                str(witness),
            ],
        )

    assert rc == 0
    # Check the positional circuit and the flagged paths all landed in the
    # run config, matching test_witness_dispatch.
    config = fake_circuit.base_testing.call_args[0][0]
    assert config.run_type == RunType.GEN_WITNESS
    assert config.circuit_path == str(circuit)
    assert config.input_file == str(inputj)
    assert config.output_file == str(outputj)
    assert config.witness_file == str(witness)
|
|
523
|
+
|
|
524
|
+
|
|
525
|
+
@pytest.mark.unit
def test_prove_mixed_positional_and_flag(tmp_path: Path) -> None:
    """``prove CIRCUIT WITNESS -p PROOF`` routes every path correctly."""
    circuit = tmp_path / "circuit.txt"
    circuit.write_text("ok")

    witness = tmp_path / "w.bin"
    witness.write_bytes(b"\x00")

    proof = tmp_path / "p.bin"

    fake_circuit = MagicMock()
    with patch(
        "python.frontend.commands.prove.ProveCommand._build_circuit",
        return_value=fake_circuit,
    ):
        rc = main(
            [
                "--no-banner",
                "prove",
                str(circuit),
                str(witness),
                "-p",
                str(proof),
            ],
        )

    assert rc == 0
    # Verify the positional and flagged paths reached the run config,
    # matching test_prove_dispatch.
    config = fake_circuit.base_testing.call_args[0][0]
    assert config.run_type == RunType.PROVE_WITNESS
    assert config.circuit_path == str(circuit)
    assert config.witness_file == str(witness)
    assert config.proof_file == str(proof)
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
@pytest.mark.unit
def test_model_check_positional(tmp_path: Path) -> None:
    """``model_check MODEL`` loads the model and runs the quantizer check."""
    model = tmp_path / "model.onnx"
    model.write_bytes(b"\x00")

    with (
        patch("onnx.load", return_value=MagicMock()) as load_stub,
        patch(
            "python.core.model_processing.onnx_quantizer.onnx_op_quantizer.ONNXOpQuantizer",
        ) as quantizer_cls,
    ):
        quantizer = MagicMock()
        quantizer_cls.return_value = quantizer
        status = main(["--no-banner", "model_check", str(model)])

    assert status == 0
    load_stub.assert_called_once_with(str(model))
    quantizer.check_model.assert_called_once()
|
|
574
|
+
|
|
575
|
+
|
|
576
|
+
@pytest.mark.unit
def test_flag_takes_precedence_over_positional(tmp_path: Path) -> None:
    """When both ``MODEL`` and ``-m MODEL`` are given, the flag wins."""
    flagged_model = tmp_path / "flag_model.onnx"
    positional_model = tmp_path / "pos_model.onnx"
    for candidate in (flagged_model, positional_model):
        candidate.write_bytes(b"\x00")
    circuit_file = tmp_path / "circuit.txt"

    argv = [
        "--no-banner",
        "compile",
        str(positional_model),
        "-m", str(flagged_model),
        "-c", str(circuit_file),
    ]

    circuit_stub = MagicMock()
    with patch(
        "python.frontend.commands.compile.CompileCommand._build_circuit",
        return_value=circuit_stub,
    ):
        rc = main(argv)

    assert rc == 0
    assert circuit_stub.model_path == str(flagged_model)
|
|
603
|
+
|
|
604
|
+
|
|
605
|
+
@pytest.mark.unit
def test_parent_dir_creation(tmp_path: Path) -> None:
    """``compile`` creates missing parent directories of the circuit path."""
    model = tmp_path / "model.onnx"
    model.write_bytes(b"\x00")
    deep_circuit = tmp_path.joinpath("nested", "deep", "circuit.txt")

    argv = ["--no-banner", "compile", "-m", str(model), "-c", str(deep_circuit)]

    circuit_stub = MagicMock()
    with patch(
        "python.frontend.commands.compile.CompileCommand._build_circuit",
        return_value=circuit_stub,
    ):
        rc = main(argv)

    assert rc == 0
    assert deep_circuit.parent.exists()
|
|
629
|
+
|
|
630
|
+
|
|
631
|
+
@pytest.mark.unit
def test_verify_mixed_positional_and_flag(tmp_path: Path) -> None:
    """``verify CIRCUIT INPUT -o ... -w ... -p ...`` routes every path correctly."""
    circuit = tmp_path / "circuit.txt"
    circuit.write_text("ok")

    inputj = tmp_path / "in.json"
    inputj.write_text('{"input":[0]}')

    outputj = tmp_path / "out.json"
    outputj.write_text('{"output":[0]}')

    witness = tmp_path / "w.bin"
    witness.write_bytes(b"\x00")

    proof = tmp_path / "p.bin"
    proof.write_bytes(b"\x00")

    fake_circuit = MagicMock()
    with patch(
        "python.frontend.commands.verify.VerifyCommand._build_circuit",
        return_value=fake_circuit,
    ):
        rc = main(
            [
                "--no-banner",
                "verify",
                str(circuit),
                str(inputj),
                "-o",
                str(outputj),
                "-w",
                str(witness),
                "-p",
                str(proof),
            ],
        )

    assert rc == 0
    # Verify the two positionals and the flagged paths reached the run
    # config, matching test_verify_dispatch.
    config = fake_circuit.base_testing.call_args[0][0]
    assert config.run_type == RunType.GEN_VERIFY
    assert config.circuit_path == str(circuit)
    assert config.input_file == str(inputj)
    assert config.output_file == str(outputj)
    assert config.witness_file == str(witness)
    assert config.proof_file == str(proof)
|
|
669
|
+
|
|
670
|
+
|
|
671
|
+
@pytest.mark.unit
def test_circuit_run_error_handling(tmp_path: Path) -> None:
    """A CircuitRunError raised during the run maps to exit status 1."""
    model = tmp_path / "model.onnx"
    model.write_bytes(b"\x00")
    circuit_file = tmp_path / "circuit.txt"

    # base_testing raises as soon as the command invokes it.
    failing_circuit = MagicMock(
        **{"base_testing.side_effect": CircuitRunError("Test error")},
    )

    argv = ["--no-banner", "compile", "-m", str(model), "-c", str(circuit_file)]
    with patch(
        "python.frontend.commands.compile.CompileCommand._build_circuit",
        return_value=failing_circuit,
    ):
        rc = main(argv)

    assert rc == 1
|
|
696
|
+
|
|
697
|
+
|
|
698
|
+
@pytest.mark.unit
def test_model_check_unsupported_op_error(tmp_path: Path) -> None:
    """``model_check`` reports an UnsupportedOpError as exit status 1."""
    model = tmp_path / "model.onnx"
    model.write_bytes(b"\x00")

    with patch("onnx.load", return_value=MagicMock()):
        rejecting_quantizer = MagicMock()
        rejecting_quantizer.check_model.side_effect = UnsupportedOpError(["BadOp"])
        with patch(
            "python.core.model_processing.onnx_quantizer.onnx_op_quantizer.ONNXOpQuantizer",
            return_value=rejecting_quantizer,
        ):
            status = main(["--no-banner", "model_check", "-m", str(model)])

    assert status == 1
|
|
717
|
+
|
|
718
|
+
|
|
719
|
+
@pytest.mark.unit
def test_empty_string_arg() -> None:
    """An empty ``-m`` value is rejected with exit status 1."""
    argv = ["--no-banner", "compile", "-m", "", "-c", "circuit.txt"]
    assert main(argv) == 1
|
|
723
|
+
|
|
724
|
+
|
|
725
|
+
@pytest.mark.unit
def test_flag_empty_string_uses_positional(tmp_path: Path) -> None:
    """An empty ``-m ""`` alongside a valid positional model exits with status 1.

    NOTE(review): the test name suggests the positional model should be used
    as a fallback, but the assertion expects failure (rc == 1). Confirm which
    behavior is intended and rename or fix accordingly.
    """
    model = tmp_path / "model.onnx"
    model.write_bytes(b"\x00")
    circuit_file = tmp_path / "circuit.txt"

    argv = [
        "--no-banner",
        "compile",
        str(model),
        "-m", "",
        "-c", str(circuit_file),
    ]

    circuit_stub = MagicMock()
    with patch(
        "python.frontend.commands.compile.CompileCommand._build_circuit",
        return_value=circuit_stub,
    ):
        rc = main(argv)

    assert rc == 1
|
|
749
|
+
|
|
750
|
+
|
|
751
|
+
# -----------------------
|
|
752
|
+
# bench command tests
|
|
753
|
+
# -----------------------
|
|
754
|
+
|
|
755
|
+
|
|
756
|
+
@pytest.mark.unit
def test_bench_list_models() -> None:
    """``bench list --list-models`` succeeds against a stubbed registry."""
    registry_listing = ["onnx: model1", "class: model2"]
    with patch(
        "python.core.utils.model_registry.list_available_models",
        return_value=registry_listing,
    ):
        status = main(["--no-banner", "bench", "list", "--list-models"])

    assert status == 0
|
|
765
|
+
|
|
766
|
+
|
|
767
|
+
@pytest.mark.unit
def test_bench_with_model_path(tmp_path: Path) -> None:
    """``bench model --model-path`` succeeds with input generation and subprocesses stubbed."""
    model = tmp_path / "model.onnx"
    model.write_bytes(b"\x00")
    argv = ["--no-banner", "bench", "model", "--model-path", str(model)]

    with patch(
        "python.frontend.commands.bench.model.ModelCommand._generate_model_input",
    ):
        with patch("python.frontend.commands.bench.model.run_subprocess"):
            status = main(argv)

    assert status == 0
|
|
781
|
+
|
|
782
|
+
|
|
783
|
+
@pytest.mark.unit
def test_bench_with_model_flag() -> None:
    """``bench model --model NAME`` succeeds for a registry entry resolved by name."""
    # Registry entry whose loader yields an instance exposing model_file_name.
    loaded_instance = MagicMock(model_file_name="test_model.onnx")
    registry_entry = MagicMock(**{"loader.return_value": loaded_instance})
    # ``name`` must be set after construction: MagicMock(name=...) only sets
    # the mock's repr, not a ``.name`` attribute.
    registry_entry.name = "test_model"

    with (
        patch(
            "python.core.utils.model_registry.get_models_to_test",
            return_value=[registry_entry],
        ),
        patch(
            "python.frontend.commands.bench.model.ModelCommand._generate_model_input",
        ),
        patch("python.frontend.commands.bench.model.run_subprocess"),
    ):
        status = main(["--no-banner", "bench", "model", "--model", "test_model"])

    assert status == 0
|
|
804
|
+
|
|
805
|
+
|
|
806
|
+
@pytest.mark.unit
def test_bench_with_source_filter() -> None:
    """``bench model --source onnx`` asks the registry for onnx models only."""
    loaded_instance = MagicMock(model_file_name="test_model.onnx")
    registry_entry = MagicMock(**{"loader.return_value": loaded_instance})
    registry_entry.name = "test_model"

    with (
        patch(
            "python.core.utils.model_registry.get_models_to_test",
            return_value=[registry_entry],
        ) as registry_query,
        patch(
            "python.frontend.commands.bench.model.ModelCommand._generate_model_input",
        ),
        patch("python.frontend.commands.bench.model.run_subprocess"),
    ):
        status = main(["--no-banner", "bench", "model", "--source", "onnx"])

    assert status == 0
    # No explicit model name, filtered by source.
    registry_query.assert_called_once_with(None, "onnx")
|
|
828
|
+
|
|
829
|
+
|
|
830
|
+
@pytest.mark.unit
def test_bench_depth_sweep_simple() -> None:
    """``bench sweep depth`` forwards the default depth range to gen_and_bench."""
    with patch("python.frontend.commands.bench.sweep.run_subprocess") as mock_run:
        rc = main(["--no-banner", "bench", "sweep", "depth"])

    assert rc == 0
    cmd = mock_run.call_args[0][0]
    assert "python.scripts.gen_and_bench" in cmd[2]
    # Check each value by its position after the flag: a bare ``"1" in cmd``
    # would pass if "1" appeared anywhere else in the command line.
    assert "--sweep" in cmd
    assert cmd[cmd.index("--sweep") + 1] == "depth"
    assert "--depth-min" in cmd
    assert cmd[cmd.index("--depth-min") + 1] == "1"
    assert "--depth-max" in cmd
    assert cmd[cmd.index("--depth-max") + 1] == "16"
|
|
844
|
+
|
|
845
|
+
|
|
846
|
+
@pytest.mark.unit
def test_bench_breadth_sweep_simple() -> None:
    """``bench sweep breadth`` forwards the default architecture depth."""
    with patch("python.frontend.commands.bench.sweep.run_subprocess") as mock_run:
        rc = main(["--no-banner", "bench", "sweep", "breadth"])

    assert rc == 0
    cmd = mock_run.call_args[0][0]
    assert "python.scripts.gen_and_bench" in cmd[2]
    # Pin each value to the flag it belongs to, rather than only checking
    # that the string occurs somewhere in the command.
    assert "--sweep" in cmd
    assert cmd[cmd.index("--sweep") + 1] == "breadth"
    assert "--arch-depth" in cmd
    assert cmd[cmd.index("--arch-depth") + 1] == "5"
|
|
858
|
+
|
|
859
|
+
|
|
860
|
+
@pytest.mark.unit
def test_bench_sweep_with_custom_args() -> None:
    """User-supplied ``--depth-min``/``--depth-max`` reach the sweep command verbatim."""
    argv = [
        "--no-banner",
        "bench",
        "sweep",
        "depth",
        "--depth-min",
        "5",
        "--depth-max",
        "10",
    ]
    with patch("python.frontend.commands.bench.sweep.run_subprocess") as runner:
        status = main(argv)

    assert status == 0
    launched = runner.call_args.args[0]
    for flag, expected in (("--depth-min", "5"), ("--depth-max", "10")):
        assert flag in launched
        position = launched.index(flag)
        assert launched[position + 1] == expected
|
|
884
|
+
|
|
885
|
+
|
|
886
|
+
@pytest.mark.unit
def test_bench_sweep_with_optional_args() -> None:
    """``--tag`` and ``--onnx-dir`` are forwarded to the sweep command with their values.

    Each flag is checked together with the value that immediately follows it,
    so the test fails if a value is attached to the wrong flag.
    """
    with patch("python.frontend.commands.bench.sweep.run_subprocess") as mock_run:
        rc = main(
            [
                "--no-banner",
                "bench",
                "sweep",
                "depth",
                "--tag",
                "test_tag",
                "--onnx-dir",
                "custom_onnx",
            ],
        )

    assert rc == 0
    # Give a clear assertion failure rather than a TypeError on ``None.args``.
    mock_run.assert_called()
    cmd = mock_run.call_args.args[0]
    for flag, expected in (("--tag", "test_tag"), ("--onnx-dir", "custom_onnx")):
        assert flag in cmd, f"{flag} missing from {cmd}"
        assert cmd[cmd.index(flag) + 1] == expected, (
            f"{flag} should be followed by {expected!r} in {cmd}"
        )
|
|
908
|
+
|
|
909
|
+
|
|
910
|
+
@pytest.mark.unit
def test_bench_missing_required_args() -> None:
    """``bench`` with no further arguments is rejected by argparse."""
    argv = ["--no-banner", "bench"]
    # argparse signals usage errors by exiting with status 2.
    usage_error_code = 2
    with pytest.raises(SystemExit) as raised:
        main(argv)
    assert raised.value.code == usage_error_code
|
|
916
|
+
|
|
917
|
+
|
|
918
|
+
@pytest.mark.unit
def test_bench_nonexistent_model_path() -> None:
    """A ``-m`` path that does not exist makes ``bench model`` return 1."""
    missing_path = "nonexistent.onnx"
    status = main(["--no-banner", "bench", "model", "-m", missing_path])
    assert status == 1
|
|
922
|
+
|
|
923
|
+
|
|
924
|
+
@pytest.mark.unit
def test_bench_no_models_found() -> None:
    """An empty registry result for ``--model`` makes ``bench model`` return 1."""
    empty_registry = patch(
        "python.core.utils.model_registry.get_models_to_test",
        return_value=[],
    )
    argv = ["--no-banner", "bench", "model", "--model", "nonexistent_model"]
    with empty_registry:
        status = main(argv)

    assert status == 1
|
|
933
|
+
|
|
934
|
+
|
|
935
|
+
@pytest.mark.unit
def test_bench_subprocess_failure(tmp_path: Path) -> None:
    """A ``RuntimeError`` raised by ``run_subprocess`` makes ``bench model`` return 1."""
    model_file = tmp_path / "model.onnx"
    # Single placeholder byte; ``_build_circuit`` is patched below.
    model_file.write_bytes(b"\x00")

    circuit = MagicMock()
    circuit.get_inputs.return_value = {"input": [0]}
    circuit.format_inputs.return_value = {"input": [0]}

    build_patch = patch(
        "python.frontend.commands.bench.model.ModelCommand._build_circuit",
        return_value=circuit,
    )
    run_patch = patch(
        "python.frontend.commands.bench.model.run_subprocess",
        side_effect=RuntimeError("Subprocess failed"),
    )
    with build_patch, run_patch:
        status = main(["--no-banner", "bench", "model", "-m", str(model_file)])

    assert status == 1
|
|
957
|
+
|
|
958
|
+
|
|
959
|
+
@pytest.mark.unit
def test_bench_model_load_failure(tmp_path: Path) -> None:
    """A ``RuntimeError`` from ``load_model`` makes ``bench model`` return 1."""
    model_file = tmp_path / "model.onnx"
    # Single placeholder byte; ``_build_circuit`` is patched below.
    model_file.write_bytes(b"\x00")

    circuit = MagicMock()
    circuit.load_model.side_effect = RuntimeError("Failed to load model")

    build_patch = patch(
        "python.frontend.commands.bench.model.ModelCommand._build_circuit",
        return_value=circuit,
    )
    with build_patch:
        status = main(["--no-banner", "bench", "model", "-m", str(model_file)])

    assert status == 1
|
|
974
|
+
|
|
975
|
+
|
|
976
|
+
@pytest.mark.unit
def test_bench_input_generation_failure(tmp_path: Path) -> None:
    """A ``RuntimeError`` from ``get_inputs`` makes ``bench model`` return 1."""
    model_file = tmp_path / "model.onnx"
    # Single placeholder byte; ``_build_circuit`` is patched below.
    model_file.write_bytes(b"\x00")

    circuit = MagicMock()
    # Loading succeeds, so the failure comes from input generation.
    circuit.load_model.return_value = None
    circuit.get_inputs.side_effect = RuntimeError("Failed to generate input")

    build_patch = patch(
        "python.frontend.commands.bench.model.ModelCommand._build_circuit",
        return_value=circuit,
    )
    with build_patch:
        status = main(["--no-banner", "bench", "model", "-m", str(model_file)])

    assert status == 1
|
|
992
|
+
|
|
993
|
+
|
|
994
|
+
@pytest.mark.unit
|
|
995
|
+
def test_bench_with_iterations(tmp_path: Path) -> None:
|
|
996
|
+
model = tmp_path / "model.onnx"
|
|
997
|
+
model.write_bytes(b"\x00")
|
|
998
|
+
|
|
999
|
+
with (
|
|
1000
|
+
patch(
|
|
1001
|
+
"python.frontend.commands.bench.model.ModelCommand._generate_model_input",
|
|
1002
|
+
),
|
|
1003
|
+
patch("python.frontend.commands.bench.model.run_subprocess") as mock_run,
|
|
1004
|
+
):
|
|
1005
|
+
rc = main(
|
|
1006
|
+
[
|
|
1007
|
+
"--no-banner",
|
|
1008
|
+
"bench",
|
|
1009
|
+
"model",
|
|
1010
|
+
"--model-path",
|
|
1011
|
+
str(model),
|
|
1012
|
+
"--iterations",
|
|
1013
|
+
"10",
|
|
1014
|
+
],
|
|
1015
|
+
)
|
|
1016
|
+
|
|
1017
|
+
assert rc == 0
|
|
1018
|
+
cmd = mock_run.call_args[0][0]
|
|
1019
|
+
assert "--iterations" in cmd
|
|
1020
|
+
idx = cmd.index("--iterations")
|
|
1021
|
+
assert cmd[idx + 1] == "10"
|