JSTprove 1.0.0__py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. jstprove-1.0.0.dist-info/METADATA +397 -0
  2. jstprove-1.0.0.dist-info/RECORD +81 -0
  3. jstprove-1.0.0.dist-info/WHEEL +6 -0
  4. jstprove-1.0.0.dist-info/entry_points.txt +2 -0
  5. jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
  6. jstprove-1.0.0.dist-info/top_level.txt +1 -0
  7. python/__init__.py +0 -0
  8. python/core/__init__.py +3 -0
  9. python/core/binaries/__init__.py +0 -0
  10. python/core/binaries/expander-exec +0 -0
  11. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  12. python/core/circuit_models/__init__.py +0 -0
  13. python/core/circuit_models/generic_onnx.py +231 -0
  14. python/core/circuit_models/simple_circuit.py +133 -0
  15. python/core/circuits/__init__.py +0 -0
  16. python/core/circuits/base.py +1000 -0
  17. python/core/circuits/errors.py +188 -0
  18. python/core/circuits/zk_model_base.py +25 -0
  19. python/core/model_processing/__init__.py +0 -0
  20. python/core/model_processing/converters/__init__.py +0 -0
  21. python/core/model_processing/converters/base.py +143 -0
  22. python/core/model_processing/converters/onnx_converter.py +1181 -0
  23. python/core/model_processing/errors.py +147 -0
  24. python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
  25. python/core/model_processing/onnx_custom_ops/conv.py +111 -0
  26. python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
  27. python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
  28. python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
  29. python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
  30. python/core/model_processing/onnx_custom_ops/relu.py +43 -0
  31. python/core/model_processing/onnx_quantizer/__init__.py +0 -0
  32. python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
  33. python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
  34. python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
  35. python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
  36. python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
  37. python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
  38. python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
  39. python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
  40. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
  41. python/core/model_templates/__init__.py +0 -0
  42. python/core/model_templates/circuit_template.py +57 -0
  43. python/core/utils/__init__.py +0 -0
  44. python/core/utils/benchmarking_helpers.py +163 -0
  45. python/core/utils/constants.py +4 -0
  46. python/core/utils/errors.py +117 -0
  47. python/core/utils/general_layer_functions.py +268 -0
  48. python/core/utils/helper_functions.py +1138 -0
  49. python/core/utils/model_registry.py +166 -0
  50. python/core/utils/scratch_tests.py +66 -0
  51. python/core/utils/witness_utils.py +291 -0
  52. python/frontend/__init__.py +0 -0
  53. python/frontend/cli.py +115 -0
  54. python/frontend/commands/__init__.py +17 -0
  55. python/frontend/commands/args.py +100 -0
  56. python/frontend/commands/base.py +199 -0
  57. python/frontend/commands/bench/__init__.py +54 -0
  58. python/frontend/commands/bench/list.py +42 -0
  59. python/frontend/commands/bench/model.py +172 -0
  60. python/frontend/commands/bench/sweep.py +212 -0
  61. python/frontend/commands/compile.py +58 -0
  62. python/frontend/commands/constants.py +5 -0
  63. python/frontend/commands/model_check.py +53 -0
  64. python/frontend/commands/prove.py +50 -0
  65. python/frontend/commands/verify.py +73 -0
  66. python/frontend/commands/witness.py +64 -0
  67. python/scripts/__init__.py +0 -0
  68. python/scripts/benchmark_runner.py +833 -0
  69. python/scripts/gen_and_bench.py +482 -0
  70. python/tests/__init__.py +0 -0
  71. python/tests/circuit_e2e_tests/__init__.py +0 -0
  72. python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
  73. python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
  74. python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
  75. python/tests/circuit_parent_classes/__init__.py +0 -0
  76. python/tests/circuit_parent_classes/test_circuit.py +969 -0
  77. python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
  78. python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
  79. python/tests/test_cli.py +1021 -0
  80. python/tests/utils_testing/__init__.py +0 -0
  81. python/tests/utils_testing/test_helper_functions.py +891 -0
