JSTprove 1.0.0__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of JSTprove might be problematic; see the registry's release page for more details.
- jstprove-1.0.0.dist-info/METADATA +397 -0
- jstprove-1.0.0.dist-info/RECORD +81 -0
- jstprove-1.0.0.dist-info/WHEEL +5 -0
- jstprove-1.0.0.dist-info/entry_points.txt +2 -0
- jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
- jstprove-1.0.0.dist-info/top_level.txt +1 -0
- python/__init__.py +0 -0
- python/core/__init__.py +3 -0
- python/core/binaries/__init__.py +0 -0
- python/core/binaries/expander-exec +0 -0
- python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
- python/core/circuit_models/__init__.py +0 -0
- python/core/circuit_models/generic_onnx.py +231 -0
- python/core/circuit_models/simple_circuit.py +133 -0
- python/core/circuits/__init__.py +0 -0
- python/core/circuits/base.py +1000 -0
- python/core/circuits/errors.py +188 -0
- python/core/circuits/zk_model_base.py +25 -0
- python/core/model_processing/__init__.py +0 -0
- python/core/model_processing/converters/__init__.py +0 -0
- python/core/model_processing/converters/base.py +143 -0
- python/core/model_processing/converters/onnx_converter.py +1181 -0
- python/core/model_processing/errors.py +147 -0
- python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
- python/core/model_processing/onnx_custom_ops/conv.py +111 -0
- python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
- python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
- python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
- python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
- python/core/model_processing/onnx_custom_ops/relu.py +43 -0
- python/core/model_processing/onnx_quantizer/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
- python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
- python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
- python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
- python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
- python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
- python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
- python/core/model_templates/__init__.py +0 -0
- python/core/model_templates/circuit_template.py +57 -0
- python/core/utils/__init__.py +0 -0
- python/core/utils/benchmarking_helpers.py +163 -0
- python/core/utils/constants.py +4 -0
- python/core/utils/errors.py +117 -0
- python/core/utils/general_layer_functions.py +268 -0
- python/core/utils/helper_functions.py +1138 -0
- python/core/utils/model_registry.py +166 -0
- python/core/utils/scratch_tests.py +66 -0
- python/core/utils/witness_utils.py +291 -0
- python/frontend/__init__.py +0 -0
- python/frontend/cli.py +115 -0
- python/frontend/commands/__init__.py +17 -0
- python/frontend/commands/args.py +100 -0
- python/frontend/commands/base.py +199 -0
- python/frontend/commands/bench/__init__.py +54 -0
- python/frontend/commands/bench/list.py +42 -0
- python/frontend/commands/bench/model.py +172 -0
- python/frontend/commands/bench/sweep.py +212 -0
- python/frontend/commands/compile.py +58 -0
- python/frontend/commands/constants.py +5 -0
- python/frontend/commands/model_check.py +53 -0
- python/frontend/commands/prove.py +50 -0
- python/frontend/commands/verify.py +73 -0
- python/frontend/commands/witness.py +64 -0
- python/scripts/__init__.py +0 -0
- python/scripts/benchmark_runner.py +833 -0
- python/scripts/gen_and_bench.py +482 -0
- python/tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
- python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
- python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
- python/tests/circuit_parent_classes/__init__.py +0 -0
- python/tests/circuit_parent_classes/test_circuit.py +969 -0
- python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
- python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
- python/tests/test_cli.py +1021 -0
- python/tests/utils_testing/__init__.py +0 -0
- python/tests/utils_testing/test_helper_functions.py +891 -0
|
@@ -0,0 +1,969 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
from collections.abc import Generator
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from unittest.mock import MagicMock, patch
|
|
5
|
+
|
|
6
|
+
import pytest
|
|
7
|
+
|
|
8
|
+
# Remove any cached copy of the module under test so that the import inside the
# ``with`` block below actually re-executes it while the patches are active.
sys.modules.pop("python.core.circuits.base", None)


# While python.core.circuits.base is being imported, replace two names in
# helper_functions with single-argument identity functions. Presumably these
# are decorators applied at import time in base, so the Circuit methods imported
# here are left undecorated — TODO confirm against base.py. The patches are
# undone when the ``with`` block exits.
with (
    patch(
        "python.core.utils.helper_functions.compute_and_store_output",
        lambda x: x,
    ),
    patch(
        "python.core.utils.helper_functions.prepare_io_files",
        lambda f: f,
    ),
): # MUST BE BEFORE THE UUT GETS IMPORTED ANYWHERE!
    from python.core.circuits.base import (
        Circuit,
        CircuitExecutionConfig,
        RunType,
        ZKProofSystems,
    )
