JSTprove 1.0.0__py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. jstprove-1.0.0.dist-info/METADATA +397 -0
  2. jstprove-1.0.0.dist-info/RECORD +81 -0
  3. jstprove-1.0.0.dist-info/WHEEL +6 -0
  4. jstprove-1.0.0.dist-info/entry_points.txt +2 -0
  5. jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
  6. jstprove-1.0.0.dist-info/top_level.txt +1 -0
  7. python/__init__.py +0 -0
  8. python/core/__init__.py +3 -0
  9. python/core/binaries/__init__.py +0 -0
  10. python/core/binaries/expander-exec +0 -0
  11. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  12. python/core/circuit_models/__init__.py +0 -0
  13. python/core/circuit_models/generic_onnx.py +231 -0
  14. python/core/circuit_models/simple_circuit.py +133 -0
  15. python/core/circuits/__init__.py +0 -0
  16. python/core/circuits/base.py +1000 -0
  17. python/core/circuits/errors.py +188 -0
  18. python/core/circuits/zk_model_base.py +25 -0
  19. python/core/model_processing/__init__.py +0 -0
  20. python/core/model_processing/converters/__init__.py +0 -0
  21. python/core/model_processing/converters/base.py +143 -0
  22. python/core/model_processing/converters/onnx_converter.py +1181 -0
  23. python/core/model_processing/errors.py +147 -0
  24. python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
  25. python/core/model_processing/onnx_custom_ops/conv.py +111 -0
  26. python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
  27. python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
  28. python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
  29. python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
  30. python/core/model_processing/onnx_custom_ops/relu.py +43 -0
  31. python/core/model_processing/onnx_quantizer/__init__.py +0 -0
  32. python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
  33. python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
  34. python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
  35. python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
  36. python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
  37. python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
  38. python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
  39. python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
  40. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
  41. python/core/model_templates/__init__.py +0 -0
  42. python/core/model_templates/circuit_template.py +57 -0
  43. python/core/utils/__init__.py +0 -0
  44. python/core/utils/benchmarking_helpers.py +163 -0
  45. python/core/utils/constants.py +4 -0
  46. python/core/utils/errors.py +117 -0
  47. python/core/utils/general_layer_functions.py +268 -0
  48. python/core/utils/helper_functions.py +1138 -0
  49. python/core/utils/model_registry.py +166 -0
  50. python/core/utils/scratch_tests.py +66 -0
  51. python/core/utils/witness_utils.py +291 -0
  52. python/frontend/__init__.py +0 -0
  53. python/frontend/cli.py +115 -0
  54. python/frontend/commands/__init__.py +17 -0
  55. python/frontend/commands/args.py +100 -0
  56. python/frontend/commands/base.py +199 -0
  57. python/frontend/commands/bench/__init__.py +54 -0
  58. python/frontend/commands/bench/list.py +42 -0
  59. python/frontend/commands/bench/model.py +172 -0
  60. python/frontend/commands/bench/sweep.py +212 -0
  61. python/frontend/commands/compile.py +58 -0
  62. python/frontend/commands/constants.py +5 -0
  63. python/frontend/commands/model_check.py +53 -0
  64. python/frontend/commands/prove.py +50 -0
  65. python/frontend/commands/verify.py +73 -0
  66. python/frontend/commands/witness.py +64 -0
  67. python/scripts/__init__.py +0 -0
  68. python/scripts/benchmark_runner.py +833 -0
  69. python/scripts/gen_and_bench.py +482 -0
  70. python/tests/__init__.py +0 -0
  71. python/tests/circuit_e2e_tests/__init__.py +0 -0
  72. python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
  73. python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
  74. python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
  75. python/tests/circuit_parent_classes/__init__.py +0 -0
  76. python/tests/circuit_parent_classes/test_circuit.py +969 -0
  77. python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
  78. python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
  79. python/tests/test_cli.py +1021 -0
  80. python/tests/utils_testing/__init__.py +0 -0
  81. python/tests/utils_testing/test_helper_functions.py +891 -0
@@ -0,0 +1,969 @@
1
+ import sys
2
+ from collections.abc import Generator
3
+ from pathlib import Path
4
+ from unittest.mock import MagicMock, patch
5
+
6
+ import pytest
7
+
8
+ sys.modules.pop("python.core.circuits.base", None)
9
+
10
+
11
+ with (
12
+ patch(
13
+ "python.core.utils.helper_functions.compute_and_store_output",
14
+ lambda x: x,
15
+ ),
16
+ patch(
17
+ "python.core.utils.helper_functions.prepare_io_files",
18
+ lambda f: f,
19
+ ),
20
+ ): # MUST BE BEFORE THE UUT GETS IMPORTED ANYWHERE!
21
+ from python.core.circuits.base import (
22
+ Circuit,
23
+ CircuitExecutionConfig,
24
+ RunType,
25
+ ZKProofSystems,
26
+ )
27
+ from python.core.circuits.errors import (
28
+ CircuitConfigurationError,
29
+ CircuitFileError,
30
+ CircuitInputError,
31
+ CircuitProcessingError,
32
+ CircuitRunError,
33
+ )
34
+
35
+
36
# ---------- Test __init__ ----------
@pytest.mark.unit
def test_circuit_init_defaults() -> None:
    """A freshly constructed Circuit exposes the expected default settings."""
    circuit = Circuit()
    expected_folders = {
        "input_folder": "inputs",
        "proof_folder": "analysis",
        "temp_folder": "temp",
        "circuit_folder": "",
        "weights_folder": "weights",
        "output_folder": "output",
    }
    for attr_name, default in expected_folders.items():
        assert getattr(circuit, attr_name) == default, attr_name
    assert circuit.proof_system == ZKProofSystems.Expander
    # No file information and no required input keys until a subclass sets them.
    assert circuit._file_info is None
    assert circuit.required_keys is None
