JSTprove 1.0.0__py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jstprove-1.0.0.dist-info/METADATA +397 -0
- jstprove-1.0.0.dist-info/RECORD +81 -0
- jstprove-1.0.0.dist-info/WHEEL +6 -0
- jstprove-1.0.0.dist-info/entry_points.txt +2 -0
- jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
- jstprove-1.0.0.dist-info/top_level.txt +1 -0
- python/__init__.py +0 -0
- python/core/__init__.py +3 -0
- python/core/binaries/__init__.py +0 -0
- python/core/binaries/expander-exec +0 -0
- python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
- python/core/circuit_models/__init__.py +0 -0
- python/core/circuit_models/generic_onnx.py +231 -0
- python/core/circuit_models/simple_circuit.py +133 -0
- python/core/circuits/__init__.py +0 -0
- python/core/circuits/base.py +1000 -0
- python/core/circuits/errors.py +188 -0
- python/core/circuits/zk_model_base.py +25 -0
- python/core/model_processing/__init__.py +0 -0
- python/core/model_processing/converters/__init__.py +0 -0
- python/core/model_processing/converters/base.py +143 -0
- python/core/model_processing/converters/onnx_converter.py +1181 -0
- python/core/model_processing/errors.py +147 -0
- python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
- python/core/model_processing/onnx_custom_ops/conv.py +111 -0
- python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
- python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
- python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
- python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
- python/core/model_processing/onnx_custom_ops/relu.py +43 -0
- python/core/model_processing/onnx_quantizer/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
- python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
- python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
- python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
- python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
- python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
- python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
- python/core/model_templates/__init__.py +0 -0
- python/core/model_templates/circuit_template.py +57 -0
- python/core/utils/__init__.py +0 -0
- python/core/utils/benchmarking_helpers.py +163 -0
- python/core/utils/constants.py +4 -0
- python/core/utils/errors.py +117 -0
- python/core/utils/general_layer_functions.py +268 -0
- python/core/utils/helper_functions.py +1138 -0
- python/core/utils/model_registry.py +166 -0
- python/core/utils/scratch_tests.py +66 -0
- python/core/utils/witness_utils.py +291 -0
- python/frontend/__init__.py +0 -0
- python/frontend/cli.py +115 -0
- python/frontend/commands/__init__.py +17 -0
- python/frontend/commands/args.py +100 -0
- python/frontend/commands/base.py +199 -0
- python/frontend/commands/bench/__init__.py +54 -0
- python/frontend/commands/bench/list.py +42 -0
- python/frontend/commands/bench/model.py +172 -0
- python/frontend/commands/bench/sweep.py +212 -0
- python/frontend/commands/compile.py +58 -0
- python/frontend/commands/constants.py +5 -0
- python/frontend/commands/model_check.py +53 -0
- python/frontend/commands/prove.py +50 -0
- python/frontend/commands/verify.py +73 -0
- python/frontend/commands/witness.py +64 -0
- python/scripts/__init__.py +0 -0
- python/scripts/benchmark_runner.py +833 -0
- python/scripts/gen_and_bench.py +482 -0
- python/tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
- python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
- python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
- python/tests/circuit_parent_classes/__init__.py +0 -0
- python/tests/circuit_parent_classes/test_circuit.py +969 -0
- python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
- python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
- python/tests/test_cli.py +1021 -0
- python/tests/utils_testing/__init__.py +0 -0
- python/tests/utils_testing/test_helper_functions.py +891 -0
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
from python.core import RUST_BINARY_NAME
|
|
11
|
+
from python.core.circuits.errors import (
|
|
12
|
+
CircuitFileError,
|
|
13
|
+
CircuitProcessingError,
|
|
14
|
+
CircuitRunError,
|
|
15
|
+
)
|
|
16
|
+
from python.core.circuits.zk_model_base import ZKModelBase
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from python.core.model_processing.converters.onnx_converter import (
|
|
20
|
+
CircuitParamsDict,
|
|
21
|
+
ONNXLayerDict,
|
|
22
|
+
)
|
|
23
|
+
from python.core.model_processing.converters.onnx_converter import (
|
|
24
|
+
ONNXConverter,
|
|
25
|
+
ONNXOpQuantizer,
|
|
26
|
+
)
|
|
27
|
+
from python.core.model_processing.onnx_quantizer.layers.base import BaseOpQuantizer
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class GenericModelONNX(ONNXConverter, ZKModelBase):
    """
    A generic ONNX-based Zero-Knowledge (ZK) circuit model wrapper.

    This class provides:
    - Integration for ONNX model loading, quantization (in `ONNXConverter`)
      and ZK circuit infrastructure (in `ZKModelBase`).
    - Support for model quantization via `ONNXOpQuantizer`.
    - Input/output scaling and formatting utilities for ZK compatibility.

    Attributes
    ----------
    name : str
        Internal identifier for the binary to be run in the Rust backend
        (set to ``RUST_BINARY_NAME``).
    op_quantizer : ONNXOpQuantizer
        Operator quantizer for applying custom ONNX quantization rules.
    rescale_config : dict
        Per-node override for rescaling during quantization.
        Keys are node names, values are booleans.
        If not specified, the assumption is that every layer is rescaled.
    model_file_name : str
        Path to the ONNX model file used for the circuit.
    scale_base : int
        Base multiplier for scaling (default: 2).
    scale_exponent : int
        Exponent applied to `scale_base` for the final scaling factor
        (default: 18).

    Parameters
    ----------
    model_name : str
        Name or path of the model. When ``use_find_model`` is True it is
        resolved with `find_model` (the ``.onnx`` extension is optional);
        otherwise it is used verbatim as the model file path.
    use_find_model : bool, keyword-only
        Whether to resolve `model_name` with `find_model` (default: False).

    Notes
    -----
    - The scaling factor (`scale_base ** scale_exponent`) determines how floating point
      inputs/outputs are represented as integers inside the ZK circuit.
    - By default, scaling is fixed; dynamic scaling based on model analysis
      is planned for future implementation.
    - The quantization logic assumes operators are registered with
      `ONNXOpQuantizer`.
    """

    def __init__(
        self: GenericModelONNX,
        model_name: str,
        *,
        use_find_model: bool = False,
    ) -> None:
        """Configure the model and initialise the ONNX converter.

        Raises:
            CircuitFileError: If any initialisation step fails, including
                model resolution by `find_model`. The underlying exception is
                chained as ``__cause__``.
        """
        try:
            self.name = RUST_BINARY_NAME
            self.op_quantizer = ONNXOpQuantizer()
            self.rescale_config = {}
            if use_find_model:
                self.model_file_name = self.find_model(model_name)
            else:
                self.model_file_name = model_name

            # Fixed scaling factor of 2 ** 18 (see the Notes above).
            self.scale_base = 2
            self.scale_exponent = 18
            # Only ONNXConverter's initialiser is invoked explicitly here.
            ONNXConverter.__init__(self)
        except Exception as e:
            msg = f"Failed to initialize GenericModelONNX with model '{model_name}'"
            raise CircuitFileError(
                msg,
                file_path=model_name,
            ) from e

    def find_model(self: GenericModelONNX, model_name: str) -> str:
        """Resolve the ONNX model file path.

        Candidates are tried in this order:

        1. ``model_name`` itself, with ``.onnx`` appended if it does not
           already end with it.
        2. ``models_onnx/<model_name>``. This candidate is skipped when
           ``model_name`` already mentions ``models_onnx``.

        Args:
            model_name (str): Name of the model (with or without `.onnx` extension).

        Returns:
            str: Path of the first candidate that exists on disk.

        Raises:
            CircuitFileError: If no candidate path exists.
        """
        # Test the suffix rather than a substring. With a substring test,
        # names such as "data.onnx_models/net" would never get the extension.
        if not model_name.endswith(".onnx"):
            model_name = model_name + ".onnx"

        # Check the path as given first.
        if Path(model_name).exists():
            return model_name

        # Fall back to the models_onnx directory. If the caller already
        # pointed inside it, the check above was the only real candidate.
        if "models_onnx" in model_name:
            models_onnx_path = model_name
        else:
            models_onnx_path = f"models_onnx/{model_name}"

        if not Path(models_onnx_path).exists():
            msg = f"Model file not found: '{model_name}'"
            raise CircuitFileError(
                msg,
                file_path=models_onnx_path,
            )
        return models_onnx_path

    def adjust_inputs(self: GenericModelONNX, input_file: str) -> str:
        """Preprocess and flatten model inputs for the circuit.

        While the parent ``adjust_inputs`` runs, ``self.input_shape`` is
        temporarily replaced by a single flattened dimension. That dimension
        is the product of ``self.adjust_shape(input_shape)``. The original
        shape is restored afterwards, even when the parent call fails.

        Args:
            input_file (str): Input data file or array compatible with the
                model. It is passed unchanged to the parent ``adjust_inputs``.

        Returns:
            str: Adjusted input file after reshaping and scaling.

        Raises:
            ValueError: Wraps any exception raised while adjusting the inputs.
        """
        overridden = False
        try:
            input_shape = self.input_shape.copy()
            flat_shape = [math.prod(self.adjust_shape(input_shape))]
            self.input_shape = flat_shape
            overridden = True
            return super().adjust_inputs(input_file)
        except Exception as e:
            msg = f"Failed to adjust inputs for GenericModelONNX: {e}"
            raise ValueError(msg) from e
        finally:
            # Undo the temporary override on every path, including when the
            # parent call raises. This keeps a failed call from leaving the
            # model configured with the flattened shape.
            if overridden:
                self.input_shape = input_shape.copy()

    def get_outputs(
        self: GenericModelONNX,
        inputs: np.ndarray | list[int] | torch.Tensor,
    ) -> torch.Tensor:
        """Run inference and flatten outputs.

        Args:
            inputs (np.ndarray | list[int] | torch.Tensor): Preprocessed model
                inputs, forwarded to the parent ``get_outputs``.

        Returns:
            torch.Tensor: Flattened model outputs as a tensor.

        Raises:
            CircuitRunError: If the parent ``get_outputs`` raises.
        """
        try:
            raw_outputs = super().get_outputs(inputs)
        except Exception as e:
            msg = "Failed to get outputs for GenericModelONNX"
            raise CircuitRunError(
                msg,
                operation="get_outputs",
            ) from e
        else:
            return torch.as_tensor(np.array(raw_outputs)).flatten()

    def format_inputs(
        self: GenericModelONNX,
        inputs: np.ndarray | list[int] | torch.Tensor,
    ) -> dict[str, list[int]]:
        """Format raw inputs into scaled integer lists for the Rust backend.

        The inputs are flattened and scaled by `scale_base ** scale_exponent`
        (obtained from ``BaseOpQuantizer.get_scaling``). They are then
        converted to int64 (``.long()``) so they are compatible with the ZK
        circuit.

        Args:
            inputs (np.ndarray | list[int] | torch.Tensor): Raw model inputs.

        Returns:
            dict[str, list[int]]: Dictionary mapping ``"input"`` to the scaled
            integer values.

        Raises:
            CircuitProcessingError: If flattening, scaling or conversion fails.
        """
        try:
            x = {"input": inputs}
            scaling = BaseOpQuantizer.get_scaling(
                scale_base=self.scale_base,
                scale_exponent=self.scale_exponent,
            )
            for key in x:  # noqa: PLC0206
                x[key] = torch.as_tensor(x[key]).flatten().tolist()
                # NOTE(review): the list -> tensor round-trip re-infers the
                # dtype. For Python floats this is torch's default dtype
                # (float32 unless changed), so float64 inputs lose precision
                # before scaling. `.long()` also truncates toward zero
                # instead of rounding. Confirm both match backend expectations.
                x[key] = (torch.as_tensor(x[key]) * scaling).long().tolist()
        except Exception as e:
            msg = f"Failed to format inputs for GenericModelONNX: {e}"
            raise CircuitProcessingError(
                msg,
                operation="format_inputs",
                data_type=type(inputs).__name__,
            ) from e
        else:
            return x

    def get_weights(
        self: GenericModelONNX,
    ) -> dict[str, list[ONNXLayerDict]]:
        """Return only the weights element of the parent ``get_weights`` triple.

        The parent returns ``(architecture, weights, circuit_params)``.
        """
        _, w_and_b, _ = super().get_weights()
        # Currently want to read these in separately
        return w_and_b

    def get_architecture(
        self: GenericModelONNX,
    ) -> dict[str, list[ONNXLayerDict]]:
        """Return only the architecture element of the parent ``get_weights`` triple."""
        architecture, _, _ = super().get_weights()
        # Currently want to read these in separately
        return architecture

    def get_metadata(
        self: GenericModelONNX,
    ) -> CircuitParamsDict:
        """Return only the circuit-parameters element of the parent ``get_weights`` triple."""
        _, _, circuit_params = super().get_weights()
        # Currently want to read these in separately
        return circuit_params
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
if __name__ == "__main__":
    # This module only defines GenericModelONNX; running it directly does nothing.
    ...
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from random import randint
|
|
4
|
+
|
|
5
|
+
from python.core.circuits.base import Circuit, RunType
|
|
6
|
+
from python.core.utils.helper_functions import CircuitExecutionConfig, ZKProofSystems
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class SimpleCircuit(Circuit):
    """Minimal example circuit that adds two integers.

    The circuit takes three inputs, ``value_a``, ``value_b`` and a random
    ``nonce``, and its output is ``value_a + value_b``.
    """

    def __init__(self: SimpleCircuit) -> None:
        # Initialise the base Circuit first; the assignments below may
        # override attributes it sets.
        super().__init__()

        # Circuit-specific parameters. The name must match the binary exactly.
        self.name = "simple_circuit"
        # scale_base ** scale_exponent == 1 for this circuit.
        self.scale_exponent = 1
        self.scale_base = 1

        # Default operands, used when get_outputs() receives no inputs.
        self.input_a, self.input_b = 100, 200

        # NOTE: random.randint is not a cryptographically secure generator
        # and must not be used to produce nonces in production.
        self.nonce = randint(0, 10000000)  # noqa: S311

        self.required_keys = ["value_a", "value_b", "nonce"]
        self.input_shape = [1]

    def get_inputs(self: SimpleCircuit) -> dict[str, int]:
        """Return the circuit's current input values.

        Returns:
            dict[str, int]: Mapping with the keys `value_a`, `value_b` and `nonce`.
        """
        return {
            "value_a": self.input_a,
            "value_b": self.input_b,
            "nonce": self.nonce,
        }

    def get_outputs(self: SimpleCircuit, inputs: dict[str, int] | None = None) -> int:
        """Add `value_a` and `value_b`, printing the operation first.

        Args:
            inputs (dict[str, int], optional):
                Mapping containing `value_a`, `value_b` and `nonce`.
                When None, the instance's `input_a`, `input_b` and `nonce`
                are used. Defaults to None.

        Returns:
            int: ``value_a + value_b``. The nonce does not affect the result.
        """
        source = inputs if inputs is not None else {
            "value_a": self.input_a,
            "value_b": self.input_b,
            "nonce": self.nonce,
        }
        lhs = source["value_a"]
        rhs = source["value_b"]
        print(f"Performing addition operation: {lhs} + {rhs}")  # noqa: T201
        return lhs + rhs

    def format_inputs(self: SimpleCircuit, inputs: dict[str, int]) -> dict[str, int]:
        """Return the inputs unchanged; this circuit needs no formatting.

        Args:
            inputs (dict[str, int]): Circuit input values.

        Returns:
            dict[str, int]: The very same mapping that was passed in.
        """
        return inputs
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
# Example code demonstrating circuit operations
if __name__ == "__main__":
    # Create a single circuit instance.
    print("\n--- Creating circuit instance ---")  # noqa: T201
    circuit = SimpleCircuit()

    print("\n--- Testing different operations ---")  # noqa: T201

    # get_outputs() recomputes value_a + value_b on every call. Nothing is
    # cached, and this is the first time the output is requested.
    print("\nGetting circuit output:")  # noqa: T201
    output = circuit.get_outputs()
    print(f"Circuit output: {output}")  # noqa: T201

    # Compile the circuit. The later steps reuse the same circuit_path.
    print("\nRunning compilation:")  # noqa: T201
    circuit.base_testing(
        CircuitExecutionConfig(
            run_type=RunType.COMPILE_CIRCUIT,
            dev_mode=True,
            circuit_path="simple_circuit.txt",
            input_file="inputs/simple_circuit_input.json",
            # NOTE(review): the later steps use a .json output path; confirm
            # whether the .txt extension here is intentional.
            output_file="output/simple_circuit_output.txt",
            proof_system=ZKProofSystems.Expander,
        ),
    )

    # Print the file paths stored in the circuit's private `_file_info`.
    print("\n--- Verifying input and output files ---")  # noqa: T201
    print(f"Input file: {circuit._file_info['input_file']}")  # noqa: SLF001, T201
    print(f"Output file: {circuit._file_info['output_file']}")  # noqa: SLF001, T201

    # Generate the witness, writing the outputs as JSON.
    circuit.base_testing(
        CircuitExecutionConfig(
            run_type=RunType.GEN_WITNESS,
            circuit_path="simple_circuit.txt",
            input_file="inputs/simple_circuit_input.json",
            output_file="output/simple_circuit_output.json",
            write_json=True,
            proof_system=ZKProofSystems.Expander,
        ),
    )

    # Proving and verifying each use a fresh instance, and therefore a new
    # random nonce. Presumably this mimics separate runs (TODO confirm).
    circuit = SimpleCircuit()
    circuit.base_testing(
        CircuitExecutionConfig(
            run_type=RunType.PROVE_WITNESS,
            circuit_path="simple_circuit.txt",
            input_file="inputs/simple_circuit_input.json",
            output_file="output/simple_circuit_output.json",
            proof_system=ZKProofSystems.Expander,
        ),
    )

    circuit = SimpleCircuit()
    circuit.base_testing(
        CircuitExecutionConfig(
            run_type=RunType.GEN_VERIFY,
            circuit_path="simple_circuit.txt",
            input_file="inputs/simple_circuit_input.json",
            output_file="output/simple_circuit_output.json",
            proof_system=ZKProofSystems.Expander,
        ),
    )
|
|
File without changes
|