JSTprove 1.0.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of JSTprove might be problematic. Click here for more details.

Files changed (81):
  1. jstprove-1.0.0.dist-info/METADATA +397 -0
  2. jstprove-1.0.0.dist-info/RECORD +81 -0
  3. jstprove-1.0.0.dist-info/WHEEL +5 -0
  4. jstprove-1.0.0.dist-info/entry_points.txt +2 -0
  5. jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
  6. jstprove-1.0.0.dist-info/top_level.txt +1 -0
  7. python/__init__.py +0 -0
  8. python/core/__init__.py +3 -0
  9. python/core/binaries/__init__.py +0 -0
  10. python/core/binaries/expander-exec +0 -0
  11. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  12. python/core/circuit_models/__init__.py +0 -0
  13. python/core/circuit_models/generic_onnx.py +231 -0
  14. python/core/circuit_models/simple_circuit.py +133 -0
  15. python/core/circuits/__init__.py +0 -0
  16. python/core/circuits/base.py +1000 -0
  17. python/core/circuits/errors.py +188 -0
  18. python/core/circuits/zk_model_base.py +25 -0
  19. python/core/model_processing/__init__.py +0 -0
  20. python/core/model_processing/converters/__init__.py +0 -0
  21. python/core/model_processing/converters/base.py +143 -0
  22. python/core/model_processing/converters/onnx_converter.py +1181 -0
  23. python/core/model_processing/errors.py +147 -0
  24. python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
  25. python/core/model_processing/onnx_custom_ops/conv.py +111 -0
  26. python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
  27. python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
  28. python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
  29. python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
  30. python/core/model_processing/onnx_custom_ops/relu.py +43 -0
  31. python/core/model_processing/onnx_quantizer/__init__.py +0 -0
  32. python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
  33. python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
  34. python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
  35. python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
  36. python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
  37. python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
  38. python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
  39. python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
  40. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
  41. python/core/model_templates/__init__.py +0 -0
  42. python/core/model_templates/circuit_template.py +57 -0
  43. python/core/utils/__init__.py +0 -0
  44. python/core/utils/benchmarking_helpers.py +163 -0
  45. python/core/utils/constants.py +4 -0
  46. python/core/utils/errors.py +117 -0
  47. python/core/utils/general_layer_functions.py +268 -0
  48. python/core/utils/helper_functions.py +1138 -0
  49. python/core/utils/model_registry.py +166 -0
  50. python/core/utils/scratch_tests.py +66 -0
  51. python/core/utils/witness_utils.py +291 -0
  52. python/frontend/__init__.py +0 -0
  53. python/frontend/cli.py +115 -0
  54. python/frontend/commands/__init__.py +17 -0
  55. python/frontend/commands/args.py +100 -0
  56. python/frontend/commands/base.py +199 -0
  57. python/frontend/commands/bench/__init__.py +54 -0
  58. python/frontend/commands/bench/list.py +42 -0
  59. python/frontend/commands/bench/model.py +172 -0
  60. python/frontend/commands/bench/sweep.py +212 -0
  61. python/frontend/commands/compile.py +58 -0
  62. python/frontend/commands/constants.py +5 -0
  63. python/frontend/commands/model_check.py +53 -0
  64. python/frontend/commands/prove.py +50 -0
  65. python/frontend/commands/verify.py +73 -0
  66. python/frontend/commands/witness.py +64 -0
  67. python/scripts/__init__.py +0 -0
  68. python/scripts/benchmark_runner.py +833 -0
  69. python/scripts/gen_and_bench.py +482 -0
  70. python/tests/__init__.py +0 -0
  71. python/tests/circuit_e2e_tests/__init__.py +0 -0
  72. python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
  73. python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
  74. python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
  75. python/tests/circuit_parent_classes/__init__.py +0 -0
  76. python/tests/circuit_parent_classes/test_circuit.py +969 -0
  77. python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
  78. python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
  79. python/tests/test_cli.py +1021 -0
  80. python/tests/utils_testing/__init__.py +0 -0
  81. python/tests/utils_testing/test_helper_functions.py +891 -0
