JSTprove 1.0.0__py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. jstprove-1.0.0.dist-info/METADATA +397 -0
  2. jstprove-1.0.0.dist-info/RECORD +81 -0
  3. jstprove-1.0.0.dist-info/WHEEL +6 -0
  4. jstprove-1.0.0.dist-info/entry_points.txt +2 -0
  5. jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
  6. jstprove-1.0.0.dist-info/top_level.txt +1 -0
  7. python/__init__.py +0 -0
  8. python/core/__init__.py +3 -0
  9. python/core/binaries/__init__.py +0 -0
  10. python/core/binaries/expander-exec +0 -0
  11. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  12. python/core/circuit_models/__init__.py +0 -0
  13. python/core/circuit_models/generic_onnx.py +231 -0
  14. python/core/circuit_models/simple_circuit.py +133 -0
  15. python/core/circuits/__init__.py +0 -0
  16. python/core/circuits/base.py +1000 -0
  17. python/core/circuits/errors.py +188 -0
  18. python/core/circuits/zk_model_base.py +25 -0
  19. python/core/model_processing/__init__.py +0 -0
  20. python/core/model_processing/converters/__init__.py +0 -0
  21. python/core/model_processing/converters/base.py +143 -0
  22. python/core/model_processing/converters/onnx_converter.py +1181 -0
  23. python/core/model_processing/errors.py +147 -0
  24. python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
  25. python/core/model_processing/onnx_custom_ops/conv.py +111 -0
  26. python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
  27. python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
  28. python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
  29. python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
  30. python/core/model_processing/onnx_custom_ops/relu.py +43 -0
  31. python/core/model_processing/onnx_quantizer/__init__.py +0 -0
  32. python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
  33. python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
  34. python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
  35. python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
  36. python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
  37. python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
  38. python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
  39. python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
  40. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
  41. python/core/model_templates/__init__.py +0 -0
  42. python/core/model_templates/circuit_template.py +57 -0
  43. python/core/utils/__init__.py +0 -0
  44. python/core/utils/benchmarking_helpers.py +163 -0
  45. python/core/utils/constants.py +4 -0
  46. python/core/utils/errors.py +117 -0
  47. python/core/utils/general_layer_functions.py +268 -0
  48. python/core/utils/helper_functions.py +1138 -0
  49. python/core/utils/model_registry.py +166 -0
  50. python/core/utils/scratch_tests.py +66 -0
  51. python/core/utils/witness_utils.py +291 -0
  52. python/frontend/__init__.py +0 -0
  53. python/frontend/cli.py +115 -0
  54. python/frontend/commands/__init__.py +17 -0
  55. python/frontend/commands/args.py +100 -0
  56. python/frontend/commands/base.py +199 -0
  57. python/frontend/commands/bench/__init__.py +54 -0
  58. python/frontend/commands/bench/list.py +42 -0
  59. python/frontend/commands/bench/model.py +172 -0
  60. python/frontend/commands/bench/sweep.py +212 -0
  61. python/frontend/commands/compile.py +58 -0
  62. python/frontend/commands/constants.py +5 -0
  63. python/frontend/commands/model_check.py +53 -0
  64. python/frontend/commands/prove.py +50 -0
  65. python/frontend/commands/verify.py +73 -0
  66. python/frontend/commands/witness.py +64 -0
  67. python/scripts/__init__.py +0 -0
  68. python/scripts/benchmark_runner.py +833 -0
  69. python/scripts/gen_and_bench.py +482 -0
  70. python/tests/__init__.py +0 -0
  71. python/tests/circuit_e2e_tests/__init__.py +0 -0
  72. python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
  73. python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
  74. python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
  75. python/tests/circuit_parent_classes/__init__.py +0 -0
  76. python/tests/circuit_parent_classes/test_circuit.py +969 -0
  77. python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
  78. python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
  79. python/tests/test_cli.py +1021 -0
  80. python/tests/utils_testing/__init__.py +0 -0
  81. python/tests/utils_testing/test_helper_functions.py +891 -0
