JSTprove 1.0.0__py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jstprove-1.0.0.dist-info/METADATA +397 -0
- jstprove-1.0.0.dist-info/RECORD +81 -0
- jstprove-1.0.0.dist-info/WHEEL +6 -0
- jstprove-1.0.0.dist-info/entry_points.txt +2 -0
- jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
- jstprove-1.0.0.dist-info/top_level.txt +1 -0
- python/__init__.py +0 -0
- python/core/__init__.py +3 -0
- python/core/binaries/__init__.py +0 -0
- python/core/binaries/expander-exec +0 -0
- python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
- python/core/circuit_models/__init__.py +0 -0
- python/core/circuit_models/generic_onnx.py +231 -0
- python/core/circuit_models/simple_circuit.py +133 -0
- python/core/circuits/__init__.py +0 -0
- python/core/circuits/base.py +1000 -0
- python/core/circuits/errors.py +188 -0
- python/core/circuits/zk_model_base.py +25 -0
- python/core/model_processing/__init__.py +0 -0
- python/core/model_processing/converters/__init__.py +0 -0
- python/core/model_processing/converters/base.py +143 -0
- python/core/model_processing/converters/onnx_converter.py +1181 -0
- python/core/model_processing/errors.py +147 -0
- python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
- python/core/model_processing/onnx_custom_ops/conv.py +111 -0
- python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
- python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
- python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
- python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
- python/core/model_processing/onnx_custom_ops/relu.py +43 -0
- python/core/model_processing/onnx_quantizer/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
- python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
- python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
- python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
- python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
- python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
- python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
- python/core/model_templates/__init__.py +0 -0
- python/core/model_templates/circuit_template.py +57 -0
- python/core/utils/__init__.py +0 -0
- python/core/utils/benchmarking_helpers.py +163 -0
- python/core/utils/constants.py +4 -0
- python/core/utils/errors.py +117 -0
- python/core/utils/general_layer_functions.py +268 -0
- python/core/utils/helper_functions.py +1138 -0
- python/core/utils/model_registry.py +166 -0
- python/core/utils/scratch_tests.py +66 -0
- python/core/utils/witness_utils.py +291 -0
- python/frontend/__init__.py +0 -0
- python/frontend/cli.py +115 -0
- python/frontend/commands/__init__.py +17 -0
- python/frontend/commands/args.py +100 -0
- python/frontend/commands/base.py +199 -0
- python/frontend/commands/bench/__init__.py +54 -0
- python/frontend/commands/bench/list.py +42 -0
- python/frontend/commands/bench/model.py +172 -0
- python/frontend/commands/bench/sweep.py +212 -0
- python/frontend/commands/compile.py +58 -0
- python/frontend/commands/constants.py +5 -0
- python/frontend/commands/model_check.py +53 -0
- python/frontend/commands/prove.py +50 -0
- python/frontend/commands/verify.py +73 -0
- python/frontend/commands/witness.py +64 -0
- python/scripts/__init__.py +0 -0
- python/scripts/benchmark_runner.py +833 -0
- python/scripts/gen_and_bench.py +482 -0
- python/tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
- python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
- python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
- python/tests/circuit_parent_classes/__init__.py +0 -0
- python/tests/circuit_parent_classes/test_circuit.py +969 -0
- python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
- python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
- python/tests/test_cli.py +1021 -0
- python/tests/utils_testing/__init__.py +0 -0
- python/tests/utils_testing/test_helper_functions.py +891 -0
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Callable
|
|
4
|
+
|
|
5
|
+
import onnx
|
|
6
|
+
|
|
7
|
+
from python.core.model_processing.onnx_quantizer.exceptions import (
|
|
8
|
+
HandlerImplementationError,
|
|
9
|
+
MissingHandlerError,
|
|
10
|
+
UnsupportedOpError,
|
|
11
|
+
)
|
|
12
|
+
from python.core.model_processing.onnx_quantizer.layers.base import (
|
|
13
|
+
PassthroughQuantizer,
|
|
14
|
+
ScaleConfig,
|
|
15
|
+
)
|
|
16
|
+
from python.core.model_processing.onnx_quantizer.layers.constant import (
|
|
17
|
+
ConstantQuantizer,
|
|
18
|
+
)
|
|
19
|
+
from python.core.model_processing.onnx_quantizer.layers.conv import ConvQuantizer
|
|
20
|
+
from python.core.model_processing.onnx_quantizer.layers.gemm import GemmQuantizer
|
|
21
|
+
from python.core.model_processing.onnx_quantizer.layers.maxpool import MaxpoolQuantizer
|
|
22
|
+
from python.core.model_processing.onnx_quantizer.layers.relu import ReluQuantizer
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ONNXOpQuantizer:
    """
    Registry for ONNX operator quantizers.

    Use this to obtain the quantized layer for any supported ONNX
    operation type.

    Attributes
    ----------
    handlers : dict[str, Callable]
        Maps ONNX op_type strings to quantizer handler instances.
    new_initializers : list[onnx.TensorProto]
        Newly created ONNX initializers (typically weights or biases)
        produced during quantization. The same list object is shared with
        handlers that may add new constants (Conv and Gemm here).

    Methods
    -------
    register(op_type, handler)
        Registers a handler for an ONNX op_type.
    quantize(node, graph, scale_exponent, scale_base, initializer_map,
             *, rescale=True)
        Apply quantization to a specific ONNX node using its registered
        handler.
    check_model(model)
        Ensure all operations in the model are supported and validate
        that each layer's parameters are supported.
    check_layer(node, initializer_map)
        Validate a single ONNX node using its handler's check_supported
        method, checking that the layer's parameters and structure are
        supported.
    get_initializer_map(model)
        Build a {name: TensorProto} mapping for the model's initializers.
    """

    def __init__(self: ONNXOpQuantizer) -> None:
        # NOTE(review): registered handlers are objects exposing a
        # ``quantize`` method (see ``register``); the Callable annotation is
        # kept unchanged for backward compatibility.
        self.handlers: dict[
            str,
            Callable[
                [onnx.NodeProto, bool],
                onnx.NodeProto | list[onnx.NodeProto],
            ],
        ] = {}
        # Passed by reference to handlers that create new initializers, so
        # anything they append is visible through this attribute.
        self.new_initializers: list[onnx.TensorProto] = []

        # Register handlers
        self.register("Conv", ConvQuantizer(self.new_initializers))
        self.register("Relu", ReluQuantizer())
        self.register("Reshape", PassthroughQuantizer())
        self.register("Gemm", GemmQuantizer(self.new_initializers))
        self.register("Constant", ConstantQuantizer())
        self.register("MaxPool", MaxpoolQuantizer())
        self.register("Flatten", PassthroughQuantizer())

    def register(
        self: ONNXOpQuantizer,
        op_type: str,
        handler: Callable[
            [onnx.NodeProto, bool],
            onnx.NodeProto | list[onnx.NodeProto],
        ],
    ) -> None:
        """Register a quantizer handler for a given ONNX op_type.

        Registering the same op_type twice replaces the earlier handler.

        Args:
            op_type (str): Name of the ONNX operator type (e.g., "Conv", "Relu").
            handler: Handler instance implementing ``quantize()``
                (and optionally ``check_supported()``).

        Raises:
            HandlerImplementationError: If handler has no callable
                ``quantize`` method.
        """
        if not hasattr(handler, "quantize") or not callable(handler.quantize):
            raise HandlerImplementationError(op_type, "Missing 'quantize' method.")

        self.handlers[op_type] = handler

    def quantize(  # noqa: PLR0913
        self: ONNXOpQuantizer,
        node: onnx.NodeProto,
        graph: onnx.GraphProto,
        scale_exponent: int,
        scale_base: int,
        initializer_map: dict[str, onnx.TensorProto],
        *,
        rescale: bool = True,
    ) -> list[onnx.NodeProto]:
        """Quantize an ONNX node using its registered handler.

        Args:
            node (onnx.NodeProto): The ONNX node to quantize.
            graph (onnx.GraphProto): The ONNX graph containing the node.
            scale_exponent (int): Exponent of the quantization scale.
                The scaling factor is scale_base**scale_exponent.
            scale_base (int): Base of the quantization scale.
                The scaling factor is scale_base**scale_exponent.
            initializer_map (dict[str, onnx.TensorProto]):
                Mapping of initializer names (typically weights and biases)
                to tensors.
            rescale (bool): Keyword-only. Whether to apply rescaling.
                Defaults to True.

        Returns:
            list[onnx.NodeProto]: The quantized node plus any additional
            nodes created in the process. If the handler returns a single
            NodeProto, it is wrapped in a one-element list.

        Raises:
            UnsupportedOpError: If no handler is registered for
                ``node.op_type``.
        """
        handler = self.handlers.get(node.op_type)
        if handler is None:
            raise UnsupportedOpError(node.op_type)

        result = handler.quantize(
            node=node,
            graph=graph,
            scale_config=ScaleConfig(scale_exponent, scale_base, rescale),
            initializer_map=initializer_map,
        )
        # Normalize to a list so callers can always iterate the result.
        if isinstance(result, onnx.NodeProto):
            return [result]
        return result

    def check_model(self: ONNXOpQuantizer, model: onnx.ModelProto) -> None:
        """Verify that all nodes in the model are supported and valid.

        First checks that every op_type in the graph has a registered
        handler, then validates each node individually via ``check_layer``.

        Args:
            model (onnx.ModelProto): The ONNX model to check.

        Raises:
            UnsupportedOpError: If the model contains operators with no
                registered handler. The set of such op_types is passed to
                the exception.
            Any exception raised by a handler's ``check_supported`` method
            for a node whose parameters are not supported.
        """
        initializer_map = self.get_initializer_map(model)

        model_ops = {node.op_type for node in model.graph.node}
        unsupported = model_ops - self.handlers.keys()

        if unsupported:
            raise UnsupportedOpError(unsupported)

        # Call check_layer on each node (e.g., for param validation)
        for node in model.graph.node:
            self.check_layer(node, initializer_map)

    def check_layer(
        self: ONNXOpQuantizer,
        node: onnx.NodeProto,
        initializer_map: dict[str, onnx.TensorProto],
    ) -> None:
        """
        Check an individual node using its handler.

        The node's parameters are checked against the handler's supported
        parameter requirements. Handlers without a callable
        ``check_supported`` method are accepted without further checks.

        Args:
            node (onnx.NodeProto): The node to check.
            initializer_map (dict[str, onnx.TensorProto]): Mapping of
                initializer names to tensors, typically weights and biases.

        Raises:
            MissingHandlerError: If no handler is registered for the given node.
        """
        handler = self.handlers.get(node.op_type)
        if handler is None:
            raise MissingHandlerError(node.op_type)

        if hasattr(handler, "check_supported") and callable(handler.check_supported):
            handler.check_supported(node, initializer_map)

    def get_initializer_map(
        self: ONNXOpQuantizer,
        model: onnx.ModelProto,
    ) -> dict[str, onnx.TensorProto]:
        """Build a dictionary mapping initializer names to tensors in the graph.

        If several initializers share a name, the last one wins.

        Args:
            model (onnx.ModelProto): The ONNX model.

        Returns:
            dict[str, onnx.TensorProto]: Map from initializer name to tensor.
        """
        return {init.name: init for init in model.graph.initializer}