@@ -0,0 +1,166 @@
1
+ from __future__ import annotations
2
+
3
+ import importlib
4
+ import pkgutil
5
+ from collections import namedtuple
6
+ from pathlib import Path
7
+ from typing import TYPE_CHECKING
8
+
9
+ if TYPE_CHECKING:
10
+ import types
11
+ from collections.abc import Callable
12
+
13
+
14
+ from python.core import circuit_models, circuits
15
+ from python.core.circuit_models.generic_onnx import GenericModelONNX
16
+ from python.core.circuits.base import Circuit
17
+
18
# Immutable record describing one testable model: a display name, where it
# came from ("class" or a file prefix such as "onnx"), a zero-argument-capable
# loader callable, and extra positional/keyword arguments for that loader.
ModelEntry = namedtuple(  # noqa: PYI024
    "ModelEntry",
    ["name", "source", "loader", "args", "kwargs"],
)


def scan_model_files(
    directory: str,
    extension: str,
    loader_fn: Callable,
    prefix: str,
) -> list[ModelEntry]:
    """Scan a directory for model files and wrap each one in a lazy loader.

    Only regular files directly inside ``directory`` (no recursion) whose name
    ends with ``extension`` are considered. Each match becomes a ``ModelEntry``
    whose ``loader`` calls ``loader_fn`` with that file's path.

    Args:
        directory (str): Path to the directory to scan.
        extension (str): File extension to filter by (e.g., ".onnx"); it is
            stripped from the file name to form the entry name.
        loader_fn (Callable):
            Loader function that can instantiate a model from a file path.
        prefix (str): Source prefix used to categorize models (e.g., "onnx");
            stored as each entry's ``source``.

    Returns:
        list[ModelEntry]: One entry per matching file, ordered by file name.
        An empty list is returned when ``directory`` does not exist or is not
        a directory.
    """
    root = Path(directory)
    if not root.is_dir():
        # build_models_to_test() runs at module import with a cwd-relative
        # path; a missing directory must not make the module unimportable.
        return []

    entries = []
    # Sort so discovery order is deterministic (iterdir order is arbitrary).
    for item in sorted(root.iterdir(), key=lambda p: p.name):
        file_name = item.name
        if not (item.is_file() and file_name.endswith(extension)):
            continue
        # Strip by length (not [:-len]) so an empty extension keeps the name.
        name = file_name[: len(file_name) - len(extension)]
        path = str(item)
        entries.append(
            ModelEntry(
                name=name,
                source=prefix,
                # Bind ``path`` as a default so each lambda keeps its own file.
                loader=lambda p=path: loader_fn(p),
                args=(),
                kwargs={},
            ),
        )
    return entries
61
+
62
+
63
def import_all_submodules(package: types.ModuleType) -> None:
    """Eagerly import every module below ``package``, recursing into subpackages.

    Importing a module runs its top-level code, so the classes it defines
    exist (and are visible through ``__subclasses__``) once this returns.

    Args:
        package (types.ModuleType): An already-imported package (it must have
            a ``__path__``) whose submodules should be loaded.
    """
    name_prefix = f"{package.__name__}."
    for module_info in pkgutil.walk_packages(package.__path__, name_prefix):
        importlib.import_module(module_info.name)
74
+
75
+
76
# Import every submodule of the circuit packages up front: a class only shows
# up in ``Circuit.__subclasses__()`` once the module that defines it has run.
for _package in (circuit_models, circuits):
    import_all_submodules(_package)
del _package  # keep the loop variable out of the module namespace
79
+
80
+
81
def onnx_test_loader(path: str) -> GenericModelONNX:
    """Build a ``GenericModelONNX`` for the ONNX file at ``path``.

    Serves as the ``loader_fn`` for ONNX entries discovered on disk; the model
    is created with ``use_find_model=True``.

    Args:
        path (str): Filesystem path to an ``.onnx`` file.

    Returns:
        GenericModelONNX: The wrapped model.
    """
    model = GenericModelONNX(path, use_find_model=True)
    return model
83
+
84
+
85
def all_subclasses(cls) -> set:  # noqa: ANN001
    """Collect every direct and indirect subclass of ``cls``.

    The hierarchy is walked with an explicit work list; classes reachable via
    several parents (diamonds) are recorded once.

    Args:
        cls (type): The base class whose descendants are wanted.

    Returns:
        set: All transitive subclasses of ``cls`` (``cls`` itself excluded).
    """
    found: set = set()
    pending = [cls]
    while pending:
        for child in pending.pop().__subclasses__():
            if child not in found:
                found.add(child)
                pending.append(child)
    return found
