JSTprove 1.0.0__py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. jstprove-1.0.0.dist-info/METADATA +397 -0
  2. jstprove-1.0.0.dist-info/RECORD +81 -0
  3. jstprove-1.0.0.dist-info/WHEEL +6 -0
  4. jstprove-1.0.0.dist-info/entry_points.txt +2 -0
  5. jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
  6. jstprove-1.0.0.dist-info/top_level.txt +1 -0
  7. python/__init__.py +0 -0
  8. python/core/__init__.py +3 -0
  9. python/core/binaries/__init__.py +0 -0
  10. python/core/binaries/expander-exec +0 -0
  11. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  12. python/core/circuit_models/__init__.py +0 -0
  13. python/core/circuit_models/generic_onnx.py +231 -0
  14. python/core/circuit_models/simple_circuit.py +133 -0
  15. python/core/circuits/__init__.py +0 -0
  16. python/core/circuits/base.py +1000 -0
  17. python/core/circuits/errors.py +188 -0
  18. python/core/circuits/zk_model_base.py +25 -0
  19. python/core/model_processing/__init__.py +0 -0
  20. python/core/model_processing/converters/__init__.py +0 -0
  21. python/core/model_processing/converters/base.py +143 -0
  22. python/core/model_processing/converters/onnx_converter.py +1181 -0
  23. python/core/model_processing/errors.py +147 -0
  24. python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
  25. python/core/model_processing/onnx_custom_ops/conv.py +111 -0
  26. python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
  27. python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
  28. python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
  29. python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
  30. python/core/model_processing/onnx_custom_ops/relu.py +43 -0
  31. python/core/model_processing/onnx_quantizer/__init__.py +0 -0
  32. python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
  33. python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
  34. python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
  35. python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
  36. python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
  37. python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
  38. python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
  39. python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
  40. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
  41. python/core/model_templates/__init__.py +0 -0
  42. python/core/model_templates/circuit_template.py +57 -0
  43. python/core/utils/__init__.py +0 -0
  44. python/core/utils/benchmarking_helpers.py +163 -0
  45. python/core/utils/constants.py +4 -0
  46. python/core/utils/errors.py +117 -0
  47. python/core/utils/general_layer_functions.py +268 -0
  48. python/core/utils/helper_functions.py +1138 -0
  49. python/core/utils/model_registry.py +166 -0
  50. python/core/utils/scratch_tests.py +66 -0
  51. python/core/utils/witness_utils.py +291 -0
  52. python/frontend/__init__.py +0 -0
  53. python/frontend/cli.py +115 -0
  54. python/frontend/commands/__init__.py +17 -0
  55. python/frontend/commands/args.py +100 -0
  56. python/frontend/commands/base.py +199 -0
  57. python/frontend/commands/bench/__init__.py +54 -0
  58. python/frontend/commands/bench/list.py +42 -0
  59. python/frontend/commands/bench/model.py +172 -0
  60. python/frontend/commands/bench/sweep.py +212 -0
  61. python/frontend/commands/compile.py +58 -0
  62. python/frontend/commands/constants.py +5 -0
  63. python/frontend/commands/model_check.py +53 -0
  64. python/frontend/commands/prove.py +50 -0
  65. python/frontend/commands/verify.py +73 -0
  66. python/frontend/commands/witness.py +64 -0
  67. python/scripts/__init__.py +0 -0
  68. python/scripts/benchmark_runner.py +833 -0
  69. python/scripts/gen_and_bench.py +482 -0
  70. python/tests/__init__.py +0 -0
  71. python/tests/circuit_e2e_tests/__init__.py +0 -0
  72. python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
  73. python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
  74. python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
  75. python/tests/circuit_parent_classes/__init__.py +0 -0
  76. python/tests/circuit_parent_classes/test_circuit.py +969 -0
  77. python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
  78. python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
  79. python/tests/test_cli.py +1021 -0
  80. python/tests/utils_testing/__init__.py +0 -0
  81. python/tests/utils_testing/test_helper_functions.py +891 -0
