JSTprove 1.0.0__py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. jstprove-1.0.0.dist-info/METADATA +397 -0
  2. jstprove-1.0.0.dist-info/RECORD +81 -0
  3. jstprove-1.0.0.dist-info/WHEEL +6 -0
  4. jstprove-1.0.0.dist-info/entry_points.txt +2 -0
  5. jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
  6. jstprove-1.0.0.dist-info/top_level.txt +1 -0
  7. python/__init__.py +0 -0
  8. python/core/__init__.py +3 -0
  9. python/core/binaries/__init__.py +0 -0
  10. python/core/binaries/expander-exec +0 -0
  11. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  12. python/core/circuit_models/__init__.py +0 -0
  13. python/core/circuit_models/generic_onnx.py +231 -0
  14. python/core/circuit_models/simple_circuit.py +133 -0
  15. python/core/circuits/__init__.py +0 -0
  16. python/core/circuits/base.py +1000 -0
  17. python/core/circuits/errors.py +188 -0
  18. python/core/circuits/zk_model_base.py +25 -0
  19. python/core/model_processing/__init__.py +0 -0
  20. python/core/model_processing/converters/__init__.py +0 -0
  21. python/core/model_processing/converters/base.py +143 -0
  22. python/core/model_processing/converters/onnx_converter.py +1181 -0
  23. python/core/model_processing/errors.py +147 -0
  24. python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
  25. python/core/model_processing/onnx_custom_ops/conv.py +111 -0
  26. python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
  27. python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
  28. python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
  29. python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
  30. python/core/model_processing/onnx_custom_ops/relu.py +43 -0
  31. python/core/model_processing/onnx_quantizer/__init__.py +0 -0
  32. python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
  33. python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
  34. python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
  35. python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
  36. python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
  37. python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
  38. python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
  39. python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
  40. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
  41. python/core/model_templates/__init__.py +0 -0
  42. python/core/model_templates/circuit_template.py +57 -0
  43. python/core/utils/__init__.py +0 -0
  44. python/core/utils/benchmarking_helpers.py +163 -0
  45. python/core/utils/constants.py +4 -0
  46. python/core/utils/errors.py +117 -0
  47. python/core/utils/general_layer_functions.py +268 -0
  48. python/core/utils/helper_functions.py +1138 -0
  49. python/core/utils/model_registry.py +166 -0
  50. python/core/utils/scratch_tests.py +66 -0
  51. python/core/utils/witness_utils.py +291 -0
  52. python/frontend/__init__.py +0 -0
  53. python/frontend/cli.py +115 -0
  54. python/frontend/commands/__init__.py +17 -0
  55. python/frontend/commands/args.py +100 -0
  56. python/frontend/commands/base.py +199 -0
  57. python/frontend/commands/bench/__init__.py +54 -0
  58. python/frontend/commands/bench/list.py +42 -0
  59. python/frontend/commands/bench/model.py +172 -0
  60. python/frontend/commands/bench/sweep.py +212 -0
  61. python/frontend/commands/compile.py +58 -0
  62. python/frontend/commands/constants.py +5 -0
  63. python/frontend/commands/model_check.py +53 -0
  64. python/frontend/commands/prove.py +50 -0
  65. python/frontend/commands/verify.py +73 -0
  66. python/frontend/commands/witness.py +64 -0
  67. python/scripts/__init__.py +0 -0
  68. python/scripts/benchmark_runner.py +833 -0
  69. python/scripts/gen_and_bench.py +482 -0
  70. python/tests/__init__.py +0 -0
  71. python/tests/circuit_e2e_tests/__init__.py +0 -0
  72. python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
  73. python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
  74. python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
  75. python/tests/circuit_parent_classes/__init__.py +0 -0
  76. python/tests/circuit_parent_classes/test_circuit.py +969 -0
  77. python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
  78. python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
  79. python/tests/test_cli.py +1021 -0
  80. python/tests/utils_testing/__init__.py +0 -0
  81. python/tests/utils_testing/test_helper_functions.py +891 -0