96
+
97
+
98
def build_models_to_test() -> list[ModelEntry]:
    """Build a list of model entries to be tested.

    - Collects all (transitive) subclasses of the Circuit base class.
    - Filters out unwanted base or placeholder models.
    - Adds ONNX models discovered in ``python/models/models_onnx`` (a path
      relative to the current working directory).

    Class-based entries are sorted by name, because ``all_subclasses`` returns
    an unordered set and test collection order should not vary between runs.

    Returns:
        list[ModelEntry]: Class-based entries followed by ONNX entries.
    """
    excluded = frozenset(
        {"zkmodel", "genericmodelonnx", "zktorchmodel", "zkmodelbase"},
    )
    # Full key so classes sharing a lowercase name still order deterministically.
    classes = sorted(
        all_subclasses(Circuit),
        key=lambda c: (c.__name__.lower(), c.__module__, c.__qualname__),
    )

    models = []
    for cls in classes:
        name = cls.__name__.lower()
        if name in excluded:
            continue
        models.append(
            ModelEntry(name=name, source="class", loader=cls, args=(), kwargs={}),
        )

    # Add ONNX models
    models += scan_model_files(
        "python/models/models_onnx",
        ".onnx",
        onnx_test_loader,
        "onnx",
    )
    return models


MODELS_TO_TEST = build_models_to_test()
131
+
132
+
133
def list_available_models() -> list[str]:
    """List all available models in a human-readable format.

    Returns:
        list[str]: One ``"source: model_name"`` label per entry in
        ``MODELS_TO_TEST``, sorted alphabetically.
    """
    labels = [f"{entry.source}: {entry.name}" for entry in MODELS_TO_TEST]
    labels.sort()
    return labels
140
+
141
+
142
def get_models_to_test(
    selected_models: list[str] | None = None,
    source_filter: str | None = None,
) -> list[ModelEntry]:
    """Retrieve models to be tested with optional filtering.

    Args:
        selected_models (list[str], optional):
            A list of model names to include. Defaults to None (no name filter).
        source_filter (str, optional):
            Restrict models to a specific source (e.g., "onnx", "class").
            Defaults to None (no source filter).

    Returns:
        list[ModelEntry]: A new list of matching entries. It is always a fresh
        list, so callers may mutate it without affecting ``MODELS_TO_TEST``.
    """
    # Copy first: returning the module-level registry itself would let a
    # caller's append/sort/remove leak into every later call.
    models = list(MODELS_TO_TEST)

    if selected_models is not None:
        models = [m for m in models if m.name in selected_models]

    if source_filter is not None:
        models = [m for m in models if m.source == source_filter]

    return models
@@ -0,0 +1,66 @@
1
+ import onnx
2
+ from onnx import TensorProto, helper, shape_inference
3
+ from onnx import numpy_helper
4
+ from onnx import load, save
5
+ from onnx.utils import extract_model
6
+
7
def prune_model(model_path, output_names, save_path):
    """Extract the sub-graph of ``model_path`` that produces ``output_names``.

    The original graph inputs are kept as the sub-graph inputs, and the result
    is written to ``save_path`` with ``onnx.utils.extract_model``.

    Args:
        model_path: Path to the source ``.onnx`` file.
        output_names: Tensor names that become the new graph outputs.
        save_path: Destination path for the extracted model.
    """
    # Only graph input names are needed here; extract_model loads the model
    # from ``model_path`` itself, so skip loading external tensor data.
    model = load(model_path, load_external_data=False)

    # Provide model input names and the new desired output names
    input_names = [graph_input.name for graph_input in model.graph.input]

    extract_model(
        input_path=model_path,
        output_path=save_path,
        input_names=input_names,
        output_names=output_names,
    )

    print(f"Pruned model saved to {save_path}")