@@ -0,0 +1,231 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from pathlib import Path
5
+ from typing import TYPE_CHECKING
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ from python.core import RUST_BINARY_NAME
11
+ from python.core.circuits.errors import (
12
+ CircuitFileError,
13
+ CircuitProcessingError,
14
+ CircuitRunError,
15
+ )
16
+ from python.core.circuits.zk_model_base import ZKModelBase
17
+
18
+ if TYPE_CHECKING:
19
+ from python.core.model_processing.converters.onnx_converter import (
20
+ CircuitParamsDict,
21
+ ONNXLayerDict,
22
+ )
23
+ from python.core.model_processing.converters.onnx_converter import (
24
+ ONNXConverter,
25
+ ONNXOpQuantizer,
26
+ )
27
+ from python.core.model_processing.onnx_quantizer.layers.base import BaseOpQuantizer
28
+
29
+
30
class GenericModelONNX(ONNXConverter, ZKModelBase):
    """
    A generic ONNX-based Zero-Knowledge (ZK) circuit model wrapper.

    This class provides:
    - Integration for ONNX model loading, quantization (in `ONNXConverter`)
      and ZK circuit infrastructure (in `ZKModelBase`).
    - Support for model quantization via `ONNXOpQuantizer`.
    - Input/output scaling and formatting utilities for ZK compatibility.

    Attributes
    ----------
    name : str
        Internal identifier for the binary to be run in rust backend.
    op_quantizer : ONNXOpQuantizer
        Operator quantizer for applying custom ONNX quantization rules.
    rescale_config : dict
        Per-node override for rescaling during quantization.
        Keys are node names, values are booleans.
        If not specified, the assumption is to rescale each layer.
    model_file_name : str
        Path to the ONNX model file used for the circuit.
    scale_base : int
        Base multiplier for scaling (default: 2).
    scale_exponent : int
        Exponent applied to `scale_base` for the final scaling factor
        (default: 18).

    Parameters
    ----------
    model_name : str
        Name of the model to load (with or without `.onnx` extension).
    use_find_model : bool, keyword-only
        If True, resolve `model_name` via `find_model` (adds the `.onnx`
        extension and falls back to the `models_onnx/` directory).
        If False (default), `model_name` is used verbatim.

    Notes
    -----
    - The scaling factor (`scale_base ** scale_exponent`) determines how floating
      point inputs/outputs are represented as integers inside the ZK circuit.
    - By default, scaling is fixed; dynamic scaling based on model analysis
      is planned for future implementation.
    - The quantization logic assumes operators are registered with
      `ONNXOpQuantizer`.
    """

    def __init__(
        self: GenericModelONNX,
        model_name: str,
        *,
        use_find_model: bool = False,
    ) -> None:
        """Configure the model and initialise the ONNX converter.

        Raises:
            CircuitFileError: If any step of initialisation fails (including
                model resolution via `find_model`); the original exception is
                chained as the cause.
        """
        try:
            self.name = RUST_BINARY_NAME
            self.op_quantizer = ONNXOpQuantizer()
            self.rescale_config = {}
            if use_find_model:
                self.model_file_name = self.find_model(model_name)
            else:
                self.model_file_name = model_name

            self.scale_base = 2
            self.scale_exponent = 18
            ONNXConverter.__init__(self)
        except Exception as e:
            msg = f"Failed to initialize GenericModelONNX with model '{model_name}'"
            raise CircuitFileError(
                msg,
                file_path=model_name,
            ) from e

    def find_model(self: GenericModelONNX, model_name: str) -> str:
        """Resolve the ONNX model file path.

        The `.onnx` extension is appended when missing. The name is then
        checked as a direct path; failing that, it is looked up under the
        `models_onnx/` directory (unless it already points there).

        Args:
            model_name (str): Name of the model (with or without `.onnx` extension).

        Returns:
            str: Path to the model file.

        Raises:
            CircuitFileError: If no matching file exists.
        """
        # Compare against the suffix, not any substring, so names such as
        # "my.onnx_models/net" still get the extension appended.
        if not model_name.endswith(".onnx"):
            model_name = model_name + ".onnx"

        # Check direct path first.
        if Path(model_name).exists():
            return model_name

        # If the caller already pointed into models_onnx, the direct check
        # above was the only lookup; otherwise try the models_onnx directory.
        if "models_onnx" in model_name:
            models_onnx_path = model_name
        else:
            models_onnx_path = f"models_onnx/{model_name}"

        if not Path(models_onnx_path).exists():
            msg = f"Model file not found: '{model_name}'"
            raise CircuitFileError(
                msg,
                file_path=models_onnx_path,
            )
        return models_onnx_path

    def adjust_inputs(self: GenericModelONNX, input_file: str) -> str:
        """Preprocess and flatten model inputs for the circuit.

        `self.input_shape` is temporarily replaced with a single flattened
        dimension while the parent implementation runs, and is restored
        afterwards, including when the parent raises.

        Args:
            input_file (str): Input data file or array compatible with the model.

        Returns:
            str: Adjusted input file after reshaping and scaling.

        Raises:
            ValueError: If adjusting the inputs fails for any reason.
        """
        try:
            input_shape = self.input_shape.copy()
            shape = self.adjust_shape(input_shape)
            self.input_shape = [math.prod(shape)]
            try:
                x = super().adjust_inputs(input_file)
            finally:
                # Restore the original shape even on failure so the instance
                # is not left in the temporary flattened state.
                self.input_shape = input_shape.copy()
        except Exception as e:
            msg = f"Failed to adjust inputs for GenericModelONNX: {e}"
            raise ValueError(msg) from e
        else:
            return x

    def get_outputs(
        self: GenericModelONNX,
        inputs: np.ndarray | list[int] | torch.Tensor,
    ) -> torch.Tensor:
        """Run inference and flatten outputs.

        Args:
            inputs (np.ndarray | list[int] | torch.Tensor): Preprocessed model inputs.

        Returns:
            torch.Tensor: Flattened model outputs as a tensor.

        Raises:
            CircuitRunError: If the parent inference call fails.
        """
        try:
            raw_outputs = super().get_outputs(inputs)
        except Exception as e:
            msg = "Failed to get outputs for GenericModelONNX"
            raise CircuitRunError(
                msg,
                operation="get_outputs",
            ) from e
        else:
            return torch.as_tensor(np.array(raw_outputs)).flatten()

    def format_inputs(
        self: GenericModelONNX,
        inputs: np.ndarray | list[int] | torch.Tensor,
    ) -> dict[str, list[int]]:
        """Format raw inputs into scaled integer lists for the rust backend.

        Inputs are flattened, multiplied by the factor returned by
        `BaseOpQuantizer.get_scaling(scale_base, scale_exponent)`, and
        converted with `.long()`. That conversion truncates toward zero
        rather than rounding.

        Args:
            inputs (np.ndarray | list[int] | torch.Tensor): Raw model inputs.

        Returns:
            dict[str, list[int]]: Dictionary mapping `input` to scaled integer values.

        Raises:
            CircuitProcessingError: If conversion or scaling fails.
        """
        try:
            x = {"input": inputs}
            scaling = BaseOpQuantizer.get_scaling(
                scale_base=self.scale_base,
                scale_exponent=self.scale_exponent,
            )
            for key in x:  # noqa: PLC0206
                # The round-trip through a Python list lets torch re-infer the
                # dtype (e.g. float64 arrays become the default float dtype).
                # It is kept so the scaled values stay unchanged.
                x[key] = torch.as_tensor(x[key]).flatten().tolist()
                x[key] = (torch.as_tensor(x[key]) * scaling).long().tolist()
        except Exception as e:
            msg = f"Failed to format inputs for GenericModelONNX: {e}"
            raise CircuitProcessingError(
                msg,
                operation="format_inputs",
                data_type=type(inputs).__name__,
            ) from e
        else:
            return x

    def get_weights(
        self: GenericModelONNX,
    ) -> dict[str, list[ONNXLayerDict]]:
        """Return only the weights-and-biases part of the parent's `get_weights()`."""
        _, w_and_b, _ = super().get_weights()
        # Currently want to read these in separately
        return w_and_b

    def get_architecture(
        self: GenericModelONNX,
    ) -> dict[str, list[ONNXLayerDict]]:
        """Return only the architecture part of the parent's `get_weights()`."""
        architecture, _, _ = super().get_weights()
        # Currently want to read these in separately
        return architecture

    def get_metadata(
        self: GenericModelONNX,
    ) -> CircuitParamsDict:
        """Return only the circuit-params part of the parent's `get_weights()`."""
        _, _, circuit_params = super().get_weights()
        # Currently want to read these in separately
        return circuit_params