@@ -0,0 +1,188 @@
1
+ # python/core/circuits/errors.py
2
+ from __future__ import annotations
3
+
4
+ from python.core.utils.helper_functions import RunType
5
+
6
+
7
class CircuitError(Exception):
    """
    Root of the circuit exception hierarchy.

    Attributes:
        message (str): Human-readable description of the error.
        details (dict): Structured context for debugging or logging; an
            empty dict when none (or an empty one) was supplied.
    """

    def __init__(
        self: CircuitError,
        message: str,
        details: dict | None = None,
    ) -> None:
        """Store ``message`` and ``details`` and initialise ``Exception``.

        Args:
            message (str): Human-readable description of the error.
            details (dict | None): Optional structured context.
        """
        super().__init__(message)
        self.message = message
        self.details = details if details else {}

    def __str__(self: CircuitError) -> str:
        """Return the message, followed by details and the chained cause.

        Segments are separated by ``" | "``. The details segment is emitted
        only when ``details`` is non-empty, and the cause segment only when
        the exception was raised with an explicit ``from`` cause.
        """
        segments = [self.message]

        if self.details:
            segments.append(f"Details: {self.details}")

        cause = self.__cause__
        if cause is not None:
            segments.append(f"Caused by: {cause}")

        return " | ".join(segments)
34
+
35
+
36
class CircuitConfigurationError(CircuitError):
    """
    Raised when circuit is not properly configured (missing or invalid attributes).

    Attributes:
        missing_attributes (list): List of missing attributes (if known);
            an empty list when none were supplied.
    """

    def __init__(
        self: CircuitConfigurationError,
        message: str | None = None,
        missing_attributes: list | None = None,
        details: dict | None = None,
    ) -> None:
        """Build the error, deriving a default message when none is given.

        Args:
            message (str | None): Explicit message. When falsy, a default is
                generated (listing ``missing_attributes`` if provided).
            missing_attributes (list | None): Names of the missing attributes.
            details (dict | None): Optional structured context forwarded to
                the base class.
        """
        if missing_attributes and not message:
            # Stringify each entry: the list is untyped, and a TypeError
            # raised while formatting would mask the configuration error.
            joined = ", ".join(str(attr) for attr in missing_attributes)
            message = (
                "Circuit class (python) is misconfigured."
                f" Missing required attributes: {joined}"
            )
        elif not message:
            message = "Circuit is misconfigured."
        super().__init__(message, details)
        self.missing_attributes = missing_attributes or []
59
+
60
+
61
class CircuitInputError(CircuitError):
    """
    Raised when input validation fails (missing or invalid values).

    Attributes:
        parameter (str | None): Name of the problematic parameter (if known).
        expected (str | None): Expected type or value description (optional).
        actual (object | None): Actual value encountered (optional).
    """

    def __init__(  # noqa: PLR0913
        self: CircuitInputError,
        message: str | None = None,
        parameter: str | None = None,
        expected: str | None = None,
        # ``object`` rather than the builtin function ``any``, which is not
        # a type and is meaningless to type checkers.
        actual: object | None = None,
        details: dict | None = None,
    ) -> None:
        """Build the error, deriving a default message when none is given.

        Args:
            message (str | None): Explicit message. When falsy, one is
                generated from ``parameter``/``expected``/``actual``.
            parameter (str | None): Name of the offending parameter.
            expected (str | None): Description of what was expected.
            actual (object | None): The value that was received. ``None`` is
                treated as "not provided" and is never rendered.
            details (dict | None): Optional structured context forwarded to
                the base class.
        """
        if parameter and not message:
            msg_parts = [f"Issue with parameter '{parameter}'."]
            if expected:
                msg_parts.append(f"Expected: {expected}.")
            if actual is not None:
                msg_parts.append(f"Got: {actual!r}.")
            message = " ".join(msg_parts)
        elif not message:
            message = "Invalid circuit class (python) input."
        super().__init__(message, details)
        self.parameter = parameter
        self.expected = expected
        self.actual = actual