21
+
22
+
23
def cut_model(model_path, output_names, save_path):
    """Replace a model's graph outputs with the given (intermediate) tensors.

    Shapes are inferred first so intermediate tensors get ``value_info``
    entries. The existing graph outputs are replaced by one output per name in
    ``output_names`` (nodes are left in place), then the model is saved.

    Args:
        model_path: Path to the source ``.onnx`` file.
        output_names: Tensor names to expose as the new graph outputs.
        save_path: Destination path for the modified model.

    Raises:
        ValueError: If a requested tensor is not described in the graph's
            value_info, inputs, or original outputs.
    """
    model = onnx.load(model_path)
    model = shape_inference.infer_shapes(model)

    graph = model.graph

    # Snapshot the original outputs *before* removing them so they stay valid
    # lookup candidates; copies are taken because the originals are popped.
    original_outputs = []
    for graph_output in graph.output:
        snapshot = onnx.ValueInfoProto()
        snapshot.CopyFrom(graph_output)
        original_outputs.append(snapshot)

    # Remove all current outputs one by one (cannot use .clear() or assignment)
    while len(graph.output) > 0:
        graph.output.pop()

    # Look in value_info, input, or original output; first match wins.
    candidates = {}
    for vi in [*graph.value_info, *graph.input, *original_outputs]:
        candidates.setdefault(vi.name, vi)

    # Add new outputs
    for name in output_names:
        value_info = candidates.get(name)
        if value_info is None:
            raise ValueError(f"Tensor {name} not found in model graph.")

        tensor_type = value_info.type.tensor_type
        if tensor_type.HasField("shape"):
            # Keep symbolic dims (dim_param) and unknown dims (None) instead
            # of collapsing them to 0.
            shape = []
            for dim in tensor_type.shape.dim:
                kind = dim.WhichOneof("value")
                if kind == "dim_value":
                    shape.append(dim.dim_value)
                elif kind == "dim_param":
                    shape.append(dim.dim_param)
                else:
                    shape.append(None)
        else:
            shape = None  # unknown rank: emit no shape rather than a scalar
        new_output = helper.make_tensor_value_info(name, tensor_type.elem_type, shape)
        graph.output.append(new_output)

    for output in graph.output:
        print(output)
        # NOTE(review): looks like a debugging override for one specific
        # model's tensor; confirm whether it is still needed.
        if output.name == "/conv1/Conv_output_0":
            output.type.tensor_type.elem_type = TensorProto.INT64

    onnx.save(model, save_path)
    print(f"Saved cut model with outputs {output_names} to {save_path}")
52
+
53
+
54
if __name__ == "__main__":
    # Ad-hoc driver: extract the sub-graph of doom.onnx that ends at
    # /Relu_3_output_0. Earlier experiments cut at "/Relu_2_output_0" and
    # "/conv1/Conv_output_0", e.g.
    #   cut_model("models_onnx/doom.onnx", ["/Relu_2_output_0"],
    #             "test_doom_after_conv.onnx")
    prune_model(
        model_path="models_onnx/doom.onnx",
        output_names=["/Relu_3_output_0"],  # replace with your intermediate tensor
        save_path="models_onnx/test_doom_cut.onnx",
    )
@@ -0,0 +1,291 @@
1
+ from __future__ import annotations
2
+
3
+ import struct
4
+ from abc import ABC, abstractmethod
5
+ from pathlib import Path
6
+ from typing import TYPE_CHECKING, BinaryIO
7
+
8
+ if TYPE_CHECKING:
9
+ from collections.abc import Callable
10
+
11
+ from python.core.utils.errors import ProofSystemNotImplementedError
12
+ from python.core.utils.helper_functions import ZKProofSystems
13
+
14
+
15
+ # -------------------------
16
+ # Base Witness Loader
17
+ # -------------------------
18
class WitnessLoader(ABC):
    """Abstract interface for reading and checking a proof system's witness files.

    Concrete loaders implement ``load_witness`` (parse the file at ``path``)
    and ``compare_witness_to_io`` (check parsed data against expected I/O).
    """

    def __init__(self: WitnessLoader, path: str) -> None:
        """Remember where the witness file lives.

        Args:
            path (str): Path to the witness file; it is not opened here.
        """
        self.path = path

    @abstractmethod
    def load_witness(self: WitnessLoader) -> dict:
        """Load witness data from file."""
        ...

    @abstractmethod
    def compare_witness_to_io(
        self: WitnessLoader,
        witnesses: dict,
        expected_inputs: dict,
        expected_outputs: dict,
        modulus: int,
    ) -> bool:
        """Compare witness to expected I/O."""
        ...