@@ -0,0 +1,1021 @@
1
+ # python/tests/test_cli.py
2
+ from pathlib import Path
3
+ from unittest.mock import MagicMock, patch
4
+
5
+ import pytest
6
+
7
+ from python.core.circuits.errors import CircuitRunError
8
+ from python.core.model_processing.onnx_quantizer.exceptions import (
9
+ UnsupportedOpError,
10
+ )
11
+ from python.core.utils.helper_functions import RunType
12
+ from python.frontend.cli import main
13
+
14
+ # -----------------------
15
+ # unit tests: argument parsing, validation, and dispatch
16
+ # -----------------------
17
+
18
+
@pytest.mark.unit
def test_witness_dispatch(tmp_path: Path) -> None:
    """`witness` with -c/-i/-o/-w flags dispatches a GEN_WITNESS run config.

    The circuit and input files are created so the CLI's existence checks
    pass; the output and witness paths are targets and are left absent.
    ``_build_circuit`` is patched, so no real circuit is loaded.
    """
    circuit = tmp_path / "circuit.txt"
    circuit.write_text("ok")

    inputj = tmp_path / "in.json"
    inputj.write_text('{"input":[0]}')

    outputj = tmp_path / "out.json"  # doesn't need to pre-exist
    witness = tmp_path / "w.bin"  # doesn't need to pre-exist

    fake_circuit = MagicMock()
    with patch(
        "python.frontend.commands.witness.WitnessCommand._build_circuit",
        return_value=fake_circuit,
    ):
        rc = main(
            [
                "--no-banner",
                "witness",
                "-c",
                str(circuit),
                "-i",
                str(inputj),
                "-o",
                str(outputj),
                "-w",
                str(witness),
            ],
        )

    assert rc == 0
    # The first positional argument of base_testing is the run config.
    config = fake_circuit.base_testing.call_args.args[0]
    assert config.run_type == RunType.GEN_WITNESS
    assert config.circuit_path == str(circuit)
    assert config.input_file == str(inputj)
    assert config.output_file == str(outputj)
    assert config.witness_file == str(witness)
62
+
63
+
@pytest.mark.unit
def test_witness_dispatch_positional(tmp_path: Path) -> None:
    """`witness CIRCUIT INPUT OUTPUT WITNESS` (all positional) dispatches GEN_WITNESS.

    Only the circuit and input files are created; the output and witness
    paths are targets. ``_build_circuit`` is patched.
    """
    circuit = tmp_path / "circuit.txt"
    circuit.write_text("ok")

    inputj = tmp_path / "in.json"
    inputj.write_text('{"input":[0]}')

    outputj = tmp_path / "out.json"  # doesn't need to pre-exist
    witness = tmp_path / "w.bin"  # doesn't need to pre-exist

    fake_circuit = MagicMock()
    with patch(
        "python.frontend.commands.witness.WitnessCommand._build_circuit",
        return_value=fake_circuit,
    ):
        rc = main(
            [
                "--no-banner",
                "witness",
                str(circuit),
                str(inputj),
                str(outputj),
                str(witness),
            ],
        )

    assert rc == 0
    # The first positional argument of base_testing is the run config.
    config = fake_circuit.base_testing.call_args.args[0]
    assert config.run_type == RunType.GEN_WITNESS
    assert config.circuit_path == str(circuit)
    assert config.input_file == str(inputj)
    assert config.output_file == str(outputj)
    assert config.witness_file == str(witness)
102
+
103
+
@pytest.mark.unit
def test_prove_dispatch(tmp_path: Path) -> None:
    """`prove` given -c/-w/-p flags dispatches a PROVE_WITNESS run.

    The circuit and witness are created up front; the proof path is the
    output target and is left absent. No ECC flag is passed, so the config
    is expected to carry ``ecc is False``.
    """
    circuit_path = tmp_path / "circuit.txt"
    witness_path = tmp_path / "w.bin"
    proof_path = tmp_path / "p.bin"  # output target, not created here

    circuit_path.write_text("ok")
    witness_path.write_bytes(b"\x00")

    argv = [
        "--no-banner",
        "prove",
        "-c",
        str(circuit_path),
        "-w",
        str(witness_path),
        "-p",
        str(proof_path),
    ]
    stub = MagicMock()
    with patch(
        "python.frontend.commands.prove.ProveCommand._build_circuit",
        return_value=stub,
    ):
        exit_code = main(argv)

    assert exit_code == 0
    run_cfg = stub.base_testing.call_args.args[0]
    assert run_cfg.run_type == RunType.PROVE_WITNESS
    assert run_cfg.circuit_path == str(circuit_path)
    assert run_cfg.witness_file == str(witness_path)
    assert run_cfg.proof_file == str(proof_path)
    assert run_cfg.ecc is False
140
+
141
+
@pytest.mark.unit
def test_prove_dispatch_positional(tmp_path: Path) -> None:
    """`prove CIRCUIT WITNESS PROOF` (all positional) dispatches PROVE_WITNESS."""
    circuit_path = tmp_path / "circuit.txt"
    witness_path = tmp_path / "w.bin"
    proof_path = tmp_path / "p.bin"  # output target, not created here

    circuit_path.write_text("ok")
    witness_path.write_bytes(b"\x00")

    argv = [
        "--no-banner",
        "prove",
        str(circuit_path),
        str(witness_path),
        str(proof_path),
    ]
    stub = MagicMock()
    with patch(
        "python.frontend.commands.prove.ProveCommand._build_circuit",
        return_value=stub,
    ):
        exit_code = main(argv)

    assert exit_code == 0
    run_cfg = stub.base_testing.call_args.args[0]
    assert run_cfg.run_type == RunType.PROVE_WITNESS
    assert run_cfg.circuit_path == str(circuit_path)
    assert run_cfg.witness_file == str(witness_path)
    assert run_cfg.proof_file == str(proof_path)
    assert run_cfg.ecc is False
