JSTprove 1.0.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of JSTprove might be problematic. Click here for more details.

Files changed (81) hide show
  1. jstprove-1.0.0.dist-info/METADATA +397 -0
  2. jstprove-1.0.0.dist-info/RECORD +81 -0
  3. jstprove-1.0.0.dist-info/WHEEL +5 -0
  4. jstprove-1.0.0.dist-info/entry_points.txt +2 -0
  5. jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
  6. jstprove-1.0.0.dist-info/top_level.txt +1 -0
  7. python/__init__.py +0 -0
  8. python/core/__init__.py +3 -0
  9. python/core/binaries/__init__.py +0 -0
  10. python/core/binaries/expander-exec +0 -0
  11. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  12. python/core/circuit_models/__init__.py +0 -0
  13. python/core/circuit_models/generic_onnx.py +231 -0
  14. python/core/circuit_models/simple_circuit.py +133 -0
  15. python/core/circuits/__init__.py +0 -0
  16. python/core/circuits/base.py +1000 -0
  17. python/core/circuits/errors.py +188 -0
  18. python/core/circuits/zk_model_base.py +25 -0
  19. python/core/model_processing/__init__.py +0 -0
  20. python/core/model_processing/converters/__init__.py +0 -0
  21. python/core/model_processing/converters/base.py +143 -0
  22. python/core/model_processing/converters/onnx_converter.py +1181 -0
  23. python/core/model_processing/errors.py +147 -0
  24. python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
  25. python/core/model_processing/onnx_custom_ops/conv.py +111 -0
  26. python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
  27. python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
  28. python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
  29. python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
  30. python/core/model_processing/onnx_custom_ops/relu.py +43 -0
  31. python/core/model_processing/onnx_quantizer/__init__.py +0 -0
  32. python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
  33. python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
  34. python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
  35. python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
  36. python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
  37. python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
  38. python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
  39. python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
  40. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
  41. python/core/model_templates/__init__.py +0 -0
  42. python/core/model_templates/circuit_template.py +57 -0
  43. python/core/utils/__init__.py +0 -0
  44. python/core/utils/benchmarking_helpers.py +163 -0
  45. python/core/utils/constants.py +4 -0
  46. python/core/utils/errors.py +117 -0
  47. python/core/utils/general_layer_functions.py +268 -0
  48. python/core/utils/helper_functions.py +1138 -0
  49. python/core/utils/model_registry.py +166 -0
  50. python/core/utils/scratch_tests.py +66 -0
  51. python/core/utils/witness_utils.py +291 -0
  52. python/frontend/__init__.py +0 -0
  53. python/frontend/cli.py +115 -0
  54. python/frontend/commands/__init__.py +17 -0
  55. python/frontend/commands/args.py +100 -0
  56. python/frontend/commands/base.py +199 -0
  57. python/frontend/commands/bench/__init__.py +54 -0
  58. python/frontend/commands/bench/list.py +42 -0
  59. python/frontend/commands/bench/model.py +172 -0
  60. python/frontend/commands/bench/sweep.py +212 -0
  61. python/frontend/commands/compile.py +58 -0
  62. python/frontend/commands/constants.py +5 -0
  63. python/frontend/commands/model_check.py +53 -0
  64. python/frontend/commands/prove.py +50 -0
  65. python/frontend/commands/verify.py +73 -0
  66. python/frontend/commands/witness.py +64 -0
  67. python/scripts/__init__.py +0 -0
  68. python/scripts/benchmark_runner.py +833 -0
  69. python/scripts/gen_and_bench.py +482 -0
  70. python/tests/__init__.py +0 -0
  71. python/tests/circuit_e2e_tests/__init__.py +0 -0
  72. python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
  73. python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
  74. python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
  75. python/tests/circuit_parent_classes/__init__.py +0 -0
  76. python/tests/circuit_parent_classes/test_circuit.py +969 -0
  77. python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
  78. python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
  79. python/tests/test_cli.py +1021 -0
  80. python/tests/utils_testing/__init__.py +0 -0
  81. python/tests/utils_testing/test_helper_functions.py +891 -0
@@ -0,0 +1,1138 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import json
5
+ import logging
6
+ import os
7
+ import re
8
+ import shutil
9
+ import subprocess
10
+ import sys
11
+ from collections.abc import Callable
12
+ from dataclasses import dataclass
13
+ from enum import Enum
14
+ from importlib.metadata import version as get_version
15
+ from pathlib import Path
16
+ from time import time
17
+ from typing import Any, TypeVar
18
+
19
+ try:
20
+ import tomllib # Python 3.11+
21
+ except ModuleNotFoundError:
22
+ import tomli as tomllib
23
+
24
+ from python.core import PACKAGE_NAME
25
+ from python.core.utils.benchmarking_helpers import (
26
+ end_memory_collection,
27
+ start_memory_collection,
28
+ )
29
+ from python.core.utils.errors import (
30
+ FileCacheError,
31
+ MissingFileError,
32
+ ProofBackendError,
33
+ ProofSystemNotImplementedError,
34
+ )
35
+
36
# Generic callable TypeVar (e.g. for typing signature-preserving decorators).
# NOTE(review): not referenced in the visible part of this module — confirm use.
F = TypeVar("F", bound=Callable[..., Any])
# Module-level logger used by the helpers in this module.
logger = logging.getLogger(__name__)
38
+
39
+
40
class RunType(Enum):
    """Operations the circuit tooling can run.

    The values of COMPILE_CIRCUIT, GEN_WITNESS and PROVE_WITNESS are passed
    verbatim as the command-type argument to the Rust circuit binary (see
    compile_circuit, generate_witness and generate_proof).
    """

    # Default run_type of CircuitExecutionConfig.
    # NOTE(review): how END_TO_END is dispatched is not visible here.
    END_TO_END = "end_to_end"
    COMPILE_CIRCUIT = "run_compile_circuit"
    GEN_WITNESS = "run_gen_witness"
    PROVE_WITNESS = "run_prove_witness"
    # NOTE(review): presumably the verification command type; its consumer is
    # outside the visible portion of this file.
    GEN_VERIFY = "run_gen_verify"
46
+
47
+
48
class ZKProofSystems(Enum):
    """Supported zero-knowledge proof back ends.

    Only Expander is implemented; compile_circuit, generate_witness and
    generate_proof raise ProofSystemNotImplementedError for any other value.
    """

    Expander = "Expander"
