JSTprove 1.0.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of JSTprove might be problematic. Click here for more details.

Files changed (81) hide show
  1. jstprove-1.0.0.dist-info/METADATA +397 -0
  2. jstprove-1.0.0.dist-info/RECORD +81 -0
  3. jstprove-1.0.0.dist-info/WHEEL +5 -0
  4. jstprove-1.0.0.dist-info/entry_points.txt +2 -0
  5. jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
  6. jstprove-1.0.0.dist-info/top_level.txt +1 -0
  7. python/__init__.py +0 -0
  8. python/core/__init__.py +3 -0
  9. python/core/binaries/__init__.py +0 -0
  10. python/core/binaries/expander-exec +0 -0
  11. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  12. python/core/circuit_models/__init__.py +0 -0
  13. python/core/circuit_models/generic_onnx.py +231 -0
  14. python/core/circuit_models/simple_circuit.py +133 -0
  15. python/core/circuits/__init__.py +0 -0
  16. python/core/circuits/base.py +1000 -0
  17. python/core/circuits/errors.py +188 -0
  18. python/core/circuits/zk_model_base.py +25 -0
  19. python/core/model_processing/__init__.py +0 -0
  20. python/core/model_processing/converters/__init__.py +0 -0
  21. python/core/model_processing/converters/base.py +143 -0
  22. python/core/model_processing/converters/onnx_converter.py +1181 -0
  23. python/core/model_processing/errors.py +147 -0
  24. python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
  25. python/core/model_processing/onnx_custom_ops/conv.py +111 -0
  26. python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
  27. python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
  28. python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
  29. python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
  30. python/core/model_processing/onnx_custom_ops/relu.py +43 -0
  31. python/core/model_processing/onnx_quantizer/__init__.py +0 -0
  32. python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
  33. python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
  34. python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
  35. python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
  36. python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
  37. python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
  38. python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
  39. python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
  40. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
  41. python/core/model_templates/__init__.py +0 -0
  42. python/core/model_templates/circuit_template.py +57 -0
  43. python/core/utils/__init__.py +0 -0
  44. python/core/utils/benchmarking_helpers.py +163 -0
  45. python/core/utils/constants.py +4 -0
  46. python/core/utils/errors.py +117 -0
  47. python/core/utils/general_layer_functions.py +268 -0
  48. python/core/utils/helper_functions.py +1138 -0
  49. python/core/utils/model_registry.py +166 -0
  50. python/core/utils/scratch_tests.py +66 -0
  51. python/core/utils/witness_utils.py +291 -0
  52. python/frontend/__init__.py +0 -0
  53. python/frontend/cli.py +115 -0
  54. python/frontend/commands/__init__.py +17 -0
  55. python/frontend/commands/args.py +100 -0
  56. python/frontend/commands/base.py +199 -0
  57. python/frontend/commands/bench/__init__.py +54 -0
  58. python/frontend/commands/bench/list.py +42 -0
  59. python/frontend/commands/bench/model.py +172 -0
  60. python/frontend/commands/bench/sweep.py +212 -0
  61. python/frontend/commands/compile.py +58 -0
  62. python/frontend/commands/constants.py +5 -0
  63. python/frontend/commands/model_check.py +53 -0
  64. python/frontend/commands/prove.py +50 -0
  65. python/frontend/commands/verify.py +73 -0
  66. python/frontend/commands/witness.py +64 -0
  67. python/scripts/__init__.py +0 -0
  68. python/scripts/benchmark_runner.py +833 -0
  69. python/scripts/gen_and_bench.py +482 -0
  70. python/tests/__init__.py +0 -0
  71. python/tests/circuit_e2e_tests/__init__.py +0 -0
  72. python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
  73. python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
  74. python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
  75. python/tests/circuit_parent_classes/__init__.py +0 -0
  76. python/tests/circuit_parent_classes/test_circuit.py +969 -0
  77. python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
  78. python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
  79. python/tests/test_cli.py +1021 -0
  80. python/tests/utils_testing/__init__.py +0 -0
  81. python/tests/utils_testing/test_helper_functions.py +891 -0