175
+
176
+
@pytest.mark.unit
def test_verify_dispatch(tmp_path: Path) -> None:
    """`verify` with -c/-i/-o/-w/-p flags dispatches a GEN_VERIFY run config.

    All five referenced files are created before running. The output JSON
    in particular must exist for `verify`. ``_build_circuit`` is patched.
    """
    circuit = tmp_path / "circuit.txt"
    circuit.write_text("ok")

    inputj = tmp_path / "in.json"
    inputj.write_text('{"input":[0]}')

    outputj = tmp_path / "out.json"
    outputj.write_text('{"output":[0]}')  # verify requires it exists

    witness = tmp_path / "w.bin"
    witness.write_bytes(b"\x00")

    proof = tmp_path / "p.bin"
    proof.write_bytes(b"\x00")

    fake_circuit = MagicMock()
    with patch(
        "python.frontend.commands.verify.VerifyCommand._build_circuit",
        return_value=fake_circuit,
    ):
        rc = main(
            [
                "--no-banner",
                "verify",
                "-c",
                str(circuit),
                "-i",
                str(inputj),
                "-o",
                str(outputj),
                "-w",
                str(witness),
                "-p",
                str(proof),
            ],
        )

    assert rc == 0
    # The first positional argument of base_testing is the run config.
    config = fake_circuit.base_testing.call_args.args[0]
    assert config.run_type == RunType.GEN_VERIFY
    assert config.circuit_path == str(circuit)
    assert config.input_file == str(inputj)
    assert config.output_file == str(outputj)
    assert config.witness_file == str(witness)
    assert config.proof_file == str(proof)
    assert config.ecc is False
230
+
231
+
@pytest.mark.unit
def test_verify_dispatch_positional(tmp_path: Path) -> None:
    """`verify CIRCUIT INPUT OUTPUT WITNESS PROOF` (all positional) dispatches GEN_VERIFY.

    All five referenced files are created before running; ``_build_circuit``
    is patched.
    """
    circuit = tmp_path / "circuit.txt"
    circuit.write_text("ok")

    inputj = tmp_path / "in.json"
    inputj.write_text('{"input":[0]}')

    outputj = tmp_path / "out.json"
    outputj.write_text('{"output":[0]}')

    witness = tmp_path / "w.bin"
    witness.write_bytes(b"\x00")

    proof = tmp_path / "p.bin"
    proof.write_bytes(b"\x00")

    fake_circuit = MagicMock()
    with patch(
        "python.frontend.commands.verify.VerifyCommand._build_circuit",
        return_value=fake_circuit,
    ):
        rc = main(
            [
                "--no-banner",
                "verify",
                str(circuit),
                str(inputj),
                str(outputj),
                str(witness),
                str(proof),
            ],
        )

    assert rc == 0
    # The first positional argument of base_testing is the run config.
    config = fake_circuit.base_testing.call_args.args[0]
    assert config.run_type == RunType.GEN_VERIFY
    assert config.circuit_path == str(circuit)
    assert config.input_file == str(inputj)
    assert config.output_file == str(outputj)
    assert config.witness_file == str(witness)
    assert config.proof_file == str(proof)
    assert config.ecc is False
280
+
281
+
@pytest.mark.unit
def test_compile_dispatch(tmp_path: Path) -> None:
    """`compile -m MODEL -c CIRCUIT` wires the model onto the circuit and compiles.

    Only the model must exist beforehand; the circuit path is an output.
    """
    model_file = tmp_path / "model.onnx"
    model_file.write_bytes(b"\x00")
    circuit_out = tmp_path / "circuit.txt"  # output target, not created here

    stub = MagicMock()
    target = "python.frontend.commands.compile.CompileCommand._build_circuit"
    with patch(target, return_value=stub):
        exit_code = main(
            ["--no-banner", "compile", "-m", str(model_file), "-c", str(circuit_out)],
        )

    assert exit_code == 0
    # All three model-path attributes are expected to point at the model.
    for attr in ("model_file_name", "onnx_path", "model_path"):
        assert getattr(stub, attr) == str(model_file), attr
    run_cfg = stub.base_testing.call_args.args[0]
    assert run_cfg.run_type == RunType.COMPILE_CIRCUIT
    assert run_cfg.circuit_path == str(circuit_out)
    assert run_cfg.dev_mode is False
316
+
317
+
@pytest.mark.unit
def test_compile_dispatch_positional(tmp_path: Path) -> None:
    """`compile MODEL CIRCUIT` (all positional) behaves like the flag form."""
    model_file = tmp_path / "model.onnx"
    model_file.write_bytes(b"\x00")
    circuit_out = tmp_path / "circuit.txt"  # output target, not created here

    stub = MagicMock()
    target = "python.frontend.commands.compile.CompileCommand._build_circuit"
    with patch(target, return_value=stub):
        exit_code = main(
            ["--no-banner", "compile", str(model_file), str(circuit_out)],
        )

    assert exit_code == 0
    for attr in ("model_file_name", "onnx_path", "model_path"):
        assert getattr(stub, attr) == str(model_file), attr
    run_cfg = stub.base_testing.call_args.args[0]
    assert run_cfg.run_type == RunType.COMPILE_CIRCUIT
    assert run_cfg.circuit_path == str(circuit_out)
    assert run_cfg.dev_mode is False