50
+
51
+
52
class ExpanderMode(Enum):
    """Sub-command passed to the ``expander-exec`` binary by run_expander_raw.

    PROVE writes the proof file (``-o``); VERIFY reads it (``-i``) and
    therefore requires it to exist beforehand.
    """

    PROVE = "prove"
    VERIFY = "verify"
55
+
56
+
57
@dataclass
class CircuitExecutionConfig:
    """Configuration for circuit execution operations.

    Most path fields default to ``None``. The ``prepare_io_files`` decorator
    fills unset ones with defaults derived from the circuit name. When
    ``circuit_path`` is set, it also derives the quantized-model, metadata,
    architecture and weights paths from that file.
    """

    # Which operation to run.
    run_type: RunType = RunType.END_TO_END
    # Witness file path.
    witness_file: str | None = None
    # Input JSON path; its parent directory is used as the input folder.
    input_file: str | None = None
    # Proof file path; its parent directory is used as the proof folder.
    proof_file: str | None = None
    public_path: str | None = None
    # NOTE(review): not populated by prepare_io_files; consumer not visible here.
    verification_key: str | None = None
    circuit_name: str | None = None
    metadata_path: str | None = None
    architecture_path: str | None = None
    # Weights-and-biases JSON path.
    w_and_b_path: str | None = None
    # Output JSON path; its parent directory is used as the output folder.
    output_file: str | None = None
    proof_system: ZKProofSystems = ZKProofSystems.Expander
    # When set, prepare_io_files overrides quantized/metadata/architecture/
    # w_and_b paths with siblings named after this file's stem.
    circuit_path: str | None = None
    # Reset to None by prepare_io_files unless circuit_path is set.
    quantized_path: str | None = None
    # Presumably forwarded to generate_proof(ecc=...) — TODO confirm.
    ecc: bool = True
    # Presumably forwarded as dev_mode (rebuild via `cargo run`) — TODO confirm.
    dev_mode: bool = False
    # NOTE(review): consumer not visible here.
    write_json: bool = False
    # Presumably enables time/memory benchmarking — TODO confirm.
    bench: bool = False
79
+
80
+
81
def filter_expander_output(stderr: str) -> str:
    """Keep Rust panic + MPI exit summary, drop system stack traces and notes.

    Lines are kept when they:

    * contain ``panicked at`` or ``assertion`` (case-insensitive); this also
      starts a "panic capture" region;
    * follow such a line, until a native stack-frame line (``[ N] ...``) or
      ``*** Process received signal ***`` ends the region;
    * start with ``prterun noticed that process rank`` (the MPI exit
      summary). These are kept exactly once, even inside a capture region.

    Args:
        stderr (str): Raw stderr captured from ``expander-exec``.

    Returns:
        str: The kept lines joined with ``\\n`` (empty if nothing matched).
    """
    filtered: list[str] = []
    rust_panic_started = False

    for line in stderr.splitlines():
        # Start capturing Rust panic/assertion
        if "panicked at" in line or "assertion" in line.lower():
            rust_panic_started = True
            filtered.append(line)
            continue

        # Always keep MPI exit summary. Checked before the capture branch so a
        # summary line inside a panic region is not appended twice.
        if line.startswith("prterun noticed that process rank"):
            filtered.append(line)
            continue

        # Keep lines following Rust panic that are relevant
        if rust_panic_started:
            # Stop at system stack traces or abort messages
            if (
                re.match(r"\[\s*\d+\]", line)
                or "*** Process received signal ***" in line
            ):
                rust_panic_started = False
                continue
            filtered.append(line)

    return "\n".join(filtered)
110
+
111
+
112
def extract_rust_error(stderr: str) -> str:
    """Return the most relevant Rust error text found in ``stderr``.

    Two formats are recognised, tried in this order:

    1. A panic: every line after the first ``thread '...' panicked at``
       header, up to (not including) a ``stack backtrace:`` line. Further
       panic-header lines inside that region are skipped. The joined block is
       returned stripped, provided at least one line was collected.
    2. Otherwise, the first line whose stripped form starts with ``Error:``,
       returned stripped.

    Args:
        stderr (str): Captured stderr of the Rust process.

    Returns:
        str: The extracted message, or an empty string if neither is present.
    """
    panic_header = re.compile(r"thread '.*' panicked at")
    all_lines = stderr.splitlines()

    header_idx = next(
        (idx for idx, text in enumerate(all_lines) if panic_header.match(text)),
        None,
    )
    if header_idx is not None:
        panic_body: list[str] = []
        for text in all_lines[header_idx + 1 :]:
            if panic_header.match(text):
                continue
            if "stack backtrace:" in text.lower():
                break
            panic_body.append(text)
        if panic_body:
            return "\n".join(panic_body).strip()

    stripped_lines = (text.strip() for text in all_lines)
    return next((s for s in stripped_lines if s.startswith("Error:")), "")