228
+
229
+
230
if __name__ == "__main__":
    # This module only defines GenericModelONNX; running it directly does nothing.
    ...
@@ -0,0 +1,133 @@
1
+ from __future__ import annotations
2
+
3
+ from random import randint
4
+
5
+ from python.core.circuits.base import Circuit, RunType
6
+ from python.core.utils.helper_functions import CircuitExecutionConfig, ZKProofSystems
7
+
8
+
9
class SimpleCircuit(Circuit):
    """Minimal example circuit that adds two integers.

    Inputs are `value_a`, `value_b` and a randomly drawn `nonce`; the output
    is `value_a + value_b`. No fixed-point scaling is applied.
    """

    def __init__(self: SimpleCircuit) -> None:
        super().__init__()

        # The name must match the compiled binary exactly.
        self.name = "simple_circuit"
        self.scale_exponent = 1
        self.scale_base = 1

        self.input_a = 100
        self.input_b = 200
        # NOTE: `random.randint` is not a cryptographically secure PRNG.
        # This nonce is for demonstration only, never for production use.
        self.nonce = randint(0, 10000000)  # noqa: S311

        self.required_keys = ["value_a", "value_b", "nonce"]

        self.input_shape = [1]

    def get_inputs(self: SimpleCircuit) -> dict[str, int]:
        """Return the circuit's current inputs.

        Returns:
            dict[str, int]: The keys `value_a`, `value_b` and `nonce`.
        """
        return dict(value_a=self.input_a, value_b=self.input_b, nonce=self.nonce)

    def get_outputs(self: SimpleCircuit, inputs: dict[str, int] | None = None) -> int:
        """Add `value_a` and `value_b`.

        Args:
            inputs (dict[str, int], optional): Mapping with `value_a`,
                `value_b` and `nonce`. When omitted, the instance's own
                values are used.

        Returns:
            int: The sum `value_a + value_b`.
        """
        if inputs is not None:
            values = inputs
        else:
            values = {
                "value_a": self.input_a,
                "value_b": self.input_b,
                "nonce": self.nonce,
            }
        message = f"Performing addition operation: {values['value_a']} + {values['value_b']}"
        print(message)  # noqa: T201
        return values["value_a"] + values["value_b"]

    def format_inputs(self: SimpleCircuit, inputs: dict[str, int]) -> dict[str, int]:
        """Return the inputs unchanged; this circuit needs no formatting.

        Args:
            inputs (dict[str, int]): Circuit input values.

        Returns:
            dict[str, int]: The same mapping that was passed in.
        """
        return inputs
