JSTprove 1.0.0__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of JSTprove might be problematic. Click here for more details.
- jstprove-1.0.0.dist-info/METADATA +397 -0
- jstprove-1.0.0.dist-info/RECORD +81 -0
- jstprove-1.0.0.dist-info/WHEEL +5 -0
- jstprove-1.0.0.dist-info/entry_points.txt +2 -0
- jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
- jstprove-1.0.0.dist-info/top_level.txt +1 -0
- python/__init__.py +0 -0
- python/core/__init__.py +3 -0
- python/core/binaries/__init__.py +0 -0
- python/core/binaries/expander-exec +0 -0
- python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
- python/core/circuit_models/__init__.py +0 -0
- python/core/circuit_models/generic_onnx.py +231 -0
- python/core/circuit_models/simple_circuit.py +133 -0
- python/core/circuits/__init__.py +0 -0
- python/core/circuits/base.py +1000 -0
- python/core/circuits/errors.py +188 -0
- python/core/circuits/zk_model_base.py +25 -0
- python/core/model_processing/__init__.py +0 -0
- python/core/model_processing/converters/__init__.py +0 -0
- python/core/model_processing/converters/base.py +143 -0
- python/core/model_processing/converters/onnx_converter.py +1181 -0
- python/core/model_processing/errors.py +147 -0
- python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
- python/core/model_processing/onnx_custom_ops/conv.py +111 -0
- python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
- python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
- python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
- python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
- python/core/model_processing/onnx_custom_ops/relu.py +43 -0
- python/core/model_processing/onnx_quantizer/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
- python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
- python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
- python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
- python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
- python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
- python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
- python/core/model_templates/__init__.py +0 -0
- python/core/model_templates/circuit_template.py +57 -0
- python/core/utils/__init__.py +0 -0
- python/core/utils/benchmarking_helpers.py +163 -0
- python/core/utils/constants.py +4 -0
- python/core/utils/errors.py +117 -0
- python/core/utils/general_layer_functions.py +268 -0
- python/core/utils/helper_functions.py +1138 -0
- python/core/utils/model_registry.py +166 -0
- python/core/utils/scratch_tests.py +66 -0
- python/core/utils/witness_utils.py +291 -0
- python/frontend/__init__.py +0 -0
- python/frontend/cli.py +115 -0
- python/frontend/commands/__init__.py +17 -0
- python/frontend/commands/args.py +100 -0
- python/frontend/commands/base.py +199 -0
- python/frontend/commands/bench/__init__.py +54 -0
- python/frontend/commands/bench/list.py +42 -0
- python/frontend/commands/bench/model.py +172 -0
- python/frontend/commands/bench/sweep.py +212 -0
- python/frontend/commands/compile.py +58 -0
- python/frontend/commands/constants.py +5 -0
- python/frontend/commands/model_check.py +53 -0
- python/frontend/commands/prove.py +50 -0
- python/frontend/commands/verify.py +73 -0
- python/frontend/commands/witness.py +64 -0
- python/scripts/__init__.py +0 -0
- python/scripts/benchmark_runner.py +833 -0
- python/scripts/gen_and_bench.py +482 -0
- python/tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
- python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
- python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
- python/tests/circuit_parent_classes/__init__.py +0 -0
- python/tests/circuit_parent_classes/test_circuit.py +969 -0
- python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
- python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
- python/tests/test_cli.py +1021 -0
- python/tests/utils_testing/__init__.py +0 -0
- python/tests/utils_testing/test_helper_functions.py +891 -0
|
@@ -0,0 +1,1158 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Generator
|
|
6
|
+
|
|
7
|
+
import pytest
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
import python.tests.circuit_e2e_tests.helper_fns_for_tests # noqa: F401
|
|
11
|
+
from python.core.circuits.errors import CircuitRunError
|
|
12
|
+
|
|
13
|
+
# Assume these are your models
|
|
14
|
+
# Enums, utils
|
|
15
|
+
from python.core.utils.helper_functions import CircuitExecutionConfig, RunType
|
|
16
|
+
from python.tests.circuit_e2e_tests.helper_fns_for_tests import (
|
|
17
|
+
BAD_OUTPUT,
|
|
18
|
+
GOOD_OUTPUT,
|
|
19
|
+
NestedArray,
|
|
20
|
+
add_1_to_first_element,
|
|
21
|
+
assert_very_close,
|
|
22
|
+
check_model_compiles, # noqa: F401
|
|
23
|
+
check_witness_generated, # noqa: F401
|
|
24
|
+
circuit_compile_results,
|
|
25
|
+
contains_float,
|
|
26
|
+
model_fixture, # noqa: F401
|
|
27
|
+
temp_input_file, # noqa: F401
|
|
28
|
+
temp_output_file, # noqa: F401
|
|
29
|
+
temp_proof_file, # noqa: F401
|
|
30
|
+
temp_witness_file, # noqa: F401
|
|
31
|
+
witness_generated_results,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
# Expected number of occurrences of a message in captured stdout when a test
# drives the pipeline two / three times before reading the capture.
OUTPUTTWICE: int = 2
OUTPUTTHREETIMES: int = 3
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@pytest.mark.e2e()
def test_circuit_compiles(model_fixture: dict[str, Any]) -> None:
    """Check that the compiled circuit file for the model exists on disk.

    The outcome is recorded in ``circuit_compile_results`` keyed by the model
    (presumably consumed by the ``check_model_compiles`` fixture in
    helper_fns_for_tests -- confirm there).
    """
    model = model_fixture["model"]
    # Record failure first so an assertion error below leaves False behind.
    circuit_compile_results[model] = False
    # Wrap in Path() so both str and Path values for circuit_path work;
    # the unbound ``Path.exists(str_value)`` form raises AttributeError.
    circuit_path = Path(model_fixture["circuit_path"])
    assert circuit_path.exists(), f"Compiled circuit not found at {circuit_path}"
    circuit_compile_results[model] = True
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@pytest.mark.e2e()
def test_witness_dev(
    model_fixture: dict[str, Any],
    capsys: Generator[pytest.CaptureFixture[str], None, None],
    temp_witness_file: Generator[Path, None, None],
    temp_input_file: Generator[Path, None, None],
    temp_output_file: Generator[Path, None, None],
    check_model_compiles: None,
) -> None:
    """Generate a witness (non-dev mode) and check the run's console output.

    The model's entry in ``witness_generated_results`` is set to False up
    front and flipped to True only once every assertion has passed.
    """
    # Fixture is requested for its setup; its value is unused.
    del check_model_compiles

    model = model_fixture["model"]
    witness_generated_results[model] = False

    config = CircuitExecutionConfig(
        run_type=RunType.GEN_WITNESS,
        dev_mode=False,
        witness_file=temp_witness_file,
        circuit_path=str(model_fixture["circuit_path"]),
        input_file=temp_input_file,
        output_file=temp_output_file,
        write_json=True,
        quantized_path=str(model_fixture["quantized_model"]),
    )
    model.base_testing(config)

    out, err = capsys.readouterr()
    print(out)

    assert err == ""
    assert temp_witness_file.exists()
    assert "Running cargo command:" in out

    for expected in GOOD_OUTPUT:
        assert expected in out, (
            f"Expected '{expected}' in stdout, but it was not found."
        )
    for unexpected in BAD_OUTPUT:
        assert unexpected not in out, (
            f"Did not expect '{unexpected}' in stdout, but it was found."
        )

    witness_generated_results[model] = True
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@pytest.mark.e2e()
def test_witness_wrong_outputs_dev(
    model_fixture: dict[str, Any],
    capsys: Generator[pytest.CaptureFixture[str], None, None],
    temp_witness_file: Generator[Path, None, None],
    temp_input_file: Generator[Path, None, None],
    temp_output_file: Generator[Path, None, None],
    monkeypatch: Generator[pytest.MonkeyPatch, None, None],
    check_model_compiles: None,
    check_witness_generated: None,
    caplog: Generator[pytest.LogCaptureFixture, None, None],
) -> None:
    """Witness generation must fail when the model reports wrong outputs.

    ``get_outputs`` is monkeypatched to pass its result through
    ``add_1_to_first_element``, so the claimed outputs are wrong. The run is
    expected to raise ``CircuitRunError``, leave no witness file, keep the
    GOOD_OUTPUT messages out of stdout, and log every BAD_OUTPUT message.
    """
    _ = check_witness_generated
    _ = check_model_compiles

    model = model_fixture["model"]
    original_get_outputs = model.get_outputs

    def patched_get_outputs(*args: Any, **kwargs: Any) -> NestedArray:
        # Delegate to the real implementation, then corrupt the result.
        result = original_get_outputs(*args, **kwargs)
        return add_1_to_first_element(result)

    monkeypatch.setattr(model, "get_outputs", patched_get_outputs)
    with pytest.raises(CircuitRunError):
        model.base_testing(
            CircuitExecutionConfig(
                run_type=RunType.GEN_WITNESS,
                dev_mode=False,
                witness_file=temp_witness_file,
                circuit_path=str(model_fixture["circuit_path"]),
                input_file=temp_input_file,
                output_file=temp_output_file,
                write_json=True,
                quantized_path=str(model_fixture["quantized_model"]),
            ),
        )

    captured = capsys.readouterr()
    stdout = captured.out
    stderr = captured.err
    print(stdout)

    assert stderr == ""

    # The failed run must not leave a witness behind.
    assert not temp_witness_file.exists()
    assert "Running cargo command:" in stdout
    for output in GOOD_OUTPUT:
        assert output not in stdout, (
            f"Did not expect '{output}' in stdout, but it was found."
        )
    # Failure diagnostics are checked in the logging capture, not stdout.
    for output in BAD_OUTPUT:
        assert output in caplog.text, (
            f"Expected '{output}' in log output, but it was not found."
        )
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
@pytest.mark.e2e()
def test_witness_prove_verify_true_inputs_dev(
    model_fixture: dict[str, Any],
    capsys: Generator[pytest.CaptureFixture[str], None, None],
    temp_witness_file: Generator[Path, None, None],
    temp_input_file: Generator[Path, None, None],
    temp_output_file: Generator[Path, None, None],
    temp_proof_file: Generator[Path, None, None],
    check_model_compiles: None,
    check_witness_generated: None,
) -> None:
    """Run witness generation, proving and verification end to end.

    All three stages go through ``model.base_testing`` with ``dev_mode=False``
    and unmodified inputs. The combined stdout is then checked for known
    failure messages (must be absent) and stage markers (exactly once each).
    """
    _ = check_witness_generated
    _ = check_model_compiles

    model = model_fixture["model"]
    circuit_path = str(model_fixture["circuit_path"])
    quantized_path = str(model_fixture["quantized_model"])

    model.base_testing(
        CircuitExecutionConfig(
            run_type=RunType.GEN_WITNESS,
            dev_mode=False,
            witness_file=temp_witness_file,
            circuit_path=circuit_path,
            input_file=temp_input_file,
            output_file=temp_output_file,
            write_json=True,
            quantized_path=quantized_path,
        ),
    )
    model.base_testing(
        CircuitExecutionConfig(
            run_type=RunType.PROVE_WITNESS,
            dev_mode=False,
            witness_file=temp_witness_file,
            circuit_path=circuit_path,
            input_file=temp_input_file,
            output_file=temp_output_file,
            proof_file=temp_proof_file,
            quantized_path=quantized_path,
        ),
    )
    model.base_testing(
        CircuitExecutionConfig(
            run_type=RunType.GEN_VERIFY,
            dev_mode=False,
            witness_file=temp_witness_file,
            circuit_path=circuit_path,
            input_file=temp_input_file,
            output_file=temp_output_file,
            proof_file=temp_proof_file,
            quantized_path=quantized_path,
        ),
    )

    # Nothing may be printed between the runs and this read: the substring
    # counts below are taken over everything captured so far.
    captured = capsys.readouterr()
    stdout = captured.out
    stderr = captured.err
    print(stdout)

    assert stderr == ""
    assert temp_witness_file.exists(), "Witness file not generated"

    # Unexpected output
    assert stdout.count("poly.num_vars() == *params") == 0, (
        "'poly.num_vars() == *params' thrown. May need a dummy variable(s) "
        "to get rid of error. Dummy variables should be private variables. "
        "Can set = 1 in read_inputs and assert == 1 at end of circuit"
    )
    assert stdout.count("Proof generation failed") == 0, "Proof generation failed"
    assert temp_proof_file.exists(), "Proof file not generated"
    assert stdout.count("Verification generation failed") == 0, "Verification failed"

    # Expected output: three stages -> three cargo invocations, and each
    # stage marker exactly once.
    cargo_calls = stdout.count("Running cargo command:")
    assert cargo_calls == OUTPUTTHREETIMES, (
        "Expected 'Running cargo command:' in stdout three times, "
        f"found {cargo_calls}."
    )
    for marker in ("Witness Generated", "proving", "Proved", "Verified"):
        occurrences = stdout.count(marker)
        assert occurrences == 1, (
            f"Expected '{marker}' in stdout once, found {occurrences}."
        )
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
@pytest.mark.e2e()
def test_witness_prove_verify_true_inputs_dev_expander_call(
    model_fixture: dict[str, Any],
    capsys: Generator[pytest.CaptureFixture[str], None, None],
    temp_witness_file: Generator[Path, None, None],
    temp_input_file: Generator[Path, None, None],
    temp_output_file: Generator[Path, None, None],
    temp_proof_file: Generator[Path, None, None],
    check_model_compiles: None,
    check_witness_generated: None,
) -> None:
    """Witness, prove and verify end to end with ``ecc=False`` for prove/verify.

    With ``ecc=False`` the prove and verify stages are expected to report via
    ``expander-exec`` (see the "expander-exec ... succeeded" checks), so only
    the witness stage is expected to print "Running cargo command:".
    """
    _ = check_witness_generated
    _ = check_model_compiles

    model = model_fixture["model"]
    circuit_path = str(model_fixture["circuit_path"])
    quantized_path = str(model_fixture["quantized_model"])

    model.base_testing(
        CircuitExecutionConfig(
            run_type=RunType.GEN_WITNESS,
            dev_mode=False,
            witness_file=temp_witness_file,
            circuit_path=circuit_path,
            input_file=temp_input_file,
            output_file=temp_output_file,
            write_json=True,
            quantized_path=quantized_path,
        ),
    )
    model.base_testing(
        CircuitExecutionConfig(
            run_type=RunType.PROVE_WITNESS,
            dev_mode=False,
            witness_file=temp_witness_file,
            circuit_path=circuit_path,
            input_file=temp_input_file,
            output_file=temp_output_file,
            proof_file=temp_proof_file,
            ecc=False,
            quantized_path=quantized_path,
        ),
    )
    model.base_testing(
        CircuitExecutionConfig(
            run_type=RunType.GEN_VERIFY,
            dev_mode=False,
            witness_file=temp_witness_file,
            circuit_path=circuit_path,
            input_file=temp_input_file,
            output_file=temp_output_file,
            proof_file=temp_proof_file,
            ecc=False,
            quantized_path=quantized_path,
        ),
    )

    captured = capsys.readouterr()
    stdout = captured.out
    stderr = captured.err
    print(stdout)

    assert stderr == ""
    assert temp_witness_file.exists(), "Witness file not generated"

    # Unexpected output
    assert stdout.count("poly.num_vars() == *params") == 0, (
        "'poly.num_vars() == *params' thrown. May need a dummy variable(s) "
        "to get rid of error. Dummy variables should be private variables. "
        "Can set = 1 in read_inputs and assert == 1 at end of circuit"
    )
    assert stdout.count("Proof generation failed") == 0, "Proof generation failed"
    assert temp_proof_file.exists(), "Proof file not generated"
    assert stdout.count("Verification generation failed") == 0, "Verification failed"

    # Expected output: cargo only for the witness stage, each marker once.
    cargo_calls = stdout.count("Running cargo command:")
    assert cargo_calls == 1, (
        f"Expected 'Running cargo command:' in stdout once, found {cargo_calls}."
    )
    for marker in (
        "Witness Generated",
        "proving",
        "verifying proof",
        "success",
        "expander-exec verify succeeded",
        "expander-exec prove succeeded",
    ):
        occurrences = stdout.count(marker)
        assert occurrences == 1, (
            f"Expected '{marker}' in stdout once, found {occurrences}."
        )
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
@pytest.mark.e2e()
def test_witness_read_after_write_json(
    model_fixture: dict[str, Any],
    capsys: Generator[pytest.CaptureFixture[str], None, None],
    temp_witness_file: Generator[Path, None, None],
    temp_input_file: Generator[Path, None, None],
    temp_output_file: Generator[Path, None, None],
    check_model_compiles: None,
    check_witness_generated: None,
) -> None:
    """An input JSON written by the model can be read back to regenerate a witness.

    The first run writes the input file (``write_json=True``). The witness is
    then removed, and a second run reads that same file (``write_json=False``).
    The second run must recreate the witness and print the expected messages,
    and the input file's contents must be unchanged afterwards.
    """
    _ = check_witness_generated
    _ = check_model_compiles

    model = model_fixture["model"]

    def run_witness(*, write_json: bool) -> None:
        # Same witness-generation config for both passes; only write_json varies.
        model.base_testing(
            CircuitExecutionConfig(
                run_type=RunType.GEN_WITNESS,
                dev_mode=False,
                witness_file=temp_witness_file,
                circuit_path=str(model_fixture["circuit_path"]),
                input_file=temp_input_file,
                output_file=temp_output_file,
                write_json=write_json,
                quantized_path=str(model_fixture["quantized_model"]),
            ),
        )

    # Step 1: have the model write the input JSON (and a first witness).
    run_witness(write_json=True)

    if temp_witness_file.exists():
        temp_witness_file.unlink()
    assert not temp_witness_file.exists()

    written_input_data = temp_input_file.read_text()

    # Step 2: regenerate the witness from the file just written.
    run_witness(write_json=False)

    # Step 3: inspect everything captured across both runs.
    out, err = capsys.readouterr()
    print(out)

    assert err == ""
    assert temp_witness_file.exists(), "Witness file not generated"
    assert "Running cargo command:" in out

    for expected in GOOD_OUTPUT:
        assert expected in out, (
            f"Expected '{expected}' in stdout, but it was not found."
        )
    for unexpected in BAD_OUTPUT:
        assert unexpected not in out, (
            f"Did not expect '{unexpected}' in stdout, but it was found."
        )

    read_input_data = temp_input_file.read_text()
    assert read_input_data == written_input_data, (
        "Input JSON read is not identical to what was written"
    )
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
@pytest.mark.e2e()
def test_witness_fresh_compile_dev(
    model_fixture: dict[str, Any],
    capsys: Generator[pytest.CaptureFixture[str], None, None],
    temp_witness_file: Generator[Path, None, None],
    temp_input_file: Generator[Path, None, None],
    temp_output_file: Generator[Path, None, None],
    check_model_compiles: None,
    check_witness_generated: None,
) -> None:
    """Generate a witness with ``dev_mode=True`` and check the console output."""
    # Both fixtures are requested for their setup; their values are unused.
    del check_witness_generated, check_model_compiles

    dev_config = CircuitExecutionConfig(
        run_type=RunType.GEN_WITNESS,
        dev_mode=True,
        witness_file=temp_witness_file,
        circuit_path=str(model_fixture["circuit_path"]),
        input_file=temp_input_file,
        output_file=temp_output_file,
        write_json=True,
        quantized_path=str(model_fixture["quantized_model"]),
    )
    model_fixture["model"].base_testing(dev_config)

    out, err = capsys.readouterr()
    print(out)

    assert err == ""
    assert temp_witness_file.exists()
    assert "Running cargo command:" in out

    for expected in GOOD_OUTPUT:
        assert expected in out, (
            f"Expected '{expected}' in stdout, but it was not found."
        )
    for unexpected in BAD_OUTPUT:
        assert unexpected not in out, (
            f"Did not expect '{unexpected}' in stdout, but it was found."
        )
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
# NOTE(review): this test was annotated "Use once fixed input shape read in
# rust", yet it carries no skip marker -- confirm the Rust side now accepts
# flattened inputs, or mark the test xfail.
@pytest.mark.e2e()
def test_witness_incorrect_input_shape(
    model_fixture: dict[str, Any],
    capsys: Generator[pytest.CaptureFixture[str], None, None],
    temp_witness_file: Generator[Path, None, None],
    temp_input_file: Generator[Path, None, None],
    temp_output_file: Generator[Path, None, None],
    check_model_compiles: None,
    check_witness_generated: None,
) -> None:
    """Witness generation still succeeds when every input list is flattened.

    The model first writes its input JSON. Each list-valued entry is then
    flattened to 1-D and written back, and a second run reads the reshaped
    file. Across both runs, each GOOD_OUTPUT message and the cargo banner
    must appear exactly twice.
    """
    _ = check_witness_generated
    _ = check_model_compiles

    model = model_fixture["model"]
    circuit_path = str(model_fixture["circuit_path"])
    quantized_path = str(model_fixture["quantized_model"])

    # Step 1: Write the input file via write_json=True
    model.base_testing(
        CircuitExecutionConfig(
            run_type=RunType.GEN_WITNESS,
            dev_mode=False,
            witness_file=temp_witness_file,
            circuit_path=circuit_path,
            input_file=temp_input_file,
            output_file=temp_output_file,
            write_json=True,
            quantized_path=quantized_path,
        ),
    )
    assert temp_witness_file.exists()
    temp_witness_file.unlink()
    assert not temp_witness_file.exists()

    # Flatten every list-valued input so its shape differs from what was written.
    input_data = json.loads(temp_input_file.read_text())
    for key, value in input_data.items():
        if isinstance(value, list):
            input_data[key] = torch.as_tensor(value).flatten().tolist()
        assert torch.as_tensor(input_data[key]).dim() <= 1, (
            f"Input data for {key} is not 1D tensor. "
            "This is a testing error, not a model error. "
            "Please fix this test to properly flatten."
        )
    temp_input_file.write_text(json.dumps(input_data))

    # Step 2: Read from that same input file (write_json=False)
    model.base_testing(
        CircuitExecutionConfig(
            run_type=RunType.GEN_WITNESS,
            dev_mode=False,
            witness_file=temp_witness_file,
            circuit_path=circuit_path,
            input_file=temp_input_file,
            output_file=temp_output_file,
            write_json=False,
            quantized_path=quantized_path,
        ),
    )

    # Step 3: Validate expected outputs and no errors
    captured = capsys.readouterr()
    stdout = captured.out
    stderr = captured.err
    print(stdout)

    assert stderr == ""
    assert temp_witness_file.exists(), "Witness file not generated"

    cargo_calls = stdout.count("Running cargo command:")
    assert cargo_calls == OUTPUTTWICE, (
        f"Expected 'Running cargo command:' in stdout twice, found {cargo_calls}."
    )

    # Check good output appeared once per run
    for output in GOOD_OUTPUT:
        occurrences = stdout.count(output)
        assert occurrences == OUTPUTTWICE, (
            f"Expected '{output}' in stdout twice, found {occurrences}."
        )

    # Ensure no unexpected errors
    for output in BAD_OUTPUT:
        assert output not in stdout, (
            f"Did not expect '{output}' in stdout, but it was found."
        )
|
|
556
|
+
|
|
557
|
+
|
|
558
|
+
@pytest.mark.e2e()
def test_witness_unscaled(
    model_fixture: dict[str, Any],
    capsys: Generator[pytest.CaptureFixture[str], None, None],
    temp_witness_file: Generator[Path, None, None],
    temp_input_file: Generator[Path, None, None],
    temp_output_file: Generator[Path, None, None],
    check_model_compiles: None,
    check_witness_generated: None,
) -> None:
    """Witness generation accepts inputs divided back down by the model's scale.

    The model first writes its input and output JSON. Every input is then
    divided by ``scale_base ** scale_exponent`` (which yields floats) and fed
    back in. The regenerated outputs must be very close to the originals.

    Raises:
        NotImplementedError: if the model exposes no ``scale_base`` /
            ``scale_exponent`` attributes.
    """
    _ = check_witness_generated
    _ = check_model_compiles

    model = model_fixture["model"]
    circuit_path = str(model_fixture["circuit_path"])
    quantized_path = str(model_fixture["quantized_model"])

    # Step 1: Write the input file via write_json=True
    model.base_testing(
        CircuitExecutionConfig(
            run_type=RunType.GEN_WITNESS,
            dev_mode=False,
            witness_file=temp_witness_file,
            circuit_path=circuit_path,
            input_file=temp_input_file,
            output_file=temp_output_file,
            write_json=True,
            quantized_path=quantized_path,
        ),
    )
    if temp_witness_file.exists():
        temp_witness_file.unlink()
    assert not temp_witness_file.exists()

    # Rescale. No printing here: anything written to stdout before
    # capsys.readouterr() would be scanned by the output checks below.
    input_data = json.loads(temp_input_file.read_text())
    if not (hasattr(model, "scale_base") and hasattr(model, "scale_exponent")):
        msg = "Model does not have scale_base attribute"
        raise NotImplementedError(msg)
    scale = model.scale_base**model.scale_exponent  # loop-invariant
    for key, value in input_data.items():
        input_data[key] = torch.div(torch.as_tensor(value), scale).tolist()
    assert contains_float(input_data), (
        "This is a testing error, not a model error. "
        "Please fix this test to properly turn data to float."
    )

    # Keep the reference outputs, then delete the file so step 2 must
    # recreate it.
    written_output_data = temp_output_file.read_text()
    temp_output_file.unlink()

    temp_input_file.write_text(json.dumps(input_data))

    # Step 2: Read from that same input file (write_json=False)
    model.base_testing(
        CircuitExecutionConfig(
            run_type=RunType.GEN_WITNESS,
            dev_mode=False,
            witness_file=temp_witness_file,
            circuit_path=circuit_path,
            input_file=temp_input_file,
            output_file=temp_output_file,
            write_json=False,
            quantized_path=quantized_path,
        ),
    )

    # Step 3: Validate expected outputs and no errors
    captured = capsys.readouterr()
    stdout = captured.out
    stderr = captured.err
    print(stdout)

    assert stderr == ""
    assert temp_witness_file.exists(), "Witness file not generated"

    cargo_calls = stdout.count("Running cargo command:")
    assert cargo_calls == OUTPUTTWICE, (
        f"Expected 'Running cargo command:' in stdout twice, found {cargo_calls}."
    )

    # Check good output appeared once per run
    for output in GOOD_OUTPUT:
        occurrences = stdout.count(output)
        assert occurrences == OUTPUTTWICE, (
            f"Expected '{output}' in stdout twice, found {occurrences}."
        )

    # Ensure no unexpected errors
    for output in BAD_OUTPUT:
        assert output not in stdout, (
            f"Did not expect '{output}' in stdout, but it was found."
        )

    assert temp_output_file.exists(), "Output file not generated"
    assert_very_close(
        json.loads(temp_output_file.read_text()),
        json.loads(written_output_data),
        model,
    )
|
|
667
|
+
|
|
668
|
+
|
|
669
|
+
@pytest.mark.e2e()
def test_witness_unscaled_and_incorrect_shape_input(
    model_fixture: dict[str, Any],
    capsys: pytest.CaptureFixture[str],
    temp_witness_file: Path,
    temp_input_file: Path,
    temp_output_file: Path,
    check_model_compiles: None,
    check_witness_generated: None,
) -> None:
    """Witness generation accepts a flattened, unscaled input file.

    Step 1 runs the model with ``write_json=True`` to produce a reference
    input file and output file. The input is then rewritten so every
    list-valued entry is flattened to 1-D and divided by
    ``scale_base ** scale_exponent``. Step 2 re-runs witness generation with
    ``write_json=False`` on that rewritten input. Step 3 checks that a witness
    was produced, that stdout/stderr look clean, and that the new outputs are
    close to the reference outputs.
    """
    # Requested only for their setup side effects.
    _ = check_witness_generated
    _ = check_model_compiles

    # Step 1: Write the input file via write_json=True
    model_write = model_fixture["model"]
    model_write.base_testing(
        CircuitExecutionConfig(
            run_type=RunType.GEN_WITNESS,
            dev_mode=False,
            witness_file=temp_witness_file,
            circuit_path=str(model_fixture["circuit_path"]),
            input_file=temp_input_file,
            output_file=temp_output_file,
            write_json=True,
            quantized_path=str(model_fixture["quantized_model"]),
        ),
    )
    # Delete the first witness so the existence check in Step 3 proves the
    # second run produced its own.
    if Path.exists(temp_witness_file):
        Path.unlink(temp_witness_file)
    assert not Path.exists(temp_witness_file)

    # Flatten every list-valued input to 1-D.
    with Path.open(temp_input_file, "r") as f:
        input_data = json.load(f)
    for key in input_data:
        if isinstance(input_data[key], list):
            input_data[key] = torch.as_tensor(input_data[key]).flatten().tolist()
        assert torch.as_tensor(input_data[key]).dim() <= 1, (
            f"Input data for {key} is not 1D tensor. "
            "This is a testing error, not a model error. "
            "Please fix this test to properly flatten."
        )

    # Rescale: divide by scale_base ** scale_exponent; the result must contain
    # floats (checked below).
    if not (
        hasattr(model_write, "scale_base") and hasattr(model_write, "scale_exponent")
    ):
        msg = "Model does not have scale_base attribute"
        raise NotImplementedError(msg)
    scale = model_write.scale_base**model_write.scale_exponent
    for key in input_data:
        input_data[key] = torch.div(torch.as_tensor(input_data[key]), scale).tolist()
    assert contains_float(input_data), (
        "This is a testing error, not a model error. "
        "Please fix this test to properly turn data to float."
    )

    # Keep the reference output, then delete the file so Step 2 must recreate it.
    with Path.open(temp_output_file, "r") as f:
        written_output_data = f.read()
    Path.unlink(temp_output_file)

    with Path.open(temp_input_file, "w") as f:
        json.dump(input_data, f)

    # Step 2: Read from that same input file (write_json=False)
    # that has been flattened and rescaled
    model_read = model_fixture["model"]
    model_read.base_testing(
        CircuitExecutionConfig(
            run_type=RunType.GEN_WITNESS,
            dev_mode=False,
            witness_file=temp_witness_file,
            circuit_path=str(model_fixture["circuit_path"]),
            input_file=temp_input_file,
            output_file=temp_output_file,
            write_json=False,
            quantized_path=str(model_fixture["quantized_model"]),
        ),
    )

    # Step 3: Validate expected outputs and no errors.
    # capsys is not read earlier in this test, so stdout covers both runs and
    # each expected marker should appear OUTPUTTWICE times.
    captured = capsys.readouterr()
    stdout = captured.out
    stderr = captured.err

    print(stdout)

    assert stderr == ""

    assert Path.exists(temp_witness_file), "Witness file not generated"
    cargo_count = stdout.count("Running cargo command:")
    assert cargo_count == OUTPUTTWICE, (
        f"Expected 'Running cargo command:' in stdout {OUTPUTTWICE} times, "
        f"but found it {cargo_count} times."
    )

    # Check good output appeared
    for output in GOOD_OUTPUT:
        output_count = stdout.count(output)
        assert output_count == OUTPUTTWICE, (
            f"Expected '{output}' in stdout {OUTPUTTWICE} times, "
            f"but found it {output_count} times."
        )

    # Ensure no unexpected errors
    for output in BAD_OUTPUT:
        assert output not in stdout, (
            f"Did not expect '{output}' in stdout, but it was found."
        )

    assert Path.exists(temp_output_file), "Output file not generated"
    with Path.open(temp_output_file, "r") as f:
        new_output_file = f.read()
    assert_very_close(
        json.loads(new_output_file),
        json.loads(written_output_data),
        model_write,
    )
|
|
791
|
+
|
|
792
|
+
|
|
793
|
+
@pytest.mark.e2e()
def test_witness_unscaled_and_incorrect_and_bad_named_input(  # noqa: PLR0915
    model_fixture: dict[str, Any],
    capsys: pytest.CaptureFixture[str],
    temp_witness_file: Path,
    temp_input_file: Path,
    temp_output_file: Path,
    check_model_compiles: None,
    check_witness_generated: None,
) -> None:
    """Witness generation accepts flattened, unscaled, renamed inputs.

    Step 1 runs the model with ``write_json=True`` to produce a reference
    input file and output file. The input is then rewritten in three ways:
    list-valued entries are flattened to 1-D, divided by
    ``scale_base ** scale_exponent``, and the ``"input"`` key is renamed to
    ``"input_TESTESTTEST_0"``. Step 2 re-runs witness generation with
    ``write_json=False`` on that rewritten input. Step 3 checks that a witness
    was produced, that stdout/stderr look clean, and that the new outputs are
    close to the reference outputs.
    """
    # Requested only for their setup side effects.
    _ = check_witness_generated
    _ = check_model_compiles

    # Step 1: Write the input file via write_json=True
    model_write = model_fixture["model"]
    model_write.base_testing(
        CircuitExecutionConfig(
            run_type=RunType.GEN_WITNESS,
            dev_mode=False,
            witness_file=temp_witness_file,
            circuit_path=str(model_fixture["circuit_path"]),
            input_file=temp_input_file,
            output_file=temp_output_file,
            write_json=True,
            quantized_path=str(model_fixture["quantized_model"]),
        ),
    )
    # Delete the first witness so the existence check in Step 3 proves the
    # second run produced its own.
    if Path.exists(temp_witness_file):
        Path.unlink(temp_witness_file)
    assert not Path.exists(temp_witness_file)

    # Flatten every list-valued input to 1-D.
    with Path.open(temp_input_file, "r") as f:
        input_data = json.load(f)
    for key in input_data:
        if isinstance(input_data[key], list):
            input_data[key] = torch.as_tensor(input_data[key]).flatten().tolist()
        assert torch.as_tensor(input_data[key]).dim() <= 1, (
            f"Input data for {key} is not 1D tensor. "
            "This is a testing error, not a model error. "
            "Please fix this test to properly flatten."
        )

    # Rescale: divide by scale_base ** scale_exponent; the result must contain
    # floats (checked below).
    if not (
        hasattr(model_write, "scale_base") and hasattr(model_write, "scale_exponent")
    ):
        msg = "Model does not have scale_base attribute"
        raise NotImplementedError(msg)
    scale = model_write.scale_base**model_write.scale_exponent
    for key in input_data:
        input_data[key] = torch.div(torch.as_tensor(input_data[key]), scale).tolist()
    assert contains_float(input_data), (
        "This is a testing error, not a model error. "
        "Please fix this test to properly turn data to float."
    )

    # Rename: move the "input" entry under an unexpected key name; every other
    # entry is copied unchanged.
    new_input_data = {}
    count = 0
    for key in input_data:
        if key == "input":
            new_input_data[f"input_TESTESTTEST_{count}"] = input_data[key]
            count += 1
        else:
            new_input_data[key] = input_data[key]
    assert "input" not in new_input_data, (
        "This is a testing error, not a model error. "
        "Please fix this test to not include 'input' as a key in the input data."
    )

    with Path.open(temp_input_file, "w") as f:
        json.dump(new_input_data, f)

    # Keep the reference output, then delete the file so Step 2 must recreate it.
    with Path.open(temp_output_file, "r") as f:
        written_output_data = f.read()
    Path.unlink(temp_output_file)

    # Step 2: Read from that same input file (write_json=False)
    # that has been flattened, rescaled and renamed
    model_read = model_fixture["model"]
    model_read.base_testing(
        CircuitExecutionConfig(
            run_type=RunType.GEN_WITNESS,
            dev_mode=False,
            witness_file=temp_witness_file,
            circuit_path=str(model_fixture["circuit_path"]),
            input_file=temp_input_file,
            output_file=temp_output_file,
            write_json=False,
            quantized_path=str(model_fixture["quantized_model"]),
        ),
    )

    # Step 3: Validate expected outputs and no errors.
    # capsys is not read earlier in this test, so stdout covers both runs and
    # each expected marker should appear OUTPUTTWICE times.
    captured = capsys.readouterr()
    stdout = captured.out
    stderr = captured.err

    print(stdout)

    assert stderr == ""

    assert Path.exists(temp_witness_file), "Witness file not generated"
    cargo_count = stdout.count("Running cargo command:")
    assert cargo_count == OUTPUTTWICE, (
        f"Expected 'Running cargo command:' in stdout {OUTPUTTWICE} times, "
        f"but found it {cargo_count} times."
    )

    # Check good output appeared
    for output in GOOD_OUTPUT:
        output_count = stdout.count(output)
        assert output_count == OUTPUTTWICE, (
            f"Expected '{output}' in stdout {OUTPUTTWICE} times, "
            f"but found it {output_count} times."
        )

    # Ensure no unexpected errors
    for output in BAD_OUTPUT:
        assert output not in stdout, (
            f"Did not expect '{output}' in stdout, but it was found."
        )

    assert Path.exists(temp_output_file), "Output file not generated"
    with Path.open(temp_output_file, "r") as f:
        new_output_file = f.read()

    assert_very_close(
        json.loads(new_output_file),
        json.loads(written_output_data),
        model_write,
    )
|
|
938
|
+
|
|
939
|
+
|
|
940
|
+
@pytest.mark.e2e()
def test_witness_wrong_name(
    model_fixture: dict[str, Any],
    capsys: pytest.CaptureFixture[str],
    temp_witness_file: Path,
    temp_input_file: Path,
    temp_output_file: Path,
    check_model_compiles: None,
    check_witness_generated: None,
) -> None:
    """Witness generation with the ``"input"`` entry renamed to ``"output"``.

    Step 1 runs the model with ``write_json=True`` to produce a reference
    input file and output file. The input file is then rewritten with its
    ``"input"`` key renamed to ``"output"``; the values themselves are left
    untouched. Step 2 re-runs witness generation with ``write_json=False``.
    Step 3 checks that a witness was produced, that stdout/stderr look clean,
    and that the new output file is identical to the reference output.
    """
    # Requested only for their setup side effects.
    _ = check_witness_generated
    _ = check_model_compiles

    # Step 1: Write the input file via write_json=True
    model_write = model_fixture["model"]
    model_write.base_testing(
        CircuitExecutionConfig(
            run_type=RunType.GEN_WITNESS,
            dev_mode=False,
            witness_file=temp_witness_file,
            circuit_path=str(model_fixture["circuit_path"]),
            input_file=temp_input_file,
            output_file=temp_output_file,
            write_json=True,
            quantized_path=str(model_fixture["quantized_model"]),
        ),
    )
    # Delete the first witness so the existence check in Step 3 proves the
    # second run produced its own.
    if Path.exists(temp_witness_file):
        Path.unlink(temp_witness_file)
    assert not Path.exists(temp_witness_file)

    # Rename the "input" entry to "output". Values are neither flattened nor
    # rescaled; every other entry is copied unchanged.
    with Path.open(temp_input_file, "r") as f:
        input_data = json.load(f)
    count = 0
    new_input_data = {}
    for key in input_data:
        if key == "input":
            new_input_data["output"] = input_data[key]
            count += 1
        else:
            new_input_data[key] = input_data[key]
    assert "input" not in new_input_data, (
        "This is a testing error, not a model error. "
        "Please fix this test to not include 'input' as a key in the input data."
    )
    assert "output" in new_input_data or count == 0, (
        "This is a testing error, not a model error. "
        "Please fix this test to include 'output' as a key in the input data."
    )

    # Keep the reference output, then delete the file so Step 2 must recreate it.
    with Path.open(temp_output_file, "r") as f:
        written_output_data = f.read()
    Path.unlink(temp_output_file)

    with Path.open(temp_input_file, "w") as f:
        json.dump(new_input_data, f)

    # Step 2: Read from that same input file (write_json=False)
    model_read = model_fixture["model"]
    model_read.base_testing(
        CircuitExecutionConfig(
            run_type=RunType.GEN_WITNESS,
            dev_mode=False,
            witness_file=temp_witness_file,
            circuit_path=str(model_fixture["circuit_path"]),
            input_file=temp_input_file,
            output_file=temp_output_file,
            write_json=False,
            quantized_path=str(model_fixture["quantized_model"]),
        ),
    )

    # Step 3: Validate expected outputs and no errors.
    # capsys is not read earlier in this test, so stdout covers both runs and
    # each expected marker should appear OUTPUTTWICE times.
    captured = capsys.readouterr()
    stdout = captured.out
    stderr = captured.err

    print(stdout)

    assert stderr == ""

    assert Path.exists(temp_witness_file), "Witness file not generated"
    cargo_count = stdout.count("Running cargo command:")
    assert cargo_count == OUTPUTTWICE, (
        f"Expected 'Running cargo command:' in stdout {OUTPUTTWICE} times, "
        f"but found it {cargo_count} times."
    )

    # Check good output appeared
    for output in GOOD_OUTPUT:
        output_count = stdout.count(output)
        assert output_count == OUTPUTTWICE, (
            f"Expected '{output}' in stdout {OUTPUTTWICE} times, "
            f"but found it {output_count} times."
        )

    # Ensure no unexpected errors
    for output in BAD_OUTPUT:
        assert output not in stdout, (
            f"Did not expect '{output}' in stdout, but it was found."
        )

    assert Path.exists(temp_output_file), "Output file not generated"
    with Path.open(temp_output_file, "r") as f:
        new_output_file = f.read()
    assert_very_close(
        json.loads(new_output_file),
        json.loads(written_output_data),
        model_write,
    )

    # The input values were not modified, so the output file must match the
    # reference byte-for-byte, not merely be numerically close.
    assert new_output_file == written_output_data, (
        "Output file content does not match the expected output"
    )
|
|
1051
|
+
|
|
1052
|
+
|
|
1053
|
+
def add_to_first_scalar(data: list, delta: float = 0.1) -> bool:
    """
    Add ``delta`` in place to the first leaf of a (possibly nested) list.

    Walks down index 0 at each nesting level until it reaches an element that
    is not a list, then replaces that element with ``element + delta``.

    Returns True when an element was modified. Returns False when ``data`` is
    not a list, or an empty list is reached before any scalar.
    """
    node = data
    while isinstance(node, list) and node:
        head = node[0]
        if not isinstance(head, list):
            # Found the first scalar: update it in its parent list.
            node[0] = head + delta
            return True
        node = head
    return False
|
|
1064
|
+
|
|
1065
|
+
|
|
1066
|
+
@pytest.mark.e2e()
|
|
1067
|
+
def test_witness_prove_verify_false_inputs_dev(
|
|
1068
|
+
model_fixture: dict[str, Any],
|
|
1069
|
+
capsys: Generator[pytest.CaptureFixture[str], None, None],
|
|
1070
|
+
temp_witness_file: Generator[Path, None, None],
|
|
1071
|
+
temp_input_file: Generator[Path, None, None],
|
|
1072
|
+
temp_output_file: Generator[Path, None, None],
|
|
1073
|
+
temp_proof_file: Generator[Path, None, None],
|
|
1074
|
+
check_model_compiles: None,
|
|
1075
|
+
check_witness_generated: None,
|
|
1076
|
+
) -> None:
|
|
1077
|
+
"""
|
|
1078
|
+
Same as test_witness_prove_verify_true_inputs_dev, but deliberately
|
|
1079
|
+
corrupts the witness outputs to trigger verification failure.
|
|
1080
|
+
"""
|
|
1081
|
+
_ = check_witness_generated
|
|
1082
|
+
_ = check_model_compiles
|
|
1083
|
+
|
|
1084
|
+
model = model_fixture["model"]
|
|
1085
|
+
|
|
1086
|
+
# Step 1: Generate witness
|
|
1087
|
+
model.base_testing(
|
|
1088
|
+
CircuitExecutionConfig(
|
|
1089
|
+
run_type=RunType.GEN_WITNESS,
|
|
1090
|
+
dev_mode=False,
|
|
1091
|
+
witness_file=temp_witness_file,
|
|
1092
|
+
circuit_path=str(model_fixture["circuit_path"]),
|
|
1093
|
+
input_file=temp_input_file,
|
|
1094
|
+
output_file=temp_output_file,
|
|
1095
|
+
write_json=True,
|
|
1096
|
+
quantized_path=str(model_fixture["quantized_model"]),
|
|
1097
|
+
),
|
|
1098
|
+
)
|
|
1099
|
+
|
|
1100
|
+
# Step 2: Corrupt the witness file by flipping some bytes
|
|
1101
|
+
with Path(temp_input_file).open(encoding="utf-8") as f:
|
|
1102
|
+
input_data = json.load(f)
|
|
1103
|
+
|
|
1104
|
+
first_key = next(iter(input_data)) # get the first key
|
|
1105
|
+
modified = add_to_first_scalar(input_data[first_key], 0.1)
|
|
1106
|
+
|
|
1107
|
+
if not modified:
|
|
1108
|
+
pytest.skip("Input file format not suitable for tampering test.")
|
|
1109
|
+
|
|
1110
|
+
tampered_input_file = temp_input_file.parent / "tampered_input.json"
|
|
1111
|
+
with Path(tampered_input_file).open("w", encoding="utf-8") as f:
|
|
1112
|
+
json.dump(input_data, f)
|
|
1113
|
+
|
|
1114
|
+
# Step 3: Try to prove with corrupted witness
|
|
1115
|
+
model.base_testing(
|
|
1116
|
+
CircuitExecutionConfig(
|
|
1117
|
+
run_type=RunType.PROVE_WITNESS,
|
|
1118
|
+
dev_mode=False,
|
|
1119
|
+
witness_file=temp_witness_file,
|
|
1120
|
+
circuit_path=str(model_fixture["circuit_path"]),
|
|
1121
|
+
input_file=temp_input_file,
|
|
1122
|
+
output_file=temp_output_file,
|
|
1123
|
+
proof_file=temp_proof_file,
|
|
1124
|
+
quantized_path=str(model_fixture["quantized_model"]),
|
|
1125
|
+
ecc=False,
|
|
1126
|
+
),
|
|
1127
|
+
)
|
|
1128
|
+
|
|
1129
|
+
# Step 4: Attempt verification
|
|
1130
|
+
with pytest.raises(CircuitRunError) as excinfo:
|
|
1131
|
+
model.base_testing(
|
|
1132
|
+
CircuitExecutionConfig(
|
|
1133
|
+
run_type=RunType.GEN_VERIFY,
|
|
1134
|
+
dev_mode=False,
|
|
1135
|
+
witness_file=temp_witness_file,
|
|
1136
|
+
circuit_path=str(model_fixture["circuit_path"]),
|
|
1137
|
+
input_file=tampered_input_file,
|
|
1138
|
+
output_file=temp_output_file,
|
|
1139
|
+
proof_file=temp_proof_file,
|
|
1140
|
+
quantized_path=str(model_fixture["quantized_model"]),
|
|
1141
|
+
ecc=False,
|
|
1142
|
+
),
|
|
1143
|
+
)
|
|
1144
|
+
|
|
1145
|
+
# ---- ASSERTIONS ----
|
|
1146
|
+
captured = capsys.readouterr()
|
|
1147
|
+
stdout = captured.out
|
|
1148
|
+
stderr = captured.err
|
|
1149
|
+
print(stdout)
|
|
1150
|
+
print(stderr)
|
|
1151
|
+
|
|
1152
|
+
print(excinfo.value)
|
|
1153
|
+
assert "Witness does not match provided inputs and outputs" in str(
|
|
1154
|
+
excinfo.value,
|
|
1155
|
+
)
|
|
1156
|
+
assert "'Verify' failed" in str(
|
|
1157
|
+
excinfo.value,
|
|
1158
|
+
)
|