@@ -0,0 +1,833 @@
1
+ # python/scripts/benchmark_runner.py
2
+ # ruff: noqa: S603, RUF003, RUF002, T201
3
+
4
+ """
5
+ Benchmark JSTProve by invoking the CLI phases (compile → witness → prove → verify),
6
+ streaming live output with a spinner/HUD, and logging per-phase timing/memory.
7
+
8
+ Writes one JSON object per phase per iteration to a JSONL file so you can analyze later.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ # --- Standard library --------------------------------------------------------
14
+ import argparse
15
+ import json
16
+ import logging
17
+ import os
18
+ import re
19
+ import shutil
20
+ import subprocess
21
+ import sys
22
+ import tempfile
23
+ import time
24
+ from contextlib import suppress
25
+ from dataclasses import dataclass
26
+ from datetime import datetime, timezone
27
+ from pathlib import Path
28
+ from statistics import mean, stdev
29
+
30
+ # --- Third-party -------------------------------------------------------------
31
+ import psutil
32
+
33
+ # --- Local -------------------------------------------------------------------
34
+ from python.core.utils.benchmarking_helpers import (
35
+ end_memory_collection,
36
+ start_memory_collection,
37
+ )
38
+
39
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
# NOTE: basicConfig runs at import time, so importing this module configures
# the root logger (INFO level, bare "%(message)s" format) as a side effect
# (a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO, format="%(message)s")
log = logging.getLogger("benchmark_runner")

# -----------------------------------------------------------------------------
# Parsing helpers / patterns
# -----------------------------------------------------------------------------
# Tried in order by parse_metrics(); the first pattern whose group(1) converts
# to a float supplies the phase time in seconds.
TIME_PATTERNS = [
    re.compile(r"Rust time taken:\s*([0-9.]+)"),
    re.compile(r"Time elapsed:\s*([0-9.]+)\s*seconds"),
]
# Same first-usable-match scheme for memory; parse_metrics() reports it as MB.
MEM_PATTERNS = [
    re.compile(r"Peak Memory used Overall\s*:\s*([0-9.]+)"),
    re.compile(r"Rust subprocess memory\s*:\s*([0-9.]+)"),
]

# ANSI CSI escape sequences (colors, cursor movement); stripped before parsing.
ANSI_RE = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]")
# Milestone-line patterns for the circuit compiler's progress output. They are
# not referenced elsewhere in this module (run_cli matches the same phrases
# with plain substring checks); presumably kept for external users — verify
# before removing.
ECC_HINT_RE = re.compile(r"built\s+hint\s+normalized\s+ir\b.*", re.IGNORECASE)
ECC_LAYERED_RE = re.compile(r"built\s+layered\s+circuit\b.*", re.IGNORECASE)
ECC_LINE_PATTERNS = [
    re.compile(r"built layered circuit\b.*", re.IGNORECASE),
    re.compile(r"built hint normalized ir\b.*", re.IGNORECASE),
]

# Whitelist of counter names harvested from `key=value` pairs in CLI output
# (by parse_ecc_stats and run_cli); any other key matched by KV_PAIR is ignored.
ECC_KEYS = {
    "numInputs",
    "numConstraints",
    "numInsns",
    "numVars",
    "numTerms",
    "numSegment",
    "numLayer",
    "numUsedInputs",
    "numUsedVariables",
    "numVariables",
    "numAdd",
    "numCst",
    "numMul",
    "totalCost",
}

# Accept optional spaces and thousands separators in values
# (group 1 = identifier-like key, group 2 = digits/commas; callers strip the
# commas before int()).
KV_PAIR = re.compile(r"\b([A-Za-z_][A-Za-z0-9_]*)\s*=\s*([0-9,]+)\b")

# Spinner glyphs (ASCII by default; set JSTPROVE_UNICODE=1 to switch)
_SPINNER_ASCII = "-\\|/"
_SPINNER_UNICODE = "⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"
88
+
89
+
90
+ def _term_width(default: int = 100) -> int:
91
+ """Best-effort terminal width (columns)."""
92
+ try:
93
+ return shutil.get_terminal_size((default, 20)).columns
94
+ except Exception:
95
+ return default
96
+
97
+
98
+ # NOTE: currently unused; retained only because the (deprecated) compile card used it.
99
+ def _human_bytes(n: int | None) -> str:
100
+ """Pretty-print bytes as B/KB/MB/GB/TB (unused in the current flow)."""
101
+ conversion_value = 1024.0
102
+ if n is None:
103
+ return "NA"
104
+ units = ["B", "KB", "MB", "GB", "TB"]
105
+ x = float(n)
106
+ for u in units:
107
+ if x < conversion_value or u == units[-1]:
108
+ return f"{x:.1f} {u}" if u != "B" else f"{int(x)} B"
109
+ x /= conversion_value
110
+ msg = "Unreachable code: failed to format byte size."
111
+ raise RuntimeError(msg)
112
+
113
+
114
+ # NOTE: currently unused; retained only because the (deprecated) compile card used it.
115
+ def _fmt_int(n: int | None) -> str:
116
+ """Format an int with thousands separators (unused in the current flow)."""
117
+ return f"{n:,}" if isinstance(n, int) else "NA"
118
+
119
+
120
+ def _bar(value: int, vmax: int, width: int = 24, char: str = "█") -> str:
121
+ """Fixed-width bar proportional to value/vmax, using a solid block character."""
122
+ if vmax <= 0 or value <= 0:
123
+ return " " * width
124
+ fill = max(1, int(width * min(value, vmax) / vmax))
125
+ return char * fill + " " * (width - fill)
126
+
127
+
128
+ def _marquee(t: float, width: int = 24, char: str = "█") -> str:
129
+ """Bouncing 8-char block to suggest activity when total work is unknown."""
130
+ w = max(8, min(width, 24))
131
+ pos = int((abs(((t * 0.8) % 2) - 1)) * (w - 8))
132
+ return " " * pos + char * 8 + " " * (w - 8 - pos)
133
+
134
+
135
def _sum_child_rss_mb(parent_pid: int) -> float:
    """Approximate the current total RSS of all descendants of *parent_pid*, in MB.

    Best-effort helper for the live HUD. It returns 0.0 if the parent process
    is gone or inaccessible, and it skips individual children that exit or
    deny access mid-scan. The parent's own RSS is not included.
    """
    try:
        parent = psutil.Process(parent_pid)
        # children() can itself raise (e.g. NoSuchProcess / AccessDenied) if
        # the parent exits or changes between lookup and scan; treat that
        # like "no children" instead of crashing the caller's loop.
        children = parent.children(recursive=True)
    except psutil.Error:
        return 0.0
    total = 0
    for child in children:
        with suppress(psutil.Error):
            total += child.memory_info().rss
    return total / (1024.0 * 1024.0)
146
+
147
+
148
def parse_ecc_stats(text: str) -> dict[str, int]:
    """Collect ECC counter values from every ``key=value`` pair in *text*.

    ANSI escapes are removed and carriage returns turned into newlines first.
    Only keys listed in ``ECC_KEYS`` are kept. Values may contain thousands
    separators. When a key repeats, the last parseable occurrence wins.
    Values that do not parse as integers (e.g. a lone ",") are skipped.
    """
    cleaned = ANSI_RE.sub("", text).replace("\r", "\n")
    result: dict[str, int] = {}
    for key, raw in KV_PAIR.findall(cleaned):
        if key not in ECC_KEYS:
            continue
        try:
            result[key] = int(raw.replace(",", ""))
        except ValueError:
            pass
    return result
157
+
158
+
159
def strip_ansi(s: str) -> str:
    """Return *s* with ANSI escape sequences (colors, cursor moves) removed."""
    cleaned = ANSI_RE.sub("", s)
    return cleaned
162
+
163
+
164
def count_onnx_parameters(model_path: Path) -> int:
    """
    Return the total element count across the model's ONNX initializers.

    A scalar initializer (empty ``dims``) counts as one element. Returns -1 if
    the optional ``onnx`` dependency cannot be imported. Errors raised by
    ``onnx.load`` (e.g. a missing file) propagate to the caller.
    """
    try:
        import onnx  # noqa: PLC0415 # type: ignore[import]
    except Exception:
        return -1

    graph = onnx.load(str(model_path)).graph

    def element_count(dims) -> int:
        count = 1
        for dim in dims:
            count *= int(dim)
        return count

    return int(sum(element_count(init.dims) for init in graph.initializer))
182
+
183
+
184
def file_size_bytes(path: str | Path) -> int | None:
    """Return the size of *path* in bytes, or None if it is missing or unreadable."""
    try:
        return Path(path).stat().st_size
    except (OSError, ValueError):
        # OSError: missing file, dangling symlink, permission problem, etc.
        # ValueError: paths the OS rejects outright (e.g. embedded NUL), which
        # Path.exists() also reports as "not there".
        return None
191
+
192
+
193
def parse_metrics(text: str) -> tuple[float | None, float | None]:
    """Best-effort extraction of (time in seconds, memory in MB) from CLI output.

    Each metric comes from the first pattern in ``TIME_PATTERNS`` /
    ``MEM_PATTERNS`` (checked in order) whose first match in *text* converts
    to a float. A metric with no usable match is returned as None.
    """

    def first_float(patterns) -> float | None:
        for pattern in patterns:
            found = pattern.search(text)
            if found is None:
                continue
            try:
                return float(found.group(1))
            except ValueError:
                continue
        return None

    return first_float(TIME_PATTERNS), first_float(MEM_PATTERNS)
214
+
215
+
216
def now_utc() -> str:
    """Return the current UTC time as an RFC 3339 string with whole seconds.

    Example: ``'2025-01-01T00:00:00Z'``.
    """
    stamp = datetime.now(timezone.utc)
    return stamp.strftime("%Y-%m-%dT%H:%M:%SZ")
227
+
228
+
229
@dataclass(frozen=True)
class PhaseIO:
    """Per-phase file locations used to invoke the CLI.

    Immutable bundle of paths that ``_build_phase_cmd`` turns into ``jst``
    flags. ``main`` builds one instance per iteration (inside a fresh temp
    directory) and reuses it for all four phases.
    """

    model_path: Path  # ONNX model; passed as `-m` to `compile` only
    circuit_path: Path  # passed as `-c` to every phase
    input_path: Path | None  # passed as `-i` to `witness`/`verify` when set
    output_path: Path  # passed as `-o` to `witness` and `verify`
    witness_path: Path  # passed as `-w` to `witness`, `prove`, `verify`
    proof_path: Path  # passed as `-p` to `prove` and `verify`
239
+
240
+
241
+ def _build_phase_cmd(phase: str, io: PhaseIO) -> list[str]:
242
+ """Construct the exact `jst` CLI command for a phase."""
243
+ base = ["jst", "--no-banner"]
244
+ if phase == "compile":
245
+ return [*base, "compile", "-m", str(io.model_path), "-c", str(io.circuit_path)]
246
+ if phase == "witness":
247
+ cmd = [
248
+ *base,
249
+ "witness",
250
+ "-c",
251
+ str(io.circuit_path),
252
+ "-o",
253
+ str(io.output_path),
254
+ "-w",
255
+ str(io.witness_path),
256
+ ]
257
+ if io.input_path:
258
+ cmd += ["-i", str(io.input_path)]
259
+ return cmd
260
+ if phase == "prove":
261
+ return [
262
+ *base,
263
+ "prove",
264
+ "-c",
265
+ str(io.circuit_path),
266
+ "-w",
267
+ str(io.witness_path),
268
+ "-p",
269
+ str(io.proof_path),
270
+ ]
271
+ if phase == "verify":
272
+ cmd = [
273
+ *base,
274
+ "verify",
275
+ "-c",
276
+ str(io.circuit_path),
277
+ "-o",
278
+ str(io.output_path),
279
+ "-w",
280
+ str(io.witness_path),
281
+ "-p",
282
+ str(io.proof_path),
283
+ ]
284
+ if io.input_path:
285
+ cmd += ["-i", str(io.input_path)]
286
+ return cmd
287
+ msg = f"unknown phase: {phase}"
288
+ raise ValueError(msg)
289
+
290
+
291
def run_cli(  # noqa: PLR0915, PLR0912, C901
    phase: str,
    io: PhaseIO,
) -> tuple[
    int,
    str,
    float | None,
    float | None,
    list[str],
    float | None,
    float | None,
    dict[str, int],
]:
    """
    Execute one CLI phase, streaming stdout live with a spinner/HUD and
    tracking peak RSS via psutil.

    The child's stderr is merged into stdout. Output lines are consumed as
    fast as the child writes them; the loop only sleeps once stdout has hit
    EOF and it is waiting for the child to exit. The HUD is repainted, and
    live child RSS sampled, at most ~10 times per second (plus on every idle
    pass at EOF). Because ``readline()`` blocks, neither advances while the
    child is silent. The memory figure comes primarily from the collector
    started by ``start_memory_collection``.

    If streaming is interrupted by any exception (including Ctrl+C /
    KeyboardInterrupt), the child is killed, the memory collector is still
    stopped, and the exception propagates to the caller.

    Returns:
        (returncode, combined_output, time_s, mem_mb_primary, cmd_list,
         mem_mb_rust, mem_mb_psutil, ecc_dict)

        ``time_s`` prefers a time parsed from the output and falls back to
        wall-clock elapsed. ``mem_mb_psutil`` is the collector's "total", or
        the live-sampled child peak if the collector failed.
    """
    cmd = _build_phase_cmd(phase, io)

    env = os.environ.copy()
    env.setdefault("RUST_LOG", "info")
    env.setdefault("RUST_BACKTRACE", "1")
    env.setdefault("PYTHONUNBUFFERED", "1")  # help with child buffering

    spinner = (
        _SPINNER_UNICODE
        if os.environ.get("JSTPROVE_UNICODE") == "1"
        else _SPINNER_ASCII
    )
    sp_i = 0
    peak_live_mb = 0.0
    tw = _term_width()
    bar_w = max(18, min(28, tw - 50))
    hud_interval = 0.1  # seconds between HUD repaints / live RSS samples
    last_hud = -hud_interval  # forces a repaint on the first pass

    combined_lines: list[str] = []
    ecc_live: dict[str, int] = {}
    collected_mb: float | None = None

    stop_ev, mon_thread, mon_results = start_memory_collection("")
    start = time.time()
    # Bound up front so the post-processing never hits an unbound name.
    elapsed = 0.0
    try:
        # The context manager closes the stdout pipe and reaps the child.
        with subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            env=env,
            bufsize=1,
        ) as proc:
            try:
                while True:
                    line = proc.stdout.readline() if proc.stdout else ""
                    elapsed = time.time() - start

                    if line:
                        combined_lines.append(line.rstrip("\n"))
                        low = line.lower()

                        # Echo milestone lines for user feedback
                        if ("built layered circuit" in low) or (
                            "built hint normalized ir" in low
                        ):
                            print(line, end="")

                        # Harvest ECC counters live from any k=v pairs we see
                        ecc_live.update(parse_ecc_stats(line))

                    # Throttle the (relatively expensive) psutil scan and the
                    # terminal repaint to ~10 Hz instead of once per line.
                    if not line or elapsed - last_hud >= hud_interval:
                        last_hud = elapsed
                        live_mb = _sum_child_rss_mb(proc.pid)
                        peak_live_mb = max(peak_live_mb, live_mb)
                        spin = spinner[sp_i % len(spinner)]
                        sp_i += 1
                        hud_bar = _marquee(elapsed, width=bar_w)
                        hud = (
                            f"\r[{spin}] {phase:<7} | {elapsed:6.1f}s | "
                            f"mem↑ {peak_live_mb:7.1f} MB | {hud_bar}"
                        )
                        print(hud[: tw - 1], end="", flush=True)

                    if proc.poll() is not None:
                        # Drain any remaining buffered output after exit and
                        # harvest it too (it may carry the final ECC stats).
                        tail = proc.stdout.read() if proc.stdout else ""
                        if tail:
                            combined_lines.extend(tail.splitlines())
                            ecc_live.update(parse_ecc_stats(tail))

                        # final HUD line
                        elapsed = time.time() - start
                        hud = (
                            f"\r[✔] {phase:<7} | {elapsed:6.1f}s "
                            f"| mem↑ {peak_live_mb:7.1f} MB | " + " " * bar_w
                        )
                        print(hud[: tw - 1])
                        break

                    if not line:
                        # stdout is at EOF but the child hasn't exited yet.
                        # Sleeping only here keeps the pipe drained; sleeping
                        # per line would block a chatty child on a full pipe
                        # and inflate the measured phase time.
                        time.sleep(0.09)
            except BaseException:
                # Never leave the prover running behind an error/Ctrl+C, and
                # never report an unfinished phase as a success: kill, then
                # re-raise so the caller (e.g. main) can react.
                with suppress(OSError):
                    proc.kill()
                raise
    finally:
        # Always stop the background collector. No `return` here, so an
        # in-flight exception keeps propagating.
        try:
            mem = end_memory_collection(stop_ev, mon_thread, mon_results)  # type: ignore[arg-type]
            if isinstance(mem, dict):
                collected_mb = float(mem.get("total", 0.0))
        except Exception:
            collected_mb = None

    combined = "\n".join(combined_lines)
    time_s, mem_mb_rust = parse_metrics(combined)
    if time_s is None:
        time_s = elapsed
    mem_mb_psutil = collected_mb if collected_mb is not None else peak_live_mb
    mem_mb_primary = mem_mb_psutil if mem_mb_psutil is not None else mem_mb_rust

    # If we missed something live, parse once more from the combined blob
    if not ecc_live:
        ecc_live = parse_ecc_stats(combined)

    # Also append a compact ECC block into the combined text for later eyeballing
    if ecc_live:
        kv = " ".join(f"{k}={v}" for k, v in sorted(ecc_live.items()))
        combined += f"\n[ECC]\n{kv}\n"

    return (
        # The with-block waited for the child, so returncode is set here.
        proc.returncode or 0,
        combined,
        time_s,
        mem_mb_primary,
        cmd,
        mem_mb_rust,
        mem_mb_psutil,
        ecc_live,
    )
433
+
434
+
435
# NOTE: this card printer is currently unused (kept for reference).
def _print_compile_card(
    ecc: dict,
    circuit_bytes: int | None,
    quant_bytes: int | None,
) -> None:
    """Print a boxed table of selected ECC compile counters with bars.

    Shows whichever of numAdd, numMul, numCst, numVars, numInsns,
    numConstraints and totalCost are present in *ecc* (in that order). Each
    row has a bar scaled to the largest value shown. Prints nothing when *ecc*
    is empty or has none of those keys. *circuit_bytes* and *quant_bytes* are
    accepted but ignored. Unused in the current flow.
    """
    del circuit_bytes, quant_bytes  # accepted for signature compatibility only
    if not ecc:
        return
    shown = (
        "numAdd",
        "numMul",
        "numCst",
        "numVars",
        "numInsns",
        "numConstraints",
        "totalCost",
    )
    counts = {name: int(ecc[name]) for name in shown if name in ecc}
    if not counts:
        return
    scale = max(counts.values())
    bar_width = max(24, min(40, _term_width() - 50))

    print()
    print("┌────────────────────────── Compile Stats ──────────────────────────┐")
    for name, count in counts.items():
        meter = _bar(count, scale, width=bar_width)
        print(f"│ {name:<14} {count:>12} {meter} │")
    print("└────────────────────────────────────────────────────────────────────┘")
473
+
474
+
475
+ def _quantized_path_from_circuit(circuit_path: Path) -> Path:
476
+ """Derive quantized ONNX path from circuit: <dir>/<stem>_quantized_model.onnx"""
477
+ return circuit_path.with_name(f"{circuit_path.stem}_quantized_model.onnx")
478
+
479
+
480
+ def _fmt_mean_sd(vals: list[float]) -> tuple[str, float | None, float | None]:
481
+ """Format a list as 'μ ± σ' (or single value), returning the label and μ,σ."""
482
+ if not vals:
483
+ return "NA", None, None
484
+ if len(vals) == 1:
485
+ v = vals[0]
486
+ return f"{v:.3f}", v, None
487
+ mu, sd = mean(vals), stdev(vals)
488
+ return f"{mu:.3f} ± {sd:.3f}", mu, sd
489
+
490
+
491
def _summary_card(  # noqa: PLR0915, C901
    model_name: str,
    tmap: dict[str, list[float]],
    mmap: dict[str, list[float]],
) -> None:
    """
    Print a per-phase summary table of time and memory statistics.

    For each of compile/witness/prove/verify, the card shows:

    * the mean (± sample stdev when there are 2+ samples);
    * the best (minimum) time and the peak (maximum) memory;
    * one bar per metric, scaled against that metric's largest mean.

    Column widths come from the rendered content, so multi-digit means (e.g.
    46.741) align as neatly as single-digit ones. The two bar columns are
    sized to fit the terminal, clamped to [10, 24] cells. Uses ASCII borders
    and bars when JSTPROVE_ASCII=1.

    Args:
        model_name: Unused; kept for call compatibility.
        tmap: Phase name -> time samples (seconds).
        mmap: Phase name -> memory samples (MB).
    """
    del model_name
    phases = ("compile", "witness", "prove", "verify")

    # 1) Format every cell first so column widths reflect the real content.
    #    Numeric means are kept beside their labels for bar scaling rather than
    #    re-parsed from the rounded label text.
    cells: list[tuple[str, str, str, str, str]] = []
    t_means: list[float | None] = []
    m_means: list[float | None] = []
    for ph in phases:
        tvals = tmap.get(ph, [])
        mvals = mmap.get(ph, [])
        tlabel, tmean, _tsd = _fmt_mean_sd(tvals)
        mlabel, mmean, _msd = _fmt_mean_sd(mvals)
        tbest = f"{min(tvals):.3f}" if tvals else "NA"
        mpeak = f"{max(mvals):.3f}" if mvals else "NA"
        cells.append((ph, tlabel, tbest, mlabel, mpeak))
        t_means.append(tmean)
        m_means.append(mmean)

    # Bars scale against the largest mean; 1.0 when every phase is NA so the
    # scale is never zero.
    tmax = max((x for x in t_means if x is not None), default=1.0)
    mmax = max((x for x in m_means if x is not None), default=1.0)

    # 2) Column widths: the widest of the header and every cell.
    headers = ("phase", "time (s)", "best", "mem (MB)", "peak")
    phase_w, time_w, best_w, mem_w, peak_w = (
        max(len(hdr), *(len(row[i]) for row in cells))
        for i, hdr in enumerate(headers)
    )

    # Each column renders as " <cell> " plus one border glyph, and the row
    # starts with one more border, so a row is 1 + sum(col_width + 3) wide.
    # Give the two bars what the terminal leaves, clamped to [10, 24] cells
    # (very narrow terminals may still wrap).
    min_bar, max_bar = 10, 24
    text_widths = (phase_w, time_w, best_w, mem_w, peak_w)
    overhead = 1 + sum(w + 3 for w in text_widths) + 2 * 3
    bar_w = max(min_bar, min(max_bar, (_term_width(100) - overhead) // 2))

    # 3) Glyphs: box-drawing by default, pure ASCII on request. Top/bottom
    #    rules get column junctions so they meet the vertical separators.
    if os.environ.get("JSTPROVE_ASCII") == "1":
        v, h, bar_char = "|", "-", "#"
        tl = tj = tr = lj = mj = rj = bl = bj = br = "+"
    else:
        v, h, bar_char = "│", "─", "█"
        tl, tj, tr = "┌", "┬", "┐"
        lj, mj, rj = "├", "┼", "┤"
        bl, bj, br = "└", "┴", "┘"

    col_widths = (phase_w, time_w, best_w, bar_w, mem_w, peak_w, bar_w)
    # Text columns left-aligned; "best" and "peak" right-aligned.
    aligns = ("<", "<", ">", "<", "<", ">", "<")

    def render(values: tuple[str, ...]) -> str:
        padded = (
            f" {val:{al}{w}} " for val, al, w in zip(values, aligns, col_widths)
        )
        return v + v.join(padded) + v

    def rule(left: str, joint: str, right: str) -> str:
        return left + joint.join(h * (w + 2) for w in col_widths) + right

    hdr_phase, hdr_time, hdr_best, hdr_mem, hdr_peak = headers
    print()
    print(rule(tl, tj, tr))
    print(render((hdr_phase, hdr_time, hdr_best, "t-bar", hdr_mem, hdr_peak, "m-bar")))
    print(rule(lj, mj, rj))
    for (ph, tlabel, tbest, mlabel, mpeak), tmean, mmean in zip(cells, t_means, m_means):
        tbar = _bar(tmean or 0.0, tmax, width=bar_w, char=bar_char)
        mbar = _bar(mmean or 0.0, mmax, width=bar_w, char=bar_char)
        print(render((ph, tlabel, tbest, tbar, mlabel, mpeak, mbar)))
    print(rule(bl, bj, br))
668
+
669
+
670
+ def _fmt_stats(vals: list[float]) -> str:
671
+ """Legacy one-liner stats; kept for compatibility with older callers."""
672
+ if not vals:
673
+ return "NA"
674
+ if len(vals) == 1:
675
+ return f"{vals[0]:.3f}"
676
+ return f"mean={mean(vals):.3f} stdev={stdev(vals):.3f} n={len(vals)}"
677
+
678
+
679
def summarize(rows: list[dict], model_name: str) -> None:
    """
    Group successful rows for *model_name* by phase and print the summary card.

    A row counts when its "model" equals *model_name* and its "return_code"
    is 0. Its "time_s" and "mem_mb" values (when not None) are appended, as
    floats, to that phase's time and memory lists.
    """
    phases = ("compile", "witness", "prove", "verify")
    tmap: dict[str, list[float]] = {}
    mmap: dict[str, list[float]] = {}
    for p in phases:
        tmap[p] = []
        mmap[p] = []

    successes = (
        r
        for r in rows
        if r.get("model") == model_name and r.get("return_code") == 0
    )
    for r in successes:
        for field, bucket in (("time_s", tmap), ("mem_mb", mmap)):
            if r.get(field) is not None:
                bucket[r["phase"]].append(float(r[field]))
    _summary_card(model_name, tmap, mmap)
694
+
695
+
696
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the argument parser for the benchmark runner CLI."""
    ap = argparse.ArgumentParser(
        description="Benchmark JSTProve by calling the CLI directly.",
    )
    ap.add_argument(
        "--model",
        required=True,
        help="Path to the ONNX model file "
        "(e.g., python/models/models_onnx/lenet.onnx)",
    )
    ap.add_argument("--input", required=False, help="Path to input JSON (optional).")
    ap.add_argument(
        "--iterations",
        type=int,
        default=5,
        help="Number of E2E loops (default: 5).",
    )
    ap.add_argument(
        "--output",
        default="results.jsonl",
        help="JSONL to append per-run rows (default: results.jsonl)",
    )
    ap.add_argument(
        "--summarize",
        action="store_true",
        help="Print per-phase summary card at the end.",
    )
    return ap