348
+
349
+
@pytest.mark.unit
def test_compile_missing_model_path() -> None:
    """`compile` given only a circuit path (no model) exits with status 1."""
    argv = ["--no-banner", "compile", "-c", "circuit.txt"]
    assert main(argv) == 1
354
+
355
+
@pytest.mark.unit
def test_compile_missing_circuit_path() -> None:
    """`compile` given only a model path (no circuit) exits with status 1."""
    argv = ["--no-banner", "compile", "-m", "model.onnx"]
    assert main(argv) == 1
360
+
361
+
@pytest.mark.unit
def test_witness_missing_args() -> None:
    """`witness` with only -c (no input/output/witness paths) exits with status 1."""
    argv = ["--no-banner", "witness", "-c", "circuit.txt"]
    assert main(argv) == 1
366
+
367
+
@pytest.mark.unit
def test_prove_missing_args() -> None:
    """`prove` with only -c (no witness/proof paths) exits with status 1."""
    argv = ["--no-banner", "prove", "-c", "circuit.txt"]
    assert main(argv) == 1
372
+
373
+
@pytest.mark.unit
def test_verify_missing_args() -> None:
    """`verify` with only -c (no other paths) exits with status 1."""
    argv = ["--no-banner", "verify", "-c", "circuit.txt"]
    assert main(argv) == 1
378
+
379
+
@pytest.mark.unit
def test_model_check_missing_model_path() -> None:
    """`model_check` with no model path at all exits with status 1."""
    argv = ["--no-banner", "model_check"]
    assert main(argv) == 1
384
+
385
+
@pytest.mark.unit
def test_compile_file_not_found(tmp_path: Path) -> None:
    """`compile` exits with status 1 when the model file does not exist.

    The missing model path lives under ``tmp_path`` rather than being a bare
    relative name. That way the outcome cannot depend on what happens to
    exist in the current working directory.
    """
    missing_model = tmp_path / "nonexistent.onnx"  # deliberately never created
    circuit = tmp_path / "circuit.txt"
    rc = main(
        [
            "--no-banner",
            "compile",
            "-m",
            str(missing_model),
            "-c",
            str(circuit),
        ],
    )
    assert rc == 1
400
+
401
+
@pytest.mark.unit
def test_witness_file_not_found(tmp_path: Path) -> None:
    """`witness` exits with status 1 when the circuit and input files are missing.

    The missing paths are rooted in ``tmp_path`` so the result is independent
    of the current working directory's contents.
    """
    missing_circuit = tmp_path / "nonexistent.txt"  # deliberately never created
    missing_input = tmp_path / "nonexistent.json"  # deliberately never created
    output = tmp_path / "out.json"
    witness = tmp_path / "w.bin"
    rc = main(
        [
            "--no-banner",
            "witness",
            "-c",
            str(missing_circuit),
            "-i",
            str(missing_input),
            "-o",
            str(output),
            "-w",
            str(witness),
        ],
    )
    assert rc == 1
421
+
422
+
@pytest.mark.unit
def test_prove_file_not_found(tmp_path: Path) -> None:
    """`prove` exits with status 1 when the circuit and witness files are missing.

    The missing paths are rooted in ``tmp_path`` so the result is independent
    of the current working directory's contents.
    """
    missing_circuit = tmp_path / "nonexistent.txt"  # deliberately never created
    missing_witness = tmp_path / "nonexistent.bin"  # deliberately never created
    proof = tmp_path / "proof.bin"
    rc = main(
        [
            "--no-banner",
            "prove",
            "-c",
            str(missing_circuit),
            "-w",
            str(missing_witness),
            "-p",
            str(proof),
        ],
    )
    assert rc == 1
439
+
440
+
@pytest.mark.unit
def test_verify_file_not_found() -> None:
    """`verify` exits with status 1 when every referenced file is missing."""
    flag_values = {
        "-c": "nonexistent.txt",
        "-i": "nonexistent.json",
        "-o": "nonexistent_out.json",
        "-w": "nonexistent.bin",
        "-p": "nonexistent_proof.bin",
    }
    argv = ["--no-banner", "verify"]
    for flag, value in flag_values.items():  # dicts keep insertion order
        argv += [flag, value]
    assert main(argv) == 1
460
+
461
+
@pytest.mark.unit
def test_model_check_file_not_found() -> None:
    """`model_check -m` pointing at a missing file exits with status 1."""
    exit_code = main(["--no-banner", "model_check", "-m", "nonexistent.onnx"])
    assert exit_code == 1