+
50
+
51
@pytest.mark.unit
def test_circuit_execution_config_with_new_paths() -> None:
    """CircuitExecutionConfig stores the circuit name and the three artifact paths."""
    given = {
        "circuit_name": "test_circuit",
        "metadata_path": "meta.json",
        "architecture_path": "arch.json",
        "w_and_b_path": "weights.json",
    }
    config = CircuitExecutionConfig(**given)
    for field_name, value in given.items():
        assert getattr(config, field_name) == value, field_name
+
64
+
65
# ---------- Test parse_inputs ----------
@pytest.mark.unit
def test_parse_inputs_missing_required_keys() -> None:
    """parse_inputs rejects a call that omits one of the required keys."""
    circuit = Circuit()
    circuit.required_keys = ["x", "y"]
    expected_message = "Missing required parameter: 'x'"
    with pytest.raises(CircuitInputError, match=expected_message):
        circuit.parse_inputs(y=5)
72
+
73
+
74
@pytest.mark.unit
def test_parse_inputs_type_check() -> None:
    """parse_inputs rejects values that are neither ints nor lists of ints."""
    circuit = Circuit()
    circuit.required_keys = ["x"]
    bad_value = "not-an-int"
    with pytest.raises(
        CircuitInputError,
        match="Parameter 'x' must be an int or list of ints",
    ):
        circuit.parse_inputs(x=bad_value)
83
+
84
+
85
@pytest.mark.unit
def test_parse_inputs_success_int() -> None:
    """Valid integer inputs are stored as attributes named after their keys."""
    circuit = Circuit()
    circuit.required_keys = ["x", "y"]
    values = {"x": 10, "y": 20}

    circuit.parse_inputs(**values)

    for key, value in values.items():
        assert getattr(circuit, key) == value, key
96
+
97
+
98
@pytest.mark.unit
def test_parse_inputs_success_list() -> None:
    """A list-of-ints input is stored on the circuit unchanged."""
    circuit = Circuit()
    circuit.required_keys = ["arr"]
    payload = [1, 2, 3]
    circuit.parse_inputs(arr=payload)
    assert circuit.arr == [1, 2, 3]
104
+
105
+
106
@pytest.mark.unit
def test_parse_inputs_required_keys_none() -> None:
    """parse_inputs is a configuration error while required_keys is still unset."""
    circuit = Circuit()
    with pytest.raises(CircuitConfigurationError):
        circuit.parse_inputs()
111
+
112
+
113
# ---------- Test Not Implemented --------------
@pytest.mark.unit
def test_get_inputs_not_implemented() -> None:
    """The base Circuit leaves get_inputs for subclasses to provide."""
    circuit = Circuit()
    expected_message = "get_inputs must be implemented"
    with pytest.raises(NotImplementedError, match=expected_message):
        circuit.get_inputs()
119
+
120
+
121
@pytest.mark.unit
def test_get_outputs_not_implemented() -> None:
    """The base Circuit leaves get_outputs for subclasses to provide."""
    circuit = Circuit()
    expected_message = "get_outputs must be implemented"
    with pytest.raises(NotImplementedError, match=expected_message):
        circuit.get_outputs()
126
+
127
+
128
# ---------- Test parse_proof_run_type ----------