def _artifact_sizes(phase: str, io: PhaseIO) -> dict[str, int | None]:
    """Return sizes (bytes, None if missing) of the files relevant to *phase*."""
    if phase == "compile":
        return {
            "circuit_size_bytes": file_size_bytes(io.circuit_path),
            "quantized_size_bytes": file_size_bytes(
                _quantized_path_from_circuit(io.circuit_path),
            ),
        }
    if phase == "witness":
        return {
            "witness_size_bytes": file_size_bytes(io.witness_path),
            "output_size_bytes": file_size_bytes(io.output_path),
        }
    if phase in ("prove", "verify"):
        return {"proof_size_bytes": file_size_bytes(io.proof_path)}
    return {}


def _compile_artifacts_ok(io: PhaseIO, out: str) -> bool:
    """Check the files a successful compile should have produced.

    Logs an error and returns False if the circuit file is missing. A missing
    quantized ONNX only logs a warning.
    """
    if not io.circuit_path.exists():
        log.error(
            "[compile] rc=0 but circuit file missing: "
            "%s\n----- compile output -----\n%s",
            io.circuit_path,
            out,
        )
        return False
    qpath = _quantized_path_from_circuit(io.circuit_path)
    if not qpath.exists():
        log.warning(
            "[compile] expected quantized ONNX missing: %s",
            qpath,
        )
    return True