@@ -0,0 +1,200 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Callable
4
+
5
+ import onnx
6
+
7
+ from python.core.model_processing.onnx_quantizer.exceptions import (
8
+ HandlerImplementationError,
9
+ MissingHandlerError,
10
+ UnsupportedOpError,
11
+ )
12
+ from python.core.model_processing.onnx_quantizer.layers.base import (
13
+ PassthroughQuantizer,
14
+ ScaleConfig,
15
+ )
16
+ from python.core.model_processing.onnx_quantizer.layers.constant import (
17
+ ConstantQuantizer,
18
+ )
19
+ from python.core.model_processing.onnx_quantizer.layers.conv import ConvQuantizer
20
+ from python.core.model_processing.onnx_quantizer.layers.gemm import GemmQuantizer
21
+ from python.core.model_processing.onnx_quantizer.layers.maxpool import MaxpoolQuantizer
22
+ from python.core.model_processing.onnx_quantizer.layers.relu import ReluQuantizer
23
+
24
+
25
class ONNXOpQuantizer:
    """
    Registry that maps ONNX op_type strings to quantizer handler instances.

    Look up the handler for a node's op_type and delegate quantization to it.

    Attributes
    ----------
    handlers : Dict[str, Callable]
        Maps ONNX op_type strings to quantizer handler instances.
    new_initializers : List[onnx.TensorProto]
        Newly created ONNX initializers (typically weights or biases) produced
        during quantization; shared with handlers that may append constants.

    Methods
    -------
    register(op_type, handler)
        Register a handler for an ONNX op_type.
    quantize(node, graph, scale_exponent, scale_base, initializer_map, rescale)
        Quantize one ONNX node via its registered handler.
    check_model(model)
        Verify every operation in the model is supported and each layer's
        parameters are valid.
    check_layer(node, initializer_map)
        Validate one node through its handler's check_supported method.
    get_initializer_map(model)
        Build a {name: TensorProto} mapping for the model's initializers.
    """

    def __init__(self: ONNXOpQuantizer) -> None:
        self.handlers: dict[
            str,
            Callable[
                [onnx.NodeProto, bool],
                onnx.NodeProto | list[onnx.NodeProto],
            ],
        ] = {}
        self.new_initializers = []

        # Handlers that may emit constants share the new_initializers list.
        shared_inits = self.new_initializers
        default_handlers = (
            ("Conv", ConvQuantizer(shared_inits)),
            ("Relu", ReluQuantizer()),
            ("Reshape", PassthroughQuantizer()),
            ("Gemm", GemmQuantizer(shared_inits)),
            ("Constant", ConstantQuantizer()),
            ("MaxPool", MaxpoolQuantizer()),
            ("Flatten", PassthroughQuantizer()),
        )
        for op_type, op_handler in default_handlers:
            self.register(op_type, op_handler)

    def register(
        self: ONNXOpQuantizer,
        op_type: str,
        handler: Callable[
            [onnx.NodeProto, bool],
            onnx.NodeProto | list[onnx.NodeProto],
        ],
    ) -> None:
        """Register *handler* as the quantizer for *op_type*.

        Args:
            op_type (str): Name of the ONNX operator type (e.g., "Conv", "Relu").
            handler: Handler instance implementing `quantize()` (and
                optionally `check_supported()`).

        Raises:
            HandlerImplementationError: If the handler does not expose a
                callable `quantize` method.
        """
        quantize_fn = getattr(handler, "quantize", None)
        if not callable(quantize_fn):
            raise HandlerImplementationError(op_type, "Missing 'quantize' method.")

        self.handlers[op_type] = handler

    def quantize(  # noqa: PLR0913
        self: ONNXOpQuantizer,
        node: onnx.NodeProto,
        graph: onnx.GraphProto,
        scale_exponent: int,
        scale_base: int,
        initializer_map: dict[str, onnx.TensorProto],
        *,
        rescale: bool = True,
    ) -> onnx.NodeProto | list[onnx.NodeProto]:
        """Quantize an ONNX node using its registered handler.

        Args:
            node (onnx.NodeProto): The ONNX node to quantize.
            graph (onnx.GraphProto): The ONNX graph containing the node.
            scale_exponent (int): Exponent of the quantization scale; the
                effective scaling is scale_base**scale_exponent.
            scale_base (int): Base of the quantization scale.
            initializer_map (dict[str, onnx.TensorProto]): Mapping of
                initializer names (typically weights and biases) to tensors.
            rescale (bool): Whether to apply rescaling.

        Returns:
            Union[onnx.NodeProto, List[onnx.NodeProto]]: The quantized node
            plus any additional nodes created in the process.

        Raises:
            UnsupportedOpError: If no handler is registered for the node.
        """
        op_handler = self.handlers.get(node.op_type)
        if op_handler is None:
            raise UnsupportedOpError(node.op_type)

        quantized = op_handler.quantize(
            node=node,
            graph=graph,
            scale_config=ScaleConfig(scale_exponent, scale_base, rescale),
            initializer_map=initializer_map,
        )
        # Normalise the handler's return value to a list of nodes.
        return [quantized] if isinstance(quantized, onnx.NodeProto) else quantized

    def check_model(self: ONNXOpQuantizer, model: onnx.ModelProto) -> None:
        """Verify that all nodes in the model are supported and valid.

        Args:
            model (onnx.ModelProto): The ONNX model to check.

        Raises:
            UnsupportedOpError: If the model contains unsupported operators.
        """
        initializer_map = self.get_initializer_map(model)

        unknown_ops = {
            graph_node.op_type
            for graph_node in model.graph.node
            if graph_node.op_type not in self.handlers
        }
        if unknown_ops:
            raise UnsupportedOpError(unknown_ops)

        # Per-node validation (e.g. parameter support checks).
        for graph_node in model.graph.node:
            self.check_layer(graph_node, initializer_map)

    def check_layer(
        self: ONNXOpQuantizer,
        node: onnx.NodeProto,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> None:
        """
        Validate one node through its registered handler.

        The node's parameters are checked against the handler's supported
        parameter requirements when the handler implements `check_supported`.

        Args:
            node (onnx.NodeProto): The node to check.
            initializer_map (dict[str, onnx.TensorProto]): Mapping of
                initializer names to tensors (typically weights and biases).

        Raises:
            MissingHandlerError: If no handler is registered for the node.
        """
        handler = self.handlers.get(node.op_type)
        if not handler:
            raise MissingHandlerError(node.op_type)

        checker = getattr(handler, "check_supported", None)
        if callable(checker):
            checker(node, initializer_map)

    def get_initializer_map(
        self: ONNXOpQuantizer,
        model: onnx.ModelProto,
    ) -> dict[str, onnx.TensorProto]:
        """Build a dictionary mapping initializer names to tensors in graph.

        Args:
            model (onnx.ModelProto): The ONNX model.

        Returns:
            dict[str, onnx.TensorProto]: Map from initializer name to tensor.
        """
        return {tensor.name: tensor for tensor in model.graph.initializer}
File without changes
@@ -0,0 +1,57 @@
1
+ import torch.nn as nn
2
+ from python.core.circuits.base import Circuit
3
+ from random import randint
4
+
5
class SimpleCircuit(Circuit):
    """
    Minimal example circuit that proves the addition of two inputs.

    Note: This template is irrelevant if using the onnx circuit builder.
    It only helps developers who choose to incorporate other circuit
    builders into the framework.

    Attributes:
        required_keys: Variables expected in the input dictionary (and input file).
        name: Name of the rust bin to be run by the circuit.
        scale_base: Base of the scaling applied to each value.
        scale_exponent: Exponent applied to the base to get the scaling factor;
            the scaling factor is multiplied by each input.
    """

    def __init__(self, file_name):
        # Initialize the base class
        super().__init__()

        # Circuit-specific parameters
        self.required_keys = ["input_a", "input_b", "nonce"]
        self.name = "simple_circuit"  # Use exact name that matches the binary

        self.scale_exponent = 1
        self.scale_base = 1

        # Default inputs; the nonce is randomized per instance.
        self.input_a = 100
        self.input_b = 200
        self.nonce = randint(0, 10000)

    def get_inputs(self):
        """
        Return the circuit inputs as a dict, based on what was set in __init__.
        """
        return {'input_a': self.input_a, 'input_b': self.input_b, 'nonce': self.nonce}

    def get_outputs(self, inputs=None):
        """
        Compute the output of the circuit (the sum of the two inputs).

        Overridden from the base class to ensure computation happens only once.

        Args:
            inputs: Optional dict with 'input_a' and 'input_b'; defaults to
                this instance's own inputs.
        """
        if inputs is None:
            # Reuse get_inputs() rather than duplicating the dict literal.
            inputs = self.get_inputs()
        print(f"Performing addition operation: {inputs['input_a']} + {inputs['input_b']}")
        return inputs['input_a'] + inputs['input_b']

    # def format_inputs(self, inputs):
    #     return {"input": inputs.long().tolist()}

    # def format_outputs(self, outputs):
    #     return {"output": outputs.long().tolist()}
File without changes
@@ -0,0 +1,163 @@
1
+ # python/core/utils/benchmarking_helpers.py
2
+ from __future__ import annotations
3
+
4
+ # --- Standard library --------------------------------------------------------
5
+ import threading
6
+ import time
7
+
8
+ # --- Third-party -------------------------------------------------------------
9
+ import psutil
10
+
11
+ """
12
+ Lightweight helpers to measure peak memory usage of child processes during
13
+ benchmarks. Uses a background thread that periodically sums the RSS of all
14
+ descendants of the current process (optionally filtered by a name keyword).
15
+ """
16
+
17
+
18
def _safe_rss_kb(pid: int) -> int:
    """
    Resident set size of *pid* in KB; 0 if the process is gone or inaccessible.
    """
    try:
        rss_bytes = psutil.Process(pid).memory_info().rss  # type: ignore[attr-defined]
    except (psutil.NoSuchProcess, psutil.AccessDenied):
        return 0
    return int(rss_bytes // 1024)
28
+
29
+
30
def _list_children(parent_pid: int) -> list[psutil.Process]:
    """
    All recursive descendants of *parent_pid*.
    Returns [] when the parent has vanished or access is denied.
    """
    try:
        return psutil.Process(parent_pid).children(recursive=True)
    except (psutil.NoSuchProcess, psutil.AccessDenied):
        return []
40
+
41
+
42
def _safe_name_lower(proc: psutil.Process) -> str | None:
    """
    Lowercased process name, or None if it cannot be read.
    """
    try:
        raw_name = proc.name()
    except (psutil.NoSuchProcess, psutil.AccessDenied):
        return None
    return raw_name.lower()
50
+
51
+
52
def monitor_subprocess_memory(
    parent_pid: int,
    process_name_keyword: str,
    results: dict[str, int],
    stop_event: threading.Event,
    *,
    poll_interval_s: float = 0.1,
) -> None:
    """
    Track the peak summed RSS across child processes of `parent_pid`.

    If `process_name_keyword` is non-empty, only children whose name contains
    that (lowercased) keyword are counted.

    Peaks (in KB) are written into `results` under:
        - 'peak_subprocess_mem' (RSS)
        - 'peak_subprocess_swap' (always 0; swap is not collected here)
        - 'peak_subprocess_total' (mem + swap)
    """
    needle = process_name_keyword.strip().lower()
    best_kb = 0

    # Seed the output keys so callers polling mid-run never hit a KeyError.
    results["peak_subprocess_mem"] = 0
    results["peak_subprocess_swap"] = 0
    results["peak_subprocess_total"] = 0

    while not stop_event.is_set():
        kids = _list_children(parent_pid)
        if not kids and not psutil.pid_exists(parent_pid):
            break

        if needle:
            tracked = [
                kid
                for kid in kids
                if (kid_name := _safe_name_lower(kid)) and needle in kid_name
            ]
        else:
            tracked = kids

        current_kb = sum(_safe_rss_kb(kid.pid) for kid in tracked)

        if current_kb > best_kb:
            best_kb = current_kb
            results["peak_subprocess_mem"] = best_kb
            results["peak_subprocess_swap"] = 0
            results["peak_subprocess_total"] = best_kb

        time.sleep(poll_interval_s)

    # Final write (covers the case where the peak never changed in the loop).
    results["peak_subprocess_mem"] = max(
        results.get("peak_subprocess_mem", 0),
        best_kb,
    )
    results["peak_subprocess_swap"] = 0
    results["peak_subprocess_total"] = results["peak_subprocess_mem"]
112
+
113
+
114
def start_memory_collection(
    process_name: str,
) -> tuple[threading.Event, threading.Thread, dict[str, int]]:
    """
    Start a daemon thread that monitors the current process' children.

    Args:
        process_name:
            Optional substring to filter child process names (case-insensitive).
            Pass "" to include all children.

    Returns:
        (stop_event, monitor_thread, monitor_results_dict)
    """
    monitor_results: dict[str, int] = {}
    stop_event = threading.Event()
    watcher = threading.Thread(
        target=monitor_subprocess_memory,
        args=(psutil.Process().pid, process_name, monitor_results, stop_event),
        kwargs={"poll_interval_s": 0.02},
        daemon=True,
    )
    watcher.start()
    # Give the thread a moment to start and populate the initial keys.
    time.sleep(0.05)
    return stop_event, watcher, monitor_results
140
+
141
+
142
def end_memory_collection(
    stop_event: threading.Event,
    monitor_thread: threading.Thread,
    monitor_results: dict[str, int],
) -> dict[str, float]:
    """
    Signal the monitor thread to stop, join it, and summarize peaks in MB:
    {'ram': <MB>, 'swap': <MB>, 'total': <MB>}
    """
    stop_event.set()
    monitor_thread.join(timeout=5.0)

    ram_kb = int(monitor_results.get("peak_subprocess_mem", 0))
    swap_kb = int(monitor_results.get("peak_subprocess_swap", 0))
    combined_kb = int(monitor_results.get("peak_subprocess_total", ram_kb + swap_kb))

    # 1024 KB per MB.
    return {
        "ram": ram_kb / 1024.0,
        "swap": swap_kb / 1024.0,
        "total": combined_kb / 1024.0,
    }
@@ -0,0 +1,4 @@
1
from __future__ import annotations

# Identifies a model supplied as an ONNX file.
MODEL_SOURCE_ONNX: str = "onnx"
# Identifies a model supplied as a Python class.
MODEL_SOURCE_CLASS: str = "class"
@@ -0,0 +1,117 @@
1
+ from __future__ import annotations
2
+
3
+
4
class CircuitExecutionError(Exception):
    """Root of the circuit-execution error hierarchy.

    Keeps the human-readable message on ``self.message`` in addition to the
    standard ``Exception`` args.
    """

    def __init__(self: CircuitExecutionError, message: str) -> None:
        self.message = message
        super().__init__(message)
10
+
11
+
12
class MissingFileError(CircuitExecutionError):
    """Raised when a required file cannot be found."""

    def __init__(self: MissingFileError, message: str, path: str | None = None) -> None:
        if path is None:
            full_message = message
        else:
            full_message = f"{message} [Path: {path}]"
        super().__init__(full_message)
        self.path = path
19
+
20
+
21
class FileCacheError(CircuitExecutionError):
    """Raised when reading or writing cached output fails."""

    def __init__(self: FileCacheError, message: str, path: str | None = None) -> None:
        # Append the offending path to the message when one was supplied.
        full_message = f"{message} [Path: {path}]" if path is not None else message
        super().__init__(full_message)
        self.path = path
28
+
29
+
30
class ProofBackendError(CircuitExecutionError):
    """Raised when a Cargo command fails.

    The message aggregates the command, exit code, and captured output, which
    are also stored individually on the instance.
    """

    def __init__(  # noqa: PLR0913
        self: ProofBackendError,
        message: str,
        command: list[str] | None = None,
        returncode: int | None = None,
        stdout: str | None = None,
        stderr: str | None = None,
    ) -> None:
        details = [message]
        if command is not None:
            # Normalise every element to str so the command can be joined
            # and stored uniformly.
            command = [str(part) for part in command]
            details.append(f"Command: {' '.join(command)}")
        if returncode is not None:
            details.append(f"Exit code: {returncode}")
        if stdout:
            details.append(f"STDOUT:\n{stdout}")
        if stderr:
            details.append(f"STDERR:\n{stderr}")
        super().__init__("\n".join(details))
        self.command = command
        self.returncode = returncode
        self.stdout = stdout
        self.stderr = stderr
58
+
59
+
60
class ProofSystemNotImplementedError(CircuitExecutionError):
    """Raised when an unimplemented proof system is requested."""

    def __init__(self: ProofSystemNotImplementedError, proof_system: object) -> None:
        self.proof_system = proof_system
        super().__init__(f"Proof system '{proof_system}' is not implemented.")
67
+
68
+
69
class CircuitUtilsError(Exception):
    """Common base class for errors raised by the layer utilities."""
71
+
72
+
73
class InputFileError(CircuitUtilsError):
    """Raised when reading an input file fails."""

    def __init__(
        self: InputFileError,
        file_path: str,
        message: str,
        *,
        cause: Exception | None = None,
    ) -> None:
        super().__init__(f"Failed to read input file '{file_path}': {message}")
        self.file_path = file_path
        # Record the underlying exception as the formal cause, as if raised
        # with `raise ... from cause`.
        self.__cause__ = cause
87
+
88
+
89
class MissingCircuitAttributeError(CircuitUtilsError):
    """Raised when a required attribute is missing or not set."""

    def __init__(
        self: MissingCircuitAttributeError,
        attribute_name: str,
        context: str | None = None,
    ) -> None:
        text = f"Required attribute '{attribute_name}' is missing"
        if context:
            # Append optional context, e.g. which operation needed it.
            text = f"{text} ({context})"
        super().__init__(text)
        self.attribute_name = attribute_name
102
+
103
+
104
class ShapeMismatchError(CircuitUtilsError):
    """Raised when reshaping tensors fails due to incompatible shapes."""

    def __init__(
        self: ShapeMismatchError,
        expected_shape: list[int],
        actual_shape: list[int],
    ) -> None:
        self.expected_shape = expected_shape
        self.actual_shape = actual_shape
        description = (
            f"Cannot reshape tensor of shape {actual_shape}"
            f" to expected shape {expected_shape}"
        )
        super().__init__(description)