JSTprove 1.0.0 — jstprove-1.0.0-py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of JSTprove has been flagged as potentially problematic; see the release's page on the package registry for details.
- jstprove-1.0.0.dist-info/METADATA +397 -0
- jstprove-1.0.0.dist-info/RECORD +81 -0
- jstprove-1.0.0.dist-info/WHEEL +5 -0
- jstprove-1.0.0.dist-info/entry_points.txt +2 -0
- jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
- jstprove-1.0.0.dist-info/top_level.txt +1 -0
- python/__init__.py +0 -0
- python/core/__init__.py +3 -0
- python/core/binaries/__init__.py +0 -0
- python/core/binaries/expander-exec +0 -0
- python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
- python/core/circuit_models/__init__.py +0 -0
- python/core/circuit_models/generic_onnx.py +231 -0
- python/core/circuit_models/simple_circuit.py +133 -0
- python/core/circuits/__init__.py +0 -0
- python/core/circuits/base.py +1000 -0
- python/core/circuits/errors.py +188 -0
- python/core/circuits/zk_model_base.py +25 -0
- python/core/model_processing/__init__.py +0 -0
- python/core/model_processing/converters/__init__.py +0 -0
- python/core/model_processing/converters/base.py +143 -0
- python/core/model_processing/converters/onnx_converter.py +1181 -0
- python/core/model_processing/errors.py +147 -0
- python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
- python/core/model_processing/onnx_custom_ops/conv.py +111 -0
- python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
- python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
- python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
- python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
- python/core/model_processing/onnx_custom_ops/relu.py +43 -0
- python/core/model_processing/onnx_quantizer/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
- python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
- python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
- python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
- python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
- python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
- python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
- python/core/model_templates/__init__.py +0 -0
- python/core/model_templates/circuit_template.py +57 -0
- python/core/utils/__init__.py +0 -0
- python/core/utils/benchmarking_helpers.py +163 -0
- python/core/utils/constants.py +4 -0
- python/core/utils/errors.py +117 -0
- python/core/utils/general_layer_functions.py +268 -0
- python/core/utils/helper_functions.py +1138 -0
- python/core/utils/model_registry.py +166 -0
- python/core/utils/scratch_tests.py +66 -0
- python/core/utils/witness_utils.py +291 -0
- python/frontend/__init__.py +0 -0
- python/frontend/cli.py +115 -0
- python/frontend/commands/__init__.py +17 -0
- python/frontend/commands/args.py +100 -0
- python/frontend/commands/base.py +199 -0
- python/frontend/commands/bench/__init__.py +54 -0
- python/frontend/commands/bench/list.py +42 -0
- python/frontend/commands/bench/model.py +172 -0
- python/frontend/commands/bench/sweep.py +212 -0
- python/frontend/commands/compile.py +58 -0
- python/frontend/commands/constants.py +5 -0
- python/frontend/commands/model_check.py +53 -0
- python/frontend/commands/prove.py +50 -0
- python/frontend/commands/verify.py +73 -0
- python/frontend/commands/witness.py +64 -0
- python/scripts/__init__.py +0 -0
- python/scripts/benchmark_runner.py +833 -0
- python/scripts/gen_and_bench.py +482 -0
- python/tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
- python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
- python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
- python/tests/circuit_parent_classes/__init__.py +0 -0
- python/tests/circuit_parent_classes/test_circuit.py +969 -0
- python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
- python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
- python/tests/test_cli.py +1021 -0
- python/tests/utils_testing/__init__.py +0 -0
- python/tests/utils_testing/test_helper_functions.py +891 -0
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import importlib
|
|
4
|
+
import pkgutil
|
|
5
|
+
from collections import namedtuple
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import TYPE_CHECKING
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
import types
|
|
11
|
+
from collections.abc import Callable
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
from python.core import circuit_models, circuits
|
|
15
|
+
from python.core.circuit_models.generic_onnx import GenericModelONNX
|
|
16
|
+
from python.core.circuits.base import Circuit
|
|
17
|
+
|
|
18
|
+
# Record describing one testable model: a display name, its origin
# ("class" for Circuit subclasses, or a file-scan prefix such as "onnx"),
# the callable that builds it, and the positional/keyword arguments stored
# alongside that callable (every entry built in this module uses () and {}).
ModelEntry = namedtuple(  # noqa: PYI024
    typename="ModelEntry",
    field_names="name source loader args kwargs",
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def scan_model_files(
    directory: str,
    extension: str,
    loader_fn: Callable,
    prefix: str,
) -> list[ModelEntry]:
    """Scan a directory for model files and wrap each one in a ModelEntry.

    Only regular files directly inside ``directory`` (no recursion) whose name
    ends with ``extension`` are considered. Entries are returned sorted by file
    name so the result is deterministic across platforms and runs.

    Args:
        directory (str): Path to the directory to scan. If it does not exist
            or is not a directory, an empty list is returned.
        extension (str): File extension to filter by (e.g., ".onnx"). It is
            stripped from the file name to form the entry name.
        loader_fn (Callable):
            Loader function that can instantiate a model from a file path.
        prefix (str): Source prefix used to categorize models (e.g., "onnx");
            stored as ``ModelEntry.source``.

    Returns:
        list[ModelEntry]: One entry per matching file. Each entry's ``loader``
        is a zero-argument callable that calls ``loader_fn(<file path>)``.
    """
    root = Path(directory)
    # Callers may run this at import time; a missing directory (different
    # working directory, installed package without model files) must not make
    # the importing module fail.
    if not root.is_dir():
        return []

    entries = []
    for item in sorted(root.iterdir(), key=lambda p: p.name):
        if not (item.is_file() and item.name.endswith(extension)):
            continue
        # Slice by length rather than [:-len(extension)] so an empty extension
        # keeps the full file name instead of producing "".
        name = item.name[: len(item.name) - len(extension)]
        entries.append(
            ModelEntry(
                name=name,
                source=prefix,
                # Bind the path as a default argument so each lambda keeps its
                # own path rather than the loop variable's final value.
                loader=lambda p=str(item): loader_fn(p),
                args=(),
                kwargs={},
            ),
        )
    return entries
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def import_all_submodules(package: types.ModuleType) -> None:
    """Eagerly import every module found beneath *package*.

    The package tree is walked with ``pkgutil.walk_packages`` and each
    discovered module is imported by its fully qualified dotted name. Nothing
    is returned; the imports themselves are the point.

    Args:
        package (types.ModuleType): An already-imported package (it must expose
            ``__path__``) whose submodules should be imported.
    """
    dotted_prefix = f"{package.__name__}."
    discovered = pkgutil.walk_packages(package.__path__, dotted_prefix)
    for module_info in discovered:
        importlib.import_module(module_info.name)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
# Import every submodule of these packages up front. Circuit subclasses are
# discovered later through Circuit.__subclasses__() (see all_subclasses),
# which only sees classes whose defining modules have actually been imported.
import_all_submodules(circuit_models)
import_all_submodules(circuits)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def onnx_test_loader(path: str) -> GenericModelONNX:
    """Build a GenericModelONNX for the ONNX file at *path*.

    Used as the ``loader_fn`` for ONNX entries in the test registry; the model
    is always constructed with ``use_find_model=True``.
    """
    model = GenericModelONNX(path, use_find_model=True)
    return model
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def all_subclasses(cls) -> set:  # noqa: ANN001
    """Collect every direct and indirect subclass of *cls*.

    The ``__subclasses__()`` graph is traversed iteratively. A class reachable
    through several paths (multiple inheritance) is reported only once.

    Args:
        cls (type): Root class whose descendants should be collected.

    Returns:
        set: All descendants of ``cls``; ``cls`` itself is not included.
    """
    descendants = set()
    frontier = [cls]
    while frontier:
        parent = frontier.pop()
        for child in parent.__subclasses__():
            if child not in descendants:
                descendants.add(child)
                frontier.append(child)
    return descendants
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def build_models_to_test() -> list[ModelEntry]:
    """Build a list of model entries to be tested.

    - Collects all subclasses of the Circuit base class (only those whose
      modules have been imported).
    - Filters out base or placeholder models by lower-cased class name.
    - Adds ONNX models discovered in ``python/models/models_onnx`` (a path
      relative to the current working directory).

    Class entries are sorted by (lower-cased name, module, qualname).
    ``all_subclasses`` returns a set whose iteration order for classes can
    differ between interpreter runs; a stable order keeps test collection
    (e.g. parametrization across worker processes) reproducible.

    Returns:
        list[ModelEntry]: A list of model entries to be used in tests.
    """
    # Lower-cased class names of abstract bases / wrappers that are not
    # concrete test models.
    excluded = {"zkmodel", "genericmodelonnx", "zktorchmodel", "zkmodelbase"}

    ordered_classes = sorted(
        all_subclasses(Circuit),
        key=lambda c: (c.__name__.lower(), c.__module__, c.__qualname__),
    )
    models = [
        ModelEntry(
            name=cls.__name__.lower(),
            source="class",
            loader=cls,
            args=(),
            kwargs={},
        )
        for cls in ordered_classes
        if cls.__name__.lower() not in excluded
    ]

    # Add ONNX models
    models += scan_model_files(
        "python/models/models_onnx",
        ".onnx",
        onnx_test_loader,
        "onnx",
    )
    return models
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
# Registry computed once, at import time. list_available_models() and
# get_models_to_test() read from this list rather than re-scanning.
MODELS_TO_TEST = build_models_to_test()
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def list_available_models() -> list[str]:
    """List all registered models in a human-readable format.

    Returns:
        list[str]: Lexicographically sorted strings of the form
        "source: model_name", one per entry in ``MODELS_TO_TEST``.
    """
    labels = [f"{entry.source}: {entry.name}" for entry in MODELS_TO_TEST]
    labels.sort()
    return labels
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def get_models_to_test(
    selected_models: list[str] | None = None,
    source_filter: str | None = None,
) -> list[ModelEntry]:
    """Retrieve models to be tested with optional filtering.

    Both filters are applied together (logical AND) and registry order is
    preserved. A new list is always returned, so callers may mutate the result
    without altering the module-level ``MODELS_TO_TEST`` registry.

    Args:
        selected_models (list[str], optional):
            Model names to include. ``None`` disables name filtering; an empty
            list selects nothing. Defaults to None.
        source_filter (str, optional):
            Restrict models to a specific source (e.g., "onnx", "class").
            Defaults to None.

    Returns:
        list[ModelEntry]: A filtered list of model entries.
    """
    return [
        m
        for m in MODELS_TO_TEST
        if (selected_models is None or m.name in selected_models)
        and (source_filter is None or m.source == source_filter)
    ]
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
import onnx
|
|
2
|
+
from onnx import TensorProto, helper, shape_inference
|
|
3
|
+
from onnx import numpy_helper
|
|
4
|
+
from onnx import load, save
|
|
5
|
+
from onnx.utils import extract_model
|
|
6
|
+
|
|
7
|
+
def prune_model(model_path, output_names, save_path):
    """Extract the sub-graph of an ONNX model that produces *output_names*.

    The model's existing graph inputs are kept as the extracted model's
    inputs and the requested tensors become its outputs. The result is written
    to *save_path* by ``onnx.utils.extract_model``.

    Args:
        model_path: Path to the source ``.onnx`` file.
        output_names: Names of the (possibly intermediate) tensors to expose
            as outputs.
        save_path: Destination path for the pruned model.
    """
    source = load(model_path)
    graph_input_names = [value_info.name for value_info in source.graph.input]

    extract_model(
        input_path=model_path,
        output_path=save_path,
        input_names=graph_input_names,
        output_names=output_names,
    )

    print(f"Pruned model saved to {save_path}")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def cut_model(model_path, output_names, save_path):
    """Replace an ONNX model's graph outputs with the given tensors and save it.

    Unlike ``prune_model``, graph nodes are left untouched; only
    ``graph.output`` is rewritten. Type and shape information for each
    requested tensor is taken from the shape-inferred ``value_info``, the graph
    inputs, or the model's original outputs, searched in that order.

    Args:
        model_path: Path to the source ``.onnx`` file.
        output_names: Names of tensors to expose as the new graph outputs.
        save_path: Destination path for the modified model.

    Raises:
        ValueError: If a requested tensor is not described anywhere in the
            inferred graph. The check runs before the graph is modified.
    """

    def _shape_of(value_info):
        # None means "rank unknown"; otherwise keep concrete sizes (int),
        # symbolic sizes such as "batch" (str), and unknown sizes (None).
        # Coercing everything to dim_value would turn symbolic dims into 0.
        tensor_type = value_info.type.tensor_type
        if not tensor_type.HasField("shape"):
            return None
        dims = []
        for dim in tensor_type.shape.dim:
            if dim.HasField("dim_value"):
                dims.append(dim.dim_value)
            elif dim.HasField("dim_param"):
                dims.append(dim.dim_param)
            else:
                dims.append(None)
        return dims

    model = onnx.load(model_path)
    model = shape_inference.infer_shapes(model)

    graph = model.graph

    # Snapshot type info BEFORE clearing the outputs: the original outputs are
    # usually absent from value_info, so they must be captured now or asking
    # for an existing output would fail. The first match wins.
    known = {}
    for value_info in (*graph.value_info, *graph.input, *graph.output):
        if value_info.name not in known:
            known[value_info.name] = (
                value_info.type.tensor_type.elem_type,
                _shape_of(value_info),
            )

    for name in output_names:
        if name not in known:
            raise ValueError(f"Tensor {name} not found in model graph.")

    # Remove all current outputs one by one (repeated protobuf fields cannot
    # be reassigned).
    while len(graph.output) > 0:
        graph.output.pop()

    # Add new outputs
    for name in output_names:
        elem_type, shape = known[name]
        graph.output.append(helper.make_tensor_value_info(name, elem_type, shape))

    for output in graph.output:
        print(output)
        # NOTE(review): debugging override kept from the original script. This
        # one tensor name is force-typed as INT64 regardless of its inferred
        # type; confirm it is still wanted before reusing this helper.
        if output.name == "/conv1/Conv_output_0":
            output.type.tensor_type.elem_type = TensorProto.INT64

    onnx.save(model, save_path)
    print(f"Saved cut model with outputs {output_names} to {save_path}")
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
if __name__ == "__main__":
    # Ad-hoc experiment: keep only the part of doom.onnx that produces the
    # intermediate tensor below and save it as a new model. cut_model() is
    # the alternative that only rewrites graph.output.
    source_model = "models_onnx/doom.onnx"
    target_tensor = "/Relu_3_output_0"  # replace with your intermediate tensor
    prune_model(
        model_path=source_model,
        output_names=[target_tensor],
        save_path="models_onnx/test_doom_cut.onnx",
    )
|
|
@@ -0,0 +1,291 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import struct
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import TYPE_CHECKING, BinaryIO
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from collections.abc import Callable
|
|
10
|
+
|
|
11
|
+
from python.core.utils.errors import ProofSystemNotImplementedError
|
|
12
|
+
from python.core.utils.helper_functions import ZKProofSystems
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# -------------------------
# Base Witness Loader
# -------------------------
class WitnessLoader(ABC):
    """Abstract reader for the witness file of a ZK proof system.

    Subclasses implement parsing of a specific on-disk format
    (``load_witness``) and a check of parsed witness data against the expected
    model inputs/outputs (``compare_witness_to_io``).
    """

    def __init__(self: WitnessLoader, path: str) -> None:
        # Location of the witness file; subclasses' load_witness() reads it.
        self.path = path

    @abstractmethod
    def load_witness(self: WitnessLoader) -> dict:
        """Load witness data from file."""

    @abstractmethod
    def compare_witness_to_io(
        self: WitnessLoader,
        witnesses: dict,
        expected_inputs: dict,
        expected_outputs: dict,
        modulus: int,
        scaling_function: Callable[[list[int], int, int], list[int]] | None = None,
    ) -> bool:
        """Compare witness to expected I/O.

        Args:
            witnesses (dict): Witness data as returned by ``load_witness``.
            expected_inputs (dict): Expected model inputs.
            expected_outputs (dict): Expected model outputs.
            modulus (int): Field modulus used to map signed values into the field.
            scaling_function (Callable | None, optional): Optional callable
                applied to the expected inputs before comparison. It is part of
                the interface because the module-level comparison helper
                forwards it to every loader. Defaults to None.

        Returns:
            bool: True if the witness matches the expected I/O.
        """
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def read_usize(f: BinaryIO, usize_len: int | None = 8) -> int:
    """
    Read a little-endian unsigned integer of ``usize_len`` bytes from a binary file.

    Args:
        f (BinaryIO): Opened file in binary mode.
        usize_len (int, optional): Number of bytes to read
            (default is 8, for 64-bit systems). ``None`` is treated as 8.

    Returns:
        int: The unpacked unsigned integer.

    Raises:
        ValueError: If ``usize_len`` is negative.
        EOFError: If fewer than ``usize_len`` bytes remain in ``f``.
    """
    size = 8 if usize_len is None else usize_len
    if size < 0:
        msg = f"usize_len must be non-negative, got {size}"
        raise ValueError(msg)
    raw = f.read(size)
    # Check explicitly: a short read means a truncated/corrupt file, and the
    # byte count must match the requested width (the old fixed "<Q" format
    # only ever worked for 8 bytes).
    if len(raw) != size:
        msg = f"Expected {size} bytes for usize, got {len(raw)}"
        raise EOFError(msg)
    return int.from_bytes(raw, "little")
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def read_u256(f: BinaryIO) -> int:
    """
    Read a 256-bit unsigned integer (U256) from a binary file in little-endian format.

    Args:
        f (BinaryIO): Opened file in binary mode.

    Returns:
        int: The 256-bit integer.

    Raises:
        EOFError: If fewer than 32 bytes remain in ``f``. Without this check a
            truncated file would silently decode to a smaller value (or 0).
    """
    raw = f.read(32)
    if len(raw) != 32:
        msg = f"Expected 32 bytes for U256, got {len(raw)}"
        raise EOFError(msg)
    return int.from_bytes(raw, "little")
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def read_field_elements(f: BinaryIO, count: int) -> list[int]:
    """
    Read a sequence of 32-byte little-endian field elements from a binary file.

    All ``count * 32`` bytes are read with a single call and decoded through a
    zero-copy view, instead of issuing one small read per element.

    Args:
        f (BinaryIO): Opened file in binary mode.
        count (int): Number of 32-byte elements to read. Values <= 0 return an
            empty list without reading anything.

    Returns:
        list[int]: List of integers representing the field elements.

    Raises:
        EOFError: If the file ends before ``count`` full elements were read
            (previously a truncated file silently produced zero-valued elements).
    """
    element_size = 32
    if count <= 0:
        return []
    expected = count * element_size
    data = f.read(expected)
    if len(data) != expected:
        msg = (
            f"Expected {count} field elements ({expected} bytes), "
            f"got {len(data)} bytes"
        )
        raise EOFError(msg)
    view = memoryview(data)
    return [
        int.from_bytes(view[offset : offset + element_size], "little")
        for offset in range(0, expected, element_size)
    ]
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def to_field_repr(value: int, modulus: int) -> int:
    """
    Map a (possibly negative) integer into the field of size ``modulus``.

    Python's modulo takes the sign of the divisor, so for a positive modulus
    the result always lies in ``[0, modulus)``. A negative input with
    ``abs(value) < modulus`` therefore maps to ``modulus - abs(value)``.

    Args:
        value (int): Integer to convert.
        modulus (int): Field modulus.

    Returns:
        int: The least non-negative representative of ``value`` modulo ``modulus``.
    """
    _, remainder = divmod(value, modulus)
    return remainder
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class ExpanderWitnessLoader(WitnessLoader):
    """Witness loader for the binary witness files produced for Expander."""

    def load_witness(self: ExpanderWitnessLoader) -> dict:
        """
        Load witness data from a binary file and return it in structured form.

        Layout as read here: three usize counts (witnesses, inputs per
        witness, public inputs per witness), a U256 modulus, then every
        witness back to back as 32-byte field elements laid out
        ``[inputs..., public_inputs...]``.

        Returns:
            dict: Dictionary containing:
                - num_witnesses (int)
                - num_inputs_per_witness (int)
                - num_public_inputs_per_witness (int)
                - modulus (int)
                - witnesses (list of dicts with 'inputs' and 'public_inputs')
        """
        with Path(self.path).open("rb") as f:
            num_witnesses = read_usize(f)
            num_inputs = read_usize(f)
            num_public_inputs = read_usize(f)
            modulus = read_u256(f)

            stride = num_inputs + num_public_inputs
            values = read_field_elements(f, num_witnesses * stride)

        # Reshape the flat element list into one dict per witness.
        witnesses = []
        for index in range(num_witnesses):
            start = index * stride
            split = start + num_inputs
            witnesses.append(
                {
                    "inputs": values[start:split],
                    "public_inputs": values[split : start + stride],
                },
            )

        return {
            "num_witnesses": num_witnesses,
            "num_inputs_per_witness": num_inputs,
            "num_public_inputs_per_witness": num_public_inputs,
            "modulus": modulus,
            "witnesses": witnesses,
        }

    @staticmethod
    def _flatten(values: object) -> list:
        """Flatten nested lists/tuples (or array-likes with ``tolist``) depth-first."""
        if hasattr(values, "tolist") and not isinstance(values, (list, tuple)):
            values = values.tolist()
        if isinstance(values, (list, tuple)):
            flat = []
            for item in values:
                flat.extend(ExpanderWitnessLoader._flatten(item))
            return flat
        return [values]

    def compare_witness_to_io(
        self: ExpanderWitnessLoader,
        witnesses: dict,
        expected_inputs: dict,
        expected_outputs: dict,
        modulus: int,
        scaling_function: Callable[[list[int], int, int], list[int]] | None = None,
    ) -> bool:
        """
        Compare the public inputs of the first witness
        against expected inputs and outputs.

        The first witness' public inputs are expected to be laid out as
        ``[inputs..., outputs..., scale_base, scale_exponent]``. Expected values
        are mapped into the field, so negative numbers compare as
        `modulus - abs(value)`. Both expected inputs and expected outputs are
        flattened before comparison.

        Args:
            witnesses (dict): Witness data as returned by `load_witness`.
            expected_inputs (dict):
                Dictionary containing key "input"
                mapping to a list of expected input values.
            expected_outputs (dict):
                Dictionary containing key "output"
                mapping to a (possibly nested) list of expected output integers.
            modulus (int): Field modulus.
            scaling_function
                (Callable[[list[int], int, int], list[int]] | None, optional):
                Optional scaling function to apply to inputs.
                Takes (inputs, scale_base, scale_exponent)
                and returns scaled inputs. When None, inputs are multiplied by
                ``scale_base ** scale_exponent`` and rounded. Defaults to None.

        Returns:
            bool:
                True if the witness public inputs match the expected inputs and outputs,
                False otherwise (including when there is no witness or it has
                fewer than two public inputs).
        """
        import torch  # noqa: PLC0415

        all_witnesses = witnesses["witnesses"]
        # The layout always ends with the two scale values; anything shorter
        # (or no witness at all) cannot match.
        if not all_witnesses or len(all_witnesses[0]["public_inputs"]) < 2:
            return False
        public_inputs = all_witnesses[0]["public_inputs"]
        scale_base = public_inputs[-2]
        scale_exponent = public_inputs[-1]

        # Convert expectations into field form
        inputs_list = expected_inputs.get("input", [])
        if callable(scaling_function):
            inputs_list = (
                torch.tensor(scaling_function(inputs_list, scale_base, scale_exponent))
                .flatten()
                .tolist()
            )
        else:
            inputs_list = (
                torch.round(
                    torch.tensor(inputs_list) * (scale_base**scale_exponent),
                )
                .long()
                .tolist()
            )

        expected_inputs_mod = [
            to_field_repr(v, modulus)
            for v in torch.tensor(inputs_list).flatten().tolist()
        ]
        # Flatten outputs the same way as inputs so nested output lists
        # (e.g. shape (1, N)) are compared element-wise instead of failing.
        expected_outputs_mod = [
            to_field_repr(v, modulus)
            for v in self._flatten(expected_outputs.get("output", []))
        ]

        # +2 accounts for the trailing scale_base and scale_exponent.
        if len(public_inputs) != len(expected_inputs_mod) + len(expected_outputs_mod) + 2:
            return False

        n_in = len(expected_inputs_mod)
        actual_inputs = public_inputs[:n_in]
        actual_outputs = public_inputs[n_in:-2]

        return (
            actual_inputs == expected_inputs_mod
            and actual_outputs == expected_outputs_mod
        )
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
# -------------------------
# Factory
# -------------------------
def get_loader(system: ZKProofSystems, path: str) -> WitnessLoader:
    """Return the witness loader implementation for *system*.

    Args:
        system (ZKProofSystems): Proof system whose witness format is needed.
        path (str): Witness file path handed to the loader.

    Returns:
        WitnessLoader: A loader bound to ``path``.

    Raises:
        ProofSystemNotImplementedError: If no loader exists for ``system``;
            only ``ZKProofSystems.Expander`` is handled here.
    """
    is_expander = system == ZKProofSystems.Expander
    if not is_expander:
        msg = f"No loader implemented for {system}"
        raise ProofSystemNotImplementedError(msg)
    return ExpanderWitnessLoader(path)
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
# -------------------------
# Public API
# -------------------------
def load_witness(path: str, system: ZKProofSystems = ZKProofSystems.Expander) -> dict:
    """Parse the witness file at *path* with the loader for *system*.

    Args:
        path (str): Path to the witness file.
        system (ZKProofSystems, optional): Proof system that produced the file.
            Defaults to ZKProofSystems.Expander.

    Returns:
        dict: The structure returned by the selected loader's ``load_witness``.

    Raises:
        ProofSystemNotImplementedError: Propagated from ``get_loader`` when
            ``system`` has no loader.
    """
    return get_loader(system, path).load_witness()
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def compare_witness_to_io(  # noqa: PLR0913
    witnesses: dict,
    expected_inputs: dict,
    expected_outputs: dict,
    modulus: int,
    system: ZKProofSystems = ZKProofSystems.Expander,
    scaling_function: Callable[[list[int], int, int], list[int]] | None = None,
) -> bool:
    """
    Check already-loaded witness data against expected model inputs/outputs.

    Dispatches to the loader registered for ``system``. No file is read here,
    so the loader is created with an empty path.

    Args:
        witnesses (dict): Witness data as returned by `load_witness`.
        expected_inputs (dict):
            Dictionary containing key "input" mapping to list of expected integers.
        expected_outputs (dict):
            Dictionary containing key "output" mapping to list of expected integers.
        modulus (int): Field modulus.
        system (ZKProofSystems, optional):
            The ZK proof system. Defaults to ZKProofSystems.Expander.
        scaling_function (Callable[[list[int], int, int], list[int]] | None, optional):
            Optional scaling function to apply to inputs.
            Takes (inputs, scale_base, scale_exponent) and returns scaled inputs.
            Defaults to None.

    Returns:
        bool: True if the witness matches the expected I/O, False otherwise.
    """
    comparer = get_loader(system, "")
    comparison_args = (
        witnesses,
        expected_inputs,
        expected_outputs,
        modulus,
        scaling_function,
    )
    return comparer.compare_witness_to_io(*comparison_args)
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
if __name__ == "__main__":
    import time

    # perf_counter is monotonic and high-resolution, unlike time.time(), so
    # the measured load duration cannot be skewed by wall-clock adjustments.
    start_time = time.perf_counter()
    w = load_witness("./artifacts/lenet/witness.bin", ZKProofSystems.Expander)
    elapsed = time.perf_counter() - start_time

    first = w["witnesses"][0]
    print("Modulus:", w["modulus"])  # noqa: T201
    print("First witness inputs:", first["inputs"][0])  # noqa: T201
    print(  # noqa: T201
        "First witness public inputs:",
        first["public_inputs"][0],
    )

    print(len(first["public_inputs"]))  # noqa: T201
    # NOTE(review): decodes the first public input as a negative field element
    # at scale 2**18; confirm that encoding/scale before relying on it.
    print((first["public_inputs"][0] - w["modulus"]) / 2**18)  # noqa: T201

    print("time taken: ", elapsed)  # noqa: T201
|
|
File without changes
|
python/frontend/cli.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
# python/frontend/cli.py
|
|
2
|
+
"""JSTprove CLI."""
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import argparse
|
|
7
|
+
import os
|
|
8
|
+
import sys
|
|
9
|
+
from typing import TYPE_CHECKING
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from python.frontend.commands import BaseCommand
|
|
13
|
+
|
|
14
|
+
from python.frontend.commands import (
|
|
15
|
+
BenchCommand,
|
|
16
|
+
CompileCommand,
|
|
17
|
+
ModelCheckCommand,
|
|
18
|
+
ProveCommand,
|
|
19
|
+
VerifyCommand,
|
|
20
|
+
WitnessCommand,
|
|
21
|
+
)
|
|
22
|
+
from python.frontend.commands.base import HiddenPositionalHelpFormatter
|
|
23
|
+
|
|
24
|
+
# ASCII-art banner shown by print_header(). Raw string (r-prefix) so any
# backslashes in the art are kept literally.
BANNER_TITLE = r"""
888888 .d8888b. 88888888888
"88b d88P Y88b 888
888 Y88b. 888
888 "Y888b. 888 88888b. 888d888 .d88b. 888 888 .d88b.
888 "Y88b. 888 888 "88b 888P" d88""88b 888 888 d8P Y8b
888 "888 888 888 888 888 888 888 Y88 88P 88888888
88P Y88b d88P 888 888 d88P 888 Y88..88P Y8bd8P Y8b.
888 "Y8888P" 888 88888P" 888 "Y88P" Y88P "Y8888
.d88P 888
.d88P" 888
888P" 888
"""

# Subcommands registered by main(), in this order (argparse also lists them in
# this order in --help). main() relies on each class providing name, aliases,
# help, configure_parser() and run().
COMMANDS: list[type[BaseCommand]] = [
    ModelCheckCommand,
    CompileCommand,
    WitnessCommand,
    ProveCommand,
    VerifyCommand,
    BenchCommand,
]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def print_header() -> None:
    """Print the CLI banner (no side-effects at import time).

    Writes the ASCII-art title, the product tagline and the proving-backend
    credit to stdout, followed by a blank line.
    """
    banner_lines = (
        BANNER_TITLE,
        "JSTprove — Verifiable ML by Inference Labs",
        "Based on Polyhedra Network's Expander (GKR-based proving system)",
    )
    print("\n".join(banner_lines) + "\n")  # noqa: T201
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _build_parser() -> tuple[argparse.ArgumentParser, dict[str, type[BaseCommand]]]:
    """Create the top-level parser and map every command name/alias to its class."""
    parser = argparse.ArgumentParser(
        prog="jst",
        description="ZKML CLI (compile, witness, prove, verify).",
        allow_abbrev=False,
    )
    parser.add_argument(
        "--no-banner",
        action="store_true",
        help="Suppress the startup banner.",
    )

    subparsers = parser.add_subparsers(dest="cmd", required=True)

    # argparse stores whichever name the user typed (primary name or alias)
    # in ``args.cmd``, so aliases are mapped to their command class as well.
    command_map = {}
    for command_cls in COMMANDS:
        cmd_parser = subparsers.add_parser(
            command_cls.name,
            aliases=command_cls.aliases,
            help=command_cls.help,
            allow_abbrev=False,
            formatter_class=HiddenPositionalHelpFormatter,
        )
        command_cls.configure_parser(cmd_parser)
        for key in (command_cls.name, *command_cls.aliases):
            command_map[key] = command_cls
    return parser, command_map


def main(argv: list[str] | None = None) -> int:
    """
    Entry point for the JSTprove CLI.

    Args:
        argv: Arguments to parse, without the program name. Defaults to
            ``sys.argv[1:]``.

    Returns:
        0 on success, 1 on error. Argument errors and ``--help`` exit through
        the ``SystemExit`` raised by argparse.
    """
    argv = sys.argv[1:] if argv is None else argv

    parser, command_map = _build_parser()
    args = parser.parse_args(argv)

    if not args.no_banner and not os.environ.get("JSTPROVE_NO_BANNER"):
        print_header()

    try:
        command_map[args.cmd].run(args)
    except Exception as e:  # noqa: BLE001 - top-level CLI error boundary
        # SystemExit and KeyboardInterrupt derive from BaseException, not
        # Exception, so they still propagate unchanged.
        print(f"Error: {e}", file=sys.stderr)  # noqa: T201
        return 1

    return 0
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
if __name__ == "__main__":
    # Exit with main()'s return code (0 on success, 1 on handled errors).
    sys.exit(main())
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from python.frontend.commands.base import BaseCommand
|
|
2
|
+
from python.frontend.commands.bench import BenchCommand
|
|
3
|
+
from python.frontend.commands.compile import CompileCommand
|
|
4
|
+
from python.frontend.commands.model_check import ModelCheckCommand
|
|
5
|
+
from python.frontend.commands.prove import ProveCommand
|
|
6
|
+
from python.frontend.commands.verify import VerifyCommand
|
|
7
|
+
from python.frontend.commands.witness import WitnessCommand
|
|
8
|
+
|
|
9
|
+
# Public API of the commands package: the BaseCommand interface plus every
# concrete CLI subcommand. Kept as a literal list so static analysis tools can
# read the exported names.
__all__ = [
    "BaseCommand",
    "BenchCommand",
    "CompileCommand",
    "ModelCheckCommand",
    "ProveCommand",
    "VerifyCommand",
    "WitnessCommand",
]
|