466
+
467
+
@pytest.mark.unit
def test_compile_mixed_positional_and_flag(tmp_path: Path) -> None:
    """`compile` accepts a positional model path combined with a -c flag.

    The dispatched values are checked as well as the exit status, so a
    mix-up between the positional and flag slots would be caught.
    """
    model = tmp_path / "model.onnx"
    model.write_bytes(b"\x00")
    circuit = tmp_path / "circuit.txt"

    fake_circuit = MagicMock()
    with patch(
        "python.frontend.commands.compile.CompileCommand._build_circuit",
        return_value=fake_circuit,
    ):
        rc = main(
            [
                "--no-banner",
                "compile",
                str(model),
                "-c",
                str(circuit),
            ],
        )

    assert rc == 0
    assert fake_circuit.model_path == str(model)
    config = fake_circuit.base_testing.call_args.args[0]
    assert config.run_type == RunType.COMPILE_CIRCUIT
    assert config.circuit_path == str(circuit)
490
+
491
+
@pytest.mark.unit
def test_witness_mixed_positional_and_flag(tmp_path: Path) -> None:
    """`witness` accepts a positional circuit path combined with -i/-o/-w flags.

    The dispatched config is checked as well as the exit status, so a
    positional/flag mix-up would be caught.
    """
    circuit = tmp_path / "circuit.txt"
    circuit.write_text("ok")

    inputj = tmp_path / "in.json"
    inputj.write_text('{"input":[0]}')

    outputj = tmp_path / "out.json"
    witness = tmp_path / "w.bin"

    fake_circuit = MagicMock()
    with patch(
        "python.frontend.commands.witness.WitnessCommand._build_circuit",
        return_value=fake_circuit,
    ):
        rc = main(
            [
                "--no-banner",
                "witness",
                str(circuit),
                "-i",
                str(inputj),
                "-o",
                str(outputj),
                "-w",
                str(witness),
            ],
        )

    assert rc == 0
    config = fake_circuit.base_testing.call_args.args[0]
    assert config.run_type == RunType.GEN_WITNESS
    assert config.circuit_path == str(circuit)
    assert config.input_file == str(inputj)
    assert config.output_file == str(outputj)
    assert config.witness_file == str(witness)
523
+
524
+
@pytest.mark.unit
def test_prove_mixed_positional_and_flag(tmp_path: Path) -> None:
    """`prove` accepts positional circuit/witness paths combined with a -p flag.

    The dispatched config is checked as well as the exit status, so a
    positional/flag mix-up would be caught.
    """
    circuit = tmp_path / "circuit.txt"
    circuit.write_text("ok")

    witness = tmp_path / "w.bin"
    witness.write_bytes(b"\x00")

    proof = tmp_path / "p.bin"

    fake_circuit = MagicMock()
    with patch(
        "python.frontend.commands.prove.ProveCommand._build_circuit",
        return_value=fake_circuit,
    ):
        rc = main(
            [
                "--no-banner",
                "prove",
                str(circuit),
                str(witness),
                "-p",
                str(proof),
            ],
        )

    assert rc == 0
    config = fake_circuit.base_testing.call_args.args[0]
    assert config.run_type == RunType.PROVE_WITNESS
    assert config.circuit_path == str(circuit)
    assert config.witness_file == str(witness)
    assert config.proof_file == str(proof)
    assert config.ecc is False
552
+
553
+
@pytest.mark.unit
def test_model_check_positional(tmp_path: Path) -> None:
    """`model_check MODEL` (positional) loads the model and runs the op check.

    ``onnx.load`` and ``ONNXOpQuantizer`` are both patched, so the
    placeholder bytes in the model file are not parsed as a real model.
    """
    model = tmp_path / "model.onnx"
    model.write_bytes(b"\x00")

    quantizer = MagicMock()
    with (
        patch("onnx.load", return_value=MagicMock()) as load_mock,
        patch(
            "python.core.model_processing.onnx_quantizer.onnx_op_quantizer.ONNXOpQuantizer",
            return_value=quantizer,
        ),
    ):
        exit_code = main(["--no-banner", "model_check", str(model)])

    assert exit_code == 0
    load_mock.assert_called_once_with(str(model))
    quantizer.check_model.assert_called_once()
574
+
575
+
@pytest.mark.unit
def test_flag_takes_precedence_over_positional(tmp_path: Path) -> None:
    """When a positional model and `-m` are both given, the `-m` value wins."""
    flagged_model = tmp_path / "flag_model.onnx"
    positional_model = tmp_path / "pos_model.onnx"
    for existing in (flagged_model, positional_model):
        existing.write_bytes(b"\x00")
    circuit = tmp_path / "circuit.txt"

    stub = MagicMock()
    with patch(
        "python.frontend.commands.compile.CompileCommand._build_circuit",
        return_value=stub,
    ):
        exit_code = main(
            [
                "--no-banner",
                "compile",
                str(positional_model),
                "-m",
                str(flagged_model),
                "-c",
                str(circuit),
            ],
        )

    assert exit_code == 0
    assert stub.model_path == str(flagged_model)