def main() -> int:  # noqa: PLR0915, C901, PLR0912
    """CLI entrypoint for the benchmark runner.

    Runs compile -> witness -> prove -> verify ``--iterations`` times, each
    iteration in a fresh temporary directory. One JSON row per phase is
    appended to ``--output``.

    Returns:
        0 on completion, 1 if compile reports success but leaves no circuit
        file, and 130 on Ctrl+C. argparse exits with status 2 on bad
        arguments, including a model path that is not an existing file.
    """
    ap = _build_arg_parser()
    args = ap.parse_args()

    model_path = Path(args.model).resolve()
    # Fail fast with a clear message instead of a traceback from onnx.load or
    # a cascade of failing phases.
    if not model_path.is_file():
        ap.error(f"model file not found: {model_path}")
    param_count = count_onnx_parameters(model_path)
    fixed_input = Path(args.input).resolve() if args.input else None
    out_path = Path(args.output).resolve()

    out_path.parent.mkdir(parents=True, exist_ok=True)
    rows: list[dict] = []

    try:
        for it in range(1, args.iterations + 1):
            with tempfile.TemporaryDirectory() as tmp_s:
                tmp = Path(tmp_s)
                io = PhaseIO(
                    model_path=model_path,
                    circuit_path=tmp / "circuit.txt",
                    input_path=fixed_input or (tmp / "input.json"),
                    output_path=tmp / "output.json",
                    witness_path=tmp / "witness.bin",
                    proof_path=tmp / "proof.bin",
                )

                for phase in ("compile", "witness", "prove", "verify"):
                    ts = now_utc()
                    rc, out, t, m, cmd, m_rust, m_psutil, ecc_live = run_cli(phase, io)

                    # ECC and artifact sizes are collected into the JSONL
                    # (not printed live)
                    row = {
                        "timestamp": ts,
                        "model": str(model_path),
                        "iteration": it,
                        "phase": phase,
                        "return_code": rc,
                        "time_s": t,
                        "mem_mb": m,
                        "mem_mb_rust": m_rust,
                        "mem_mb_psutil": m_psutil,
                        "ecc": ecc_live or {},
                        "cmd": cmd,
                        "tmpdir": str(tmp),
                        "param_count": param_count,
                        **_artifact_sizes(phase, io),
                    }
                    rows.append(row)
                    # Append per row so partial results survive a crash.
                    with out_path.open("a", encoding="utf-8") as f:
                        f.write(json.dumps(row) + "\n")

                    # Guard: compile claimed success but circuit missing
                    if phase == "compile" and rc == 0 and not _compile_artifacts_ok(io, out):
                        return 1

                    if rc != 0:
                        log.error("[%s] rc=%s — see logs below\n%s\n", phase, rc, out)

                    if t is not None:
                        mem_str = f"{m:.2f}" if m is not None else "NA"
                        log.info("[%s] t=%.3fs, mem=%s MB", phase, t, mem_str)
                    else:
                        log.info("[%s] metrics not parsed; rc=%s", phase, rc)

    except KeyboardInterrupt:
        log.info("\nCancelled by user (Ctrl+C).")
        return 130

    log.info("")
    log.info("✔ Wrote %d rows to %s", len(rows), out_path)
    if args.summarize:
        summarize(rows, str(model_path))
    return 0
830
+
831
+
832
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())