35
+
36
+
37
def read_usize(f: BinaryIO, usize_len: int | None = 8) -> int:
    """
    Read a little-endian unsigned integer of ``usize_len`` bytes from ``f``.

    Args:
        f (BinaryIO): Opened file in binary mode.
        usize_len (int, optional): Number of bytes to read
            (default is 8, for 64-bit systems). ``None`` means the default.

    Returns:
        int: The unpacked unsigned integer.

    Raises:
        struct.error: If fewer than ``usize_len`` bytes remain in ``f``
            (same exception type as the previous ``struct.unpack`` call).
    """
    size = 8 if usize_len is None else usize_len
    data = f.read(size)
    if len(data) != size:
        msg = f"expected {size} bytes for a usize, got {len(data)}"
        raise struct.error(msg)
    # from_bytes honours any width, unlike the fixed 8-byte "<Q" format.
    return int.from_bytes(data, "little")
50
+
51
+
52
def read_u256(f: BinaryIO) -> int:
    """
    Read a 256-bit unsigned integer (U256) from a binary file in little-endian format.

    Args:
        f (BinaryIO): Opened file in binary mode.

    Returns:
        int: The 256-bit integer.

    Raises:
        struct.error: If fewer than 32 bytes remain in ``f``; a truncated file
            is reported instead of being decoded as a smaller number.
    """
    data = f.read(32)
    if len(data) != 32:
        # struct.error matches what read_usize raises on a short read.
        msg = f"expected 32 bytes for a U256, got {len(data)}"
        raise struct.error(msg)
    return int.from_bytes(data, "little")
63
+
64
+
65
def read_field_elements(f: BinaryIO, count: int) -> list[int]:
    """
    Read a sequence of 32-byte little-endian field elements from a binary file.

    All ``count`` elements are fetched with a single ``read`` and decoded from
    a zero-copy ``memoryview``, instead of issuing one small read per element.

    Args:
        f (BinaryIO): Opened file in binary mode.
        count (int): Number of 32-byte elements to read; ``count <= 0`` reads
            nothing and returns an empty list.

    Returns:
        list[int]: List of integers representing the field elements.

    Raises:
        struct.error: If the file holds fewer than ``count * 32`` bytes.
    """
    element_size = 32
    if count <= 0:
        return []
    expected = count * element_size
    data = f.read(expected)
    if len(data) != expected:
        msg = (
            f"expected {expected} bytes for {count} field elements, "
            f"got {len(data)}"
        )
        raise struct.error(msg)
    view = memoryview(data)
    return [
        int.from_bytes(view[start : start + element_size], "little")
        for start in range(0, expected, element_size)
    ]
77
+
78
+
79
def to_field_repr(value: int, modulus: int) -> int:
    """
    Map a (possibly negative) integer into the field defined by ``modulus``.

    Python's modulo follows the sign of the divisor, so for a positive
    ``modulus`` the result is the least non-negative residue; a negative
    ``value`` therefore lands at ``modulus - abs(value)`` (for small values).

    Args:
        value (int): Integer to convert.
        modulus (int): Field modulus.

    Returns:
        int: The residue of ``value`` modulo ``modulus``.
    """
    _, residue = divmod(value, modulus)
    return residue