139
+
140
+
141
+ # Decorator to compute outputs once and store in temp folder
142
+ def compute_and_store_output(func: Callable) -> Callable:
143
+ """Decorator that computes outputs once
144
+ per circuit instance and stores in temp folder.
145
+ Instead of using in-memory cache, uses files in temp folder.
146
+
147
+ Args:
148
+ func (Callable): Method that computes outputs to be cached.
149
+
150
+ Returns:
151
+ Callable: Wrapped function that reads/writes a caches.
152
+ """
153
+
154
+ @functools.wraps(func)
155
+ def wrapper(self: object, *args: tuple, **kwargs: dict) -> object:
156
+ # Define paths for storing outputs in temp folder
157
+ temp_folder = getattr(self, "temp_folder", "temp")
158
+ try:
159
+ Path(temp_folder).mkdir(parents=True, exist_ok=True)
160
+ except OSError as e:
161
+ msg = f"Could not create temp folder {temp_folder}: {e}"
162
+ logger.exception(msg)
163
+ raise FileCacheError(msg) from e
164
+
165
+ output_cache_path = Path(temp_folder) / f"{self.name}_output_cache.json"
166
+
167
+ # Check if cached output exists
168
+ if output_cache_path.exists():
169
+ msg = f"Loading cached outputs for {self.name} from {output_cache_path}"
170
+ logger.info(msg)
171
+ try:
172
+ with output_cache_path.open() as f:
173
+ return json.load(f)
174
+ except (OSError, json.JSONDecodeError) as e:
175
+ msg = f"Error loading cached output: {e}"
176
+ logger.warning(msg)
177
+ # Continue to compute if loading fails
178
+
179
+ # Compute outputs and cache them
180
+ msg = f"Computing outputs for {self.name}..."
181
+ logger.info(msg)
182
+ output = func(self, *args, **kwargs)
183
+
184
+ # Store output in temp folder
185
+ try:
186
+ with Path(output_cache_path).open("w") as f:
187
+ json.dump(output, f)
188
+ msg = f"Stored outputs in {output_cache_path}"
189
+ logger.info(msg)
190
+ except OSError as e:
191
+ msg = f"Warning: Could not cache output to file: {e}"
192
+ logger.warning(msg)
193
+
194
+ return output
195
+
196
+ return wrapper
197
+
198
+
199
+ # Decorator to prepare input/output files
200
def prepare_io_files(func: Callable) -> Callable:
    """Decorator that resolves and fills in input/output file paths.

    This allows the decorated method to be called independently. Before
    calling ``func``, the wrapper does two things:

    * It populates every path field left unset (falsy) on the
      ``exec_config`` argument, using defaults returned by ``get_files``
      (defined elsewhere) for ``self.name``.
    * It mirrors the resolved paths into ``self._file_info``.

    ``exec_config`` is mutated in place.

    Args:
        func (Callable): Method with signature
            ``(self, exec_config: CircuitExecutionConfig, *args, **kwargs)``.

    Returns:
        Callable: Wrapped function that prepares paths and then calls
        ``func(self, exec_config, *args, **kwargs)``.
    """

    @functools.wraps(func)
    def wrapper(
        self: object,
        exec_config: CircuitExecutionConfig,
        *args: tuple,
        **kwargs: dict,
    ) -> object:

        def resolve_folder(
            key: str,
            file_attr: str | None = None,
            default: str = "",
        ) -> str:
            """Return the folder for ``key``.

            Precedence:

            1. A truthy ``exec_config.<key>``.
            2. The parent directory of a truthy ``exec_config.<file_attr>``.
            3. ``getattr(self, key, default)``.
            """
            if getattr(exec_config, key, None):
                return getattr(exec_config, key)
            if file_attr and getattr(exec_config, file_attr, None):
                return str(Path(getattr(exec_config, file_attr)).parent)
            return getattr(self, key, default)

        # NOTE(review): CircuitExecutionConfig declares no *_folder fields, so
        # the first branch of resolve_folder only applies if such an attribute
        # was attached to the instance elsewhere.
        input_folder = resolve_folder(
            "input_folder",
            "input_file",
            default="python/models/inputs",
        )
        output_folder = resolve_folder(
            "output_folder",
            "output_file",
            default="python/models/output",
        )
        proof_folder = resolve_folder(
            "proof_folder",
            "proof_file",
            default="python/models/proofs",
        )
        quantized_model_folder = resolve_folder(
            "quantized_folder",
            "quantized_path",
            default="python/models/quantized_model_folder",
        )
        # Weights and circuit folders have no per-file override on exec_config.
        weights_folder = resolve_folder(
            "weights_folder",
            default="python/models/weights",
        )
        circuit_folder = resolve_folder("circuit_folder", default="python/models/")

        # Prefer the config's proof system, then the instance's, then Expander.
        proof_system = exec_config.proof_system or getattr(
            self,
            "proof_system",
            ZKProofSystems.Expander,
        )

        # get_files is defined elsewhere; it is expected to return a dict with
        # at least the keys read below (witness_file, input_file, proof_path,
        # public_path, circuit_name, metadata_path, architecture_path,
        # w_and_b_path, output_file).
        files = get_files(
            self.name,
            proof_system,
            {
                "input": input_folder,
                "proof": proof_folder,
                "circuit": circuit_folder,
                "weights": weights_folder,
                "output": output_folder,
                "quantized_model": quantized_model_folder,
            },
        )
        # Fill in any missing (falsy) fields in exec_config with defaults from `files`
        exec_config.witness_file = exec_config.witness_file or files["witness_file"]
        exec_config.input_file = exec_config.input_file or files["input_file"]
        exec_config.proof_file = exec_config.proof_file or files["proof_path"]
        exec_config.public_path = exec_config.public_path or files["public_path"]
        exec_config.circuit_name = exec_config.circuit_name or files["circuit_name"]
        exec_config.metadata_path = exec_config.metadata_path or files["metadata_path"]
        exec_config.architecture_path = (
            exec_config.architecture_path or files["architecture_path"]
        )
        exec_config.w_and_b_path = exec_config.w_and_b_path or files["w_and_b_path"]
        exec_config.output_file = exec_config.output_file or files["output_file"]

        if exec_config.circuit_path:
            # Derive companion artifact paths from the circuit file's directory
            # and stem. NOTE(review): this overrides metadata/architecture/
            # w_and_b paths even when the caller set them explicitly.
            circuit_dir = Path(exec_config.circuit_path).parent
            name = Path(exec_config.circuit_path).stem
            exec_config.quantized_path = str(
                circuit_dir / f"{name}_quantized_model.onnx",
            )
            exec_config.metadata_path = str(
                circuit_dir / f"{name}_metadata.json",
            )
            exec_config.architecture_path = str(
                circuit_dir / f"{name}_architecture.json",
            )
            exec_config.w_and_b_path = str(
                circuit_dir / f"{name}_wandb.json",
            )
        else:
            # NOTE(review): discards any caller-supplied quantized_path; above
            # it was only used to pick quantized_model_folder.
            exec_config.quantized_path = None

        # Store paths and data for use in the decorated function
        self._file_info = {
            "witness_file": exec_config.witness_file,
            "input_file": exec_config.input_file,
            "proof_file": exec_config.proof_file,
            "public_path": exec_config.public_path,
            "circuit_name": exec_config.circuit_name,
            "metadata_path": exec_config.metadata_path,
            "architecture_path": exec_config.architecture_path,
            "w_and_b_path": exec_config.w_and_b_path,
            "output_file": exec_config.output_file,
            "inputs": exec_config.input_file,
            "weights": exec_config.w_and_b_path,  # same value as "w_and_b_path"
            "outputs": exec_config.output_file,
            "output": exec_config.output_file,
            "proof_system": exec_config.proof_system or proof_system,
            # CircuitExecutionConfig defines no model_path field, so this is
            # None unless the attribute was set on the instance elsewhere.
            "model_path": getattr(exec_config, "model_path", None),
            "quantized_model_path": exec_config.quantized_path,
        }

        # Call the original function with the populated exec_config
        return func(self, exec_config, *args, **kwargs)

    return wrapper