603
+
604
+
@pytest.mark.unit
def test_parent_dir_creation(tmp_path: Path) -> None:
    """`compile` creates missing parent directories of the circuit output path.

    ``base_testing`` is a mock, so the directories can only come from the
    CLI layer itself.
    """
    model = tmp_path / "model.onnx"
    model.write_bytes(b"\x00")
    circuit_out = tmp_path.joinpath("nested", "deep", "circuit.txt")

    stub = MagicMock()
    with patch(
        "python.frontend.commands.compile.CompileCommand._build_circuit",
        return_value=stub,
    ):
        exit_code = main(
            ["--no-banner", "compile", "-m", str(model), "-c", str(circuit_out)],
        )

    assert exit_code == 0
    assert circuit_out.parent.exists()
629
+
630
+
@pytest.mark.unit
def test_verify_mixed_positional_and_flag(tmp_path: Path) -> None:
    """`verify` accepts positional circuit/input paths combined with -o/-w/-p flags.

    The dispatched config is checked as well as the exit status, so a
    positional/flag mix-up would be caught.
    """
    circuit = tmp_path / "circuit.txt"
    circuit.write_text("ok")

    inputj = tmp_path / "in.json"
    inputj.write_text('{"input":[0]}')

    outputj = tmp_path / "out.json"
    outputj.write_text('{"output":[0]}')

    witness = tmp_path / "w.bin"
    witness.write_bytes(b"\x00")

    proof = tmp_path / "p.bin"
    proof.write_bytes(b"\x00")

    fake_circuit = MagicMock()
    with patch(
        "python.frontend.commands.verify.VerifyCommand._build_circuit",
        return_value=fake_circuit,
    ):
        rc = main(
            [
                "--no-banner",
                "verify",
                str(circuit),
                str(inputj),
                "-o",
                str(outputj),
                "-w",
                str(witness),
                "-p",
                str(proof),
            ],
        )

    assert rc == 0
    config = fake_circuit.base_testing.call_args.args[0]
    assert config.run_type == RunType.GEN_VERIFY
    assert config.circuit_path == str(circuit)
    assert config.input_file == str(inputj)
    assert config.output_file == str(outputj)
    assert config.witness_file == str(witness)
    assert config.proof_file == str(proof)
    assert config.ecc is False
669
+
670
+
@pytest.mark.unit
def test_circuit_run_error_handling(tmp_path: Path) -> None:
    """A CircuitRunError raised during the run is reported as exit status 1."""
    model = tmp_path / "model.onnx"
    model.write_bytes(b"\x00")
    circuit = tmp_path / "circuit.txt"

    # Configure the failure via Mock's dotted-name keyword form.
    failing_circuit = MagicMock(
        **{"base_testing.side_effect": CircuitRunError("Test error")},
    )
    with patch(
        "python.frontend.commands.compile.CompileCommand._build_circuit",
        return_value=failing_circuit,
    ):
        exit_code = main(
            ["--no-banner", "compile", "-m", str(model), "-c", str(circuit)],
        )

    assert exit_code == 1
696
+
697
+
@pytest.mark.unit
def test_model_check_unsupported_op_error(tmp_path: Path) -> None:
    """`model_check` exits with status 1 when the quantizer reports an unsupported op."""
    model = tmp_path / "model.onnx"
    model.write_bytes(b"\x00")

    quantizer = MagicMock()
    quantizer.check_model.side_effect = UnsupportedOpError(["BadOp"])
    with (
        patch("onnx.load", return_value=MagicMock()),
        patch(
            "python.core.model_processing.onnx_quantizer.onnx_op_quantizer.ONNXOpQuantizer",
            return_value=quantizer,
        ),
    ):
        exit_code = main(["--no-banner", "model_check", "-m", str(model)])

    assert exit_code == 1
717
+
718
+
@pytest.mark.unit
def test_empty_string_arg() -> None:
    """An empty `-m` value (with no positional fallback) is rejected with status 1."""
    argv = ["--no-banner", "compile", "-m", "", "-c", "circuit.txt"]
    assert main(argv) == 1
723
+
724
+
@pytest.mark.unit
def test_flag_empty_string_uses_positional(tmp_path: Path) -> None:
    """An explicit empty `-m ""` fails even when a valid positional model is given.

    Despite the test name, the expected outcome is exit status 1. The empty
    flag value is not silently replaced by the (existing) positional model.
    """
    model = tmp_path / "model.onnx"
    model.write_bytes(b"\x00")
    circuit = tmp_path / "circuit.txt"

    argv = [
        "--no-banner",
        "compile",
        str(model),
        "-m",
        "",
        "-c",
        str(circuit),
    ]
    stub = MagicMock()
    with patch(
        "python.frontend.commands.compile.CompileCommand._build_circuit",
        return_value=stub,
    ):
        exit_code = main(argv)

    assert exit_code == 1
749
+
750
+
751
+ # -----------------------
752
+ # bench command tests
753
+ # -----------------------
754
+
755
+
@pytest.mark.unit
def test_bench_list_models() -> None:
    """`bench list --list-models` succeeds against a stubbed model registry."""
    fake_listing = ["onnx: model1", "class: model2"]
    registry_patch = patch(
        "python.core.utils.model_registry.list_available_models",
        return_value=fake_listing,
    )
    with registry_patch:
        exit_code = main(["--no-banner", "bench", "list", "--list-models"])

    assert exit_code == 0