@pytest.mark.unit
@patch("python.core.circuits.base.compile_circuit")
@patch("python.core.circuits.base.generate_witness")
@patch("python.core.circuits.base.generate_proof")
@patch("python.core.circuits.base.generate_verification")
@patch("python.core.circuits.base.run_end_to_end")
def test_parse_proof_dispatch_logic(
    mock_end_to_end: MagicMock,
    mock_verify: MagicMock,
    mock_proof: MagicMock,
    mock_witness: MagicMock,
    mock_compile: MagicMock,
) -> None:
    """Each RunType is dispatched to its backend with the expected keyword args.

    All five scenarios share the same paths and flags; only ``run_type``
    varies, so the configs come from a small local factory.
    """

    def config_for(run_type: RunType) -> CircuitExecutionConfig:
        return CircuitExecutionConfig(
            witness_file="w",
            input_file="i",
            proof_file="p",
            public_path="pub",
            verification_key="vk",
            circuit_name="circuit",
            circuit_path="path",
            proof_system=ZKProofSystems.Expander,
            output_file="out",
            metadata_path="metadata",
            architecture_path="architecture",
            w_and_b_path="w_and_b",
            quantized_path="q",
            run_type=run_type,
            dev_mode=False,
            ecc=True,
            write_json=False,
            bench=False,
        )

    # Keyword arguments that every backend call below is expected to receive.
    shared_kwargs = {
        "circuit_name": "circuit",
        "circuit_path": "path",
        "proof_system": ZKProofSystems.Expander,
        "dev_mode": False,
        "bench": False,
        "metadata_path": "metadata",
    }

    circuit = Circuit()
    # Stub the internal preprocessing / IO helpers so no model or file work runs.
    circuit._compile_preprocessing = MagicMock()
    circuit._gen_witness_preprocessing = MagicMock(return_value="i")
    circuit.adjust_inputs = MagicMock(return_value="i")
    circuit.rename_inputs = MagicMock(return_value="i")
    circuit.load_and_compare_witness_to_io = MagicMock(return_value="True")

    # COMPILE_CIRCUIT
    circuit.parse_proof_run_type(config_for(RunType.COMPILE_CIRCUIT))
    mock_compile.assert_called_once()
    circuit._compile_preprocessing.assert_called_once_with(
        metadata_path="metadata",
        architecture_path="architecture",
        w_and_b_path="w_and_b",
        quantized_path="q",
    )
    assert mock_compile.call_args.kwargs == {
        **shared_kwargs,
        "architecture_path": "architecture",
        "w_and_b_path": "w_and_b",
    }

    # GEN_WITNESS
    circuit.parse_proof_run_type(config_for(RunType.GEN_WITNESS))
    mock_witness.assert_called_once()
    circuit._gen_witness_preprocessing.assert_called()
    assert mock_witness.call_args.kwargs == {
        **shared_kwargs,
        "witness_file": "w",
        "input_file": "i",
        "output_file": "out",
    }

    # PROVE_WITNESS
    circuit.parse_proof_run_type(config_for(RunType.PROVE_WITNESS))
    mock_proof.assert_called_once()
    assert mock_proof.call_args.kwargs == {
        **shared_kwargs,
        "witness_file": "w",
        "proof_file": "p",
        "ecc": True,
    }

    # GEN_VERIFY
    circuit.parse_proof_run_type(config_for(RunType.GEN_VERIFY))
    mock_verify.assert_called_once()
    assert mock_verify.call_args.kwargs == {
        **shared_kwargs,
        "input_file": "i",
        "output_file": "out",
        "witness_file": "w",
        "proof_file": "p",
        "ecc": True,
    }

    # END_TO_END
    circuit.parse_proof_run_type(config_for(RunType.END_TO_END))
    mock_end_to_end.assert_called_once()
    # Both preprocessing steps already ran once above; END_TO_END is expected
    # to run each of them again, so at least two calls apiece.
    minimum_calls = 2
    assert circuit._compile_preprocessing.call_count >= minimum_calls
    assert circuit._gen_witness_preprocessing.call_count >= minimum_calls
335
+
336
+
337
# ---------- Test new methods for metadata, architecture, w_and_b ----------
@pytest.mark.unit
def test_get_metadata_default() -> None:
    """The base Circuit reports empty metadata."""
    metadata = Circuit().get_metadata()
    assert metadata == {}
342
+
343
+
344
@pytest.mark.unit
def test_get_architecture_default() -> None:
    """The base Circuit reports an empty architecture."""
    architecture = Circuit().get_architecture()
    assert architecture == {}
348
+
349
+
350
@pytest.mark.unit
def test_get_w_and_b_default() -> None:
    """The base Circuit reports empty weights-and-biases."""
    w_and_b = Circuit().get_w_and_b()
    assert w_and_b == {}
354
+
355
+
356
# ---------- Optional: test get_weights ----------
@pytest.mark.unit
def test_get_weights_default() -> None:
    """The base Circuit reports empty weights."""
    weights = Circuit().get_weights()
    assert weights == {}
361
+
362
+
363
@pytest.mark.unit
def test_get_inputs_from_file() -> None:
    """Scaled inputs pass through; unscaled ones are multiplied by base**exponent."""
    circuit = Circuit()
    circuit.scale_base = 2
    circuit.scale_exponent = 2  # overall factor 2**2 == 4
    file_contents = {"input": [1, 2, 3, 4]}
    with patch(
        "python.core.circuits.base.read_from_json",
        return_value=file_contents,
    ):
        already_scaled = circuit.get_inputs_from_file("", is_scaled=True)
        assert already_scaled == {"input": [1, 2, 3, 4]}

        rescaled = circuit.get_inputs_from_file("", is_scaled=False)
        assert rescaled == {"input": [4, 8, 12, 16]}
377
+
378
+
379
@pytest.mark.unit
def test_get_inputs_from_file_multiple_inputs() -> None:
    """Every input key, including scalars such as 'nonce', is scaled the same way."""
    circuit = Circuit()
    circuit.scale_base = 2
    circuit.scale_exponent = 2  # overall factor 2**2 == 4
    file_contents = {"input": [1, 2, 3, 4], "nonce": 25}
    with patch(
        "python.core.circuits.base.read_from_json",
        return_value=file_contents,
    ):
        already_scaled = circuit.get_inputs_from_file("", is_scaled=True)
        assert already_scaled == {"input": [1, 2, 3, 4], "nonce": 25}

        rescaled = circuit.get_inputs_from_file("", is_scaled=False)
        assert rescaled == {"input": [4, 8, 12, 16], "nonce": 100}