91
+
92
+
93
class ExpanderWitnessLoader(WitnessLoader):
    """Witness loader for the Expander proof system's binary witness format."""

    def load_witness(self: ExpanderWitnessLoader) -> dict:
        """
        Load witness data from a binary file and return it in structured form.

        The file is read as (little-endian throughout):
            - three 8-byte unsigned counts: number of witnesses, inputs per
              witness, and public inputs per witness;
            - a 32-byte field modulus;
            - ``num_witnesses * (num_inputs + num_public_inputs)`` 32-byte
              field elements, each witness stored as its inputs followed by
              its public inputs.

        Returns:
            dict: Dictionary containing:
                - num_witnesses (int)
                - num_inputs_per_witness (int)
                - num_public_inputs_per_witness (int)
                - modulus (int)
                - witnesses (list of dicts with 'inputs' and 'public_inputs')
        """
        path = self.path
        with Path(path).open("rb") as f:
            num_witnesses = read_usize(f)
            num_inputs = read_usize(f)
            num_public_inputs = read_usize(f)
            modulus = read_u256(f)

            total = num_witnesses * (num_inputs + num_public_inputs)
            values = read_field_elements(f, total)

        # Reshape the flat value list into one record per witness.
        witnesses = []
        offset = 0
        for _ in range(num_witnesses):
            inputs = values[offset : offset + num_inputs]
            public_inputs = values[
                offset + num_inputs : offset + num_inputs + num_public_inputs
            ]
            witnesses.append({"inputs": inputs, "public_inputs": public_inputs})
            offset += num_inputs + num_public_inputs

        return {
            "num_witnesses": num_witnesses,
            "num_inputs_per_witness": num_inputs,
            "num_public_inputs_per_witness": num_public_inputs,
            "modulus": modulus,
            "witnesses": witnesses,
        }

    def compare_witness_to_io(
        self: ExpanderWitnessLoader,
        witnesses: dict,
        expected_inputs: dict,
        expected_outputs: dict,
        modulus: int,
        scaling_function: Callable[[list[int], int, int], list[int]] | None = None,
    ) -> bool:
        """
        Compare the public inputs of the first witness
        against expected inputs and outputs.

        The first witness's public inputs are read as
        ``[*inputs, *outputs, scale_base, scale_exponent]``. Expected values are
        mapped into the field with ``value % modulus``, so a negative number
        ``v`` is represented as ``modulus - abs(v)``.

        Args:
            witnesses (dict): Witness data as returned by `load_witness`.
            expected_inputs (dict):
                Dictionary containing key "input"
                mapping to a list of expected input values (may be nested).
                Without ``scaling_function`` they are multiplied by
                ``scale_base ** scale_exponent`` and rounded.
            expected_outputs (dict):
                Dictionary containing key "output"
                mapping to a list of expected output integers. Nested lists
                (or objects with ``tolist()``) are flattened in row-major
                order, matching how inputs are flattened.
            modulus (int): Field modulus.
            scaling_function
                (Callable[[list[int], int, int], list[int]] | None, optional):
                Optional scaling function to apply to inputs.
                Takes (inputs, scale_base, scale_exponent)
                and returns scaled inputs. Defaults to None.

        Returns:
            bool:
                True if the witness public inputs match the expected inputs and outputs,
                False otherwise.

        Raises:
            IndexError: If there is no witness, or the first witness has fewer
                than two public inputs (the scale cannot be read).
        """

        import torch  # noqa: PLC0415

        def flatten(values: object) -> list:
            """Depth-first flatten of nested lists/tuples; scalars become [v]."""
            if not isinstance(values, (list, tuple)) and hasattr(values, "tolist"):
                # numpy arrays / torch tensors -> plain Python values
                values = values.tolist()
            if isinstance(values, (list, tuple)):
                flat: list = []
                for item in values:
                    flat.extend(flatten(item))
                return flat
            return [values]

        public_inputs = witnesses["witnesses"][0]["public_inputs"]
        # The last two public inputs carry the fixed-point scale.
        scale_base = public_inputs[-2]
        scale_exponent = public_inputs[-1]

        # Convert expectations into field form
        inputs_list = expected_inputs.get("input", [])

        if callable(scaling_function):
            inputs_list = (
                torch.tensor(scaling_function(inputs_list, scale_base, scale_exponent))
                .flatten()
                .tolist()
            )
        else:
            inputs_list = (
                torch.round(
                    torch.tensor(inputs_list) * (scale_base**scale_exponent),
                )
                .long()
                .tolist()
            )
        outputs_list = expected_outputs.get("output", [])

        expected_inputs_mod = [
            to_field_repr(v, modulus)
            for v in torch.tensor(inputs_list).flatten().tolist()
        ]
        # Flatten outputs like the inputs so multi-dimensional expectations are
        # compared element-wise instead of failing on ``list % int``.
        expected_outputs_mod = [
            to_field_repr(v, modulus) for v in flatten(outputs_list)
        ]

        # Public inputs = inputs + outputs + (scale_base, scale_exponent).
        n_public = len(expected_inputs_mod) + len(expected_outputs_mod) + 2

        if n_public != len(public_inputs):
            return False

        actual_inputs = public_inputs[: len(expected_inputs_mod)]
        actual_outputs = public_inputs[len(expected_inputs_mod) : -2]

        # Compare
        return (
            actual_inputs == expected_inputs_mod
            and actual_outputs == expected_outputs_mod
        )