70
+
71
+
72
+ # Example code demonstrating circuit operations
73
if __name__ == "__main__":
    # Demonstrates the SimpleCircuit flow: compile -> witness -> prove -> verify.
    circuit_path = "simple_circuit.txt"
    input_file = "inputs/simple_circuit_input.json"
    output_json = "output/simple_circuit_output.json"

    # Create a single circuit instance
    print("\n--- Creating circuit instance ---")  # noqa: T201
    circuit = SimpleCircuit()

    print("\n--- Testing different operations ---")  # noqa: T201

    # get_outputs() recomputes the sum on each call; there is no cached value.
    print("\nComputing circuit output:")  # noqa: T201
    output = circuit.get_outputs()
    print(f"Circuit output: {output}")  # noqa: T201

    print("\nRunning compilation:")  # noqa: T201
    # NOTE(review): compilation writes a .txt output path while the later
    # stages use .json; confirm this difference is intended.
    circuit.base_testing(
        CircuitExecutionConfig(
            run_type=RunType.COMPILE_CIRCUIT,
            dev_mode=True,
            circuit_path=circuit_path,
            input_file=input_file,
            output_file="output/simple_circuit_output.txt",
            proof_system=ZKProofSystems.Expander,
        ),
    )

    # Read the input and output files to verify
    print("\n--- Verifying input and output files ---")  # noqa: T201
    print(f"Input file: {circuit._file_info['input_file']}")  # noqa: SLF001, T201
    print(f"Output file: {circuit._file_info['output_file']}")  # noqa: SLF001, T201

    circuit.base_testing(
        CircuitExecutionConfig(
            run_type=RunType.GEN_WITNESS,
            circuit_path=circuit_path,
            input_file=input_file,
            output_file=output_json,
            write_json=True,
            proof_system=ZKProofSystems.Expander,
        ),
    )

    # NOTE(review): each SimpleCircuit() draws a fresh random nonce. If a later
    # stage rebuilds inputs from the instance, they will differ from the ones
    # used for the witness; confirm that fresh instances are intended here.
    circuit = SimpleCircuit()
    circuit.base_testing(
        CircuitExecutionConfig(
            run_type=RunType.PROVE_WITNESS,
            circuit_path=circuit_path,
            input_file=input_file,
            output_file=output_json,
            proof_system=ZKProofSystems.Expander,
        ),
    )

    circuit = SimpleCircuit()
    circuit.base_testing(
        CircuitExecutionConfig(
            run_type=RunType.GEN_VERIFY,
            circuit_path=circuit_path,
            input_file=input_file,
            output_file=output_json,
            proof_system=ZKProofSystems.Expander,
        ),
    )
File without changes