393
+
394
+
395
@pytest.mark.unit
def test_get_inputs_from_file_dne(tmp_path: Path) -> None:
    """Reading a nonexistent input file raises CircuitFileError.

    The path is built inside pytest's per-test ``tmp_path`` directory, so it
    is guaranteed not to exist whatever the current working directory holds.
    """
    circuit = Circuit()
    circuit.scale_base = 2
    circuit.scale_exponent = 2
    missing_file = tmp_path / "this_file_should_not_exist_12345.json"
    with pytest.raises(CircuitFileError, match="Failed to read input file"):
        circuit.get_inputs_from_file(str(missing_file), is_scaled=True)
402
+
403
+
404
@pytest.mark.unit
def test_format_outputs() -> None:
    """format_outputs wraps the raw values under the "output" key."""
    raw_values = [10, 15, 20]
    formatted = Circuit().format_outputs(raw_values)
    assert formatted == {"output": [10, 15, 20]}
409
+
410
+
411
# ---------- _gen_witness_preprocessing ----------
@pytest.mark.unit
@patch("python.core.circuits.base.to_json")
def test_gen_witness_preprocessing_write_json_true(mock_to_json: MagicMock) -> None:
    """write_json=True: inputs come from get_inputs and both JSON files are written."""
    input_path, output_path = "in.json", "out.json"

    circuit = Circuit()
    circuit._file_info = {"quantized_model_path": "quant.pt"}
    circuit.load_quantized_model = MagicMock()
    circuit.get_inputs = MagicMock(return_value="inputs")
    circuit.get_outputs = MagicMock(return_value="outputs")
    circuit.format_inputs = MagicMock(return_value={"input": 1})
    circuit.format_outputs = MagicMock(return_value={"output": 2})

    circuit._gen_witness_preprocessing(
        input_path,
        output_path,
        None,
        write_json=True,
        is_scaled=True,
    )

    circuit.load_quantized_model.assert_called_once_with("quant.pt")
    circuit.get_inputs.assert_called_once()
    circuit.get_outputs.assert_called_once_with("inputs")
    mock_to_json.assert_any_call({"input": 1}, input_path)
    mock_to_json.assert_any_call({"output": 2}, output_path)
+ mock_to_json.assert_any_call({"output": 2}, "out.json")
436
+
437
+
438
@pytest.mark.unit
@patch("python.core.circuits.base.to_json")
def test_gen_witness_preprocessing_write_json_false(mock_to_json: MagicMock) -> None:
    """write_json=False: inputs are read from file and only the output is written."""
    input_path, output_path = "in.json", "out.json"

    circuit = Circuit()
    circuit._file_info = {"quantized_model_path": "quant.pt"}
    circuit.load_quantized_model = MagicMock()
    circuit.get_inputs_from_file = MagicMock(return_value="mock_inputs")
    # Stub every input-transformation hook; each just returns the input path string.
    for hook_name in (
        "reshape_inputs",
        "rescale_inputs",
        "rename_inputs",
        "rescale_and_reshape_inputs",
        "adjust_inputs",
    ):
        setattr(circuit, hook_name, MagicMock(return_value=input_path))
    circuit.get_outputs = MagicMock(return_value="mock_outputs")
    circuit.format_outputs = MagicMock(return_value={"output": 99})

    circuit._gen_witness_preprocessing(
        input_path,
        output_path,
        None,
        write_json=False,
        is_scaled=False,
    )

    circuit.load_quantized_model.assert_called_once_with("quant.pt")
    circuit.get_inputs_from_file.assert_called_once_with(input_path, is_scaled=False)
    circuit.get_outputs.assert_called_once_with("mock_inputs")
    circuit.format_outputs.assert_called_once_with("mock_outputs")
    mock_to_json.assert_called_once_with({"output": 99}, output_path)
467
+
468
+
469
# ---------- _compile_preprocessing ----------
@pytest.mark.unit
@patch("python.core.circuits.base.to_json")
def test_compile_preprocessing_saves_all_files(mock_to_json: MagicMock) -> None:
    """Compile preprocessing quantizes, saves the model, and writes the three JSONs."""
    circuit = Circuit()
    circuit._file_info = {"quantized_model_path": "model.pth"}
    circuit.get_model_and_quantize = MagicMock()
    circuit.get_metadata = MagicMock(return_value={"version": "1.0"})
    circuit.get_architecture = MagicMock(return_value={"layers": ["conv", "relu"]})
    circuit.get_w_and_b = MagicMock(return_value={"weights": [1, 2, 3]})
    circuit.save_quantized_model = MagicMock()

    circuit._compile_preprocessing(
        "metadata.json",
        "architecture.json",
        "w_and_b.json",
        None,
    )

    for producer in (
        circuit.get_model_and_quantize,
        circuit.get_metadata,
        circuit.get_architecture,
        circuit.get_w_and_b,
    ):
        producer.assert_called_once()
    circuit.save_quantized_model.assert_called_once_with("model.pth")

    expected_writes = [
        ({"version": "1.0"}, "metadata.json"),
        ({"layers": ["conv", "relu"]}, "architecture.json"),
        ({"weights": [1, 2, 3]}, "w_and_b.json"),
    ]
    for payload, path in expected_writes:
        mock_to_json.assert_any_call(payload, path)