765
+
766
+
@pytest.mark.unit
def test_bench_with_model_path(tmp_path: Path) -> None:
    """`bench model --model-path FILE` succeeds with input generation and the
    benchmark subprocess both stubbed out."""
    model = tmp_path / "model.onnx"
    model.write_bytes(b"\x00")

    input_patch = patch(
        "python.frontend.commands.bench.model.ModelCommand._generate_model_input",
    )
    subprocess_patch = patch("python.frontend.commands.bench.model.run_subprocess")
    with input_patch, subprocess_patch:
        exit_code = main(
            ["--no-banner", "bench", "model", "--model-path", str(model)],
        )

    assert exit_code == 0
781
+
782
+
@pytest.mark.unit
def test_bench_with_model_flag() -> None:
    """`bench model --model NAME` benchmarks a registry entry looked up by name."""
    instance = MagicMock(model_file_name="test_model.onnx")
    entry = MagicMock()
    entry.name = "test_model"  # `name` must be set after construction on a Mock
    entry.loader.return_value = instance

    registry_patch = patch(
        "python.core.utils.model_registry.get_models_to_test",
        return_value=[entry],
    )
    input_patch = patch(
        "python.frontend.commands.bench.model.ModelCommand._generate_model_input",
    )
    subprocess_patch = patch("python.frontend.commands.bench.model.run_subprocess")
    with registry_patch, input_patch, subprocess_patch:
        exit_code = main(["--no-banner", "bench", "model", "--model", "test_model"])

    assert exit_code == 0
804
+
805
+
@pytest.mark.unit
def test_bench_with_source_filter() -> None:
    """`bench model --source onnx` forwards (None, "onnx") to the registry lookup."""
    instance = MagicMock(model_file_name="test_model.onnx")
    entry = MagicMock()
    entry.name = "test_model"  # `name` must be set after construction on a Mock
    entry.loader.return_value = instance

    input_patch = patch(
        "python.frontend.commands.bench.model.ModelCommand._generate_model_input",
    )
    subprocess_patch = patch("python.frontend.commands.bench.model.run_subprocess")
    with (
        patch(
            "python.core.utils.model_registry.get_models_to_test",
            return_value=[entry],
        ) as registry_mock,
        input_patch,
        subprocess_patch,
    ):
        exit_code = main(["--no-banner", "bench", "model", "--source", "onnx"])

    assert exit_code == 0
    registry_mock.assert_called_once_with(None, "onnx")
828
+
829
+
@pytest.mark.unit
def test_bench_depth_sweep_simple() -> None:
    """`bench sweep depth` forwards the default depth bounds to gen_and_bench.

    Each bound is checked in the slot right after its flag. A bare
    ``"1" in cmd`` would also be satisfied by any unrelated "1" argument.
    """
    with patch("python.frontend.commands.bench.sweep.run_subprocess") as mock_run:
        rc = main(["--no-banner", "bench", "sweep", "depth"])

    assert rc == 0
    cmd = mock_run.call_args.args[0]
    assert "python.scripts.gen_and_bench" in cmd[2]
    assert "--sweep" in cmd
    assert "depth" in cmd
    assert "--depth-min" in cmd
    assert cmd[cmd.index("--depth-min") + 1] == "1"
    assert "--depth-max" in cmd
    assert cmd[cmd.index("--depth-max") + 1] == "16"
844
+
845
+
@pytest.mark.unit
def test_bench_breadth_sweep_simple() -> None:
    """`bench sweep breadth` forwards the default architecture depth to gen_and_bench.

    The value is checked in the slot right after ``--arch-depth``. A bare
    ``"5" in cmd`` would match any "5" anywhere in the command.
    """
    with patch("python.frontend.commands.bench.sweep.run_subprocess") as mock_run:
        rc = main(["--no-banner", "bench", "sweep", "breadth"])

    assert rc == 0
    cmd = mock_run.call_args.args[0]
    assert "python.scripts.gen_and_bench" in cmd[2]
    assert "--sweep" in cmd
    assert "breadth" in cmd
    assert "--arch-depth" in cmd
    assert cmd[cmd.index("--arch-depth") + 1] == "5"
858
+
859
+
@pytest.mark.unit
def test_bench_sweep_with_custom_args() -> None:
    """User-supplied --depth-min/--depth-max values reach the subprocess verbatim."""
    argv = [
        "--no-banner",
        "bench",
        "sweep",
        "depth",
        "--depth-min",
        "5",
        "--depth-max",
        "10",
    ]
    with patch("python.frontend.commands.bench.sweep.run_subprocess") as runner:
        exit_code = main(argv)

    assert exit_code == 0
    cmd = runner.call_args.args[0]
    for flag, expected in (("--depth-min", "5"), ("--depth-max", "10")):
        assert flag in cmd
        assert cmd[cmd.index(flag) + 1] == expected
