JSTprove 1.0.0 (jstprove-1.0.0-py3-none-macosx_11_0_arm64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of JSTprove might be problematic. Click here for more details.

Files changed (81)
  1. jstprove-1.0.0.dist-info/METADATA +397 -0
  2. jstprove-1.0.0.dist-info/RECORD +81 -0
  3. jstprove-1.0.0.dist-info/WHEEL +5 -0
  4. jstprove-1.0.0.dist-info/entry_points.txt +2 -0
  5. jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
  6. jstprove-1.0.0.dist-info/top_level.txt +1 -0
  7. python/__init__.py +0 -0
  8. python/core/__init__.py +3 -0
  9. python/core/binaries/__init__.py +0 -0
  10. python/core/binaries/expander-exec +0 -0
  11. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  12. python/core/circuit_models/__init__.py +0 -0
  13. python/core/circuit_models/generic_onnx.py +231 -0
  14. python/core/circuit_models/simple_circuit.py +133 -0
  15. python/core/circuits/__init__.py +0 -0
  16. python/core/circuits/base.py +1000 -0
  17. python/core/circuits/errors.py +188 -0
  18. python/core/circuits/zk_model_base.py +25 -0
  19. python/core/model_processing/__init__.py +0 -0
  20. python/core/model_processing/converters/__init__.py +0 -0
  21. python/core/model_processing/converters/base.py +143 -0
  22. python/core/model_processing/converters/onnx_converter.py +1181 -0
  23. python/core/model_processing/errors.py +147 -0
  24. python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
  25. python/core/model_processing/onnx_custom_ops/conv.py +111 -0
  26. python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
  27. python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
  28. python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
  29. python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
  30. python/core/model_processing/onnx_custom_ops/relu.py +43 -0
  31. python/core/model_processing/onnx_quantizer/__init__.py +0 -0
  32. python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
  33. python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
  34. python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
  35. python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
  36. python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
  37. python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
  38. python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
  39. python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
  40. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
  41. python/core/model_templates/__init__.py +0 -0
  42. python/core/model_templates/circuit_template.py +57 -0
  43. python/core/utils/__init__.py +0 -0
  44. python/core/utils/benchmarking_helpers.py +163 -0
  45. python/core/utils/constants.py +4 -0
  46. python/core/utils/errors.py +117 -0
  47. python/core/utils/general_layer_functions.py +268 -0
  48. python/core/utils/helper_functions.py +1138 -0
  49. python/core/utils/model_registry.py +166 -0
  50. python/core/utils/scratch_tests.py +66 -0
  51. python/core/utils/witness_utils.py +291 -0
  52. python/frontend/__init__.py +0 -0
  53. python/frontend/cli.py +115 -0
  54. python/frontend/commands/__init__.py +17 -0
  55. python/frontend/commands/args.py +100 -0
  56. python/frontend/commands/base.py +199 -0
  57. python/frontend/commands/bench/__init__.py +54 -0
  58. python/frontend/commands/bench/list.py +42 -0
  59. python/frontend/commands/bench/model.py +172 -0
  60. python/frontend/commands/bench/sweep.py +212 -0
  61. python/frontend/commands/compile.py +58 -0
  62. python/frontend/commands/constants.py +5 -0
  63. python/frontend/commands/model_check.py +53 -0
  64. python/frontend/commands/prove.py +50 -0
  65. python/frontend/commands/verify.py +73 -0
  66. python/frontend/commands/witness.py +64 -0
  67. python/scripts/__init__.py +0 -0
  68. python/scripts/benchmark_runner.py +833 -0
  69. python/scripts/gen_and_bench.py +482 -0
  70. python/tests/__init__.py +0 -0
  71. python/tests/circuit_e2e_tests/__init__.py +0 -0
  72. python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
  73. python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
  74. python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
  75. python/tests/circuit_parent_classes/__init__.py +0 -0
  76. python/tests/circuit_parent_classes/test_circuit.py +969 -0
  77. python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
  78. python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
  79. python/tests/test_cli.py +1021 -0
  80. python/tests/utils_testing/__init__.py +0 -0
  81. python/tests/utils_testing/test_helper_functions.py +891 -0
@@ -0,0 +1,190 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import TYPE_CHECKING, Any, Generator, Mapping, Sequence, TypeAlias, Union
5
+
6
+ import numpy as np
7
+ import pytest
8
+ import torch
9
+
10
+ if TYPE_CHECKING:
11
+ from python.core.circuits.zk_model_base import ZKModelBase
12
+ from python.core.utils.helper_functions import CircuitExecutionConfig, RunType
13
+
14
# Success / failure marker strings for witness generation; presumably matched
# against captured output by tests elsewhere — TODO confirm at call sites.
GOOD_OUTPUT = ["Witness Generated"]
BAD_OUTPUT = [
    "Witness generation failed",
    "Outputs generated do not match outputs supplied",
]

# Record lengths that ``model_fixture`` knows how to unpack from request.param.
NUMPARAMS3, NUMPARAMS4 = 3, 4
23
+
24
@pytest.fixture(scope="module")
def model_fixture(
    request: pytest.FixtureRequest,
    tmp_path_factory: pytest.TempPathFactory,
) -> dict[str, Any]:
    """Instantiate a parametrized circuit model and compile it once per module.

    ``request.param`` is read both through ``.name`` / ``.loader`` attributes
    and positionally (``len`` / indexing) — presumably a NamedTuple-like record
    supplied via indirect parametrization; TODO confirm at the call sites.

    * 3 entries: ``param[2]`` is used as constructor kwargs if it is a dict,
      otherwise as positional args.
    * 4 entries: ``param[2]`` are positional args, ``param[3]`` are kwargs.
    * any other length: the loader is called with no arguments.

    Returns:
        A dict with keys ``"name"``, ``"model_class"``, ``"circuit_path"``,
        ``"temp_dir"``, ``"model"`` and ``"quantized_model"`` (the latter is
        the quantized-model *path*, not a loaded model).
    """
    # These names are imported only under ``TYPE_CHECKING`` at module level,
    # so they do not exist at runtime; import them here, where they are called.
    from python.core.utils.helper_functions import (  # noqa: PLC0415
        CircuitExecutionConfig,
        RunType,
    )

    param = request.param
    name = f"{param.name}"
    model_class = param.loader
    args, kwargs = (), {}

    if len(param) == NUMPARAMS3:
        if isinstance(param[2], dict):
            kwargs = param[2]
        else:
            args = param[2]
    elif len(param) == NUMPARAMS4:
        args, kwargs = param[2], param[3]

    temp_dir = tmp_path_factory.mktemp(name)
    circuit_path = temp_dir / f"{name}_circuit.txt"
    quantized_path = temp_dir / f"{name}_quantized.pt"

    model = model_class(*args, **kwargs)

    # Compile once so every test in the module can reuse the circuit artifacts.
    model.base_testing(
        CircuitExecutionConfig(
            run_type=RunType.COMPILE_CIRCUIT,
            dev_mode=True,
            circuit_path=str(circuit_path),
            quantized_path=quantized_path,
        ),
    )

    return {
        "name": name,
        "model_class": model_class,
        "circuit_path": circuit_path,
        "temp_dir": temp_dir,
        "model": model,
        "quantized_model": quantized_path,
    }
65
+
66
+
67
@pytest.fixture()
def temp_witness_file(tmp_path: Path) -> Generator[Path, None, None]:
    """Yield a path for a temporary witness file and delete it afterwards.

    Only the path is built here; the file is created (or not) by the test.
    """
    witness_path = tmp_path / "temp_witness.txt"
    yield witness_path
    # Teardown: ``missing_ok`` covers tests that never wrote the file and
    # avoids the check-then-delete race of ``exists()`` + ``unlink()``.
    witness_path.unlink(missing_ok=True)
76
+
77
+
78
@pytest.fixture()
def temp_input_file(tmp_path: Path) -> Generator[Path, None, None]:
    """Yield a path for a temporary input file and delete it afterwards.

    Only the path is built here; the file is created (or not) by the test.
    """
    input_path = tmp_path / "temp_input.txt"
    yield input_path
    # Teardown: ``missing_ok`` covers tests that never wrote the file and
    # avoids the check-then-delete race of ``exists()`` + ``unlink()``.
    input_path.unlink(missing_ok=True)
87
+
88
+
89
@pytest.fixture()
def temp_output_file(tmp_path: Path) -> Generator[Path, None, None]:
    """Yield a path for a temporary output file and delete it afterwards.

    Only the path is built here; the file is created (or not) by the test.
    """
    output_path = tmp_path / "temp_output.txt"
    yield output_path
    # Teardown: ``missing_ok`` covers tests that never wrote the file and
    # avoids the check-then-delete race of ``exists()`` + ``unlink()``.
    output_path.unlink(missing_ok=True)
98
+
99
+
100
@pytest.fixture()
def temp_proof_file(tmp_path: Path) -> Generator[Path, None, None]:
    """Yield a path for a temporary proof file and delete it afterwards.

    Only the path is built here; the file is created (or not) by the test.
    """
    proof_path = tmp_path / "temp_proof.txt"
    yield proof_path
    # Teardown: ``missing_ok`` covers tests that never wrote the file and
    # avoids the check-then-delete race of ``exists()`` + ``unlink()``.
    proof_path.unlink(missing_ok=True)
109
+
110
+
111
+ ScalarOrTensor: TypeAlias = Union[int, float, torch.Tensor]
112
+ NestedArray: TypeAlias = Union[
113
+ ScalarOrTensor,
114
+ list["NestedArray"],
115
+ tuple["NestedArray"],
116
+ np.ndarray,
117
+ ]
118
+
119
+
120
+ def add_1_to_first_element(x: NestedArray) -> NestedArray:
121
+ """Safely adds 1 to the first element of any scalar/list/tensor."""
122
+ if isinstance(x, (int, float)):
123
+ return x + 1
124
+ if isinstance(x, torch.Tensor):
125
+ x = x.clone() # avoid in-place modification
126
+ x.view(-1)[0] += 1
127
+ return x
128
+ if isinstance(x, (list, tuple, np.ndarray)):
129
+ x = list(x)
130
+ x[0] = add_1_to_first_element(x[0])
131
+ return x
132
+ msg = f"Unsupported type for get_outputs patch: {type(x)}"
133
+ raise TypeError(msg)
134
+
135
+
136
# Per-model result caches keyed by the ``model`` object from ``model_fixture``.
# A stored ``False`` makes the matching check_* fixture skip later tests for
# that model; the code that writes these entries is not visible here.
circuit_compile_results: dict[Any, Any] = {}
witness_generated_results: dict[Any, Any] = {}
139
+
140
+ Nested: TypeAlias = Union[float, Mapping[str, "Nested"], Sequence["Nested"]]
141
+
142
+
143
+ def contains_float(obj: Nested) -> bool:
144
+ if isinstance(obj, float):
145
+ return True
146
+ if isinstance(obj, dict):
147
+ return any(contains_float(v) for v in obj.values())
148
+ if isinstance(obj, list):
149
+ return any(contains_float(i) for i in obj)
150
+ return False
151
+
152
+
153
def _skip_if_recorded_failure(results: dict[Any, Any], model: Any) -> bool | None:
    """Skip the requesting test if *results* holds ``False`` for *model*.

    Returns the recorded value, or ``None`` when nothing is recorded yet
    (``dict.get`` default) — an unrecorded model is NOT skipped.
    """
    result = results.get(model)
    if result is False:
        pytest.skip(
            f"Skipping because the first test failed for: {model}",
        )
    return result


@pytest.fixture(scope="module")
def check_model_compiles(model_fixture: dict[str, Any]) -> bool | None:
    """Skip dependent tests once compilation is recorded as failed for this model.

    Reads ``circuit_compile_results``; the test that writes the entry is not
    visible here — presumably the first compile test for the model.
    """
    return _skip_if_recorded_failure(circuit_compile_results, model_fixture["model"])


@pytest.fixture(scope="module")
def check_witness_generated(model_fixture: dict[str, Any]) -> bool | None:
    """Skip dependent tests once witness generation is recorded as failed.

    Reads ``witness_generated_results``; the test that writes the entry is not
    visible here — presumably the first witness test for the model.
    """
    return _skip_if_recorded_failure(
        witness_generated_results,
        model_fixture["model"],
    )
173
+
174
+
175
def assert_very_close(
    inputs_1: Mapping[Any, Any],
    inputs_2: Mapping[Any, Any],
    model: "ZKModelBase",
) -> None:
    """Assert that two mappings of scaled values agree after de-scaling.

    For every key of *inputs_1*, the values under that key in both mappings
    are divided by ``model.scale_base ** model.scale_exponent`` and compared
    element-wise with ``torch.isclose(rtol=1e-8)`` (default ``atol``).

    Both sides are converted to float64 first: ``torch.isclose`` rejects
    mismatched dtypes, which otherwise occurs when e.g. one side is a Python
    list (float32 after division) and the other a float64 NumPy array.

    Raises:
        AssertionError: if the values under any key are not close.
        KeyError: if a key of *inputs_1* is missing from *inputs_2*.
    """
    scale = model.scale_base**model.scale_exponent
    for key in inputs_1:
        x = torch.div(torch.as_tensor(inputs_1[key], dtype=torch.float64), scale)
        y = torch.div(torch.as_tensor(inputs_2[key], dtype=torch.float64), scale)
        assert torch.isclose(x, y, rtol=1e-8).all(), f"Values differ for key {key!r}"
@@ -0,0 +1,217 @@
1
+ import json
2
+ import subprocess
3
+ import sys
4
+ from pathlib import Path
5
+ from typing import Generator
6
+
7
+ import numpy as np
8
+ import onnx
9
+ import pytest
10
+ from onnx import helper, numpy_helper
11
+
12
+
13
def create_simple_gemm_onnx_model(
    input_size: int,
    output_size: int,
    model_path: Path,
) -> None:
    """Save a one-node ONNX model computing ``output = input @ weight.T + bias``.

    The graph input is FLOAT ``[1, input_size]`` and the output is FLOAT
    ``[1, output_size]``. ``weight`` (``(output_size, input_size)``) and
    ``bias`` (``(output_size,)``) are float32 initializers drawn from an
    unseeded standard-normal generator, so each call yields new parameters.
    """
    generator = np.random.default_rng()
    # Draw weight before bias (same order as the generator is consumed below).
    weight = generator.standard_normal((output_size, input_size)).astype(np.float32)
    bias = generator.standard_normal((output_size,)).astype(np.float32)

    gemm = helper.make_node(
        "Gemm",
        inputs=["input", "weight", "bias"],
        outputs=["output"],
        alpha=1.0,
        beta=1.0,
        transB=1,  # weight is stored as (out, in), so Gemm transposes it
    )
    graph = helper.make_graph(
        [gemm],
        "simple_gemm",
        [helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, input_size])],
        [helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [1, output_size])],
        [
            numpy_helper.from_array(weight, name="weight"),
            numpy_helper.from_array(bias, name="bias"),
        ],
    )
    onnx.save(
        helper.make_model(graph, producer_name="simple_gemm_creator"),
        str(model_path),
    )