329
+
330
+
331
def ensure_parent_dir(path: str) -> None:
    """Make sure the directory that will contain ``path`` exists.

    Missing intermediate directories are created and an existing directory is
    left untouched. ``path`` itself is never created.

    Args:
        path (str): File path whose parent directory should exist.
    """
    parent_dir = Path(path).parent
    parent_dir.mkdir(exist_ok=True, parents=True)
338
+
339
+
340
def run_subprocess(cmd: list[str]) -> None:
    """Execute ``cmd`` and raise if it does not exit cleanly.

    The child inherits the current environment. ``PYTHONUNBUFFERED=1`` is
    added unless the variable is already set. Output is not captured; it goes
    straight to the parent's stdout/stderr.

    Args:
        cmd (list[str]): The command and its arguments.

    Raises:
        RuntimeError: If the command exits with a non-zero return code.
    """
    child_env = {**os.environ}
    if "PYTHONUNBUFFERED" not in child_env:
        child_env["PYTHONUNBUFFERED"] = "1"

    completed = subprocess.run(cmd, text=True, env=child_env, check=False)  # noqa: S603
    exit_code = completed.returncode
    if exit_code:
        msg = f"Command failed with exit code {exit_code}"
        raise RuntimeError(msg)
355
+
356
+
357
def to_json(inputs: dict[str, Any], path: str) -> None:
    """Serialize ``inputs`` as JSON into ``path``, creating parent folders first.

    An existing file at ``path`` is overwritten.

    Args:
        inputs (dict[str, Any]): JSON-serializable data to write.
        path (str): Destination file path.
    """
    ensure_parent_dir(path)
    target = Path(path)
    with target.open(mode="w") as handle:
        json.dump(inputs, handle)
367
+
368
+
369
def read_from_json(public_path: str) -> dict[str, Any]:
    """Read and parse a JSON file.

    The file is decoded as UTF-8, the encoding JSON mandates, rather than
    the platform's locale encoding. Files containing non-ASCII text
    therefore parse identically on every OS.

    Args:
        public_path (str): Path to the JSON file to read.

    Returns:
        dict[str, Any]: The parsed JSON document.
    """
    with Path(public_path).open(encoding="utf-8") as json_data:
        return json.load(json_data)
380
+
381
+
382
def _versioned_binary_name(binary_name: str) -> str:
    """Append the package version, with dots replaced by dashes, to ``binary_name``.

    For example ``onnx_generic_circuit`` becomes ``onnx_generic_circuit_1-0-0``.
    The version is read from the installed distribution metadata. In a source
    checkout without installed metadata, it is read from ``pyproject.toml`` in
    the current working directory instead. If neither is available, the name
    is returned unchanged.
    """
    from importlib.metadata import PackageNotFoundError

    try:
        version = get_version(PACKAGE_NAME)
    except PackageNotFoundError:
        try:
            # TOML files are UTF-8 by specification.
            pyproject = tomllib.loads(
                Path("pyproject.toml").read_text(encoding="utf-8"),
            )
            version = pyproject["project"]["version"]
        except (OSError, UnicodeDecodeError, KeyError, tomllib.TOMLDecodeError):
            return binary_name
    return binary_name + f"_{version}".replace(".", "-")


def _locate_binary(binary_name: str) -> str:
    """Return the path of the first existing prebuilt ``binary_name``.

    Candidates, in order:

    1. ``./target/release`` (relative to the working directory).
    2. The package's ``binaries`` folder.
    3. ``<sys.prefix>/bin``.

    If none exists, the ``./target/release`` path is returned anyway;
    ``_build_command`` then falls back to ``cargo run``.
    """
    default_path = f"./target/release/{binary_name}"
    possible_paths = [
        default_path,
        Path(__file__).parent.parent / "binaries" / binary_name,
        Path(sys.prefix) / "bin" / binary_name,
    ]
    for path in possible_paths:
        if Path(path).exists():
            return str(path)
    return default_path


def run_cargo_command(
    binary_name: str,
    command_type: str,
    args: dict[str, str] | None = None,
    *,
    dev_mode: bool = False,
    bench: bool = False,
) -> subprocess.CompletedProcess[str]:
    """Run a Rust circuit binary (or ``cargo run``) for the given command type.

    The binary name is suffixed with the package version (``1.0.0`` becomes
    ``_1-0-0``). A prebuilt binary is looked up in ``./target/release``, the
    package ``binaries`` folder, and ``<sys.prefix>/bin``. The command runs
    through ``cargo run --release`` instead if none exists or ``dev_mode`` is
    set. ``RUST_BACKTRACE=1`` is exported so panics include a backtrace.

    Args:
        binary_name (str): Base name of the Cargo binary (without version).
        command_type (str): Command type (e.g., 'run_prove_witness',
            'run_compile_circuit'), passed as the first CLI argument.
        args (dict[str, str], optional): CLI arguments, rendered as
            ``-<key> <value>``. Defaults to None.
        dev_mode (bool, optional):
            If True, run with `cargo run --release` instead of prebuilt binary.
            Defaults to False.
        bench (bool, optional):
            If True, measure execution time and memory usage. Defaults to False.

    Raises:
        ProofBackendError: If the command cannot be started or exits with a
            non-zero status. For a non-zero exit, the message contains the
            Rust error extracted from stderr.

    Returns:
        subprocess.CompletedProcess[str]: Exit message from the subprocess.
    """
    binary_name = _versioned_binary_name(binary_name)
    binary_path = _locate_binary(binary_name)

    cmd = _build_command(
        binary_path=binary_path,
        command_type=command_type,
        args=args,
        dev_mode=dev_mode,
        binary_name=binary_name,
    )
    env = os.environ.copy()
    env["RUST_BACKTRACE"] = "1"

    msg = f"Running cargo command: {' '.join(cmd)}"
    print(msg)  # noqa: T201
    logger.info(msg)

    try:
        result = _run_subprocess_with_bench(
            cmd=cmd,
            env=env,
            bench=bench,
            binary_name=binary_name,
        )
        _handle_result(result=result, cmd=cmd)
    except OSError as e:
        msg = f"Failed to execute proof backend command '{cmd}': {e}"
        logger.exception(msg)
        raise ProofBackendError(msg) from e
    except subprocess.CalledProcessError as e:
        msg = f"Cargo command failed (return code {e.returncode}): {e.stderr}"
        logger.exception(msg)
        rust_error = extract_rust_error(e.stderr)
        msg = f"Rust backend error '{rust_error}'"
        raise ProofBackendError(msg, cmd) from e
    else:
        return result