|
|
File without changes
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
import torch.nn as nn
|
|
2
|
+
from python.core.circuits.base import Circuit
|
|
3
|
+
from random import randint
|
|
4
|
+
|
|
5
|
+
class SimpleCircuit(Circuit):
    """
    Example circuit template that adds two integer inputs.

    Note: This template is irrelevant if you use the ONNX circuit builder.
    It only helps developers who choose to incorporate other circuit
    builders into the framework.

    To begin, specify some basic attributes describing the circuit:

    required_keys - the variables expected in the input dictionary
        (and input file).
    name - name of the Rust binary to be run by the circuit.

    scale_base - the base of the scaling applied to each value.
    scale_exponent - the exponent applied to the base to get the scaling
        factor. The scaling factor is multiplied into each input.

    Other default inputs can be defined below.
    """

    def __init__(self, file_name):
        # file_name is accepted for interface compatibility but is not used
        # by this template.
        # Initialize the base class
        super().__init__()

        # Circuit-specific parameters
        self.required_keys = ["input_a", "input_b", "nonce"]
        self.name = "simple_circuit"  # Use exact name that matches the binary

        self.scale_exponent = 1
        self.scale_base = 1

        self.input_a = 100
        self.input_b = 200
        # NOTE(review): `random.randint` is not cryptographically secure;
        # switch to `secrets` if this nonce ever needs to be unpredictable.
        self.nonce = randint(0, 10000)

    # The following are some important functions used by the model.
    # get_inputs should be defined to specify the inputs to the circuit.

    def get_inputs(self):
        """
        Return the circuit inputs, built from the values set in __init__.

        Subclasses may also accept arguments here to supply the inputs.
        """
        return {"input_a": self.input_a, "input_b": self.input_b, "nonce": self.nonce}

    def get_outputs(self, inputs=None):
        """
        Compute the output of the circuit (input_a + input_b).

        This overrides the base class implementation so that the
        computation happens only once.

        Args:
            inputs: Optional mapping with 'input_a' and 'input_b'. When
                None, the defaults from get_inputs() are used.
        """
        # `is None` avoids elementwise/ambiguous comparisons for array inputs.
        if inputs is None:
            inputs = self.get_inputs()
        print(f"Performing addition operation: {inputs['input_a']} + {inputs['input_b']}")
        return inputs["input_a"] + inputs["input_b"]

    # Optional overrides: uncomment and adapt if the circuit needs custom
    # input/output formatting.
    # def format_inputs(self, inputs):
    #     return {"input": inputs.long().tolist()}

    # def format_outputs(self, outputs):
    #     return {"output": outputs.long().tolist()}