491
+
492
+
493
@pytest.mark.unit
@patch("python.core.circuits.base.to_json")
def test_compile_preprocessing_writes_each_file_once(mock_to_json: MagicMock) -> None:
    """Dict-valued weights produce exactly one JSON write per artifact.

    Metadata, architecture and a single dict of weights/biases should give
    three ``to_json`` calls in total: one per target path, with no
    numbered weight shards.
    """
    c = Circuit()
    c._file_info = {"quantized_model_path": "model.pth"}
    c.get_model_and_quantize = MagicMock()
    c.get_metadata = MagicMock(return_value={"version": "1.0"})
    c.get_architecture = MagicMock(return_value={"layers": ["conv", "relu"]})
    c.get_w_and_b = MagicMock(return_value={"weights": [1, 2, 3]})
    c.save_quantized_model = MagicMock()

    c._compile_preprocessing("metadata.json", "architecture.json", "w_and_b.json", None)

    c.get_model_and_quantize.assert_called_once()
    c.get_metadata.assert_called_once()
    c.get_architecture.assert_called_once()
    c.get_w_and_b.assert_called_once()
    c.save_quantized_model.assert_called_once_with("model.pth")

    expected_write_count = 3  # metadata + architecture + one w_and_b file
    assert mock_to_json.call_count == expected_write_count
    mock_to_json.assert_any_call({"version": "1.0"}, "metadata.json")
    mock_to_json.assert_any_call({"layers": ["conv", "relu"]}, "architecture.json")
    mock_to_json.assert_any_call({"weights": [1, 2, 3]}, "w_and_b.json")
514
+
515
+
516
@pytest.mark.unit
@patch("python.core.circuits.base.to_json")
def test_compile_preprocessing_weights_dict(mock_to_json: MagicMock) -> None:
    """A dict returned by get_w_and_b is written verbatim to the w_and_b path."""
    circuit = Circuit()
    circuit._file_info = {"quantized_model_path": "model.pth"}
    circuit.get_model_and_quantize = MagicMock()
    circuit.get_metadata = MagicMock(return_value={"TEST": "2"})
    circuit.get_architecture = MagicMock(return_value={"TEST": "1"})
    circuit.get_w_and_b = MagicMock(return_value={"a": 1})
    circuit.save_quantized_model = MagicMock()

    circuit._compile_preprocessing(
        "metadata.json",
        "architecture.json",
        "w_and_b.json",
        None,
    )

    circuit.get_model_and_quantize.assert_called_once()
    circuit.get_w_and_b.assert_called_once()
    circuit.save_quantized_model.assert_called_once_with("model.pth")
    for payload, path in (
        ({"TEST": "2"}, "metadata.json"),
        ({"TEST": "1"}, "architecture.json"),
        ({"a": 1}, "w_and_b.json"),
    ):
        mock_to_json.assert_any_call(payload, path)
535
+
536
+
537
@pytest.mark.unit
@patch("python.core.circuits.base.to_json")
def test_compile_preprocessing_weights_list(
    mock_to_json: MagicMock,
) -> None:
    """A list of weight dicts is split across numbered w_and_b files."""
    circuit = Circuit()
    circuit._file_info = {"quantized_model_path": "model.pth"}
    circuit.get_model_and_quantize = MagicMock()
    circuit.get_metadata = MagicMock(return_value={"TEST": "1"})
    circuit.get_architecture = MagicMock(return_value={"TEST": "2"})
    circuit.get_w_and_b = MagicMock(return_value=[{"w1": 1}, {"w2": 2}, {"w3": 3}])
    circuit.save_quantized_model = MagicMock()

    circuit._compile_preprocessing(
        "metadata.json",
        "architecture.json",
        "w_and_b.json",
        None,
    )

    expected_writes = [
        ({"TEST": "1"}, "metadata.json"),
        ({"TEST": "2"}, "architecture.json"),
        # First shard keeps the base name; later shards get a numeric suffix.
        ({"w1": 1}, Path("w_and_b.json")),
        ({"w2": 2}, Path("w_and_b2.json")),
        ({"w3": 3}, Path("w_and_b3.json")),
    ]
    assert mock_to_json.call_count == len(expected_writes)
    for payload, path in expected_writes:
        mock_to_json.assert_any_call(payload, path)
560
+
561
+
562
@pytest.mark.unit
@patch("python.core.circuits.base.to_json")
def test_compile_preprocessing_weights_list_single_call(
    mock_to_json: MagicMock,
) -> None:
    """A weights list supplied through get_weights is sharded into numbered files.

    Unlike the other compile-preprocessing tests, ``get_w_and_b`` is left
    unstubbed here and ``get_weights`` is mocked instead.
    """
    c = Circuit()
    c._file_info = {"quantized_model_path": "model.pth"}
    c.get_model_and_quantize = MagicMock()
    c.get_metadata = MagicMock(return_value={})
    c.get_architecture = MagicMock(return_value={})
    c.get_weights = MagicMock(return_value=[{"w1": 1}, {"w2": 2}, {"w3": 3}])
    c.save_quantized_model = MagicMock()

    c._compile_preprocessing("metadata.json", "architecture.json", "w_and_b.json", None)

    # Exactly the three weight shards are expected to be written. NOTE(review):
    # this presumes the empty metadata/architecture dicts are not written and
    # that the default get_w_and_b falls back to get_weights. Confirm against
    # Circuit._compile_preprocessing in base.py.
    call_count = 3

    assert mock_to_json.call_count == call_count
    mock_to_json.assert_any_call({"w1": 1}, Path("w_and_b.json"))
    mock_to_json.assert_any_call({"w2": 2}, Path("w_and_b2.json"))
    mock_to_json.assert_any_call({"w3": 3}, Path("w_and_b3.json"))
