JSTprove 1.0.0__py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jstprove-1.0.0.dist-info/METADATA +397 -0
- jstprove-1.0.0.dist-info/RECORD +81 -0
- jstprove-1.0.0.dist-info/WHEEL +6 -0
- jstprove-1.0.0.dist-info/entry_points.txt +2 -0
- jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
- jstprove-1.0.0.dist-info/top_level.txt +1 -0
- python/__init__.py +0 -0
- python/core/__init__.py +3 -0
- python/core/binaries/__init__.py +0 -0
- python/core/binaries/expander-exec +0 -0
- python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
- python/core/circuit_models/__init__.py +0 -0
- python/core/circuit_models/generic_onnx.py +231 -0
- python/core/circuit_models/simple_circuit.py +133 -0
- python/core/circuits/__init__.py +0 -0
- python/core/circuits/base.py +1000 -0
- python/core/circuits/errors.py +188 -0
- python/core/circuits/zk_model_base.py +25 -0
- python/core/model_processing/__init__.py +0 -0
- python/core/model_processing/converters/__init__.py +0 -0
- python/core/model_processing/converters/base.py +143 -0
- python/core/model_processing/converters/onnx_converter.py +1181 -0
- python/core/model_processing/errors.py +147 -0
- python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
- python/core/model_processing/onnx_custom_ops/conv.py +111 -0
- python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
- python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
- python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
- python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
- python/core/model_processing/onnx_custom_ops/relu.py +43 -0
- python/core/model_processing/onnx_quantizer/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
- python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
- python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
- python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
- python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
- python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
- python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
- python/core/model_templates/__init__.py +0 -0
- python/core/model_templates/circuit_template.py +57 -0
- python/core/utils/__init__.py +0 -0
- python/core/utils/benchmarking_helpers.py +163 -0
- python/core/utils/constants.py +4 -0
- python/core/utils/errors.py +117 -0
- python/core/utils/general_layer_functions.py +268 -0
- python/core/utils/helper_functions.py +1138 -0
- python/core/utils/model_registry.py +166 -0
- python/core/utils/scratch_tests.py +66 -0
- python/core/utils/witness_utils.py +291 -0
- python/frontend/__init__.py +0 -0
- python/frontend/cli.py +115 -0
- python/frontend/commands/__init__.py +17 -0
- python/frontend/commands/args.py +100 -0
- python/frontend/commands/base.py +199 -0
- python/frontend/commands/bench/__init__.py +54 -0
- python/frontend/commands/bench/list.py +42 -0
- python/frontend/commands/bench/model.py +172 -0
- python/frontend/commands/bench/sweep.py +212 -0
- python/frontend/commands/compile.py +58 -0
- python/frontend/commands/constants.py +5 -0
- python/frontend/commands/model_check.py +53 -0
- python/frontend/commands/prove.py +50 -0
- python/frontend/commands/verify.py +73 -0
- python/frontend/commands/witness.py +64 -0
- python/scripts/__init__.py +0 -0
- python/scripts/benchmark_runner.py +833 -0
- python/scripts/gen_and_bench.py +482 -0
- python/tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
- python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
- python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
- python/tests/circuit_parent_classes/__init__.py +0 -0
- python/tests/circuit_parent_classes/test_circuit.py +969 -0
- python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
- python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
- python/tests/test_cli.py +1021 -0
- python/tests/utils_testing/__init__.py +0 -0
- python/tests/utils_testing/test_helper_functions.py +891 -0
|
@@ -0,0 +1,1138 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
import re
|
|
8
|
+
import shutil
|
|
9
|
+
import subprocess
|
|
10
|
+
import sys
|
|
11
|
+
from collections.abc import Callable
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
from enum import Enum
|
|
14
|
+
from importlib.metadata import version as get_version
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from time import time
|
|
17
|
+
from typing import Any, TypeVar
|
|
18
|
+
|
|
19
|
+
try:
|
|
20
|
+
import tomllib # Python 3.11+
|
|
21
|
+
except ModuleNotFoundError:
|
|
22
|
+
import tomli as tomllib
|
|
23
|
+
|
|
24
|
+
from python.core import PACKAGE_NAME
|
|
25
|
+
from python.core.utils.benchmarking_helpers import (
|
|
26
|
+
end_memory_collection,
|
|
27
|
+
start_memory_collection,
|
|
28
|
+
)
|
|
29
|
+
from python.core.utils.errors import (
|
|
30
|
+
FileCacheError,
|
|
31
|
+
MissingFileError,
|
|
32
|
+
ProofBackendError,
|
|
33
|
+
ProofSystemNotImplementedError,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
# Generic type variable for decorators that accept and return any callable.
F = TypeVar("F", bound=Callable[..., Any])

# Module-level logger shared by all helpers in this module.
logger = logging.getLogger(__name__)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class RunType(Enum):
    """Pipeline stages that can be requested for a circuit.

    Every member except ``END_TO_END`` carries a ``run_``-prefixed value;
    for example ``compile_circuit`` passes ``COMPILE_CIRCUIT.value`` as the
    command type to the circuit binary.
    """

    END_TO_END = "end_to_end"
    COMPILE_CIRCUIT = "run_compile_circuit"
    GEN_WITNESS = "run_gen_witness"
    PROVE_WITNESS = "run_prove_witness"
    GEN_VERIFY = "run_gen_verify"
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class ZKProofSystems(Enum):
    """Zero-knowledge proof backends known to this module (only Expander)."""

    Expander = "Expander"
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class ExpanderMode(Enum):
    """Operation requested from ``expander-exec``.

    The member value is appended to the ``expander-exec`` command line as its
    sub-command.
    """

    PROVE = "prove"
    VERIFY = "verify"
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass
class CircuitExecutionConfig:
    """Options and file locations for a single circuit execution.

    Path fields default to None. The ``prepare_io_files`` decorator fills in
    any path left unset before the decorated operation runs.
    """

    # Which pipeline stage to run.
    run_type: RunType = RunType.END_TO_END

    # Artifact paths for the run.
    witness_file: str | None = None
    input_file: str | None = None
    proof_file: str | None = None
    public_path: str | None = None
    verification_key: str | None = None
    circuit_name: str | None = None
    metadata_path: str | None = None
    architecture_path: str | None = None
    w_and_b_path: str | None = None
    output_file: str | None = None

    # Proving backend to use.
    proof_system: ZKProofSystems = ZKProofSystems.Expander

    # Compiled circuit location and the quantized model derived from it.
    circuit_path: str | None = None
    quantized_path: str | None = None

    # Behaviour toggles.
    ecc: bool = True
    dev_mode: bool = False
    write_json: bool = False
    bench: bool = False
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def filter_expander_output(stderr: str) -> str:
    """Reduce Expander/MPI stderr to the lines useful in an error message.

    Keeps a Rust panic or assertion line together with the lines that follow
    it, up to the first numbered system stack frame (``[ 0] ...``) or the
    ``*** Process received signal ***`` banner. Also keeps every MPI
    ``prterun noticed that process rank`` exit-summary line exactly once.

    Args:
        stderr (str): Raw stderr captured from the Expander process.

    Returns:
        str: The retained lines joined with newlines ("" if nothing matched).
    """
    filtered: list[str] = []
    in_panic = False

    for line in stderr.splitlines():
        # A panic or assertion line (re)starts capture of the Rust message.
        if "panicked at" in line or "assertion" in line.lower():
            in_panic = True
            filtered.append(line)
            continue

        # The MPI exit summary is always kept. Handling it here, before the
        # panic section, stops it being appended twice when it appears
        # while a panic is still being captured.
        if line.startswith("prterun noticed that process rank"):
            filtered.append(line)
            continue

        if in_panic:
            # System stack frames / signal banners end the useful section.
            if (
                re.match(r"\[\s*\d+\]", line)
                or "*** Process received signal ***" in line
            ):
                in_panic = False
                continue
            filtered.append(line)

    return "\n".join(filtered)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def extract_rust_error(stderr: str) -> str:
    """Extract the human-readable Rust error message from captured stderr.

    Checked in order:

    1. Panic body: the lines after a ``thread '<name>' panicked at`` header,
       up to (not including) a ``stack backtrace:`` line. Header lines
       themselves are skipped.
    2. Plain error print: the first line that starts with ``Error:``.
    3. The first panic header line itself, for panics whose message is
       printed inline on the header.

    Args:
        stderr (str): Raw stderr captured from the Rust process.

    Returns:
        str: The extracted message (stripped), or "" if nothing matched.
    """
    lines = stderr.splitlines()
    panic_header: str | None = None
    captured: list[str] = []
    capturing = False

    for line in lines:
        if re.match(r"thread '.*' panicked at", line):
            capturing = True
            if panic_header is None:
                panic_header = line.strip()
            continue
        if capturing:
            if "stack backtrace:" in line.lower():
                break
            captured.append(line)

    # Judge the captured text, not the list: a run of blank lines must not
    # suppress the fallbacks below.
    panic_message = "\n".join(captured).strip()
    if panic_message:
        return panic_message

    for line in lines:
        stripped = line.strip()
        if stripped.startswith("Error:"):
            return stripped

    # e.g. "thread 'main' panicked at 'msg', src/main.rs:2:5" keeps the
    # message on the header line; return it rather than nothing.
    if panic_header is not None:
        return panic_header

    return ""
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
# Decorator to compute outputs once and store in temp folder
def compute_and_store_output(func: Callable) -> Callable:
    """Cache a method's result as a JSON file in the instance's temp folder.

    The first call computes the result and writes it to
    ``<temp_folder>/<self.name>_output_cache.json``. Later calls load and
    return that file instead of recomputing. The cache key is only
    ``self.name``; the call's arguments are NOT part of the key.

    Caching is best-effort:
    - An unreadable or corrupt cache file is logged and the output is
      recomputed.
    - A failure to write the cache (I/O error or non-JSON-serialisable
      result) is logged. The freshly computed result is still returned.

    Args:
        func (Callable): Method whose output is cached.

    Returns:
        Callable: Wrapped method that reads/writes the cache file.

    Raises:
        FileCacheError: If the temp folder cannot be created.
    """

    @functools.wraps(func)
    def wrapper(self: object, *args: tuple, **kwargs: dict) -> object:
        # ``temp_folder`` is optional on the instance; fall back to "temp".
        temp_folder = getattr(self, "temp_folder", "temp")
        try:
            Path(temp_folder).mkdir(parents=True, exist_ok=True)
        except OSError as e:
            msg = f"Could not create temp folder {temp_folder}: {e}"
            logger.exception(msg)
            raise FileCacheError(msg) from e

        output_cache_path = Path(temp_folder) / f"{self.name}_output_cache.json"

        if output_cache_path.exists():
            msg = f"Loading cached outputs for {self.name} from {output_cache_path}"
            logger.info(msg)
            try:
                with output_cache_path.open(encoding="utf-8") as f:
                    return json.load(f)
            except (OSError, ValueError) as e:
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
                msg = f"Error loading cached output: {e}"
                logger.warning(msg)
                # Continue to compute if loading fails

        msg = f"Computing outputs for {self.name}..."
        logger.info(msg)
        output = func(self, *args, **kwargs)

        try:
            # Serialise before touching the file, so an unserialisable result
            # cannot leave a truncated cache file behind.
            payload = json.dumps(output)
            output_cache_path.write_text(payload, encoding="utf-8")
            msg = f"Stored outputs in {output_cache_path}"
            logger.info(msg)
        except (OSError, TypeError, ValueError) as e:
            msg = f"Warning: Could not cache output to file: {e}"
            logger.warning(msg)

        return output

    return wrapper
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
# Decorator to prepare input/output files
def prepare_io_files(func: Callable) -> Callable:
    """Fill in default file paths on ``exec_config`` before calling ``func``.

    The wrapped method is called as ``func(self, exec_config, *args, **kwargs)``.
    Before it runs, the wrapper:

    1. Resolves each artifact folder. Priority: a same-named attribute on
       ``exec_config``, then the parent directory of a related file on
       ``exec_config``, then the instance attribute, then a hard-coded default.
    2. Calls ``get_files`` with those folders and uses its paths for any path
       still unset on ``exec_config``.
    3. If ``exec_config.circuit_path`` is set, places the quantized-model,
       metadata, architecture and weights paths next to it, overriding
       existing values. Otherwise it sets ``quantized_path`` to None.
    4. Records the final paths on ``self._file_info``.

    ``exec_config`` is mutated in place.

    Args:
        func (Callable): Method taking ``(self, exec_config, ...)``.

    Returns:
        Callable: Wrapped method with ``exec_config`` populated.
    """

    @functools.wraps(func)
    def wrapper(
        self: object,
        exec_config: CircuitExecutionConfig,
        *args: tuple,
        **kwargs: dict,
    ) -> object:
        def resolve_folder(
            key: str,
            file_attr: str | None = None,
            default: str = "",
        ) -> str:
            explicit_folder = getattr(exec_config, key, None)
            if explicit_folder:
                return explicit_folder
            related_file = (
                getattr(exec_config, file_attr, None) if file_attr else None
            )
            if related_file:
                return str(Path(related_file).parent)
            return getattr(self, key, default)

        input_folder = resolve_folder(
            "input_folder", "input_file", default="python/models/inputs",
        )
        output_folder = resolve_folder(
            "output_folder", "output_file", default="python/models/output",
        )
        proof_folder = resolve_folder(
            "proof_folder", "proof_file", default="python/models/proofs",
        )
        quantized_model_folder = resolve_folder(
            "quantized_folder",
            "quantized_path",
            default="python/models/quantized_model_folder",
        )
        weights_folder = resolve_folder(
            "weights_folder", default="python/models/weights",
        )
        circuit_folder = resolve_folder("circuit_folder", default="python/models/")

        proof_system = exec_config.proof_system or getattr(
            self, "proof_system", ZKProofSystems.Expander,
        )

        folder_map = {
            "input": input_folder,
            "proof": proof_folder,
            "circuit": circuit_folder,
            "weights": weights_folder,
            "output": output_folder,
            "quantized_model": quantized_model_folder,
        }
        files = get_files(self.name, proof_system, folder_map)

        # (config attribute, key in ``files``) -- unset attributes take the
        # default path from ``files``.
        defaults_table = (
            ("witness_file", "witness_file"),
            ("input_file", "input_file"),
            ("proof_file", "proof_path"),
            ("public_path", "public_path"),
            ("circuit_name", "circuit_name"),
            ("metadata_path", "metadata_path"),
            ("architecture_path", "architecture_path"),
            ("w_and_b_path", "w_and_b_path"),
            ("output_file", "output_file"),
        )
        for attr_name, files_key in defaults_table:
            setattr(
                exec_config,
                attr_name,
                getattr(exec_config, attr_name) or files[files_key],
            )

        if exec_config.circuit_path:
            circuit_file = Path(exec_config.circuit_path)
            circuit_dir = circuit_file.parent
            stem = circuit_file.stem
            exec_config.quantized_path = str(
                circuit_dir / f"{stem}_quantized_model.onnx",
            )
            exec_config.metadata_path = str(circuit_dir / f"{stem}_metadata.json")
            exec_config.architecture_path = str(
                circuit_dir / f"{stem}_architecture.json",
            )
            exec_config.w_and_b_path = str(circuit_dir / f"{stem}_wandb.json")
        else:
            # NOTE(review): this discards a quantized_path the caller set
            # explicitly -- confirm that is intended.
            exec_config.quantized_path = None

        # Expose the resolved paths to the decorated method; "weights" is an
        # alias of the weights-and-biases JSON path.
        self._file_info = {
            "witness_file": exec_config.witness_file,
            "input_file": exec_config.input_file,
            "proof_file": exec_config.proof_file,
            "public_path": exec_config.public_path,
            "circuit_name": exec_config.circuit_name,
            "metadata_path": exec_config.metadata_path,
            "architecture_path": exec_config.architecture_path,
            "w_and_b_path": exec_config.w_and_b_path,
            "output_file": exec_config.output_file,
            "inputs": exec_config.input_file,
            "weights": exec_config.w_and_b_path,
            "outputs": exec_config.output_file,
            "output": exec_config.output_file,
            "proof_system": exec_config.proof_system or proof_system,
            "model_path": getattr(exec_config, "model_path", None),
            "quantized_model_path": exec_config.quantized_path,
        }

        return func(self, exec_config, *args, **kwargs)

    return wrapper
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def ensure_parent_dir(path: str) -> None:
    """Make sure the directory that will contain ``path`` exists.

    Missing intermediate directories are created. Existing directories are
    left untouched, so repeated calls are harmless.

    Args:
        path (str): File path whose parent directory should exist.
    """
    parent_dir = Path(path).parent
    parent_dir.mkdir(exist_ok=True, parents=True)
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def run_subprocess(cmd: list[str]) -> None:
    """Execute ``cmd`` and raise if it exits with a non-zero status.

    The child inherits the current environment. ``PYTHONUNBUFFERED`` defaults
    to "1" (an existing value is kept). Output is not captured; it goes to
    this process's stdout/stderr.

    Args:
        cmd (list[str]): The command and its arguments.

    Raises:
        RuntimeError: If the command exits with a non-zero return code.
    """
    child_env = {**os.environ}
    if "PYTHONUNBUFFERED" not in child_env:
        child_env["PYTHONUNBUFFERED"] = "1"

    completed = subprocess.run(cmd, text=True, env=child_env, check=False)  # noqa: S603
    exit_code = completed.returncode
    if exit_code == 0:
        return

    msg = f"Command failed with exit code {exit_code}"
    raise RuntimeError(msg)
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
def to_json(inputs: dict[str, Any], path: str) -> None:
    """Serialise ``inputs`` as JSON to ``path``, creating parent directories.

    The file is written as UTF-8 (the encoding JSON requires) regardless of
    the platform's locale. It pairs with ``read_from_json``.

    Args:
        inputs (dict[str, Any]): JSON-serialisable data to write.
        path (str): Destination file path; it is overwritten if present.

    Raises:
        TypeError: If ``inputs`` contains values ``json`` cannot serialise.
        OSError: If the file cannot be written.
    """
    ensure_parent_dir(path)
    with Path(path).open("w", encoding="utf-8") as outfile:
        json.dump(inputs, outfile)
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
def read_from_json(public_path: str) -> dict[str, Any]:
    """Load and return the JSON document stored at ``public_path``.

    The file is decoded as UTF-8 (the encoding JSON requires) instead of the
    platform's locale encoding, so non-ASCII content reads identically on
    every OS.

    Args:
        public_path (str): Path to the JSON file to read.

    Returns:
        dict[str, Any]: The parsed data. NOTE(review): ``json.load`` returns
        whatever the top-level JSON value is; callers assume an object.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the contents are not valid JSON.
    """
    with Path(public_path).open(encoding="utf-8") as json_data:
        return json.load(json_data)
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def run_cargo_command(
    binary_name: str,
    command_type: str,
    args: dict[str, str] | None = None,
    *,
    dev_mode: bool = False,
    bench: bool = False,
) -> subprocess.CompletedProcess[str]:
    """Run a versioned Rust circuit binary (or ``cargo run``) for one command.

    The binary name gets a version suffix, with the version's dots replaced
    by dashes (e.g. ``name_1-0-0``). The version comes from the installed
    package metadata, falling back to ``pyproject.toml`` in the current
    working directory. If neither is available, the name is left unsuffixed.

    The executable is looked up in ``./target/release``, the package's
    ``binaries`` directory and ``<sys.prefix>/bin``, in that order. Cargo is
    used instead when ``dev_mode`` is set or no executable exists.

    Args:
        binary_name (str): Base name of the Cargo binary (without version).
        command_type (str): Sub-command passed to the binary
            (e.g. 'run_compile_circuit').
        args (dict[str, str], optional): CLI options; each key is passed as
            ``-<key>``. Defaults to None.
        dev_mode (bool, optional):
            If True, run with `cargo run --release` instead of prebuilt binary.
            Defaults to False.
        bench (bool, optional):
            If True, measure execution time and memory usage. Defaults to False.

    Raises:
        ProofBackendError: If the command cannot be started or exits with a
            non-zero status.

    Returns:
        subprocess.CompletedProcess[str]: The completed process with captured
            stdout/stderr.
    """
    # Version lookup is best-effort: any failure just means no suffix.
    try:
        version = get_version(PACKAGE_NAME)
        binary_name = binary_name + f"_{version}".replace(".", "-")
    except Exception:
        try:
            # NOTE(review): reads pyproject.toml relative to the CWD, which may
            # belong to a different project.
            pyproject = tomllib.loads(Path("pyproject.toml").read_text())
            version = pyproject["project"]["version"]
            binary_name = binary_name + f"_{version}".replace(".", "-")
        except (FileNotFoundError, KeyError, tomllib.TOMLDecodeError):
            pass

    binary_path = None
    possible_paths = [
        f"./target/release/{binary_name}",
        Path(__file__).parent.parent / "binaries" / binary_name,
        Path(sys.prefix) / "bin" / binary_name,
    ]

    for path in possible_paths:
        if Path(path).exists():
            binary_path = str(path)
            break

    if not binary_path:
        # Nothing prebuilt; _build_command falls back to `cargo run` for this.
        binary_path = f"./target/release/{binary_name}"
    cmd = _build_command(
        binary_path=binary_path,
        command_type=command_type,
        args=args,
        dev_mode=dev_mode,
        binary_name=binary_name,
    )
    env = os.environ.copy()
    env["RUST_BACKTRACE"] = "1"

    msg = f"Running cargo command: {' '.join(cmd)}"
    print(msg)  # noqa: T201
    logger.info(msg)

    try:
        result = _run_subprocess_with_bench(
            cmd=cmd,
            env=env,
            bench=bench,
            binary_name=binary_name,
        )
        _handle_result(result=result, cmd=cmd)
    except OSError as e:
        msg = f"Failed to execute proof backend command '{cmd}': {e}"
        logger.exception(msg)
        raise ProofBackendError(msg) from e
    except subprocess.CalledProcessError as e:
        msg = f"Cargo command failed (return code {e.returncode}): {e.stderr}"
        logger.exception(msg)
        rust_error = extract_rust_error(e.stderr)
        msg = f"Rust backend error '{rust_error}'"
        raise ProofBackendError(msg, cmd) from e
    else:
        return result
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
def _build_command(
|
|
472
|
+
binary_path: str,
|
|
473
|
+
command_type: str,
|
|
474
|
+
args: dict[str, str] | None,
|
|
475
|
+
*,
|
|
476
|
+
dev_mode: bool,
|
|
477
|
+
binary_name: str,
|
|
478
|
+
) -> list[str]:
|
|
479
|
+
"""Build the command list for subprocess."""
|
|
480
|
+
cmd = (
|
|
481
|
+
["cargo", "run", "--bin", binary_name, "--release"]
|
|
482
|
+
if dev_mode or not Path(binary_path).exists()
|
|
483
|
+
# dev_mode indicates that we want a recompile, this happens with compile
|
|
484
|
+
# or if there is no executable already created, then we create a new one
|
|
485
|
+
else [binary_path]
|
|
486
|
+
)
|
|
487
|
+
cmd.append(command_type)
|
|
488
|
+
if args:
|
|
489
|
+
for key, value in args.items():
|
|
490
|
+
cmd.append(f"-{key}")
|
|
491
|
+
if not (isinstance(value, bool) and value):
|
|
492
|
+
cmd.append(str(value))
|
|
493
|
+
return cmd
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
def _run_subprocess_with_bench(
    cmd: list[str],
    env: dict[str, str],
    binary_name: str,
    *,
    bench: bool,
) -> subprocess.CompletedProcess[str]:
    """Run ``cmd`` with captured output and report its wall-clock time.

    When ``bench`` is True, memory collection for ``binary_name`` is started
    before the run and its total is reported afterwards. The collection is
    also shut down if the command fails, so it never outlives the run.

    Args:
        cmd (list[str]): Command to execute.
        env (dict[str, str]): Environment for the child process.
        binary_name (str): Process name handed to the memory collector.
        bench (bool): Whether to collect memory usage.

    Returns:
        subprocess.CompletedProcess[str]: The completed process.

    Raises:
        subprocess.CalledProcessError: If the command exits non-zero.
        OSError: If the command cannot be started.
    """
    if bench:
        stop_event, monitor_thread, monitor_results = start_memory_collection(
            binary_name,
        )
    start_time = time()
    try:
        result = subprocess.run(  # noqa: S603
            cmd,
            check=True,
            capture_output=True,
            text=True,
            env=env,
        )
    except BaseException:
        # Hand the monitor handles back even on failure, so the collection
        # started above is shut down instead of being left running.
        if bench:
            end_memory_collection(stop_event, monitor_thread, monitor_results)
        raise
    end_time = time()
    elapsed = end_time - start_time
    print("\n--- BENCHMARK RESULTS ---")  # noqa: T201
    print(f"Rust time taken: {elapsed:.4f} seconds")  # noqa: T201
    msg = f"Rust command completed in {elapsed:.4f} seconds"
    logger.info(msg)

    if bench:
        memory = end_memory_collection(stop_event, monitor_thread, monitor_results)
        msg = f"Rust subprocess memory: {memory['total']:.2f} MB"
        print(msg)  # noqa: T201
        logger.info(msg)

    print(result.stdout)  # noqa: T201
    logger.info(result.stdout)
    return result
|
|
531
|
+
|
|
532
|
+
|
|
533
|
+
def _handle_result(result: subprocess.CompletedProcess[str], cmd: list[str]) -> None:
|
|
534
|
+
"""Handle the subprocess result and raise errors if needed."""
|
|
535
|
+
if result.returncode != 0:
|
|
536
|
+
msg = f"Proving Backend failed (code {result.returncode}):\n{result.stderr}"
|
|
537
|
+
logger.error(msg)
|
|
538
|
+
msg = (
|
|
539
|
+
f"Proving Backend command '{' '.join(cmd)}'"
|
|
540
|
+
f" failed with code {result.returncode}:\n"
|
|
541
|
+
f"STDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
|
|
542
|
+
)
|
|
543
|
+
raise ProofBackendError(msg)
|
|
544
|
+
|
|
545
|
+
|
|
546
|
+
def _copy_binary_if_needed(
|
|
547
|
+
binary_name: str,
|
|
548
|
+
binary_path: str,
|
|
549
|
+
*,
|
|
550
|
+
dev_mode: bool,
|
|
551
|
+
) -> None:
|
|
552
|
+
"""Copy the binary if conditions are met."""
|
|
553
|
+
src = f"./target/release/{binary_name}"
|
|
554
|
+
if Path(src).exists() and (str(src) != str(binary_path)) and dev_mode:
|
|
555
|
+
shutil.copy(src, binary_path)
|
|
556
|
+
|
|
557
|
+
|
|
558
|
+
def get_expander_file_paths(circuit_name: str) -> dict[str, str]:
    """Derive the standard Expander artifact file names for a circuit.

    Args:
        circuit_name (str): Base name of the circuit.

    Returns:
        dict[str, str]: ``circuit_file``, ``witness_file`` and ``proof_file``
            mapped to ``<circuit_name>_<kind>.txt``.
    """
    kinds = (
        ("circuit_file", "circuit"),
        ("witness_file", "witness"),
        ("proof_file", "proof"),
    )
    return {key: f"{circuit_name}_{kind}.txt" for key, kind in kinds}
|
|
574
|
+
|
|
575
|
+
|
|
576
|
+
def run_expander_raw(  # noqa: PLR0913, PLR0912, C901
    mode: ExpanderMode,
    circuit_file: str,
    witness_file: str,
    proof_file: str,
    pcs_type: str = "Hyrax",
    *,
    bench: bool = False,
) -> subprocess.CompletedProcess[str]:
    """Prove or verify a witness with the Expander ``expander-exec`` tool.

    A prebuilt ``expander-exec`` is looked up in ``./Expander/target/release``,
    the package's ``binaries`` directory and ``<sys.prefix>/bin``. If none is
    found, the tool is run through ``mpiexec -n 1 cargo run`` against
    ``Expander/Cargo.toml``. Either way the command is wrapped in
    ``/usr/bin/time``.

    Args:
        mode (ExpanderMode): Operation mode (PROVE or VERIFY).
        circuit_file (str): Path to the circuit definition file.
        witness_file (str): Path to the witness file.
        proof_file (str): Path to the proof file (input for verification,
            output for proving).
        pcs_type (str, optional):
            Polynomial commitment scheme type ("Hyrax" or "Raw").
            Defaults to "Hyrax".
        bench (bool, optional):
            If True, collect runtime and memory benchmark data. Defaults to False.

    Raises:
        MissingFileError: If the circuit or witness file (or, for VERIFY, the
            proof file) does not exist.
        ProofBackendError: If the process cannot be started or exits non-zero.

    Returns:
        subprocess.CompletedProcess[str]: Exit message from the subprocess.
    """
    # For VERIFY the proof file is an input, so it must already exist too.
    required_files = [(circuit_file, "circuit_file"), (witness_file, "witness_file")]
    if mode == ExpanderMode.VERIFY:
        required_files.append((proof_file, "proof_file"))
    for file_path, label in required_files:
        if file_path and not Path(file_path).exists():
            msg = f"Missing file required for {label}"
            raise MissingFileError(msg, file_path)

    env = os.environ.copy()
    env["RUSTFLAGS"] = "-C target-cpu=native"
    time_measure = "/usr/bin/time"

    expander_binary_path = None
    possible_paths = [
        "./Expander/target/release/expander-exec",
        Path(__file__).parent.parent / "binaries" / "expander-exec",
        Path(sys.prefix) / "bin" / "expander-exec",
    ]

    for path in possible_paths:
        if Path(path).exists():
            expander_binary_path = str(path)
            break

    if expander_binary_path:
        args = [
            time_measure,
            expander_binary_path,
            "-p",
            pcs_type,
        ]
    else:
        args = [
            time_measure,
            "mpiexec",
            "-n",
            "1",
            "cargo",
            "run",
            "--manifest-path",
            "Expander/Cargo.toml",
            "--bin",
            "expander-exec",
            "--release",
            "--",
            "-p",
            pcs_type,
        ]

    # PROVE writes the proof (-o); VERIFY reads it (-i).
    proof_command = "-o" if mode == ExpanderMode.PROVE else "-i"
    args.append(mode.value)
    args.extend(["-c", circuit_file])
    args.extend(["-w", witness_file])
    args.extend([proof_command, proof_file])

    try:
        if bench:
            stop_event, monitor_thread, monitor_results = start_memory_collection(
                "expander-exec",
            )
        start_time = time()
        try:
            result = subprocess.run(  # noqa: S603
                args,
                env=env,
                capture_output=True,
                text=True,
                check=False,
            )
        except BaseException:
            # Shut down the memory collection started above before the
            # failure propagates.
            if bench:
                end_memory_collection(stop_event, monitor_thread, monitor_results)
            raise
        end_time = time()

        print("\n--- BENCHMARK RESULTS ---")  # noqa: T201
        print(f"Rust time taken: {end_time - start_time:.4f} seconds")  # noqa: T201

        if bench:
            memory = end_memory_collection(stop_event, monitor_thread, monitor_results)
            print(f"Rust subprocess memory: {memory['total']:.2f} MB")  # noqa: T201

        if result.returncode != 0:
            clean_stderr = filter_expander_output(result.stderr)
            msg = f"Expander {mode.value} failed:\n{clean_stderr}"
            logger.warning(msg)
            msg = f"Expander {mode.value} failed"
            raise ProofBackendError(
                msg,
                command=args,
                returncode=result.returncode,
                stdout=result.stdout,
                stderr=clean_stderr,
            )

        print(  # noqa: T201
            f"✅ expander-exec {mode.value} succeeded:\n{result.stdout}",
        )

        print(f"Time taken: {end_time - start_time:.4f} seconds")  # noqa: T201
    except OSError as e:
        msg = f"Failed to execute Expander {mode.value}: {e}"
        logger.exception(msg)
        raise ProofBackendError(
            msg,
            command=args,
        ) from e
    else:
        return result
|
|
712
|
+
|
|
713
|
+
|
|
714
|
+
def compile_circuit(  # noqa: PLR0913
    circuit_name: str,
    circuit_path: str,
    metadata_path: str,
    architecture_path: str,
    w_and_b_path: str,
    proof_system: ZKProofSystems = ZKProofSystems.Expander,
    *,
    dev_mode: bool = True,
    bench: bool = False,
) -> subprocess.CompletedProcess[str]:
    """Compile a model into a zk circuit via the circuit's Rust binary.

    Args:
        circuit_name (str): Name of the circuit. Its last path component is
            used as the binary name handed to ``run_cargo_command``.
        circuit_path (str): Path to the circuit file (sent under key ``"c"``).
        metadata_path (str): Path to the metadata file (sent under key ``"m"``).
        architecture_path (str): Path to the architecture file
            (sent under key ``"a"``).
        w_and_b_path (str): Path to the weights-and-biases file
            (sent under key ``"b"``).
        proof_system (ZKProofSystems, optional): Proof system to use; only
            ZKProofSystems.Expander is handled. Defaults to
            ZKProofSystems.Expander.
        dev_mode (bool, optional): If True, recompiles the rust binary
            (run in development mode). Defaults to True.
        bench (bool, optional): Whether or not to run benchmarking metrics.
            Defaults to False.

    Raises:
        ProofSystemNotImplementedError: If ``proof_system`` is not Expander.
        ProofBackendError: Re-raised after logging warnings when
            ``run_cargo_command`` fails.

    Returns:
        subprocess.CompletedProcess[str]: Whatever ``run_cargo_command``
        returns for the compile step.
    """
    if proof_system != ZKProofSystems.Expander:
        msg = f"Proof system {proof_system} not implemented"
        raise ProofSystemNotImplementedError(msg)

    # The binary is named after the final component of circuit_name.
    binary = Path(circuit_name).name
    cli_args = {
        "n": circuit_name,
        "c": circuit_path,
        "m": metadata_path,
        "a": architecture_path,
        "b": w_and_b_path,
    }

    try:
        return run_cargo_command(
            binary_name=binary,
            command_type=RunType.COMPILE_CIRCUIT.value,
            args=cli_args,
            dev_mode=dev_mode,
            bench=bench,
        )
    except ProofBackendError as err:
        failure_note = f"Warning: Compile operation failed: {err}."
        binary_note = f" Using binary: {binary}"
        logger.warning(failure_note)
        logger.warning(binary_note)
        raise
|
|
774
|
+
|
|
775
|
+
|
|
776
|
+
def generate_witness(  # noqa: PLR0913
    circuit_name: str,
    circuit_path: str,
    witness_file: str,
    input_file: str,
    output_file: str,
    metadata_path: str,
    proof_system: ZKProofSystems = ZKProofSystems.Expander,
    *,
    dev_mode: bool = False,
    bench: bool = False,
) -> subprocess.CompletedProcess[str]:
    """Generate a witness file for a circuit via the circuit's Rust binary.

    Args:
        circuit_name (str): Name of the circuit. Its last path component is
            used as the binary name handed to ``run_cargo_command``.
        circuit_path (str): Path to the circuit definition.
        witness_file (str): Path to the output witness file.
        input_file (str): Path to the input JSON file with private inputs.
        output_file (str): Path to the output JSON file with computed outputs.
        metadata_path (str): Path to the circuit metadata file.
        proof_system (ZKProofSystems, optional): Proof system to use; only
            ZKProofSystems.Expander is handled. Defaults to
            ZKProofSystems.Expander.
        dev_mode (bool, optional): If True, recompiles the rust binary
            (run in development mode). Defaults to False.
        bench (bool, optional): If True, enable benchmarking.
            Defaults to False.

    Raises:
        ProofSystemNotImplementedError: If ``proof_system`` is not Expander.
        ProofBackendError: Re-raised after logging a warning when
            ``run_cargo_command`` fails.

    Returns:
        subprocess.CompletedProcess[str]: Whatever ``run_cargo_command``
        returns for the witness step.
    """
    if proof_system != ZKProofSystems.Expander:
        msg = f"Proof system {proof_system} not implemented"
        raise ProofSystemNotImplementedError(msg)

    # The binary is named after the final component of circuit_name.
    binary = Path(circuit_name).name
    cli_args = {
        "n": circuit_name,
        "c": circuit_path,
        "i": input_file,
        "o": output_file,
        "w": witness_file,
        "m": metadata_path,
    }

    try:
        return run_cargo_command(
            binary_name=binary,
            command_type=RunType.GEN_WITNESS.value,
            args=cli_args,
            dev_mode=dev_mode,
            bench=bench,
        )
    except ProofBackendError as err:
        failure_note = f"Warning: Witness generation failed: {err}"
        logger.warning(failure_note)
        raise
|
|
839
|
+
|
|
840
|
+
|
|
841
|
+
def generate_proof(  # noqa: PLR0913
    circuit_name: str,
    circuit_path: str,
    witness_file: str,
    proof_file: str,
    metadata_path: str,
    proof_system: ZKProofSystems = ZKProofSystems.Expander,
    *,
    dev_mode: bool = False,
    ecc: bool = True,
    bench: bool = False,
) -> subprocess.CompletedProcess[str]:
    """Generate a proof for a witness.

    Args:
        circuit_name (str): Name of the circuit. Its last path component is
            used as the binary name on the ECC path.
        circuit_path (str): Path to the circuit definition.
        witness_file (str): Path to the witness file.
        proof_file (str): Path to the output proof file.
        metadata_path (str): Path to the circuit metadata file
            (only used on the ECC path).
        proof_system (ZKProofSystems, optional): Proof system to use; only
            ZKProofSystems.Expander is handled. Defaults to
            ZKProofSystems.Expander.
        dev_mode (bool, optional): If True, recompiles the rust binary
            (run in development mode). Only used on the ECC path.
            Defaults to False.
        ecc (bool, optional): If true, run proof using ECC api, otherwise
            run directly through Expander (``run_expander_raw``).
            Defaults to True.
        bench (bool, optional): If True, enable benchmarking.
            Defaults to False.

    Raises:
        ProofSystemNotImplementedError: If ``proof_system`` is not Expander.
        ProofBackendError: Re-raised after logging a warning when the ECC
            path's ``run_cargo_command`` fails; may also come from
            ``run_expander_raw`` on the non-ECC path.

    Returns:
        subprocess.CompletedProcess[str]: Result of the underlying call.
    """
    if proof_system != ZKProofSystems.Expander:
        msg = f"Proof system {proof_system} not implemented"
        raise ProofSystemNotImplementedError(msg)

    if not ecc:
        # Direct Expander invocation: metadata_path and dev_mode are unused.
        return run_expander_raw(
            mode=ExpanderMode.PROVE,
            circuit_file=circuit_path,
            witness_file=witness_file,
            proof_file=proof_file,
            bench=bench,
        )

    # The binary is named after the final component of circuit_name.
    binary = Path(circuit_name).name
    cli_args = {
        "n": circuit_name,
        "c": circuit_path,
        "w": witness_file,
        "p": proof_file,
        "m": metadata_path,
    }

    try:
        return run_cargo_command(
            binary_name=binary,
            command_type=RunType.PROVE_WITNESS.value,
            args=cli_args,
            dev_mode=dev_mode,
            bench=bench,
        )
    except ProofBackendError as err:
        failure_note = f"Warning: Proof generation failed: {err}"
        logger.warning(failure_note)
        raise
|
|
915
|
+
|
|
916
|
+
|
|
917
|
+
def generate_verification(  # noqa: PLR0913
    circuit_name: str,
    circuit_path: str,
    input_file: str,
    output_file: str,
    witness_file: str,
    proof_file: str,
    metadata_path: str,
    proof_system: ZKProofSystems = ZKProofSystems.Expander,
    *,
    dev_mode: bool = False,
    ecc: bool = True,
    bench: bool = False,
) -> subprocess.CompletedProcess[str]:
    """Verify a given proof.

    Args:
        circuit_name (str): Name of the circuit. Its last path component is
            used as the binary name on the ECC path.
        circuit_path (str): Path to the circuit definition.
        input_file (str): Path to the input JSON file with public inputs
            (only used on the ECC path).
        output_file (str): Path to the output JSON file with expected outputs
            (only used on the ECC path).
        witness_file (str): Path to the witness file.
        proof_file (str): Path to the proof file to verify.
        metadata_path (str): Path to the circuit metadata file
            (only used on the ECC path).
        proof_system (ZKProofSystems, optional): Proof system to use; only
            ZKProofSystems.Expander is handled. Defaults to
            ZKProofSystems.Expander.
        dev_mode (bool, optional): If True, recompiles the rust binary
            (run in development mode). Only used on the ECC path.
            Defaults to False.
        ecc (bool, optional): If true, verify using the ECC api, otherwise
            run directly through Expander (``run_expander_raw``).
            Defaults to True.
        bench (bool, optional): If True, enable benchmarking.
            Defaults to False.

    Raises:
        ProofSystemNotImplementedError: If ``proof_system`` is not Expander.
        ProofBackendError: Re-raised after logging a warning when the ECC
            path's ``run_cargo_command`` fails; may also come from
            ``run_expander_raw`` on the non-ECC path.

    Returns:
        subprocess.CompletedProcess[str]: Result of the underlying call.
    """
    if proof_system != ZKProofSystems.Expander:
        msg = f"Proof system {proof_system} not implemented"
        raise ProofSystemNotImplementedError(msg)

    if not ecc:
        # Direct Expander invocation: inputs, outputs, metadata_path and
        # dev_mode are unused on this path.
        return run_expander_raw(
            mode=ExpanderMode.VERIFY,
            circuit_file=circuit_path,
            witness_file=witness_file,
            proof_file=proof_file,
            bench=bench,
        )

    # The binary is named after the final component of circuit_name.
    binary = Path(circuit_name).name
    cli_args = {
        "n": circuit_name,
        "c": circuit_path,
        "i": input_file,
        "o": output_file,
        "w": witness_file,
        "p": proof_file,
        "m": metadata_path,
    }

    try:
        return run_cargo_command(
            binary_name=binary,
            command_type=RunType.GEN_VERIFY.value,
            args=cli_args,
            dev_mode=dev_mode,
            bench=bench,
        )
    except ProofBackendError as err:
        failure_note = f"Warning: Verification generation failed: {err}"
        logger.warning(failure_note)
        raise
|
|
996
|
+
|
|
997
|
+
|
|
998
|
+
def run_end_to_end(  # noqa: PLR0913
    circuit_name: str,
    circuit_path: str,
    input_file: str,
    output_file: str,
    proof_system: ZKProofSystems = ZKProofSystems.Expander,
    *,
    demo: bool = False,
    dev_mode: bool = False,
    ecc: bool = True,
) -> int:
    """Run the full pipeline for proving and verifying a circuit.

    Steps:
        1. Compile the circuit.
        2. Generate a witness from inputs.
        3. Produce a proof from the witness.
        4. Verify the proof against inputs and outputs.

    Auxiliary files are derived from ``circuit_path`` with its extension
    removed (``<base>``): ``<base>_metadata.json``,
    ``<base>_architecture.json``, ``<base>_wandb.json``,
    ``<base>_witness<ext>`` and ``<base>_proof.bin``.

    Args:
        circuit_name (str): Name of the circuit.
        circuit_path (str): Path to the circuit definition.
        input_file (str): Path to the input JSON file with public inputs.
        output_file (str): Path to the output JSON file with expected outputs.
        proof_system (ZKProofSystems, optional):
            Proof system to use. Defaults to ZKProofSystems.Expander.
        demo (bool, optional): Accepted for interface compatibility; it is
            currently not used by this function. Defaults to False.
        dev_mode (bool, optional):
            If True, recompiles the rust binary (run in development mode).
            Defaults to False.
        ecc (bool, optional):
            If true, run proof using ECC api, otherwise run directly through
            Expander. Defaults to True.

    Raises:
        ProofSystemNotImplementedError: If proof system is not supported.
        ProofBackendError: Propagated from any pipeline step that fails.

    Returns:
        int: Exit code from the verification step (0 = success).
    """
    _ = demo  # currently unused; kept so the public signature is stable
    if proof_system != ZKProofSystems.Expander:
        msg = f"Proof system {proof_system} not implemented"
        raise ProofSystemNotImplementedError(msg)

    path = Path(circuit_path)
    base = str(path.with_suffix(""))  # circuit path without its extension
    ext = path.suffix

    metadata_path = f"{base}_metadata.json"
    witness_file = f"{base}_witness{ext}"
    proof_file = f"{base}_proof.bin"

    # dev_mode and ecc are keyword-only in every step function (they follow a
    # bare `*`), so they must be passed by name; passing them positionally
    # raises TypeError.
    compile_circuit(
        circuit_name,
        circuit_path,
        metadata_path,
        f"{base}_architecture.json",
        # NOTE(review): get_files() names this file "<name>_w_and_b.json";
        # confirm which naming the pipeline expects here.
        f"{base}_wandb.json",
        proof_system,
        dev_mode=dev_mode,
    )
    generate_witness(
        circuit_name,
        circuit_path,
        witness_file,
        input_file,
        output_file,
        metadata_path,
        proof_system,
        dev_mode=dev_mode,
    )
    generate_proof(
        circuit_name,
        circuit_path,
        witness_file,
        proof_file,
        metadata_path,
        proof_system,
        dev_mode=dev_mode,
        ecc=ecc,
    )
    verification = generate_verification(
        circuit_name,
        circuit_path,
        input_file,
        output_file,
        witness_file,
        proof_file,
        metadata_path,
        proof_system,
        dev_mode=dev_mode,
        ecc=ecc,
    )
    # The signature promises an int exit code, not the CompletedProcess.
    return verification.returncode
|
|
1090
|
+
|
|
1091
|
+
|
|
1092
|
+
def get_files(
    name: str,
    proof_system: ZKProofSystems,
    folders: dict[str, str],
) -> dict[str, str]:
    """Build the standard file paths for a named circuit.

    Paths are only joined here; no directories are created or checked.

    Args:
        name (str): The base name for all generated files.
        proof_system (ZKProofSystems): The ZK proof system being used.
        folders (dict[str, str]): Folder paths. The keys read are
            ``"input"``, ``"proof"``, ``"weights"`` and ``"output"``; any
            other keys are ignored.

    Raises:
        KeyError: If one of the folder keys above is missing.
        ProofSystemNotImplementedError: If ``proof_system`` is not Expander.

    Returns:
        dict[str, str]: Mapping of descriptive keys to file paths, plus
        ``"circuit_name"`` (the plain name) for Expander.
    """
    # Folders are looked up in the same order as the resulting entries, so a
    # missing key surfaces as the same KeyError as before.
    input_dir = Path(folders["input"])
    proof_dir = Path(folders["proof"])
    weights_dir = Path(folders["weights"])
    output_dir = Path(folders["output"])

    common_paths = {
        "input_file": str(input_dir / f"{name}_input.json"),
        "public_path": str(proof_dir / f"{name}_public.json"),
        "metadata_path": str(weights_dir / f"{name}_metadata.json"),
        "architecture_path": str(weights_dir / f"{name}_architecture.json"),
        "w_and_b_path": str(weights_dir / f"{name}_w_and_b.json"),
        "output_file": str(output_dir / f"{name}_output.json"),
    }

    if proof_system != ZKProofSystems.Expander:
        msg = f"Proof system {proof_system} not implemented"
        raise ProofSystemNotImplementedError(msg)

    # Expander-specific artifacts.
    return {
        **common_paths,
        "circuit_name": name,
        "witness_file": str(input_dir / f"{name}_witness.txt"),
        "proof_path": str(proof_dir / f"{name}_proof.bin"),
    }
|