68
+
69
+
70
def _cli_command(subcommand: str, *args: str) -> list[str]:
    """Return argv that runs ``python.frontend.cli <subcommand> <args...>``."""
    return [sys.executable, "-m", "python.frontend.cli", subcommand, *args]


@pytest.mark.e2e()
def test_parallel_compile_and_witness_two_simple_models(  # noqa: PLR0915
    tmp_path: Path,
    capsys: pytest.CaptureFixture[str],  # noqa: ARG001 - requested, not used
) -> None:
    """Compile two differently-shaped GEMM models, then run both witnesses at once.

    The two ``compile`` invocations run one after the other; only the two
    ``witness`` invocations run concurrently. Each witness run must succeed,
    write its output and witness files, and emit an ``"output"`` list of the
    model's output size.
    """
    tmp_dir = Path(tmp_path)
    model1_input_size, model1_output_size = 4, 2
    model2_input_size, model2_output_size = 6, 3

    model1_path = tmp_dir / "simple_gemm1.onnx"
    model2_path = tmp_dir / "simple_gemm2.onnx"
    create_simple_gemm_onnx_model(model1_input_size, model1_output_size, model1_path)
    create_simple_gemm_onnx_model(model2_input_size, model2_output_size, model2_path)

    circuit1_path = tmp_dir / "circuit1.txt"
    circuit2_path = tmp_dir / "circuit2.txt"

    # Compile sequentially; stderr is included in the failure message.
    for label, model_path, circuit_path in (
        ("model1", model1_path, circuit1_path),
        ("model2", model2_path, circuit2_path),
    ):
        result = subprocess.run(  # noqa: S603
            _cli_command("compile", "-m", str(model_path), "-c", str(circuit_path)),
            capture_output=True,
            text=True,
            check=False,
        )
        assert result.returncode == 0, f"Compile failed for {label}: {result.stderr}"

    input1_path = tmp_dir / "input1.json"
    input2_path = tmp_dir / "input2.json"
    input1_path.write_text(json.dumps({"input": [1.0] * model1_input_size}))
    input2_path.write_text(json.dumps({"input": [1.0] * model2_input_size}))

    output1_path = tmp_dir / "output1.json"
    witness1_path = tmp_dir / "witness1.bin"
    output2_path = tmp_dir / "output2.json"
    witness2_path = tmp_dir / "witness2.bin"

    witness_cmd1 = _cli_command(
        "witness",
        "-c", str(circuit1_path),
        "-i", str(input1_path),
        "-o", str(output1_path),
        "-w", str(witness1_path),
    )
    witness_cmd2 = _cli_command(
        "witness",
        "-c", str(circuit2_path),
        "-i", str(input2_path),
        "-o", str(output2_path),
        "-w", str(witness2_path),
    )

    # Send each process's combined stdout/stderr to a log file so failures are
    # diagnosable without pipes that could block a process mid-run.
    log1_path = tmp_dir / "witness1.log"
    log2_path = tmp_dir / "witness2.log"
    with log1_path.open("w") as log1, log2_path.open("w") as log2:
        # Start both before waiting on either so they actually overlap.
        proc1 = subprocess.Popen(  # noqa: S603
            witness_cmd1, stdout=log1, stderr=subprocess.STDOUT,
        )
        proc2 = subprocess.Popen(  # noqa: S603
            witness_cmd2, stdout=log2, stderr=subprocess.STDOUT,
        )
        proc1.wait()
        proc2.wait()

    assert proc1.returncode == 0, f"Witness failed for model1:\n{log1_path.read_text()}"
    assert proc2.returncode == 0, f"Witness failed for model2:\n{log2_path.read_text()}"

    assert output1_path.exists(), "Output1 file not generated"
    assert output2_path.exists(), "Output2 file not generated"
    assert witness1_path.exists(), "Witness1 file not generated"
    assert witness2_path.exists(), "Witness2 file not generated"

    output1 = json.loads(output1_path.read_text())
    output2 = json.loads(output2_path.read_text())

    # Model1: input 4 -> output 2
    assert "output" in output1, "Output1 missing 'output' key"
    assert len(output1["output"]) == model1_output_size, (
        f"Output1 should have {model1_output_size} elements,"
        f" got {len(output1['output'])}"
    )

    # Model2: input 6 -> output 3
    assert "output" in output2, "Output2 missing 'output' key"
    assert len(output2["output"]) == model2_output_size, (
        f"Output2 should have {model2_output_size} elements,"
        f" got {len(output2['output'])}"
    )
File without changes