92
+
93
+
94
class CircuitRunError(CircuitError):
    """
    Raised when an operation (compile, prove, verify, etc.) fails.

    Attributes:
        operation (RunType | None): The operation that failed (if known).
    """

    def __init__(
        self: CircuitRunError,
        message: str | None = None,
        operation: RunType | None = None,
        details: dict | None = None,
    ) -> None:
        """Build the error, deriving a default message when none is given.

        Args:
            message (str | None): Explicit message. When falsy, one is
                generated from ``operation`` if that is provided.
            operation (RunType | None): The operation that failed.
            details (dict | None): Optional structured context forwarded to
                the base class.
        """
        operations_to_name = {
            RunType.COMPILE_CIRCUIT: "Compile",
            RunType.GEN_VERIFY: "Verify",
            RunType.PROVE_WITNESS: "Prove",
            RunType.GEN_WITNESS: "Witness",
        }
        if operation and not message:
            label = operations_to_name.get(operation)
            if label is None:
                # Operation has no friendly name in the table above: fall
                # back to its own name (or string form) instead of letting
                # the message read "Circuit operation 'None' failed.".
                label = getattr(operation, "name", None) or str(operation)
            message = f"Circuit operation '{label}' failed."
        elif not message:
            message = "Circuit run failed."
        super().__init__(message, details)
        self.operation = operation
120
+
121
+
122
class CircuitFileError(CircuitError):
    """
    Raised when file-related operations fail
    (e.g., reading, writing, or accessing files).

    Attributes:
        file_path (str | None): Path to the problematic file (if known).
    """

    def __init__(
        self: CircuitFileError,
        message: str | None = None,
        file_path: str | None = None,
        details: dict | None = None,
    ) -> None:
        """Build the error, deriving a default message when none is given.

        Args:
            message (str | None): Explicit message. When falsy, a default is
                generated (mentioning ``file_path`` if provided).
            file_path (str | None): Path of the file involved.
            details (dict | None): Optional structured context forwarded to
                the base class.
        """
        if not message:
            # An explicit message always wins; otherwise pick the most
            # specific default available.
            message = (
                f"File operation failed for path: {file_path}"
                if file_path
                else "Circuit file operation failed."
            )
        super().__init__(message, details)
        self.file_path = file_path
143
+
144
+
145
class CircuitProcessingError(CircuitError):
    """
    Raised when data processing operations fail
    (e.g., tensor operations, scaling, reshaping).

    Attributes:
        operation (str | None): Name of the operation that failed (if known).
        data_type (str | None): Type of data being processed (if known).
    """

    def __init__(
        self: CircuitProcessingError,
        message: str | None = None,
        operation: str | None = None,
        data_type: str | None = None,
        details: dict | None = None,
    ) -> None:
        """Build the error, deriving a default message when none is given.

        Args:
            message (str | None): Explicit message. When falsy, a default is
                generated (mentioning ``operation`` if provided).
            operation (str | None): Name of the failing operation.
            data_type (str | None): Kind of data being processed; stored on
                the instance but not included in the generated message.
            details (dict | None): Optional structured context forwarded to
                the base class.
        """
        if not message:
            # An explicit message always wins; otherwise pick the most
            # specific default available.
            message = (
                f"Data processing failed during {operation}."
                if operation
                else "Circuit data processing failed."
            )
        super().__init__(message, details)
        self.operation = operation
        self.data_type = data_type
169
+
170
+
171
class WitnessMatchError(CircuitError):
    """
    Raised when a witness does not match the provided inputs and outputs.

    The rendered message always starts with a fixed sentence; a
    caller-supplied ``message`` is appended after it. No ``details`` are
    passed to the base class, so ``details`` is always an empty dict.
    """

    def __init__(
        self: WitnessMatchError,
        message: str | None = None,
    ) -> None:
        """Build the error with the fixed prefix plus optional extra text.

        Args:
            message (str | None): Extra context appended (after a single
                space) to the fixed prefix. Ignored when falsy.
        """
        common_message = "Witness does not match provided inputs and outputs!"
        if message:
            common_message += f" {message}"
        super().__init__(common_message)
@@ -0,0 +1,25 @@
1
+ from __future__ import annotations
2
+
3
+ from python.core.circuits.base import Circuit
4
+ from python.core.utils.general_layer_functions import GeneralLayerFunctions
5
+
6
+
7
class ZKModelBase(GeneralLayerFunctions, Circuit):
    """
    Common parent for Zero-Knowledge (ZK) ML model classes.

    Inherits from both ``GeneralLayerFunctions`` and ``Circuit`` to give ZK
    circuit ML models a single standard interface.

    The constructor defined here always raises, so each subclass must
    provide its own ``__init__`` that defines the model's architecture,
    layers, and circuit details.
    """

    def __init__(self: ZKModelBase) -> None:
        """Reject direct initialisation; subclasses must override this.

        Raises:
            NotImplementedError: Always, when this implementation is invoked.
        """
        raise NotImplementedError("Must implement __init__")