583
+
584
+
585
@pytest.mark.unit
def test_compile_preprocessing_raises_on_bad_weights() -> None:
    """A w_and_b value that is neither dict nor list is a configuration error."""
    circuit = Circuit()
    circuit._file_info = {"quantized_model_path": "model.pth"}
    circuit.get_model_and_quantize = MagicMock()
    circuit.get_metadata = MagicMock(return_value={})
    circuit.get_architecture = MagicMock(return_value={})
    circuit.get_w_and_b = MagicMock(return_value="bad_type")
    circuit.save_quantized_model = MagicMock()

    paths = ("metadata.json", "architecture.json", "w_and_b.json")
    with pytest.raises(CircuitConfigurationError, match="Unsupported w_and_b type"):
        circuit._compile_preprocessing(*paths, None)
602
+
603
+
604
# ---------- Test check attributes --------------
@pytest.mark.unit
def test_check_attributes_true() -> None:
    """A fully configured circuit passes check_attributes without raising."""
    circuit = Circuit()
    circuit.required_keys = ["input"]
    circuit.name = "test"
    circuit.scale_exponent = 2
    circuit.scale_base = 2
    circuit.check_attributes()  # must not raise
613
+
614
+
615
@pytest.mark.unit
def test_check_attributes_no_scaling() -> None:
    """Omitting scale_exponent is reported as a misconfiguration naming it."""
    circuit = Circuit()
    circuit.required_keys = ["input"]
    circuit.name = "test"
    circuit.scale_base = 2
    with pytest.raises(CircuitConfigurationError) as exc_info:
        circuit.check_attributes()

    message = str(exc_info.value)
    for fragment in ("Circuit class (python) is misconfigured", "scale_exponent"):
        assert fragment in message
627
+
628
+
629
@pytest.mark.unit
def test_check_attributes_no_scalebase() -> None:
    """Omitting scale_base is reported as a misconfiguration naming it."""
    circuit = Circuit()
    circuit.required_keys = ["input"]
    circuit.name = "test"
    circuit.scale_exponent = 2

    with pytest.raises(CircuitConfigurationError) as exc_info:
        circuit.check_attributes()

    message = str(exc_info.value)
    for fragment in ("Circuit class (python) is misconfigured", "scale_base"):
        assert fragment in message
642
+
643
+
644
@pytest.mark.unit
def test_check_attributes_no_name() -> None:
    """Omitting name is reported as a misconfiguration naming it."""
    circuit = Circuit()
    circuit.required_keys = ["input"]
    circuit.scale_base = 2
    circuit.scale_exponent = 2

    with pytest.raises(CircuitConfigurationError) as exc_info:
        circuit.check_attributes()

    message = str(exc_info.value)
    for fragment in ("Circuit class (python) is misconfigured", "name"):
        assert fragment in message
657
+
658
+
659
# ---------- base_testing ------------
@pytest.mark.unit
@patch.object(Circuit, "parse_proof_run_type")
def test_base_testing_calls_parse_proof_run_type_correctly(
    mock_parse: MagicMock,
) -> None:
    """base_testing fills artifact paths from _file_info and forwards one config."""
    circuit = Circuit()
    circuit.name = "test"
    circuit._file_info = {
        "metadata_path": "metadata.json",
        "architecture_path": "architecture.json",
        "w_and_b_path": "w_and_b.json",
    }

    incoming = CircuitExecutionConfig(
        run_type=RunType.GEN_WITNESS,
        witness_file="w.wtns",
        input_file="i.json",
        proof_file="p.json",
        public_path="pub.json",
        verification_key="vk.key",
        circuit_name="circuit_model",
        output_file="o.json",
        circuit_path="circuit_path.txt",
        quantized_path="quantized_path.pt",
        write_json=True,
        proof_system=ZKProofSystems.Expander,
    )
    circuit.base_testing(incoming)

    # The forwarded config carries the caller's values plus the three paths
    # taken from _file_info.
    expected_config = CircuitExecutionConfig(
        witness_file="w.wtns",
        input_file="i.json",
        proof_file="p.json",
        public_path="pub.json",
        verification_key="vk.key",
        circuit_name="circuit_model",
        circuit_path="circuit_path.txt",
        proof_system=ZKProofSystems.Expander,
        output_file="o.json",
        metadata_path="metadata.json",
        architecture_path="architecture.json",
        w_and_b_path="w_and_b.json",
        quantized_path="quantized_path.pt",
        run_type=RunType.GEN_WITNESS,
        dev_mode=False,
        ecc=True,
        write_json=True,
        bench=False,
    )
    mock_parse.assert_called_once_with(expected_config)
