jstprove-1.0.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
- jstprove-1.0.0.dist-info/METADATA +397 -0
- jstprove-1.0.0.dist-info/RECORD +81 -0
- jstprove-1.0.0.dist-info/WHEEL +6 -0
- jstprove-1.0.0.dist-info/entry_points.txt +2 -0
- jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
- jstprove-1.0.0.dist-info/top_level.txt +1 -0
- python/__init__.py +0 -0
- python/core/__init__.py +3 -0
- python/core/binaries/__init__.py +0 -0
- python/core/binaries/expander-exec +0 -0
- python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
- python/core/circuit_models/__init__.py +0 -0
- python/core/circuit_models/generic_onnx.py +231 -0
- python/core/circuit_models/simple_circuit.py +133 -0
- python/core/circuits/__init__.py +0 -0
- python/core/circuits/base.py +1000 -0
- python/core/circuits/errors.py +188 -0
- python/core/circuits/zk_model_base.py +25 -0
- python/core/model_processing/__init__.py +0 -0
- python/core/model_processing/converters/__init__.py +0 -0
- python/core/model_processing/converters/base.py +143 -0
- python/core/model_processing/converters/onnx_converter.py +1181 -0
- python/core/model_processing/errors.py +147 -0
- python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
- python/core/model_processing/onnx_custom_ops/conv.py +111 -0
- python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
- python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
- python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
- python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
- python/core/model_processing/onnx_custom_ops/relu.py +43 -0
- python/core/model_processing/onnx_quantizer/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
- python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
- python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
- python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
- python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
- python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
- python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
- python/core/model_templates/__init__.py +0 -0
- python/core/model_templates/circuit_template.py +57 -0
- python/core/utils/__init__.py +0 -0
- python/core/utils/benchmarking_helpers.py +163 -0
- python/core/utils/constants.py +4 -0
- python/core/utils/errors.py +117 -0
- python/core/utils/general_layer_functions.py +268 -0
- python/core/utils/helper_functions.py +1138 -0
- python/core/utils/model_registry.py +166 -0
- python/core/utils/scratch_tests.py +66 -0
- python/core/utils/witness_utils.py +291 -0
- python/frontend/__init__.py +0 -0
- python/frontend/cli.py +115 -0
- python/frontend/commands/__init__.py +17 -0
- python/frontend/commands/args.py +100 -0
- python/frontend/commands/base.py +199 -0
- python/frontend/commands/bench/__init__.py +54 -0
- python/frontend/commands/bench/list.py +42 -0
- python/frontend/commands/bench/model.py +172 -0
- python/frontend/commands/bench/sweep.py +212 -0
- python/frontend/commands/compile.py +58 -0
- python/frontend/commands/constants.py +5 -0
- python/frontend/commands/model_check.py +53 -0
- python/frontend/commands/prove.py +50 -0
- python/frontend/commands/verify.py +73 -0
- python/frontend/commands/witness.py +64 -0
- python/scripts/__init__.py +0 -0
- python/scripts/benchmark_runner.py +833 -0
- python/scripts/gen_and_bench.py +482 -0
- python/tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
- python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
- python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
- python/tests/circuit_parent_classes/__init__.py +0 -0
- python/tests/circuit_parent_classes/test_circuit.py +969 -0
- python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
- python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
- python/tests/test_cli.py +1021 -0
- python/tests/utils_testing/__init__.py +0 -0
- python/tests/utils_testing/test_helper_functions.py +891 -0
python/core/utils/general_layer_functions.py
@@ -0,0 +1,268 @@
from __future__ import annotations

import json
from pathlib import Path

import torch

from python.core.utils.errors import (
    CircuitUtilsError,
    InputFileError,
    MissingCircuitAttributeError,
    ShapeMismatchError,
)


class GeneralLayerFunctions:
    """
    A collection of utility functions for reading, generating, scaling, and
    formatting model inputs/outputs. This is primarily intended for
    preparing inputs for ONNX models or similar layer-based models.
    """

    def read_input(self: GeneralLayerFunctions, file_name: str) -> list | dict:
        """Read model input data from a JSON file.

        Args:
            file_name (str): Path to the JSON file containing input data.

        Returns:
            list | dict: The value of the "input" field from the JSON file.
        """
        try:
            with Path(file_name).open("r") as file:
                data = json.load(file)
        except FileNotFoundError as e:
            raise InputFileError(file_name, "File not found", cause=e) from e
        except json.JSONDecodeError as e:
            raise InputFileError(
                file_name,
                f"Invalid JSON format: {e.msg}",
                cause=e,
            ) from e

        if "input" not in data:
            raise InputFileError(file_name, "Missing required 'input' field in JSON")

        return data["input"]

    def get_inputs_from_file(
        self: GeneralLayerFunctions,
        file_name: str,
        *,
        is_scaled: bool = False,
    ) -> torch.Tensor:
        """Load and optionally scale inputs from a file.

        Args:
            file_name (str): Path to the file containing input data.
            is_scaled (bool, optional):
                If True, the file values are treated as already scaled and used
                as-is. If False, applies scaling using
                `self.scale_base ** self.scale_exponent`. Defaults to False.

        Returns:
            torch.Tensor: The loaded, reshaped, and potentially rescaled input tensor.
        """
        inputs = self.read_input(file_name)
        try:
            tensor = torch.as_tensor(inputs)
        except Exception as e:
            raise InputFileError(
                file_name,
                f"Invalid input data for tensor conversion: {e}",
            ) from e

        if not is_scaled:
            if not (hasattr(self, "scale_base") and hasattr(self, "scale_exponent")):
                attr_name = "scale_base/scale_exponent"
                msg = "needed for scaling"
                raise MissingCircuitAttributeError(
                    attr_name,
                    msg,
                )
            tensor = torch.mul(tensor, self.scale_base**self.scale_exponent)

        tensor = tensor.long()

        if hasattr(self, "input_shape"):
            shape = self.input_shape
            if hasattr(self, "adjust_shape") and callable(
                self.adjust_shape,
            ):
                shape = self.adjust_shape(shape)
            try:
                tensor = tensor.reshape(shape)
            except RuntimeError as e:
                raise ShapeMismatchError(shape, list(tensor.shape)) from e
        return tensor

    def get_inputs(
        self: GeneralLayerFunctions,
        file_path: str | None = None,
        *,
        is_scaled: bool = False,
    ) -> torch.Tensor:
        """Retrieve model inputs, either from a file or by generating new inputs.

        Args:
            file_path (str, optional):
                Path to the input file. If None,
                new random inputs are generated. Defaults to None.
            is_scaled (bool, optional):
                Whether to skip scaling of loaded inputs. Defaults to False.

        Raises:
            MissingCircuitAttributeError: If `self.input_shape` is not defined.

        Returns:
            torch.Tensor: The input tensor shaped according to `self.input_shape`.
        """
        if file_path is None:
            attr_name = "input_shape"
            if not hasattr(self, attr_name):
                msg = "needed to generate random inputs"
                raise MissingCircuitAttributeError(
                    attr_name,
                    msg,
                )
            return self.create_new_inputs()

        return self.get_inputs_from_file(file_path, is_scaled=is_scaled).reshape(
            self.input_shape,
        )

    def create_new_inputs(self: GeneralLayerFunctions) -> torch.Tensor:
        """Generate new random input tensors.

        Returns:
            torch.Tensor | dict[str, torch.Tensor]:
                - If `self.input_shape` is a list/tuple, returns a single tensor.
                - If `self.input_shape` is a dict, returns a dictionary mapping
                  input names to tensors.
        """
        attr_name = "input_shape"
        if not hasattr(self, attr_name):
            context = "needed to generate new inputs"
            raise MissingCircuitAttributeError(
                attr_name,
                context,
            )
        # ONNX inputs will be in this form, and require inputs to not be scaled up
        if isinstance(self.input_shape, dict):
            keys = self.input_shape.keys()
            if len(keys) == 1:
                # If unknown dim in batch spot, assume batch size of 1
                first_key = next(iter(keys))
                input_shape = self.input_shape[first_key]
                input_shape[0] = 1 if input_shape[0] < 1 else input_shape[0]
                return self.get_rand_inputs(input_shape)
            inputs = {}
            for key in keys:
                # If unknown dim in batch spot, assume batch size of 1
                input_shape = self.input_shape[key]
                if not isinstance(input_shape, list) and not isinstance(
                    input_shape,
                    tuple,
                ):
                    msg = f"Invalid input shape for key '{key}': {input_shape}"
                    raise CircuitUtilsError(msg)
                input_shape[0] = 1 if input_shape[0] < 1 else input_shape[0]
                inputs[key] = self.get_rand_inputs(input_shape)
            return inputs
        if not (hasattr(self, "scale_base") and hasattr(self, "scale_exponent")):
            attr_name = "scale_base/scale_exponent"
            context = "needed for scaling random inputs"
            raise MissingCircuitAttributeError(
                attr_name,
                context,
            )

        return torch.mul(
            self.get_rand_inputs(self.input_shape),
            self.scale_base**self.scale_exponent,
        ).long()

    def get_rand_inputs(
        self: GeneralLayerFunctions,
        input_shape: list[int],
    ) -> torch.Tensor:
        """Generate random input values in the range [-1, 1).

        Args:
            input_shape (list[int]): Shape of the tensor to generate.

        Returns:
            torch.Tensor: A tensor of random values in [-1, 1).
        """
        if not isinstance(input_shape, (list, tuple)):
            msg = (f"Invalid input_shape type: {type(input_shape)}."
                   " Expected list or tuple of ints.")
            raise CircuitUtilsError(msg)
        if not all(isinstance(x, int) and x > 0 for x in input_shape):
            raise ShapeMismatchError(
                expected_shape="positive integers",
                actual_shape=input_shape,
            )
        return torch.rand(input_shape) * 2 - 1

    def format_inputs(
        self: GeneralLayerFunctions,
        inputs: torch.Tensor,
    ) -> dict[str, list[int]]:
        """Format input tensors for JSON serialization.

        Args:
            inputs (torch.Tensor): The input tensor.

        Returns:
            dict[str, list[int]]:
                A dictionary with the key "input"
                containing the tensor as a list of integers.
        """
        return {"input": inputs.long().tolist()}

    def format_outputs(
        self: GeneralLayerFunctions,
        outputs: torch.Tensor,
    ) -> dict[str, list[int]]:
        """Format output tensors for JSON serialization,
        including rescaled outputs for readability.

        Args:
            outputs (torch.Tensor): The output tensor.

        Returns:
            dict[str, list[int]]: A dictionary containing:
                - "output": the raw output tensor as a list of integers.
                - "rescaled_output": the output divided by the scaling factor.
        """
        if hasattr(self, "scale_exponent") and hasattr(self, "scale_base"):
            try:
                rescaled = torch.div(outputs, self.scale_base**self.scale_exponent)
            except Exception as e:
                msg = ("Failed to rescale outputs using scale_base="
                       f"{getattr(self, 'scale_base', None)} "
                       f"and scale_exponent={getattr(self, 'scale_exponent', None)}: {e}")
                raise CircuitUtilsError(msg) from e
            return {
                "output": outputs.long().tolist(),
                "rescaled_output": rescaled.tolist(),
            }
        return {"output": outputs.long().tolist()}

    def format_inputs_outputs(
        self: GeneralLayerFunctions,
        inputs: torch.Tensor,
        outputs: torch.Tensor,
    ) -> tuple[dict[str, list[int]], dict[str, list[int]]]:
        """Format both inputs and outputs for JSON serialization.

        Args:
            inputs (torch.Tensor): Model inputs.
            outputs (torch.Tensor): Model outputs.

        Returns:
            tuple[dict[str, list[int]], dict[str, list[int]]]:
                A tuple containing the formatted inputs and formatted outputs.
        """
        return self.format_inputs(inputs), self.format_outputs(outputs)
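
For orientation, a minimal usage sketch follows. It is not code shipped in the wheel: the `DemoCircuit` class, its `input_shape`, and the scale values are assumptions. The mixin itself only checks, via `hasattr`, that the consuming class defines `input_shape`, `scale_base`, and `scale_exponent`; the import path mirrors the module's own `python.core.utils.*` imports.

# Hypothetical usage sketch: DemoCircuit, its shape, and the scale values are
# assumptions, not part of the package.
import json
from pathlib import Path

import torch

from python.core.utils.general_layer_functions import GeneralLayerFunctions


class DemoCircuit(GeneralLayerFunctions):
    # Attributes the mixin probes with hasattr() before scaling and reshaping.
    input_shape = [1, 3, 8, 8]   # assumed NCHW input shape
    scale_base = 2               # fixed-point scaling: values multiplied by 2**16
    scale_exponent = 16


circuit = DemoCircuit()

# No file path: random values in [-1, 1) are drawn, scaled by 2**16, cast to long.
x = circuit.get_inputs()

# format_inputs() produces the {"input": [...]} layout that read_input() expects,
# so the tensor round-trips through a JSON file. is_scaled=True skips re-scaling.
Path("input.json").write_text(json.dumps(circuit.format_inputs(x)))
y = circuit.get_inputs(file_path="input.json", is_scaled=True)
assert torch.equal(x, y)

# format_outputs() adds a "rescaled_output" entry (values divided by 2**16)
# whenever the scaling attributes are present.
print(circuit.format_outputs(x)["rescaled_output"])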