469
+
470
+
471
+ def _build_command(
472
+ binary_path: str,
473
+ command_type: str,
474
+ args: dict[str, str] | None,
475
+ *,
476
+ dev_mode: bool,
477
+ binary_name: str,
478
+ ) -> list[str]:
479
+ """Build the command list for subprocess."""
480
+ cmd = (
481
+ ["cargo", "run", "--bin", binary_name, "--release"]
482
+ if dev_mode or not Path(binary_path).exists()
483
+ # dev_mode indicates that we want a recompile, this happens with compile
484
+ # or if there is no executable already created, then we create a new one
485
+ else [binary_path]
486
+ )
487
+ cmd.append(command_type)
488
+ if args:
489
+ for key, value in args.items():
490
+ cmd.append(f"-{key}")
491
+ if not (isinstance(value, bool) and value):
492
+ cmd.append(str(value))
493
+ return cmd
494
+
495
+
496
def _run_subprocess_with_bench(
    cmd: list[str],
    env: dict[str, str],
    binary_name: str,
    *,
    bench: bool,
) -> subprocess.CompletedProcess[str]:
    """Run the subprocess with optional benchmarking.

    The command runs with ``check=True`` and captured text output. Wall time
    is always printed and logged. When ``bench`` is True, memory usage is
    sampled via ``start_memory_collection``/``end_memory_collection`` and
    reported. The collection is also ended if the subprocess raises, so the
    monitor started here never outlives the call. The child's stdout is
    printed and logged on success.

    Args:
        cmd (list[str]): Command to execute.
        env (dict[str, str]): Environment for the child process.
        binary_name (str): Process name passed to the memory monitor.
        bench (bool): Whether to collect memory statistics.

    Raises:
        subprocess.CalledProcessError: If the command exits non-zero.

    Returns:
        subprocess.CompletedProcess[str]: The completed process.
    """
    # (stop_event, monitor_thread, monitor_results) while collection is active.
    monitor = start_memory_collection(binary_name) if bench else None
    try:
        start_time = time()
        result = subprocess.run(  # noqa: S603
            cmd,
            check=True,
            capture_output=True,
            text=True,
            env=env,
        )
        end_time = time()
        print("\n--- BENCHMARK RESULTS ---")  # noqa: T201
        print(f"Rust time taken: {end_time - start_time:.4f} seconds")  # noqa: T201
        msg = f"Rust command completed in {end_time - start_time:.4f} seconds"
        logger.info(msg)

        if monitor is not None:
            stop_event, monitor_thread, monitor_results = monitor
            # Ended here; clear it so the finally block does not end it twice.
            monitor = None
            memory = end_memory_collection(stop_event, monitor_thread, monitor_results)
            msg = f"Rust subprocess memory: {memory['total']:.2f} MB"
            print(msg)  # noqa: T201
            logger.info(msg)
    finally:
        # Reached with an active monitor only if something above raised
        # (e.g. CalledProcessError from check=True).
        if monitor is not None:
            end_memory_collection(*monitor)

    print(result.stdout)  # noqa: T201
    logger.info(result.stdout)
    return result
531
+
532
+
533
def _handle_result(result: subprocess.CompletedProcess[str], cmd: list[str]) -> None:
    """Raise ProofBackendError when ``result`` reports a non-zero exit status.

    On failure, stderr is logged at error level. The raised error carries the
    joined command line plus the captured stdout and stderr. A zero exit
    status is a no-op.
    """
    exit_code = result.returncode
    if exit_code == 0:
        return

    log_msg = f"Proving Backend failed (code {exit_code}):\n{result.stderr}"
    logger.error(log_msg)

    joined_cmd = " ".join(cmd)
    error_msg = (
        f"Proving Backend command '{joined_cmd}' failed with code {exit_code}:\n"
        f"STDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
    )
    raise ProofBackendError(error_msg)
544
+
545
+
546
def _copy_binary_if_needed(
    binary_name: str,
    binary_path: str,
    *,
    dev_mode: bool,
) -> None:
    """Copy ``./target/release/<binary_name>`` to ``binary_path`` in dev mode.

    Nothing happens unless all of the following hold:

    * ``dev_mode`` is set;
    * the freshly built binary exists;
    * the destination differs from the source path.
    """
    if not dev_mode:
        return
    built_binary = f"./target/release/{binary_name}"
    if built_binary == str(binary_path) or not Path(built_binary).exists():
        return
    shutil.copy(built_binary, binary_path)
556
+
557
+
558
def get_expander_file_paths(circuit_name: str) -> dict[str, str]:
    """Build the conventional Expander artifact file names for a circuit.

    Each artifact is named ``<circuit_name>_<kind>.txt``.

    Args:
        circuit_name (str): The base name of the circuit.

    Returns:
        dict[str, str]: Mapping with keys ``circuit_file``, ``witness_file``
        and ``proof_file`` (in that order).
    """
    artifact_kinds = ("circuit", "witness", "proof")
    return {f"{kind}_file": f"{circuit_name}_{kind}.txt" for kind in artifact_kinds}
574
+
575
+
576
def _expander_base_command(pcs_type: str) -> list[str]:
    """Build the command prefix used to invoke ``expander-exec``.

    The first existing prebuilt ``expander-exec`` is run directly. It is
    looked up in the local Expander checkout, then the package's
    ``binaries`` folder, then ``<sys.prefix>/bin``. If none exists, it is
    built and run with ``mpiexec -n 1 cargo run`` from
    ``Expander/Cargo.toml``. Either way the command is wrapped in
    ``/usr/bin/time``.

    Args:
        pcs_type (str): Polynomial commitment scheme passed via ``-p``.

    Returns:
        list[str]: Command prefix ending with ``-p <pcs_type>``.
    """
    time_measure = "/usr/bin/time"
    possible_paths = [
        "./Expander/target/release/expander-exec",
        Path(__file__).parent.parent / "binaries" / "expander-exec",
        Path(sys.prefix) / "bin" / "expander-exec",
    ]
    expander_binary_path = next(
        (str(path) for path in possible_paths if Path(path).exists()),
        None,
    )
    if expander_binary_path:
        return [time_measure, expander_binary_path, "-p", pcs_type]
    return [
        time_measure,
        "mpiexec",
        "-n",
        "1",
        "cargo",
        "run",
        "--manifest-path",
        "Expander/Cargo.toml",
        "--bin",
        "expander-exec",
        "--release",
        "--",
        "-p",
        pcs_type,
    ]