|
|
File without changes
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
# python/core/utils/benchmarking_helpers.py
"""
Lightweight helpers to measure peak memory usage of child processes during
benchmarks. Uses a background thread that periodically sums the RSS of all
descendants of the current process (optionally filtered by a name keyword).
"""

from __future__ import annotations

# --- Standard library --------------------------------------------------------
import threading
import time

# --- Third-party -------------------------------------------------------------
import psutil
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _safe_rss_kb(pid: int) -> int:
    """
    Look up the resident set size of process `pid`, in whole KB.

    Yields 0 instead of raising if the process has vanished or cannot be
    inspected.
    """
    try:
        info = psutil.Process(pid).memory_info()
        kb = info.rss // 1024  # type: ignore[attr-defined]
    except (psutil.NoSuchProcess, psutil.AccessDenied):
        return 0
    return int(kb)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _list_children(parent_pid: int) -> list[psutil.Process]:
    """
    Collect every descendant (recursively) of process `parent_pid`.

    Falls back to an empty list when the parent no longer exists or
    access to it is denied.
    """
    try:
        descendants = psutil.Process(parent_pid).children(recursive=True)
    except (psutil.NoSuchProcess, psutil.AccessDenied):
        descendants = []
    return descendants
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _safe_name_lower(proc: psutil.Process) -> str | None:
    """
    Fetch the process name in lowercase.

    Gives None if the process disappeared or its name is not accessible.
    """
    try:
        raw_name = proc.name()
    except (psutil.NoSuchProcess, psutil.AccessDenied):
        return None
    return raw_name.lower()
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def monitor_subprocess_memory(
    parent_pid: int,
    process_name_keyword: str,
    results: dict[str, int],
    stop_event: threading.Event,
    *,
    poll_interval_s: float = 0.1,
) -> None:
    """
    Track the peak sum of RSS across descendant processes of `parent_pid`.

    Polls every `poll_interval_s` seconds until `stop_event` is set, or until
    the parent process no longer exists and has no visible descendants.

    If `process_name_keyword` is non-empty after stripping whitespace, only
    descendants whose lowercased name contains the lowercased keyword are
    included. ``None`` is treated like ``""``, which means no filtering.

    Writes peaks (KB) into `results` under:
      - 'peak_subprocess_mem'   (RSS)
      - 'peak_subprocess_swap'  (always 0; swap is not collected here)
      - 'peak_subprocess_total' (mem + swap)
    """
    keyword = (process_name_keyword or "").strip().lower()
    peak_rss_kb = 0

    # Initialize keys so callers can inspect mid-run safely
    results["peak_subprocess_mem"] = 0
    results["peak_subprocess_swap"] = 0
    results["peak_subprocess_total"] = 0

    while not stop_event.is_set():
        children = _list_children(parent_pid)
        if not children and not psutil.pid_exists(parent_pid):
            break

        if keyword:
            # Processes whose name can't be read are excluded.
            children = [
                c for c in children if keyword in (_safe_name_lower(c) or "")
            ]

        rss_sum_kb = sum(_safe_rss_kb(c.pid) for c in children)

        if rss_sum_kb > peak_rss_kb:
            peak_rss_kb = rss_sum_kb
            results["peak_subprocess_mem"] = peak_rss_kb
            results["peak_subprocess_swap"] = 0
            results["peak_subprocess_total"] = peak_rss_kb

        # Unlike time.sleep, Event.wait returns as soon as stop_event is set,
        # so shutdown is not delayed by up to a full poll interval.
        stop_event.wait(poll_interval_s)

    # Final write (covers the case where peak never changed inside the loop)
    results["peak_subprocess_mem"] = max(
        results.get("peak_subprocess_mem", 0),
        peak_rss_kb,
    )
    results["peak_subprocess_swap"] = 0
    results["peak_subprocess_total"] = results["peak_subprocess_mem"]
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def start_memory_collection(
    process_name: str,
) -> tuple[threading.Event, threading.Thread, dict[str, int]]:
    """
    Spawn and start a monitoring thread for the current process' children.

    The returned results dict already contains the three peak keys (all 0)
    when this function returns. Callers can therefore read it immediately,
    without waiting for the monitor thread to run.

    Args:
        process_name:
            Optional substring to filter child process names (case-insensitive).
            Pass "" to include all children.

    Returns:
        (stop_event, monitor_thread, monitor_results_dict)
    """
    parent_pid = psutil.Process().pid
    # Pre-seeding the keys replaces the old sleep-based wait for the thread
    # to populate them, which was a race if the thread started late.
    monitor_results: dict[str, int] = {
        "peak_subprocess_mem": 0,
        "peak_subprocess_swap": 0,
        "peak_subprocess_total": 0,
    }
    stop_event = threading.Event()
    monitor_thread = threading.Thread(
        target=monitor_subprocess_memory,
        args=(parent_pid, process_name, monitor_results, stop_event),
        kwargs={"poll_interval_s": 0.02},
        daemon=True,  # never block interpreter exit
    )
    monitor_thread.start()
    return stop_event, monitor_thread, monitor_results
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def end_memory_collection(
    stop_event: threading.Event,
    monitor_thread: threading.Thread,
    monitor_results: dict[str, int],
) -> dict[str, float]:
    """
    Signal the monitor to stop, wait for it briefly, and summarize in MB.

    Missing entries in `monitor_results` count as 0, except 'total', which
    defaults to ram + swap.

    Returns:
        {'ram': <MB>, 'swap': <MB>, 'total': <MB>}
    """
    stop_event.set()
    # Bounded wait so a stuck monitor cannot hang the caller.
    monitor_thread.join(timeout=5.0)

    ram = int(monitor_results.get("peak_subprocess_mem", 0))
    swap = int(monitor_results.get("peak_subprocess_swap", 0))
    total = int(monitor_results.get("peak_subprocess_total", ram + swap))

    mb_per_kb = 1.0 / 1024.0
    labelled = (("ram", ram), ("swap", swap), ("total", total))
    return {label: kb * mb_per_kb for label, kb in labelled}
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class CircuitExecutionError(Exception):
    """Root of the circuit-execution error hierarchy.

    The message is passed to ``Exception`` and also kept on ``.message``.
    """

    def __init__(self: CircuitExecutionError, message: str) -> None:
        self.message = message
        super().__init__(message)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class MissingFileError(CircuitExecutionError):
    """Raised when a required file cannot be found.

    If ``path`` is given, it is appended to the message as ``[Path: ...]``
    and stored on ``.path``.
    """

    def __init__(self: MissingFileError, message: str, path: str | None = None) -> None:
        if path is not None:
            message = f"{message} [Path: {path}]"
        super().__init__(message)
        self.path = path
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class FileCacheError(CircuitExecutionError):
    """Raised when reading or writing cached output fails.

    An optional ``path`` is appended to the message and kept on ``.path``.
    """

    def __init__(self: FileCacheError, message: str, path: str | None = None) -> None:
        text = f"{message} [Path: {path}]" if path is not None else message
        super().__init__(text)
        self.path = path
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ProofBackendError(CircuitExecutionError):
    """Raised when a Cargo command fails.

    The message is extended, line by line, with the command (if any), its
    exit code (if any), and any non-empty stdout/stderr. The command is
    stored with every element converted to ``str``.
    """

    def __init__(  # noqa: PLR0913
        self: ProofBackendError,
        message: str,
        command: list[str] | None = None,
        returncode: int | None = None,
        stdout: str | None = None,
        stderr: str | None = None,
    ) -> None:
        normalized_cmd = None if command is None else [str(arg) for arg in command]

        sections = [message]
        if normalized_cmd is not None:
            sections.append("Command: " + " ".join(normalized_cmd))
        if returncode is not None:
            sections.append(f"Exit code: {returncode}")
        for label, stream in (("STDOUT", stdout), ("STDERR", stderr)):
            if stream:
                sections.append(f"{label}:\n{stream}")

        super().__init__("\n".join(sections))
        self.command = normalized_cmd
        self.returncode = returncode
        self.stdout = stdout
        self.stderr = stderr
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class ProofSystemNotImplementedError(CircuitExecutionError):
    """Raised when the requested proof system has no implementation.

    The offending value is kept on ``.proof_system``.
    """

    def __init__(self: ProofSystemNotImplementedError, proof_system: object) -> None:
        self.proof_system = proof_system
        super().__init__(f"Proof system '{proof_system}' is not implemented.")
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class CircuitUtilsError(Exception):
    """Common base class for circuit/layer utility errors."""
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class InputFileError(CircuitUtilsError):
    """Raised when an input file cannot be read.

    The message is prefixed with the offending path. An optional ``cause``
    is attached as ``__cause__`` so tracebacks show the underlying error.
    """

    def __init__(
        self: InputFileError,
        file_path: str,
        message: str,
        *,
        cause: Exception | None = None,
    ) -> None:
        super().__init__(f"Failed to read input file '{file_path}': {message}")
        self.file_path = file_path
        self.__cause__ = cause
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class MissingCircuitAttributeError(CircuitUtilsError):
    """Raised when a required attribute is missing or not set.

    A truthy ``context`` is appended to the message in parentheses.
    """

    def __init__(
        self: MissingCircuitAttributeError,
        attribute_name: str,
        context: str | None = None,
    ) -> None:
        detail = f" ({context})" if context else ""
        super().__init__(f"Required attribute '{attribute_name}' is missing" + detail)
        self.attribute_name = attribute_name
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class ShapeMismatchError(CircuitUtilsError):
    """Raised when reshaping tensors fails due to incompatible shapes.

    Both shapes are kept on the instance for programmatic inspection.
    """

    def __init__(
        self: ShapeMismatchError,
        expected_shape: list[int],
        actual_shape: list[int],
    ) -> None:
        self.expected_shape = expected_shape
        self.actual_shape = actual_shape
        text = (
            f"Cannot reshape tensor of shape {actual_shape}"
            f" to expected shape {expected_shape}"
        )
        super().__init__(text)
|