216
+
217
+
218
+ # -------------------------
219
+ # Factory
220
+ # -------------------------
221
def get_loader(system: ZKProofSystems, path: str) -> WitnessLoader:
    """Return the witness loader for ``system``, bound to ``path``.

    Only ``ZKProofSystems.Expander`` currently has a loader.

    Raises:
        ProofSystemNotImplementedError: If ``system`` has no loader.
    """
    if system != ZKProofSystems.Expander:
        msg = f"No loader implemented for {system}"
        raise ProofSystemNotImplementedError(msg)
    return ExpanderWitnessLoader(path)
226
+
227
+
228
+ # -------------------------
229
+ # Public API
230
+ # -------------------------
231
def load_witness(path: str, system: ZKProofSystems = ZKProofSystems.Expander) -> dict:
    """Parse the witness file at ``path`` with the loader for ``system``.

    Returns:
        dict: The structure produced by the selected loader's
        ``load_witness`` (see ``ExpanderWitnessLoader.load_witness``).
    """
    return get_loader(system, path).load_witness()
234
+
235
+
236
def compare_witness_to_io(  # noqa: PLR0913
    witnesses: dict,
    expected_inputs: dict,
    expected_outputs: dict,
    modulus: int,
    system: ZKProofSystems = ZKProofSystems.Expander,
    scaling_function: Callable[[list[int], int, int], list[int]] | None = None,
) -> bool:
    """
    Check already-loaded witness data against expected inputs and outputs.

    Dispatches to the loader registered for ``system``; the comparison works
    purely on the in-memory ``witnesses`` dict.

    Args:
        witnesses (dict): Witness data as returned by `load_witness`.
        expected_inputs (dict):
            Dictionary containing key "input" mapping to list of expected integers.
        expected_outputs (dict):
            Dictionary containing key "output" mapping to list of expected integers.
        modulus (int): Field modulus.
        system (ZKProofSystems, optional):
            The ZK proof system. Defaults to ZKProofSystems.Expander.
        scaling_function (Callable[[list[int], int, int], list[int]] | None, optional):
            Optional scaling function to apply to inputs.
            Takes (inputs, scale_base, scale_exponent) and returns scaled inputs.
            Defaults to None.

    Returns:
        bool: True if the witness matches the expected I/O, False otherwise.
    """
    # No file is read during comparison, so the loader is given an empty path.
    comparer = get_loader(system, path="")
    return comparer.compare_witness_to_io(
        witnesses,
        expected_inputs,
        expected_outputs,
        modulus,
        scaling_function=scaling_function,
    )
272
+
273
+
274
if __name__ == "__main__":
    import time

    # perf_counter is monotonic and high-resolution; time.time() can jump
    # with wall-clock adjustments and is not meant for measuring durations.
    start_time = time.perf_counter()
    w = load_witness("./artifacts/lenet/witness.bin", ZKProofSystems.Expander)
    end_time = time.perf_counter()
    print("Modulus:", w["modulus"])  # noqa: T201
    print("First witness inputs:", w["witnesses"][0]["inputs"][0])  # noqa: T201
    print(  # noqa: T201
        "First witness public inputs:",
        w["witnesses"][0]["public_inputs"][0],
    )

    print(len(w["witnesses"][0]["public_inputs"]))  # noqa: T201
    # NOTE(review): assumes the first public input is negative and the scale is
    # 2**18; the actual scale is stored in the last two public inputs.
    print((w["witnesses"][0]["public_inputs"][0] - w["modulus"]) / 2**18)  # noqa: T201
    elapsed = end_time - start_time

    print("time taken: ", elapsed)  # noqa: T201