def run_expander_raw(  # noqa: PLR0913
    mode: ExpanderMode,
    circuit_file: str,
    witness_file: str,
    proof_file: str,
    pcs_type: str = "Hyrax",
    *,
    bench: bool = False,
) -> subprocess.CompletedProcess[str]:
    """Run the Expander executable directly using Cargo.

    The call proceeds in four steps:

    1. Check that the required input files exist (the circuit and witness
       files, plus the proof file in VERIFY mode).
    2. Run ``expander-exec <mode>`` with ``RUSTFLAGS=-C target-cpu=native``.
    3. Print wall time, and memory usage when ``bench`` is set. The memory
       collection is ended even if the process fails to start.
    4. Report success, or raise on a non-zero exit.

    Args:
        mode (ExpanderMode): Operation mode (PROVE or VERIFY).
        circuit_file (str): Path to the circuit definition file.
        witness_file (str): Path to the witness file.
        proof_file (str): Path to the proof file (input for verification,
            output for proving).
        pcs_type (str, optional):
            Polynomial commitment scheme type ("Hyrax" or "Raw").
            Defaults to "Hyrax".
        bench (bool, optional):
            If True, collect runtime and memory benchmark data. Defaults to False.

    Raises:
        MissingFileError: If a required input file does not exist.
        ProofBackendError: If the process cannot be started or exits non-zero.

    Returns:
        subprocess.CompletedProcess[str]: Exit message from the subprocess.
    """
    required_files = [(circuit_file, "circuit_file"), (witness_file, "witness_file")]
    if mode == ExpanderMode.VERIFY:
        # For VERIFY mode, the proof_file is an input, not just an output
        required_files.append((proof_file, "proof_file"))
    for file_path, label in required_files:
        if file_path and not Path(file_path).exists():
            msg = f"Missing file required for {label}"
            raise MissingFileError(msg, file_path)

    env = os.environ.copy()
    env["RUSTFLAGS"] = "-C target-cpu=native"

    # PROVE writes the proof (-o); VERIFY reads it (-i).
    proof_command = "-o" if mode == ExpanderMode.PROVE else "-i"
    args = [
        *_expander_base_command(pcs_type),
        mode.value,
        "-c",
        circuit_file,
        "-w",
        witness_file,
        proof_command,
        proof_file,
    ]

    # (stop_event, monitor_thread, monitor_results) while collection is active.
    monitor = None
    try:
        if bench:
            monitor = start_memory_collection("expander-exec")
        start_time = time()
        result = subprocess.run(  # noqa: S603
            args,
            env=env,
            capture_output=True,
            text=True,
            check=False,
        )
        end_time = time()

        print("\n--- BENCHMARK RESULTS ---")  # noqa: T201
        print(f"Rust time taken: {end_time - start_time:.4f} seconds")  # noqa: T201

        if monitor is not None:
            stop_event, monitor_thread, monitor_results = monitor
            # Ended here; clear it so the finally block does not end it twice.
            monitor = None
            memory = end_memory_collection(stop_event, monitor_thread, monitor_results)
            print(f"Rust subprocess memory: {memory['total']:.2f} MB")  # noqa: T201

        if result.returncode != 0:
            clean_stderr = filter_expander_output(result.stderr)
            msg = f"Expander {mode.value} failed:\n{clean_stderr}"
            logger.warning(msg)
            msg = f"Expander {mode.value} failed"
            raise ProofBackendError(
                msg,
                command=args,
                returncode=result.returncode,
                stdout=result.stdout,
                stderr=clean_stderr,
            )

        print(  # noqa: T201
            f"✅ expander-exec {mode.value} succeeded:\n{result.stdout}",
        )

        print(f"Time taken: {end_time - start_time:.4f} seconds")  # noqa: T201
    except OSError as e:
        msg = f"Failed to execute Expander {mode.value}: {e}"
        logger.exception(msg)
        raise ProofBackendError(
            msg,
            command=args,
        ) from e
    else:
        return result
    finally:
        # Only still set if something raised before the monitor was ended
        # (e.g. /usr/bin/time or the binary could not be started).
        if monitor is not None:
            end_memory_collection(*monitor)
712
+
713
+
714
def compile_circuit(  # noqa: PLR0913
    circuit_name: str,
    circuit_path: str,
    metadata_path: str,
    architecture_path: str,
    w_and_b_path: str,
    proof_system: ZKProofSystems = ZKProofSystems.Expander,
    *,
    dev_mode: bool = True,
    bench: bool = False,
) -> subprocess.CompletedProcess[str]:
    """Compile a model into a zk circuit through the Rust backend.

    Args:
        circuit_name (str): Name of the circuit. Its final path component is
            used as the backend binary name.
        circuit_path (str): Path to the circuit file (``-c``).
        metadata_path (str): Path to the model metadata JSON (``-m``).
        architecture_path (str): Path to the architecture JSON (``-a``).
        w_and_b_path (str): Path to the weights-and-biases JSON (``-b``).
        proof_system (ZKProofSystems, optional):
            Proof system to use. Defaults to ZKProofSystems.Expander.
        dev_mode (bool, optional):
            If True, recompiles the rust binary (run in development mode).
            Defaults to True.
        bench (bool, optional):
            Whether or not to run benchmarking metrics. Defaults to False.

    Raises:
        ProofSystemNotImplementedError: If the proof system is not supported.
        ProofBackendError: Re-raised after logging if compilation fails.

    Returns:
        subprocess.CompletedProcess[str]: Exit message from the subprocess.
    """
    if proof_system != ZKProofSystems.Expander:
        msg = f"Proof system {proof_system} not implemented"
        raise ProofSystemNotImplementedError(msg)

    binary_name = Path(circuit_name).name
    cli_args = {
        "n": circuit_name,
        "c": circuit_path,
        "m": metadata_path,
        "a": architecture_path,
        "b": w_and_b_path,
    }
    try:
        return run_cargo_command(
            binary_name=binary_name,
            command_type=RunType.COMPILE_CIRCUIT.value,
            args=cli_args,
            dev_mode=dev_mode,
            bench=bench,
        )
    except ProofBackendError as e:
        failure_note = f"Warning: Compile operation failed: {e}."
        binary_note = f" Using binary: {binary_name}"
        logger.warning(failure_note)
        logger.warning(binary_note)
        raise