|
|
27
|
+
from python.core.circuits.errors import (
|
|
28
|
+
CircuitConfigurationError,
|
|
29
|
+
CircuitFileError,
|
|
30
|
+
CircuitInputError,
|
|
31
|
+
CircuitProcessingError,
|
|
32
|
+
CircuitRunError,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# ---------- Test __init__ ----------
|
|
37
|
+
@pytest.mark.unit
def test_circuit_init_defaults() -> None:
    """A freshly constructed Circuit exposes the expected default settings."""
    circuit = Circuit()
    expected_folders = {
        "input_folder": "inputs",
        "proof_folder": "analysis",
        "temp_folder": "temp",
        "circuit_folder": "",
        "weights_folder": "weights",
        "output_folder": "output",
    }
    for attr_name, expected_value in expected_folders.items():
        assert getattr(circuit, attr_name) == expected_value
    assert circuit.proof_system == ZKProofSystems.Expander
    # Nothing about files or required inputs is known until the circuit is configured.
    assert circuit._file_info is None
    assert circuit.required_keys is None
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@pytest.mark.unit
def test_circuit_execution_config_with_new_paths() -> None:
    """CircuitExecutionConfig keeps the metadata/architecture/weights paths it is given."""
    path_fields = {
        "metadata_path": "meta.json",
        "architecture_path": "arch.json",
        "w_and_b_path": "weights.json",
    }
    config = CircuitExecutionConfig(circuit_name="test_circuit", **path_fields)
    assert config.circuit_name == "test_circuit"
    for field_name, expected_path in path_fields.items():
        assert getattr(config, field_name) == expected_path
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# ---------- Test parse_inputs ----------
|
|
66
|
+
@pytest.mark.unit
def test_parse_inputs_missing_required_keys() -> None:
    """parse_inputs rejects a call that omits one of the required keys."""
    circuit = Circuit()
    circuit.required_keys = ["x", "y"]
    expected_message = "Missing required parameter: 'x'"
    with pytest.raises(CircuitInputError, match=expected_message):
        circuit.parse_inputs(y=5)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@pytest.mark.unit
def test_parse_inputs_type_check() -> None:
    """A string value for a required key is rejected with a type error message."""
    circuit = Circuit()
    circuit.required_keys = ["x"]
    invalid_value = "not-an-int"
    expected_message = "Parameter 'x' must be an int or list of ints"
    with pytest.raises(CircuitInputError, match=expected_message):
        circuit.parse_inputs(x=invalid_value)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@pytest.mark.unit
def test_parse_inputs_success_int() -> None:
    """Valid integer inputs are stored as attributes named after their keys."""
    circuit = Circuit()
    circuit.required_keys = ["x", "y"]
    supplied = {"x": 10, "y": 20}

    circuit.parse_inputs(**supplied)

    assert circuit.x == supplied["x"]
    assert circuit.y == supplied["y"]
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@pytest.mark.unit
def test_parse_inputs_success_list() -> None:
    """A list-of-ints input is stored unchanged on the circuit."""
    circuit = Circuit()
    circuit.required_keys = ["arr"]
    supplied_values = [1, 2, 3]
    circuit.parse_inputs(arr=supplied_values)
    assert circuit.arr == [1, 2, 3]
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@pytest.mark.unit
def test_parse_inputs_required_keys_none() -> None:
    """parse_inputs is a configuration error while required_keys is still unset."""
    circuit = Circuit()
    # required_keys is left at its default (None) on purpose.
    with pytest.raises(CircuitConfigurationError):
        circuit.parse_inputs()
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
# ---------- Test Not Implemented --------------
|
|
114
|
+
@pytest.mark.unit
def test_get_inputs_not_implemented() -> None:
    """The base Circuit leaves get_inputs for subclasses to implement."""
    circuit = Circuit()
    expected_message = "get_inputs must be implemented"
    with pytest.raises(NotImplementedError, match=expected_message):
        circuit.get_inputs()
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
@pytest.mark.unit
def test_get_outputs_not_implemented() -> None:
    """The base Circuit leaves get_outputs for subclasses to implement."""
    circuit = Circuit()
    expected_message = "get_outputs must be implemented"
    with pytest.raises(NotImplementedError, match=expected_message):
        circuit.get_outputs()
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
# ---------- Test parse_proof_run_type ----------
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _make_dispatch_config(run_type: RunType) -> CircuitExecutionConfig:
    """Build the config used by ``test_parse_proof_dispatch_logic`` for ``run_type``.

    Every dispatch case in that test uses the same placeholder values; only the
    run type differs. The kwargs assertions compare against these placeholders,
    so they are kept in one place.
    """
    return CircuitExecutionConfig(
        witness_file="w",
        input_file="i",
        proof_file="p",
        public_path="pub",
        verification_key="vk",
        circuit_name="circuit",
        circuit_path="path",
        proof_system=ZKProofSystems.Expander,
        output_file="out",
        metadata_path="metadata",
        architecture_path="architecture",
        w_and_b_path="w_and_b",
        quantized_path="q",
        run_type=run_type,
        dev_mode=False,
        ecc=True,
        write_json=False,
        bench=False,
    )


@pytest.mark.unit
@patch("python.core.circuits.base.compile_circuit")
@patch("python.core.circuits.base.generate_witness")
@patch("python.core.circuits.base.generate_proof")
@patch("python.core.circuits.base.generate_verification")
@patch("python.core.circuits.base.run_end_to_end")
def test_parse_proof_dispatch_logic(
    mock_end_to_end: MagicMock,
    mock_verify: MagicMock,
    mock_proof: MagicMock,
    mock_witness: MagicMock,
    mock_compile: MagicMock,
) -> None:
    """parse_proof_run_type routes each RunType to its backend with the expected kwargs.

    The backend functions are patched in ``python.core.circuits.base``.
    ``@patch`` decorators apply bottom-up, hence the reversed argument order.
    The circuit's own preprocessing hooks are replaced with mocks so only the
    dispatch logic is exercised.
    """
    c = Circuit()

    # Mock internal preprocessing methods
    c._compile_preprocessing = MagicMock()
    c._gen_witness_preprocessing = MagicMock(return_value="i")
    c.adjust_inputs = MagicMock(return_value="i")
    c.rename_inputs = MagicMock(return_value="i")
    # Stub with a real bool rather than a merely truthy string.
    c.load_and_compare_witness_to_io = MagicMock(return_value=True)

    # COMPILE_CIRCUIT
    c.parse_proof_run_type(_make_dispatch_config(RunType.COMPILE_CIRCUIT))
    mock_compile.assert_called_once()
    c._compile_preprocessing.assert_called_once_with(
        metadata_path="metadata",
        architecture_path="architecture",
        w_and_b_path="w_and_b",
        quantized_path="q",
    )
    _, kwargs = mock_compile.call_args
    assert kwargs == {
        "circuit_name": "circuit",
        "circuit_path": "path",
        "proof_system": ZKProofSystems.Expander,
        "dev_mode": False,
        "bench": False,
        "architecture_path": "architecture",
        "metadata_path": "metadata",
        "w_and_b_path": "w_and_b",
    }

    # GEN_WITNESS
    c.parse_proof_run_type(_make_dispatch_config(RunType.GEN_WITNESS))
    mock_witness.assert_called_once()
    c._gen_witness_preprocessing.assert_called()
    _, kwargs = mock_witness.call_args
    assert kwargs == {
        "circuit_name": "circuit",
        "circuit_path": "path",
        "witness_file": "w",
        "input_file": "i",
        "output_file": "out",
        "proof_system": ZKProofSystems.Expander,
        "dev_mode": False,
        "bench": False,
        "metadata_path": "metadata",
    }

    # PROVE_WITNESS
    c.parse_proof_run_type(_make_dispatch_config(RunType.PROVE_WITNESS))
    mock_proof.assert_called_once()
    _, kwargs = mock_proof.call_args
    assert kwargs == {
        "circuit_name": "circuit",
        "circuit_path": "path",
        "witness_file": "w",
        "proof_file": "p",
        "proof_system": ZKProofSystems.Expander,
        "dev_mode": False,
        "ecc": True,
        "bench": False,
        "metadata_path": "metadata",
    }

    # GEN_VERIFY
    c.parse_proof_run_type(_make_dispatch_config(RunType.GEN_VERIFY))
    mock_verify.assert_called_once()
    _, kwargs = mock_verify.call_args
    assert kwargs == {
        "circuit_name": "circuit",
        "circuit_path": "path",
        "input_file": "i",
        "output_file": "out",
        "witness_file": "w",
        "proof_file": "p",
        "proof_system": ZKProofSystems.Expander,
        "dev_mode": False,
        "ecc": True,
        "bench": False,
        "metadata_path": "metadata",
    }

    # END_TO_END: the earlier cases already invoked both preprocessing mocks.
    # After this run, each mock must have reached at least two calls in total.
    c.parse_proof_run_type(_make_dispatch_config(RunType.END_TO_END))

    preprocess_call_count = 2

    mock_end_to_end.assert_called_once()
    assert c._compile_preprocessing.call_count >= preprocess_call_count
    assert c._gen_witness_preprocessing.call_count >= preprocess_call_count
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
# ---------- Test new methods for metadata, architecture, w_and_b ----------
|
|
338
|
+
@pytest.mark.unit
def test_get_metadata_default() -> None:
    """Without an override, get_metadata returns an empty dict."""
    metadata = Circuit().get_metadata()
    assert metadata == {}
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
@pytest.mark.unit
def test_get_architecture_default() -> None:
    """Without an override, get_architecture returns an empty dict."""
    architecture = Circuit().get_architecture()
    assert architecture == {}
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
@pytest.mark.unit
def test_get_w_and_b_default() -> None:
    """Without an override, get_w_and_b returns an empty dict."""
    weights_and_biases = Circuit().get_w_and_b()
    assert weights_and_biases == {}
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
# ---------- Optional: test get_weights ----------
|
|
357
|
+
@pytest.mark.unit
def test_get_weights_default() -> None:
    """Without an override, get_weights returns an empty dict."""
    weights = Circuit().get_weights()
    assert weights == {}
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
@pytest.mark.unit
def test_get_inputs_from_file() -> None:
    """Scaled inputs are returned as read; unscaled ones come back multiplied.

    With scale_base=2 and scale_exponent=2 the expected factor is 4.
    """
    circuit = Circuit()
    circuit.scale_base = 2
    circuit.scale_exponent = 2
    file_contents = {"input": [1, 2, 3, 4]}
    with patch("python.core.circuits.base.read_from_json", return_value=file_contents):
        already_scaled = circuit.get_inputs_from_file("", is_scaled=True)
        assert already_scaled == {"input": [1, 2, 3, 4]}

        rescaled = circuit.get_inputs_from_file("", is_scaled=False)
        assert rescaled == {"input": [4, 8, 12, 16]}
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
@pytest.mark.unit
def test_get_inputs_from_file_multiple_inputs() -> None:
    """Every key, including the scalar ``nonce``, is scaled when is_scaled=False."""
    circuit = Circuit()
    circuit.scale_base = 2
    circuit.scale_exponent = 2
    file_contents = {"input": [1, 2, 3, 4], "nonce": 25}
    with patch("python.core.circuits.base.read_from_json", return_value=file_contents):
        already_scaled = circuit.get_inputs_from_file("", is_scaled=True)
        assert already_scaled == {"input": [1, 2, 3, 4], "nonce": 25}

        rescaled = circuit.get_inputs_from_file("", is_scaled=False)
        assert rescaled == {"input": [4, 8, 12, 16], "nonce": 100}
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
@pytest.mark.unit
def test_get_inputs_from_file_dne() -> None:
    """A non-existent input path surfaces as a CircuitFileError."""
    circuit = Circuit()
    circuit.scale_base = 2
    circuit.scale_exponent = 2
    # Relative path assumed not to exist in the test's working directory.
    missing_path = "this_file_should_not_exist_12345.json"
    with pytest.raises(CircuitFileError, match="Failed to read input file"):
        circuit.get_inputs_from_file(missing_path, is_scaled=True)
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
@pytest.mark.unit
def test_format_outputs() -> None:
    """format_outputs wraps the raw outputs under the ``output`` key."""
    raw_outputs = [10, 15, 20]
    formatted = Circuit().format_outputs(raw_outputs)
    assert formatted == {"output": [10, 15, 20]}
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
# ---------- _gen_witness_preprocessing ----------
|
|
412
|
+
@pytest.mark.unit
@patch("python.core.circuits.base.to_json")
def test_gen_witness_preprocessing_write_json_true(mock_to_json: MagicMock) -> None:
    """With write_json=True, inputs come from get_inputs and both files are written."""
    circuit = Circuit()
    circuit._file_info = {"quantized_model_path": "quant.pt"}
    stubs = {
        "load_quantized_model": MagicMock(),
        "get_inputs": MagicMock(return_value="inputs"),
        "get_outputs": MagicMock(return_value="outputs"),
        "format_inputs": MagicMock(return_value={"input": 1}),
        "format_outputs": MagicMock(return_value={"output": 2}),
    }
    for attr_name, stub in stubs.items():
        setattr(circuit, attr_name, stub)

    circuit._gen_witness_preprocessing(
        "in.json",
        "out.json",
        None,
        write_json=True,
        is_scaled=True,
    )

    stubs["load_quantized_model"].assert_called_once_with("quant.pt")
    stubs["get_inputs"].assert_called_once()
    stubs["get_outputs"].assert_called_once_with("inputs")
    mock_to_json.assert_any_call({"input": 1}, "in.json")
    mock_to_json.assert_any_call({"output": 2}, "out.json")
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
@pytest.mark.unit
@patch("python.core.circuits.base.to_json")
def test_gen_witness_preprocessing_write_json_false(mock_to_json: MagicMock) -> None:
    """With write_json=False, inputs are read from file and only outputs are written."""
    circuit = Circuit()
    circuit._file_info = {"quantized_model_path": "quant.pt"}
    circuit.load_quantized_model = MagicMock()
    circuit.get_inputs_from_file = MagicMock(return_value="mock_inputs")
    # Every input-transformation hook is stubbed with the same placeholder.
    for hook_name in (
        "reshape_inputs",
        "rescale_inputs",
        "rename_inputs",
        "rescale_and_reshape_inputs",
        "adjust_inputs",
    ):
        setattr(circuit, hook_name, MagicMock(return_value="in.json"))
    circuit.get_outputs = MagicMock(return_value="mock_outputs")
    circuit.format_outputs = MagicMock(return_value={"output": 99})

    circuit._gen_witness_preprocessing(
        "in.json",
        "out.json",
        None,
        write_json=False,
        is_scaled=False,
    )

    circuit.load_quantized_model.assert_called_once_with("quant.pt")
    circuit.get_inputs_from_file.assert_called_once_with("in.json", is_scaled=False)
    circuit.get_outputs.assert_called_once_with("mock_inputs")
    circuit.format_outputs.assert_called_once_with("mock_outputs")
    mock_to_json.assert_called_once_with({"output": 99}, "out.json")
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
# ---------- _compile_preprocessing ----------
|
|
470
|
+
@pytest.mark.unit
@patch("python.core.circuits.base.to_json")
def test_compile_preprocessing_saves_all_files(mock_to_json: MagicMock) -> None:
    """_compile_preprocessing quantizes, saves the model, and writes all three JSON files."""
    circuit = Circuit()
    circuit._file_info = {"quantized_model_path": "model.pth"}
    circuit.get_model_and_quantize = MagicMock()
    circuit.get_metadata = MagicMock(return_value={"version": "1.0"})
    circuit.get_architecture = MagicMock(return_value={"layers": ["conv", "relu"]})
    circuit.get_w_and_b = MagicMock(return_value={"weights": [1, 2, 3]})
    circuit.save_quantized_model = MagicMock()

    circuit._compile_preprocessing("metadata.json", "architecture.json", "w_and_b.json", None)

    for producer in (
        circuit.get_model_and_quantize,
        circuit.get_metadata,
        circuit.get_architecture,
        circuit.get_w_and_b,
    ):
        producer.assert_called_once()
    circuit.save_quantized_model.assert_called_once_with("model.pth")

    expected_writes = (
        ({"version": "1.0"}, "metadata.json"),
        ({"layers": ["conv", "relu"]}, "architecture.json"),
        ({"weights": [1, 2, 3]}, "w_and_b.json"),
    )
    for payload, destination in expected_writes:
        mock_to_json.assert_any_call(payload, destination)
|
|
491
|
+
|
|
492
|
+
|
|
493
|
+
@pytest.mark.unit
@patch("python.core.circuits.base.to_json")
def test_compile_preprocessing_dict_w_and_b_writes_three_files(
    mock_to_json: MagicMock,
) -> None:
    """With dict-valued weights, _compile_preprocessing writes exactly three JSON files.

    The name differs from ``test_compile_preprocessing_saves_all_files`` so that
    neither definition shadows the other at module import time. The exact
    write count complements that test's ``assert_any_call`` checks.
    """
    c = Circuit()
    c._file_info = {"quantized_model_path": "model.pth"}
    c.get_model_and_quantize = MagicMock()
    c.get_metadata = MagicMock(return_value={"version": "1.0"})
    c.get_architecture = MagicMock(return_value={"layers": ["conv", "relu"]})
    c.get_w_and_b = MagicMock(return_value={"weights": [1, 2, 3]})
    c.save_quantized_model = MagicMock()

    c._compile_preprocessing("metadata.json", "architecture.json", "w_and_b.json", None)

    c.get_model_and_quantize.assert_called_once()
    c.get_metadata.assert_called_once()
    c.get_architecture.assert_called_once()
    c.get_w_and_b.assert_called_once()
    c.save_quantized_model.assert_called_once_with("model.pth")

    # metadata + architecture + a single w_and_b file. The list-weights test
    # uses the same accounting: 2 + one file per shard.
    expected_writes = 3
    assert mock_to_json.call_count == expected_writes
    mock_to_json.assert_any_call({"version": "1.0"}, "metadata.json")
    mock_to_json.assert_any_call({"layers": ["conv", "relu"]}, "architecture.json")
    mock_to_json.assert_any_call({"weights": [1, 2, 3]}, "w_and_b.json")
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
@pytest.mark.unit
@patch("python.core.circuits.base.to_json")
def test_compile_preprocessing_weights_dict(mock_to_json: MagicMock) -> None:
    """A dict returned by get_w_and_b is written to the given w_and_b path."""
    circuit = Circuit()
    circuit._file_info = {"quantized_model_path": "model.pth"}
    circuit.get_model_and_quantize = MagicMock()
    circuit.get_metadata = MagicMock(return_value={"TEST": "2"})
    circuit.get_architecture = MagicMock(return_value={"TEST": "1"})
    circuit.get_w_and_b = MagicMock(return_value={"a": 1})
    circuit.save_quantized_model = MagicMock()

    circuit._compile_preprocessing("metadata.json", "architecture.json", "w_and_b.json", None)

    circuit.get_model_and_quantize.assert_called_once()
    circuit.get_w_and_b.assert_called_once()
    circuit.save_quantized_model.assert_called_once_with("model.pth")
    expected_writes = [
        ({"TEST": "2"}, "metadata.json"),
        ({"TEST": "1"}, "architecture.json"),
        ({"a": 1}, "w_and_b.json"),
    ]
    for payload, destination in expected_writes:
        mock_to_json.assert_any_call(payload, destination)
|
|
535
|
+
|
|
536
|
+
|
|
537
|
+
@pytest.mark.unit
@patch("python.core.circuits.base.to_json")
def test_compile_preprocessing_weights_list(
    mock_to_json: MagicMock,
) -> None:
    """A list of weight dicts is written as one file per element.

    The first element goes to the given path; later ones get a numeric suffix
    starting at 2.
    """
    circuit = Circuit()
    circuit._file_info = {"quantized_model_path": "model.pth"}
    circuit.get_model_and_quantize = MagicMock()
    circuit.get_metadata = MagicMock(return_value={"TEST": "1"})
    circuit.get_architecture = MagicMock(return_value={"TEST": "2"})
    circuit.get_w_and_b = MagicMock(return_value=[{"w1": 1}, {"w2": 2}, {"w3": 3}])
    circuit.save_quantized_model = MagicMock()

    circuit._compile_preprocessing("metadata.json", "architecture.json", "w_and_b.json", None)

    expected_total_writes = 2 + 3  # metadata and architecture, plus one per shard
    assert mock_to_json.call_count == expected_total_writes
    mock_to_json.assert_any_call({"TEST": "1"}, "metadata.json")
    mock_to_json.assert_any_call({"TEST": "2"}, "architecture.json")
    expected_shard_writes = (
        ({"w1": 1}, "w_and_b.json"),
        ({"w2": 2}, "w_and_b2.json"),
        ({"w3": 3}, "w_and_b3.json"),
    )
    for shard, filename in expected_shard_writes:
        mock_to_json.assert_any_call(shard, Path(filename))
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
@pytest.mark.unit
@patch("python.core.circuits.base.to_json")
def test_compile_preprocessing_weights_list_single_call(
    mock_to_json: MagicMock,
) -> None:
    """With list-valued weights from get_weights, three sharded files are written."""
    circuit = Circuit()
    circuit._file_info = {"quantized_model_path": "model.pth"}
    circuit.get_model_and_quantize = MagicMock()
    circuit.get_metadata = MagicMock(return_value={})
    circuit.get_architecture = MagicMock(return_value={})
    # NOTE(review): this stubs get_weights, not get_w_and_b as the sibling
    # tests do. Presumably the default get_w_and_b delegates to get_weights
    # — confirm against Circuit.
    circuit.get_weights = MagicMock(return_value=[{"w1": 1}, {"w2": 2}, {"w3": 3}])
    circuit.save_quantized_model = MagicMock()

    circuit._compile_preprocessing("metadata.json", "architecture.json", "w_and_b.json", None)

    # NOTE(review): 3 equals the number of weight shards alone. For this to
    # hold, the empty metadata/architecture dicts presumably are not written
    # — confirm against Circuit._compile_preprocessing.
    expected_writes = 3
    assert mock_to_json.call_count == expected_writes
    for shard, filename in (
        ({"w1": 1}, "w_and_b.json"),
        ({"w2": 2}, "w_and_b2.json"),
        ({"w3": 3}, "w_and_b3.json"),
    ):
        mock_to_json.assert_any_call(shard, Path(filename))
|
|
583
|
+
|
|
584
|
+
|
|
585
|
+
@pytest.mark.unit
def test_compile_preprocessing_raises_on_bad_weights() -> None:
    """A string w_and_b value is rejected as an unsupported type."""
    circuit = Circuit()
    circuit._file_info = {"quantized_model_path": "model.pth"}
    stubs = {
        "get_model_and_quantize": MagicMock(),
        "get_metadata": MagicMock(return_value={}),
        "get_architecture": MagicMock(return_value={}),
        "get_w_and_b": MagicMock(return_value="bad_type"),
        "save_quantized_model": MagicMock(),
    }
    for attr_name, stub in stubs.items():
        setattr(circuit, attr_name, stub)

    output_paths = ("metadata.json", "architecture.json", "w_and_b.json")
    with pytest.raises(CircuitConfigurationError, match="Unsupported w_and_b type"):
        circuit._compile_preprocessing(*output_paths, None)
|
|
602
|
+
|
|
603
|
+
|
|
604
|
+
# ---------- Test check attributes --------------
|
|
605
|
+
@pytest.mark.unit
def test_check_attributes_true() -> None:
    """A fully configured circuit passes check_attributes without raising."""
    circuit = Circuit()
    circuit.required_keys = ["input"]
    circuit.name = "test"
    circuit.scale_exponent, circuit.scale_base = 2, 2
    # Success is simply the absence of an exception.
    circuit.check_attributes()
|
|
613
|
+
|
|
614
|
+
|
|
615
|
+
@pytest.mark.unit
def test_check_attributes_no_scaling() -> None:
    """check_attributes names scale_exponent when it was never set."""
    circuit = Circuit()
    circuit.required_keys = ["input"]
    circuit.name = "test"
    circuit.scale_base = 2
    with pytest.raises(CircuitConfigurationError) as exc_info:
        circuit.check_attributes()

    error_text = str(exc_info.value)
    for fragment in ("Circuit class (python) is misconfigured", "scale_exponent"):
        assert fragment in error_text
|
|
627
|
+
|
|
628
|
+
|
|
629
|
+
@pytest.mark.unit
def test_check_attributes_no_scalebase() -> None:
    """check_attributes names scale_base when it was never set."""
    circuit = Circuit()
    circuit.required_keys = ["input"]
    circuit.name = "test"
    circuit.scale_exponent = 2

    with pytest.raises(CircuitConfigurationError) as exc_info:
        circuit.check_attributes()

    error_text = str(exc_info.value)
    for fragment in ("Circuit class (python) is misconfigured", "scale_base"):
        assert fragment in error_text
|
|
642
|
+
|
|
643
|
+
|
|
644
|
+
@pytest.mark.unit
def test_check_attributes_no_name() -> None:
    """check_attributes names the missing ``name`` attribute."""
    circuit = Circuit()
    circuit.required_keys = ["input"]
    circuit.scale_base = 2
    circuit.scale_exponent = 2

    with pytest.raises(CircuitConfigurationError) as exc_info:
        circuit.check_attributes()

    error_text = str(exc_info.value)
    for fragment in ("Circuit class (python) is misconfigured", "name"):
        assert fragment in error_text
|
|
657
|
+
|
|
658
|
+
|
|
659
|
+
# ---------- base_testing ------------
|
|
660
|
+
@pytest.mark.unit
@patch.object(Circuit, "parse_proof_run_type")
def test_base_testing_calls_parse_proof_run_type_correctly(
    mock_parse: MagicMock,
) -> None:
    """base_testing forwards one config to parse_proof_run_type.

    The forwarded config has the metadata, architecture and w_and_b paths taken
    from ``_file_info``.
    """
    c = Circuit()
    c.name = "test"
    c._file_info = {
        "metadata_path": "metadata.json",
        "architecture_path": "architecture.json",
        "w_and_b_path": "w_and_b.json",
    }

    # Fields supplied by the caller; the expected config repeats them verbatim.
    caller_fields = {
        "run_type": RunType.GEN_WITNESS,
        "witness_file": "w.wtns",
        "input_file": "i.json",
        "proof_file": "p.json",
        "public_path": "pub.json",
        "verification_key": "vk.key",
        "circuit_name": "circuit_model",
        "output_file": "o.json",
        "circuit_path": "circuit_path.txt",
        "quantized_path": "quantized_path.pt",
        "write_json": True,
        "proof_system": ZKProofSystems.Expander,
    }
    c.base_testing(CircuitExecutionConfig(**caller_fields))

    mock_parse.assert_called_once()
    expected_config = CircuitExecutionConfig(
        **caller_fields,
        metadata_path="metadata.json",
        architecture_path="architecture.json",
        w_and_b_path="w_and_b.json",
        dev_mode=False,
        ecc=True,
        bench=False,
    )
    mock_parse.assert_called_once_with(expected_config)
|
|
711
|
+
|
|
712
|
+
|
|
713
|
+
@pytest.mark.unit
def test_prepare_io_files_sets_new_file_paths() -> None:
    """Test that prepare_io_files decorator sets the new file paths correctly.

    ``get_files`` is patched to return a fixed path mapping. After calling the
    decorated method, the metadata/architecture/weights paths must appear both
    in the circuit's ``_file_info`` and on the execution config that was
    passed in.
    """
    from python.core.utils.helper_functions import prepare_io_files  # noqa: PLC0415

    class TestCircuit(Circuit):
        def __init__(self: "TestCircuit") -> None:
            super().__init__()
            self.name = "test_circuit"

        @prepare_io_files
        def test_method(
            self: "TestCircuit",
            exec_config: CircuitExecutionConfig,
        ) -> dict:
            # The decorator performs the work under test; just expose its effect.
            _ = exec_config
            return self._file_info

    c = TestCircuit()

    with patch("python.core.utils.helper_functions.get_files") as mock_get_files:
        mock_get_files.return_value = {
            "witness_file": "witness.wtns",
            "input_file": "input.json",
            "proof_path": "proof.json",
            "public_path": "public.json",
            "circuit_name": "test_circuit",
            "metadata_path": "metadata.json",
            "architecture_path": "architecture.json",
            "w_and_b_path": "w_and_b.json",
            "output_file": "output.json",
        }

        config = CircuitExecutionConfig(run_type=RunType.COMPILE_CIRCUIT)
        file_info = c.test_method(config)

        # Paths recorded on the circuit...
        assert file_info["metadata_path"] == "metadata.json"
        assert file_info["architecture_path"] == "architecture.json"
        assert file_info["w_and_b_path"] == "w_and_b.json"
        # ...and written onto the config object the caller passed in.
        assert config.metadata_path == "metadata.json"
        assert config.architecture_path == "architecture.json"
        assert config.w_and_b_path == "w_and_b.json"
|
|
752
|
+
|
|
753
|
+
|
|
754
|
+
@pytest.mark.unit
@patch.object(Circuit, "parse_proof_run_type")
def test_base_testing_uses_default_circuit_path(mock_parse: MagicMock) -> None:
    """With no circuit_path given, base_testing derives ``<circuit_name>.txt``.

    The file-info paths from the circuit must also be copied onto the config
    that is forwarded to parse_proof_run_type.
    """

    class MyCircuit(Circuit):
        def __init__(self: "MyCircuit") -> None:
            super().__init__()
            self._file_info = {
                "metadata_path": "metadata.json",
                "architecture_path": "architecture.json",
                "w_and_b_path": "w_and_b.json",
            }

    MyCircuit().base_testing(CircuitExecutionConfig(circuit_name="test_model"))

    mock_parse.assert_called_once()
    forwarded = mock_parse.call_args.args[0]

    expected_attrs = {
        "circuit_name": "test_model",
        "circuit_path": "test_model.txt",
        "metadata_path": "metadata.json",
        "architecture_path": "architecture.json",
        "w_and_b_path": "w_and_b.json",
    }
    for attr_name, wanted in expected_attrs.items():
        assert getattr(forwarded, attr_name) == wanted, attr_name
|
|
777
|
+
|
|
778
|
+
|
|
779
|
+
@pytest.mark.unit
@patch.object(Circuit, "parse_proof_run_type")
def test_base_testing_returns_none(mock_parse: MagicMock) -> None:
    """base_testing returns None and dispatches to parse_proof_run_type once."""

    class MyCircuit(Circuit):
        def __init__(self: "MyCircuit") -> None:
            super().__init__()
            self._file_info = {
                "metadata_path": "metadata.json",
                "architecture_path": "architecture.json",
                "w_and_b_path": "w_and_b.json",
            }

    circuit = MyCircuit()
    outcome = circuit.base_testing(CircuitExecutionConfig(circuit_name="abc"))

    assert outcome is None
    mock_parse.assert_called_once()
|
|
795
|
+
|
|
796
|
+
|
|
797
|
+
@pytest.mark.unit
@patch.object(Circuit, "parse_proof_run_type")
def test_base_testing_weights_exists(mock_parse: MagicMock) -> None:
    """base_testing rejects a circuit that never set up its file information.

    parse_proof_run_type is patched only so the real pipeline cannot run; the
    mock itself is not inspected.
    """
    del mock_parse

    class MyCircuit(Circuit):
        def __init__(self: "MyCircuit") -> None:
            super().__init__()

    bare_circuit = MyCircuit()
    with pytest.raises(CircuitConfigurationError, match="Circuit file information"):
        bare_circuit.base_testing(CircuitExecutionConfig(circuit_name="abc"))
|
|
809
|
+
|
|
810
|
+
|
|
811
|
+
@pytest.mark.unit
def test_parse_proof_run_type_invalid_run_type(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """An unsupported run type raises CircuitRunError and logs the failure.

    ``run_type`` is deliberately a plain string that is not a ``RunType``
    member, so it matches none of the supported run types.
    """
    c = Circuit()
    config_invalid = CircuitExecutionConfig(
        witness_file="w.wtns",
        input_file="i.json",
        proof_file="p.json",
        public_path="pub.json",
        verification_key="vk.key",
        circuit_name="model",
        circuit_path="path.txt",
        proof_system=None,
        output_file="out.json",
        metadata_path="metadata.json",
        architecture_path="architecture.json",
        w_and_b_path="w_and_b.json",
        quantized_path="quantized_model.pt",
        run_type="NOT_A_REAL_RUN_TYPE",  # Invalid run type
        dev_mode=False,
        ecc=True,
        write_json=False,
        bench=False,
    )

    with pytest.raises(CircuitRunError, match="Unsupported run type"):
        c.parse_proof_run_type(config_invalid)

    # Both the specific diagnosis and the generic failure line must be logged.
    assert "Unknown run type: NOT_A_REAL_RUN_TYPE" in caplog.text
    assert "Operation NOT_A_REAL_RUN_TYPE failed" in caplog.text
|
|
843
|
+
|
|
844
|
+
|
|
845
|
+
@pytest.mark.unit
@patch(
    "python.core.circuits.base.compile_circuit",
    side_effect=Exception("Boom goes the dynamite!"),
)
@patch.object(Circuit, "_compile_preprocessing")
def test_parse_proof_run_type_catches_internal_exception(
    mock_compile_preprocessing: MagicMock,
    mock_compile: MagicMock,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """A generic exception in the compile step is re-raised as CircuitRunError.

    ``compile_circuit`` is patched to raise a bare ``Exception``. The test
    checks the wrapped error message, the logged failure line, and that both
    the preprocessing hook and ``compile_circuit`` were invoked.
    """
    c = Circuit()

    config_exception = CircuitExecutionConfig(
        witness_file="w.wtns",
        input_file="i.json",
        proof_file="p.json",
        public_path="pub.json",
        verification_key="vk.key",
        circuit_name="model",
        circuit_path="path.txt",
        proof_system=None,
        output_file="out.json",
        metadata_path="metadata.json",
        architecture_path="architecture.json",
        w_and_b_path="w_and_b.json",
        quantized_path="quantized_path.pt",
        run_type=RunType.COMPILE_CIRCUIT,
        dev_mode=False,
        ecc=True,
        write_json=False,
        bench=False,
    )

    # This will raise inside `compile_circuit`, which is patched to raise
    with pytest.raises(CircuitRunError, match="Circuit operation 'Compile' failed"):
        c.parse_proof_run_type(config_exception)

    # Check that the error message is logged
    assert "Operation RunType.COMPILE_CIRCUIT failed" in caplog.text
    mock_compile.assert_called()
    mock_compile_preprocessing.assert_called()
|
|
888
|
+
|
|
889
|
+
|
|
890
|
+
@pytest.mark.unit
def test_save_and_load_model_not_implemented() -> None:
    """The base Circuit exposes the plain and quantized model save/load hooks."""
    circuit = Circuit()
    for hook_name in (
        "save_model",
        "load_model",
        "save_quantized_model",
        "load_quantized_model",
    ):
        assert hasattr(circuit, hook_name), hook_name
|
|
897
|
+
|
|
898
|
+
|
|
899
|
+
# ---------- New error handling tests ----------
|
|
900
|
+
@pytest.mark.unit
def test_adjust_inputs_file_error() -> None:
    """A read failure inside adjust_inputs surfaces as CircuitFileError."""
    circuit = Circuit()
    circuit.input_variables = ["input"]
    circuit.input_shape = [2, 2]
    circuit.scale_base = 2
    circuit.scale_exponent = 1

    missing_file = FileNotFoundError("File not found")
    with patch(
        "python.core.circuits.base.read_from_json",
        side_effect=missing_file,
    ):
        with pytest.raises(CircuitFileError, match="Failed to read input file"):
            circuit.adjust_inputs("nonexistent.json")
|
|
915
|
+
|
|
916
|
+
|
|
917
|
+
@pytest.mark.unit
def test_adjust_inputs_processing_error() -> None:
    """A tensor-construction failure in adjust_inputs becomes CircuitProcessingError."""
    circuit = Circuit()
    circuit.input_variables = ["input"]
    circuit.input_shape = [2, 2]
    circuit.scale_base = 2
    circuit.scale_exponent = 1

    # Reading succeeds, but building the tensor blows up.
    with patch(
        "python.core.circuits.base.read_from_json",
        return_value={"input": [1, 2, 3, 4]},
    ), patch(
        "torch.tensor",
        side_effect=RuntimeError("Invalid tensor shape"),
    ):
        with pytest.raises(
            CircuitProcessingError,
            match="Failed to reshape input data",
        ):
            circuit.adjust_inputs("dummy.json")
|
|
938
|
+
|
|
939
|
+
|
|
940
|
+
@pytest.mark.unit
def test_get_inputs_from_file_file_error() -> None:
    """A CircuitFileError from the safe JSON reader propagates unchanged."""
    circuit = Circuit()
    read_failure = CircuitFileError("Failed to read input file: protected.json")

    with patch.object(circuit, "_read_from_json_safely", side_effect=read_failure):
        with pytest.raises(CircuitFileError, match="Failed to read input file"):
            circuit.get_inputs_from_file("protected.json")
|
|
951
|
+
|
|
952
|
+
|
|
953
|
+
@pytest.mark.unit
|
|
954
|
+
def test_get_inputs_from_file_processing_error() -> None:
|
|
955
|
+
c = Circuit()
|
|
956
|
+
c.scale_base = 2
|
|
957
|
+
c.scale_exponent = 1
|
|
958
|
+
|
|
959
|
+
with patch.object(
|
|
960
|
+
c,
|
|
961
|
+
"_read_from_json_safely",
|
|
962
|
+
return_value={"input": "invalid_data"},
|
|
963
|
+
):
|
|
964
|
+
_ = c
|
|
965
|
+
with pytest.raises(
|
|
966
|
+
CircuitProcessingError,
|
|
967
|
+
match="Failed to scale input data",
|
|
968
|
+
):
|
|
969
|
+
c.get_inputs_from_file("dummy.json", is_scaled=False)
|