File without changes
python/frontend/cli.py ADDED
@@ -0,0 +1,115 @@
1
+ # python/frontend/cli.py
2
+ """JSTprove CLI."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import os
8
+ import sys
9
+ from typing import TYPE_CHECKING
10
+
11
+ if TYPE_CHECKING:
12
+ from python.frontend.commands import BaseCommand
13
+
14
+ from python.frontend.commands import (
15
+ BenchCommand,
16
+ CompileCommand,
17
+ ModelCheckCommand,
18
+ ProveCommand,
19
+ VerifyCommand,
20
+ WitnessCommand,
21
+ )
22
+ from python.frontend.commands.base import HiddenPositionalHelpFormatter
23
+
24
BANNER_TITLE = r"""
888888 .d8888b. 88888888888
"88b d88P Y88b 888
888 Y88b. 888
888 "Y888b. 888 88888b. 888d888 .d88b. 888 888 .d88b.
888 "Y88b. 888 888 "88b 888P" d88""88b 888 888 d8P Y8b
888 "888 888 888 888 888 888 888 Y88 88P 88888888
88P Y88b d88P 888 888 d88P 888 Y88..88P Y8bd8P Y8b.
888 "Y8888P" 888 88888P" 888 "Y88P" Y88P "Y8888
.d88P 888
.d88P" 888
888P" 888
"""

# Subcommand classes; main() registers one subparser per class, in this order.
COMMANDS: list[type[BaseCommand]] = [
    ModelCheckCommand,
    CompileCommand,
    WitnessCommand,
    ProveCommand,
    VerifyCommand,
    BenchCommand,
]


def print_header() -> None:
    """Write the ASCII banner and the product/backend tagline to stdout.

    Only runs when called (main() calls it), so importing this module prints
    nothing.
    """
    header = "".join(
        (
            BANNER_TITLE,
            "\n",
            "JSTprove — Verifiable ML by Inference Labs\n",
            "Based on Polyhedra Network's Expander (GKR-based proving system)\n",
        ),
    )
    print(header)  # noqa: T201
56
+
57
+
58
def main(argv: list[str] | None = None) -> int:
    """
    Entry point for the JSTprove CLI.

    Builds an argparse parser with one subcommand (plus aliases) per class in
    ``COMMANDS``, prints the banner unless ``--no-banner`` is given or
    ``JSTPROVE_NO_BANNER`` is set, and runs the selected command.

    Args:
        argv: Arguments to parse, without the program name. Defaults to
            ``sys.argv[1:]``.

    Returns:
        0 on success, 1 on error.
    """
    argv = sys.argv[1:] if argv is None else argv

    parser = argparse.ArgumentParser(
        prog="jst",
        description="ZKML CLI (compile, witness, prove, verify).",
        allow_abbrev=False,
    )
    parser.add_argument(
        "--no-banner",
        action="store_true",
        help="Suppress the startup banner.",
    )

    subparsers = parser.add_subparsers(dest="cmd", required=True)

    # Map every command name and alias to its class for dispatch.
    command_map = {}
    for command_cls in COMMANDS:
        cmd_parser = subparsers.add_parser(
            command_cls.name,
            aliases=command_cls.aliases,
            help=command_cls.help,
            allow_abbrev=False,
            formatter_class=HiddenPositionalHelpFormatter,
        )
        command_cls.configure_parser(cmd_parser)
        command_map[command_cls.name] = command_cls
        for alias in command_cls.aliases:
            command_map[alias] = command_cls

    # Parsing happens outside the try block, so argparse's own exits
    # (--help, usage errors) propagate unchanged.
    args = parser.parse_args(argv)

    if not args.no_banner and not os.environ.get("JSTPROVE_NO_BANNER"):
        print_header()

    try:
        command_cls = command_map[args.cmd]
        command_cls.run(args)
    except Exception as e:  # noqa: BLE001 - top-level CLI error boundary
        # SystemExit and KeyboardInterrupt derive from BaseException, not
        # Exception, so they still propagate to the caller.
        print(f"Error: {e}", file=sys.stderr)  # noqa: T201
        return 1

    return 0
112
+
113
+
114
if __name__ == "__main__":
    # Use main()'s return value (0 success, 1 error) as the process exit code.
    sys.exit(main())
@@ -0,0 +1,17 @@
1
+ from python.frontend.commands.base import BaseCommand
2
+ from python.frontend.commands.bench import BenchCommand
3
+ from python.frontend.commands.compile import CompileCommand
4
+ from python.frontend.commands.model_check import ModelCheckCommand
5
+ from python.frontend.commands.prove import ProveCommand
6
+ from python.frontend.commands.verify import VerifyCommand
7
+ from python.frontend.commands.witness import WitnessCommand
8
+
9
# Public names re-exported by this package (same order as the imports above).
__all__ = [
    "BaseCommand", "BenchCommand", "CompileCommand", "ModelCheckCommand",
    "ProveCommand", "VerifyCommand", "WitnessCommand",
]
+ ]