884
+
885
+
@pytest.mark.unit
def test_bench_sweep_with_optional_args() -> None:
    """Optional --tag/--onnx-dir values are forwarded right after their flags.

    The value is checked in the slot after its own flag, so a value that
    appears elsewhere in the command (or under a different flag) is not
    accepted by mistake.
    """
    with patch("python.frontend.commands.bench.sweep.run_subprocess") as mock_run:
        rc = main(
            [
                "--no-banner",
                "bench",
                "sweep",
                "depth",
                "--tag",
                "test_tag",
                "--onnx-dir",
                "custom_onnx",
            ],
        )

    assert rc == 0
    cmd = mock_run.call_args.args[0]
    assert "--tag" in cmd
    assert cmd[cmd.index("--tag") + 1] == "test_tag"
    assert "--onnx-dir" in cmd
    assert cmd[cmd.index("--onnx-dir") + 1] == "custom_onnx"
908
+
909
+
@pytest.mark.unit
def test_bench_missing_required_args() -> None:
    """`bench` with no subcommand is an argparse usage error (SystemExit)."""
    usage_error_code = 2  # argparse's exit status for usage errors
    with pytest.raises(SystemExit) as excinfo:
        main(["--no-banner", "bench"])
    assert excinfo.value.code == usage_error_code
916
+
917
+
@pytest.mark.unit
def test_bench_nonexistent_model_path() -> None:
    """`bench model -m` with a path that does not exist exits with status 1."""
    argv = ["--no-banner", "bench", "model", "-m", "nonexistent.onnx"]
    exit_code = main(argv)
    assert exit_code == 1
922
+
923
+
@pytest.mark.unit
def test_bench_no_models_found() -> None:
    """`bench model --model NAME` exits with status 1 when the registry finds nothing."""
    empty_registry = patch(
        "python.core.utils.model_registry.get_models_to_test",
        return_value=[],
    )
    with empty_registry:
        exit_code = main(
            ["--no-banner", "bench", "model", "--model", "nonexistent_model"],
        )

    assert exit_code == 1
933
+
934
+
@pytest.mark.unit
def test_bench_subprocess_failure(tmp_path: Path) -> None:
    """A RuntimeError from the benchmark subprocess is reported as exit status 1."""
    model = tmp_path / "model.onnx"
    model.write_bytes(b"\x00")

    stub = MagicMock()
    stub.get_inputs.return_value = {"input": [0]}
    stub.format_inputs.return_value = {"input": [0]}

    build_patch = patch(
        "python.frontend.commands.bench.model.ModelCommand._build_circuit",
        return_value=stub,
    )
    run_patch = patch(
        "python.frontend.commands.bench.model.run_subprocess",
        side_effect=RuntimeError("Subprocess failed"),
    )
    with build_patch, run_patch:
        exit_code = main(["--no-banner", "bench", "model", "-m", str(model)])

    assert exit_code == 1
957
+
958
+
@pytest.mark.unit
def test_bench_model_load_failure(tmp_path: Path) -> None:
    """A RuntimeError from ``load_model`` is reported as exit status 1."""
    model = tmp_path / "model.onnx"
    model.write_bytes(b"\x00")

    broken_circuit = MagicMock(
        **{"load_model.side_effect": RuntimeError("Failed to load model")},
    )
    build_patch = patch(
        "python.frontend.commands.bench.model.ModelCommand._build_circuit",
        return_value=broken_circuit,
    )
    with build_patch:
        exit_code = main(["--no-banner", "bench", "model", "-m", str(model)])

    assert exit_code == 1
974
+
975
+
@pytest.mark.unit
def test_bench_input_generation_failure(tmp_path: Path) -> None:
    """A RuntimeError while generating inputs is reported as exit status 1."""
    model = tmp_path / "model.onnx"
    model.write_bytes(b"\x00")

    broken_circuit = MagicMock(
        **{
            "load_model.return_value": None,
            "get_inputs.side_effect": RuntimeError("Failed to generate input"),
        },
    )
    build_patch = patch(
        "python.frontend.commands.bench.model.ModelCommand._build_circuit",
        return_value=broken_circuit,
    )
    with build_patch:
        exit_code = main(["--no-banner", "bench", "model", "-m", str(model)])

    assert exit_code == 1
992
+
993
+
@pytest.mark.unit
def test_bench_with_iterations(tmp_path: Path) -> None:
    """`--iterations N` is forwarded to the benchmark subprocess right after its flag."""
    model = tmp_path / "model.onnx"
    model.write_bytes(b"\x00")

    argv = [
        "--no-banner",
        "bench",
        "model",
        "--model-path",
        str(model),
        "--iterations",
        "10",
    ]
    input_patch = patch(
        "python.frontend.commands.bench.model.ModelCommand._generate_model_input",
    )
    with (
        input_patch,
        patch("python.frontend.commands.bench.model.run_subprocess") as run_mock,
    ):
        exit_code = main(argv)

    assert exit_code == 0
    cmd = run_mock.call_args.args[0]
    assert "--iterations" in cmd
    assert cmd[cmd.index("--iterations") + 1] == "10"