774
+
775
+
776
def generate_witness(  # noqa: PLR0913
    circuit_name: str,
    circuit_path: str,
    witness_file: str,
    input_file: str,
    output_file: str,
    metadata_path: str,
    proof_system: ZKProofSystems = ZKProofSystems.Expander,
    *,
    dev_mode: bool = False,
    bench: bool = False,
) -> subprocess.CompletedProcess[str]:
    """Generate a witness file for a circuit through the Rust backend.

    Args:
        circuit_name (str): Name of the circuit. Its final path component is
            used as the backend binary name.
        circuit_path (str): Path to the circuit definition (``-c``).
        witness_file (str): Path to the output witness file (``-w``).
        input_file (str): Path to the input JSON file with private inputs (``-i``).
        output_file (str): Path to the output JSON file with computed outputs (``-o``).
        metadata_path (str): Path to the model metadata JSON (``-m``).
        proof_system (ZKProofSystems, optional): Proof system to use.
            Defaults to ZKProofSystems.Expander.
        dev_mode (bool, optional):
            If True, recompiles the rust binary (run in development mode).
            Defaults to False.
        bench (bool, optional):
            If True, enable benchmarking. Defaults to False.

    Raises:
        ProofSystemNotImplementedError: If the proof system is not supported.
        ProofBackendError: Re-raised after logging if witness generation fails.

    Returns:
        subprocess.CompletedProcess[str]: Exit message from the subprocess.
    """
    if proof_system != ZKProofSystems.Expander:
        msg = f"Proof system {proof_system} not implemented"
        raise ProofSystemNotImplementedError(msg)

    binary_name = Path(circuit_name).name
    cli_args = {
        "n": circuit_name,
        "c": circuit_path,
        "i": input_file,
        "o": output_file,
        "w": witness_file,
        "m": metadata_path,
    }
    try:
        return run_cargo_command(
            binary_name=binary_name,
            command_type=RunType.GEN_WITNESS.value,
            args=cli_args,
            dev_mode=dev_mode,
            bench=bench,
        )
    except ProofBackendError as e:
        failure_note = f"Warning: Witness generation failed: {e}"
        logger.warning(failure_note)
        raise
839
+
840
+
841
def generate_proof(  # noqa: PLR0913
    circuit_name: str,
    circuit_path: str,
    witness_file: str,
    proof_file: str,
    metadata_path: str,
    proof_system: ZKProofSystems = ZKProofSystems.Expander,
    *,
    dev_mode: bool = False,
    ecc: bool = True,
    bench: bool = False,
) -> subprocess.CompletedProcess[str]:
    """Generate a proof for a witness.

    With ``ecc`` the circuit's Rust binary is invoked with the prove-witness
    command. Without it, ``expander-exec prove`` is run directly through
    ``run_expander_raw``; there ``circuit_name`` and ``metadata_path`` are
    unused and the default PCS type applies.

    Args:
        circuit_name (str): Name of the circuit. Its final path component is
            used as the backend binary name.
        circuit_path (str): Path to the circuit definition.
        witness_file (str): Path to the witness file.
        proof_file (str): Path to the output proof file.
        metadata_path (str): Path to the model metadata JSON (ECC path only).
        proof_system (ZKProofSystems, optional): Proof system to use.
            Defaults to ZKProofSystems.Expander.
        dev_mode (bool, optional):
            If True, recompiles the rust binary (run in development mode).
            Defaults to False.
        ecc (bool, optional):
            If true, run proof using ECC api, otherwise run directly through Expander.
            Defaults to True.
        bench (bool, optional):
            If True, enable benchmarking. Defaults to False.

    Raises:
        ProofSystemNotImplementedError: If the proof system is not supported.
        ProofBackendError: If proving fails.

    Returns:
        subprocess.CompletedProcess[str]: Exit message from the subprocess.
    """
    if proof_system != ZKProofSystems.Expander:
        msg = f"Proof system {proof_system} not implemented"
        raise ProofSystemNotImplementedError(msg)

    if not ecc:
        # Bypass the ECC wrapper binary and call expander-exec directly.
        return run_expander_raw(
            mode=ExpanderMode.PROVE,
            circuit_file=circuit_path,
            witness_file=witness_file,
            proof_file=proof_file,
            bench=bench,
        )

    binary_name = Path(circuit_name).name
    cli_args = {
        "n": circuit_name,
        "c": circuit_path,
        "w": witness_file,
        "p": proof_file,
        "m": metadata_path,
    }
    try:
        return run_cargo_command(
            binary_name=binary_name,
            command_type=RunType.PROVE_WITNESS.value,
            args=cli_args,
            dev_mode=dev_mode,
            bench=bench,
        )
    except ProofBackendError as e:
        failure_note = f"Warning: Proof generation failed: {e}"
        logger.warning(failure_note)
        raise
915
+
916
+
917
def generate_verification(  # noqa: PLR0913
    circuit_name: str,
    circuit_path: str,
    input_file: str,
    output_file: str,
    witness_file: str,
    proof_file: str,
    metadata_path: str,
    proof_system: ZKProofSystems = ZKProofSystems.Expander,
    *,
    dev_mode: bool = False,
    ecc: bool = True,
    bench: bool = False,
) -> subprocess.CompletedProcess[str]:
    """Verify a given proof.

    Args:
        circuit_name (str): Name of the circuit. On the ECC route its final
            path component is used as the binary name.
        circuit_path (str): Path to the circuit definition.
        input_file (str): Path to the input JSON file with public inputs.
            Only forwarded on the ECC route.
        output_file (str): Path to the output JSON file with expected
            outputs. Only forwarded on the ECC route.
        witness_file (str): Path to the witness file.
        proof_file (str): Path to the proof file to check.
        metadata_path (str): Path to the circuit metadata file. Only
            forwarded on the ECC route.
        proof_system (ZKProofSystems, optional): Proof system to use.
            Defaults to ZKProofSystems.Expander.
        dev_mode (bool, optional): If True, recompiles the rust binary
            (development mode). Only used on the ECC route.
            Defaults to False.
        ecc (bool, optional): If True, verify through the ECC api binary;
            otherwise run Expander directly. Defaults to True.
        bench (bool, optional): If True, enable benchmarking.
            Defaults to False.

    Raises:
        ProofSystemNotImplementedError: If ``proof_system`` is not Expander.
        ProofBackendError: Re-raised, after a warning is logged, when the
            ECC-route command fails.

    Returns:
        subprocess.CompletedProcess[str]: Result of the backend subprocess.
    """
    if proof_system != ZKProofSystems.Expander:
        msg = f"Proof system {proof_system} not implemented"
        raise ProofSystemNotImplementedError(msg)

    if not ecc:
        # Direct Expander invocation: public inputs/outputs, metadata and
        # dev_mode are not passed on this route.
        return run_expander_raw(
            mode=ExpanderMode.VERIFY,
            circuit_file=circuit_path,
            witness_file=witness_file,
            proof_file=proof_file,
            bench=bench,
        )

    binary = Path(circuit_name).name
    cli_args = {
        "n": circuit_name,
        "c": circuit_path,
        "i": input_file,
        "o": output_file,
        "w": witness_file,
        "p": proof_file,
        "m": metadata_path,
    }
    try:
        return run_cargo_command(
            binary_name=binary,
            command_type=RunType.GEN_VERIFY.value,
            args=cli_args,
            dev_mode=dev_mode,
            bench=bench,
        )
    except ProofBackendError as err:
        message = f"Warning: Verification generation failed: {err}"
        logger.warning(message)
        raise