711
+
712
+
713
@pytest.mark.unit
def test_prepare_io_files_sets_new_file_paths() -> None:
    """Test that prepare_io_files decorator sets the new file paths correctly.

    The decorator is imported at call time. By then the module-level patch
    that swapped it for an identity function has been undone, so the real
    decorator is exercised. ``get_files`` is patched so that the decorator
    sees a fixed mapping of file paths.
    """
    from python.core.utils.helper_functions import prepare_io_files  # noqa: PLC0415

    class TestCircuit(Circuit):
        def __init__(self) -> None:
            super().__init__()
            self.name = "test_circuit"

        @prepare_io_files
        def test_method(self, exec_config: CircuitExecutionConfig) -> dict:
            # Expose whatever file info the decorator populated before the call.
            _ = exec_config
            return self._file_info

    c = TestCircuit()

    with patch("python.core.utils.helper_functions.get_files") as mock_get_files:
        mock_get_files.return_value = {
            "witness_file": "witness.wtns",
            "input_file": "input.json",
            "proof_path": "proof.json",
            "public_path": "public.json",
            "circuit_name": "test_circuit",
            "metadata_path": "metadata.json",
            "architecture_path": "architecture.json",
            "w_and_b_path": "w_and_b.json",
            "output_file": "output.json",
        }

        config = CircuitExecutionConfig(run_type=RunType.COMPILE_CIRCUIT)
        file_info = c.test_method(config)

    # Both the circuit's _file_info and the config object are expected to
    # carry the new artifact paths.
    assert file_info["metadata_path"] == "metadata.json"
    assert file_info["architecture_path"] == "architecture.json"
    assert file_info["w_and_b_path"] == "w_and_b.json"
    assert config.metadata_path == "metadata.json"
    assert config.architecture_path == "architecture.json"
    assert config.w_and_b_path == "w_and_b.json"
752
+
753
+
754
@pytest.mark.unit
@patch.object(Circuit, "parse_proof_run_type")
def test_base_testing_uses_default_circuit_path(mock_parse: MagicMock) -> None:
    """Without an explicit circuit_path, base_testing derives "<circuit_name>.txt"."""
    circuit = Circuit()
    circuit._file_info = {
        "metadata_path": "metadata.json",
        "architecture_path": "architecture.json",
        "w_and_b_path": "w_and_b.json",
    }

    circuit.base_testing(CircuitExecutionConfig(circuit_name="test_model"))

    mock_parse.assert_called_once()
    forwarded = mock_parse.call_args.args[0]

    assert forwarded.circuit_name == "test_model"
    assert forwarded.circuit_path == "test_model.txt"
    assert forwarded.metadata_path == "metadata.json"
    assert forwarded.architecture_path == "architecture.json"
    assert forwarded.w_and_b_path == "w_and_b.json"
777
+
778
+
779
@pytest.mark.unit
@patch.object(Circuit, "parse_proof_run_type")
def test_base_testing_returns_none(mock_parse: MagicMock) -> None:
    """base_testing returns None after delegating to parse_proof_run_type once."""
    circuit = Circuit()
    circuit._file_info = {
        "metadata_path": "metadata.json",
        "architecture_path": "architecture.json",
        "w_and_b_path": "w_and_b.json",
    }

    outcome = circuit.base_testing(CircuitExecutionConfig(circuit_name="abc"))

    assert outcome is None
    mock_parse.assert_called_once()
795
+
796
+
797
@pytest.mark.unit
@patch.object(Circuit, "parse_proof_run_type")
def test_base_testing_weights_exists(mock_parse: MagicMock) -> None:
    """base_testing refuses to run while _file_info has never been populated."""
    del mock_parse  # patched only to keep the real dispatcher from running

    circuit = Circuit()  # _file_info stays at its default (None)
    with pytest.raises(CircuitConfigurationError, match="Circuit file information"):
        circuit.base_testing(CircuitExecutionConfig(circuit_name="abc"))