File without changes
File without changes
@@ -0,0 +1,143 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from enum import Enum
5
+ from typing import TYPE_CHECKING, Optional, Union
6
+
7
+ if TYPE_CHECKING:
8
+ import numpy as np
9
+ import onnx
10
+ import torch
11
+
12
+
13
+ class ModelType(Enum):
14
+ ONNX = "ONNX"
15
+
16
+
17
# Value types allowed under a string key in a serialized layer description.
_ONNXLayerValue = Union[
    int,
    str,
    list[str],
    dict[str, list[int]],
    Optional[list],
    Optional[dict],
]
ONNXLayerDict = dict[str, _ONNXLayerValue]

# Circuit parameters: string keys mapped to an int or a str -> bool mapping.
_CircuitParamValue = Union[int, dict[str, bool]]
CircuitParamsDict = dict[str, _CircuitParamValue]
+
24
+
25
class ModelConverter(ABC):
    """
    Interface for AI model conversion, quantization, and I/O.

    A converter is expected to cover:
      - saving and loading models in various formats,
      - quantizing models,
      - extracting model weights,
      - generating model outputs.

    Every method below is abstract; a concrete subclass must implement all
    of them to supply the model-specific logic.
    """

    @abstractmethod
    def save_model(self: ModelConverter, file_path: str) -> None:
        """Write the current model to ``file_path``.

        Args:
            file_path (str): Destination path for the model file.
        """

    @abstractmethod
    def load_model(
        self: ModelConverter,
        file_path: str,
        model_type: ModelType | None = None,
    ) -> onnx.ModelProto:
        """Read a model from ``file_path``.

        Args:
            file_path (str): Location of the model file.
            model_type (ModelType | None):
                Optional tag naming the model format/type; helpful when
                more than one format is supported.

        Returns:
            onnx.ModelProto: The model that was loaded.
        """

    @abstractmethod
    def save_quantized_model(self: ModelConverter, file_path: str) -> None:
        """Write the quantized model to ``file_path``.

        Args:
            file_path (str): Destination path for the quantized model file.
        """

    @abstractmethod
    def load_quantized_model(self: ModelConverter, file_path: str) -> None:
        """Read a quantized model from ``file_path``.

        Args:
            file_path (str): Location of the quantized model file.
        """

    @abstractmethod
    def quantize_model(
        self: ModelConverter,
        model: onnx.ModelProto,
        scale_base: int,
        scale_exponent: int,
        rescale_config: dict | None = None,
    ) -> onnx.ModelProto:
        """Quantize ``model`` using the given scale and optional rescaling.

        Args:
            model (onnx.ModelProto): Model instance to quantize.
            scale_base (int): Base used for fixed-point scaling (e.g., 2).
            scale_exponent (int): Quantization scale factor.
            rescale_config (dict | None):
                Settings controlling rescaling of layers or weights during
                quantization. Defaults to None.

        Returns:
            onnx.ModelProto: The quantized model.
        """

    @abstractmethod
    def get_weights(
        self: ModelConverter,
    ) -> tuple[
        dict[str, list[ONNXLayerDict]],
        dict[str, list[ONNXLayerDict]],
        CircuitParamsDict,
    ]:
        """Return the model's architecture, weights, and circuit parameters.

        Returns:
            tuple[dict[str, list[ONNXLayerDict]],
                dict[str, list[ONNXLayerDict]], CircuitParamsDict]:
                A triple ``(architecture, weights, circuit_params)`` where
                  - ``architecture`` is a dict with serialized
                    ``architecture`` layers,
                  - ``weights`` is a dict containing ``w_and_b``
                    (serialized tensors),
                  - ``circuit_params`` is a dict containing scaling
                    parameters and ``rescale_config``.
        """

    @abstractmethod
    def get_model_and_quantize(self: ModelConverter) -> None:
        """Fetch the model and quantize it as one combined step."""

    @abstractmethod
    def get_outputs(
        self: ModelConverter,
        inputs: np.ndarray | torch.Tensor,
    ) -> list[np.ndarray]:
        """Run inference on ``inputs`` and return the model outputs.

        Args:
            inputs (np.ndarray | torch.Tensor):
                Input data, in whatever format the model expects.

        Returns:
            list[np.ndarray]: Outputs the model produced for the inputs.
        """