996
+
997
+
998
def run_end_to_end(  # noqa: PLR0913
    circuit_name: str,
    circuit_path: str,
    input_file: str,
    output_file: str,
    proof_system: ZKProofSystems = ZKProofSystems.Expander,
    *,
    demo: bool = False,
    dev_mode: bool = False,
    ecc: bool = True,
) -> int:
    """Run the full pipeline for proving and verifying a circuit.

    Steps:
        1. Compile the circuit.
        2. Generate a witness from inputs.
        3. Produce a proof from the witness.
        4. Verify the proof against inputs and outputs.

    All intermediate artifacts are placed next to ``circuit_path``: its
    extension is stripped to form a base path, and suffixes such as
    ``_witness<ext>``, ``_proof.bin`` and ``_metadata.json`` are appended.

    Args:
        circuit_name (str): Name of the circuit.
        circuit_path (str): Path to the circuit definition.
        input_file (str): Path to the input JSON file with public inputs.
        output_file (str): Path to the output JSON file with expected outputs.
        proof_system (ZKProofSystems, optional):
            Proof system to use. Defaults to ZKProofSystems.Expander.
        demo (bool, optional): Accepted for interface compatibility;
            currently unused. Defaults to False.
        dev_mode (bool, optional):
            If True, recompiles the rust binary (run in development mode).
            Defaults to False.
        ecc (bool, optional):
            If true, run proof using ECC api, otherwise run directly through
            Expander. Defaults to True.

    Raises:
        ProofSystemNotImplementedError: If ``proof_system`` is not Expander.

    Returns:
        int: Return code of the verification subprocess.
    """
    _ = demo
    if proof_system != ZKProofSystems.Expander:
        msg = f"Proof system {proof_system} not implemented"
        raise ProofSystemNotImplementedError(msg)

    path = Path(circuit_path)
    base = str(path.with_suffix(""))  # circuit path without its extension
    ext = path.suffix

    witness_file = f"{base}_witness{ext}"
    proof_file = f"{base}_proof.bin"
    metadata_path = f"{base}_metadata.json"

    # dev_mode / ecc are passed by keyword: generate_proof and
    # generate_verification declare them keyword-only (after ``*``), so
    # passing them positionally raised TypeError.
    compile_circuit(
        circuit_name,
        circuit_path,
        metadata_path,
        f"{base}_architecture.json",
        f"{base}_wandb.json",
        proof_system,
        dev_mode=dev_mode,
    )
    generate_witness(
        circuit_name,
        circuit_path,
        witness_file,
        input_file,
        output_file,
        metadata_path,
        proof_system,
        dev_mode=dev_mode,
    )
    generate_proof(
        circuit_name,
        circuit_path,
        witness_file,
        proof_file,
        metadata_path,
        proof_system,
        dev_mode=dev_mode,
        ecc=ecc,
    )
    result = generate_verification(
        circuit_name,
        circuit_path,
        input_file,
        output_file,
        witness_file,
        proof_file,
        metadata_path,
        proof_system,
        dev_mode=dev_mode,
        ecc=ecc,
    )
    # Honor the declared ``-> int`` contract instead of leaking the
    # CompletedProcess object.
    return result.returncode
1090
+
1091
+
1092
def get_files(
    name: str,
    proof_system: ZKProofSystems,
    folders: dict[str, str],
) -> dict[str, str]:
    """
    Build the standard set of artifact file paths for ``name``.

    Only path strings are computed; nothing is created or checked on disk.

    Args:
        name (str): The base name shared by every generated file.
        proof_system (ZKProofSystems): The ZK proof system being used. Only
            Expander is supported; it adds 'circuit_name', 'witness_file'
            and 'proof_path' entries.
        folders (dict[str, str]):
            Folder paths keyed by role. The keys read here are 'input',
            'proof', 'weights' and 'output'.

    Raises:
        KeyError: If ``folders`` lacks one of the keys listed above.
        ProofSystemNotImplementedError: If ``proof_system`` is not Expander.

    Returns:
        dict[str, str]: A dictionary mapping descriptive keys to file paths.
    """
    input_dir = Path(folders["input"])
    proof_dir = Path(folders["proof"])
    weights_dir = Path(folders["weights"])
    output_dir = Path(folders["output"])

    def _under(directory: Path, suffix: str) -> str:
        # "<directory>/<name><suffix>" as a plain string.
        return str(directory / f"{name}{suffix}")

    paths = {
        "input_file": _under(input_dir, "_input.json"),
        "public_path": _under(proof_dir, "_public.json"),
        "metadata_path": _under(weights_dir, "_metadata.json"),
        "architecture_path": _under(weights_dir, "_architecture.json"),
        "w_and_b_path": _under(weights_dir, "_w_and_b.json"),
        "output_file": _under(output_dir, "_output.json"),
    }

    if proof_system != ZKProofSystems.Expander:
        msg = f"Proof system {proof_system} not implemented"
        raise ProofSystemNotImplementedError(msg)

    # Expander-specific artifacts.
    paths["circuit_name"] = name
    paths["witness_file"] = _under(input_dir, "_witness.txt")
    paths["proof_path"] = _under(proof_dir, "_proof.bin")
    return paths