809
+
810
+
811
@pytest.mark.unit
def test_parse_proof_run_type_invalid_run_type(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """An unknown run_type raises CircuitRunError and logs why it failed.

    ``caplog`` is pytest's log-capturing fixture; its value is a
    ``LogCaptureFixture``, not a generator.
    """
    c = Circuit()
    config_invalid = CircuitExecutionConfig(
        witness_file="w.wtns",
        input_file="i.json",
        proof_file="p.json",
        public_path="pub.json",
        verification_key="vk.key",
        circuit_name="model",
        circuit_path="path.txt",
        proof_system=None,
        output_file="out.json",
        metadata_path="metadata.json",
        architecture_path="architecture.json",
        w_and_b_path="w_and_b.json",
        quantized_path="quantized_model.pt",
        run_type="NOT_A_REAL_RUN_TYPE",  # Invalid run type
        dev_mode=False,
        ecc=True,
        write_json=False,
        bench=False,
    )

    with pytest.raises(CircuitRunError, match="Unsupported run type"):
        c.parse_proof_run_type(config_invalid)

    # Check that the error messages are logged
    assert "Unknown run type: NOT_A_REAL_RUN_TYPE" in caplog.text
    assert "Operation NOT_A_REAL_RUN_TYPE failed" in caplog.text
843
+
844
+
845
@pytest.mark.unit
@patch(
    "python.core.circuits.base.compile_circuit",
    side_effect=Exception("Boom goes the dynamite!"),
)
@patch.object(Circuit, "_compile_preprocessing")
def test_parse_proof_run_type_catches_internal_exception(
    mock_compile_preprocessing: MagicMock,
    mock_compile: MagicMock,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """A generic failure inside compilation surfaces as CircuitRunError.

    ``compile_circuit`` is patched to raise a plain ``Exception`` and
    ``_compile_preprocessing`` is stubbed out. The caller should get a
    CircuitRunError naming the 'Compile' operation, and the failure should
    be logged.
    """
    c = Circuit()

    config_exception = CircuitExecutionConfig(
        witness_file="w.wtns",
        input_file="i.json",
        proof_file="p.json",
        public_path="pub.json",
        verification_key="vk.key",
        circuit_name="model",
        circuit_path="path.txt",
        proof_system=None,
        output_file="out.json",
        metadata_path="metadata.json",
        architecture_path="architecture.json",
        w_and_b_path="w_and_b.json",
        quantized_path="quantized_path.pt",
        run_type=RunType.COMPILE_CIRCUIT,
        dev_mode=False,
        ecc=True,
        write_json=False,
        bench=False,
    )

    # The patched `compile_circuit` raises; parse_proof_run_type must wrap it.
    with pytest.raises(CircuitRunError, match="Circuit operation 'Compile' failed"):
        c.parse_proof_run_type(config_exception)

    assert "Operation RunType.COMPILE_CIRCUIT failed" in caplog.text
    mock_compile.assert_called()
    mock_compile_preprocessing.assert_called()
+
889
+
890
@pytest.mark.unit
def test_save_and_load_model_not_implemented() -> None:
    """Circuit exposes the model save/load hooks (presence check only)."""
    circuit = Circuit()
    hook_names = (
        "save_model",
        "load_model",
        "save_quantized_model",
        "load_quantized_model",
    )
    for hook_name in hook_names:
        assert hasattr(circuit, hook_name)
+
898
+
899
+ # ---------- New error handling tests ----------
900
@pytest.mark.unit
def test_adjust_inputs_file_error() -> None:
    """A read failure in adjust_inputs is reported as CircuitFileError.

    ``read_from_json`` is patched to raise FileNotFoundError. The expected
    error message starts with "Failed to read input file".
    """
    c = Circuit()
    c.input_variables = ["input"]
    c.input_shape = [2, 2]
    c.scale_base = 2
    c.scale_exponent = 1

    with patch(
        "python.core.circuits.base.read_from_json",
        side_effect=FileNotFoundError("File not found"),
    ), pytest.raises(CircuitFileError, match="Failed to read input file"):
        c.adjust_inputs("nonexistent.json")
+
916
+
917
@pytest.mark.unit
def test_adjust_inputs_processing_error() -> None:
    """A tensor-construction failure in adjust_inputs is a CircuitProcessingError.

    ``read_from_json`` returns four values, which fit the [2, 2] input shape.
    ``torch.tensor`` is patched to raise RuntimeError, so the error is raised
    after the input file has been read. The expected message starts with
    "Failed to reshape input data".
    """
    c = Circuit()
    c.input_variables = ["input"]
    c.input_shape = [2, 2]
    c.scale_base = 2
    c.scale_exponent = 1

    with patch(
        "python.core.circuits.base.read_from_json",
        return_value={"input": [1, 2, 3, 4]},
    ), patch(
        "torch.tensor",
        side_effect=RuntimeError("Invalid tensor shape"),
    ):
        with pytest.raises(
            CircuitProcessingError,
            match="Failed to reshape input data",
        ):
            c.adjust_inputs("dummy.json")
+
939
+
940
@pytest.mark.unit
def test_get_inputs_from_file_file_error() -> None:
    """get_inputs_from_file lets a CircuitFileError from the JSON reader propagate.

    ``_read_from_json_safely`` is patched on the instance to raise
    CircuitFileError directly.
    """
    c = Circuit()
    with patch.object(
        c,
        "_read_from_json_safely",
        side_effect=CircuitFileError("Failed to read input file: protected.json"),
    ), pytest.raises(CircuitFileError, match="Failed to read input file"):
        c.get_inputs_from_file("protected.json")
+
952
+
953
@pytest.mark.unit
def test_get_inputs_from_file_processing_error() -> None:
    """Non-numeric input data fails scaling with CircuitProcessingError.

    ``_read_from_json_safely`` returns a string instead of numbers. With
    ``is_scaled=False`` the method attempts to scale it, and the expected
    error message starts with "Failed to scale input data".
    """
    c = Circuit()
    c.scale_base = 2
    c.scale_exponent = 1

    with patch.object(
        c,
        "_read_from_json_safely",
        return_value={"input": "invalid_data"},
    ), pytest.raises(
        CircuitProcessingError,
        match="Failed to scale input data",
    ):
        c.get_inputs_from_file("dummy.json", is_scaled=False)