JSTprove 1.0.0__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of JSTprove might be problematic. Click here for more details.
- jstprove-1.0.0.dist-info/METADATA +397 -0
- jstprove-1.0.0.dist-info/RECORD +81 -0
- jstprove-1.0.0.dist-info/WHEEL +5 -0
- jstprove-1.0.0.dist-info/entry_points.txt +2 -0
- jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
- jstprove-1.0.0.dist-info/top_level.txt +1 -0
- python/__init__.py +0 -0
- python/core/__init__.py +3 -0
- python/core/binaries/__init__.py +0 -0
- python/core/binaries/expander-exec +0 -0
- python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
- python/core/circuit_models/__init__.py +0 -0
- python/core/circuit_models/generic_onnx.py +231 -0
- python/core/circuit_models/simple_circuit.py +133 -0
- python/core/circuits/__init__.py +0 -0
- python/core/circuits/base.py +1000 -0
- python/core/circuits/errors.py +188 -0
- python/core/circuits/zk_model_base.py +25 -0
- python/core/model_processing/__init__.py +0 -0
- python/core/model_processing/converters/__init__.py +0 -0
- python/core/model_processing/converters/base.py +143 -0
- python/core/model_processing/converters/onnx_converter.py +1181 -0
- python/core/model_processing/errors.py +147 -0
- python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
- python/core/model_processing/onnx_custom_ops/conv.py +111 -0
- python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
- python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
- python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
- python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
- python/core/model_processing/onnx_custom_ops/relu.py +43 -0
- python/core/model_processing/onnx_quantizer/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
- python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
- python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
- python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
- python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
- python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
- python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
- python/core/model_templates/__init__.py +0 -0
- python/core/model_templates/circuit_template.py +57 -0
- python/core/utils/__init__.py +0 -0
- python/core/utils/benchmarking_helpers.py +163 -0
- python/core/utils/constants.py +4 -0
- python/core/utils/errors.py +117 -0
- python/core/utils/general_layer_functions.py +268 -0
- python/core/utils/helper_functions.py +1138 -0
- python/core/utils/model_registry.py +166 -0
- python/core/utils/scratch_tests.py +66 -0
- python/core/utils/witness_utils.py +291 -0
- python/frontend/__init__.py +0 -0
- python/frontend/cli.py +115 -0
- python/frontend/commands/__init__.py +17 -0
- python/frontend/commands/args.py +100 -0
- python/frontend/commands/base.py +199 -0
- python/frontend/commands/bench/__init__.py +54 -0
- python/frontend/commands/bench/list.py +42 -0
- python/frontend/commands/bench/model.py +172 -0
- python/frontend/commands/bench/sweep.py +212 -0
- python/frontend/commands/compile.py +58 -0
- python/frontend/commands/constants.py +5 -0
- python/frontend/commands/model_check.py +53 -0
- python/frontend/commands/prove.py +50 -0
- python/frontend/commands/verify.py +73 -0
- python/frontend/commands/witness.py +64 -0
- python/scripts/__init__.py +0 -0
- python/scripts/benchmark_runner.py +833 -0
- python/scripts/gen_and_bench.py +482 -0
- python/tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
- python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
- python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
- python/tests/circuit_parent_classes/__init__.py +0 -0
- python/tests/circuit_parent_classes/test_circuit.py +969 -0
- python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
- python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
- python/tests/test_cli.py +1021 -0
- python/tests/utils_testing/__init__.py +0 -0
- python/tests/utils_testing/test_helper_functions.py +891 -0
|
@@ -0,0 +1,1000 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
from python.core.utils.witness_utils import compare_witness_to_io, load_witness
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
import numpy as np
|
|
11
|
+
import torch
|
|
12
|
+
|
|
13
|
+
from python.core.circuits.errors import (
|
|
14
|
+
CircuitConfigurationError,
|
|
15
|
+
CircuitFileError,
|
|
16
|
+
CircuitInputError,
|
|
17
|
+
CircuitProcessingError,
|
|
18
|
+
CircuitRunError,
|
|
19
|
+
WitnessMatchError,
|
|
20
|
+
)
|
|
21
|
+
from python.core.utils.helper_functions import (
|
|
22
|
+
CircuitExecutionConfig,
|
|
23
|
+
RunType,
|
|
24
|
+
ZKProofSystems,
|
|
25
|
+
compile_circuit,
|
|
26
|
+
compute_and_store_output,
|
|
27
|
+
generate_proof,
|
|
28
|
+
generate_verification,
|
|
29
|
+
generate_witness,
|
|
30
|
+
prepare_io_files,
|
|
31
|
+
read_from_json,
|
|
32
|
+
run_end_to_end,
|
|
33
|
+
to_json,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class Circuit:
|
|
38
|
+
"""
|
|
39
|
+
Base class for all ZK circuits.
|
|
40
|
+
|
|
41
|
+
This class defines the standard interface and common utilities for
|
|
42
|
+
building, testing, and running ZK circuits.
|
|
43
|
+
Subclasses are expected to implement circuit-specific logic such as
|
|
44
|
+
input preparation, output computation, and model handling.
|
|
45
|
+
"""
|
|
46
|
+
|
|
47
|
+
def __init__(self: Circuit) -> None:
    """Set up default folders, the proof backend and per-run bookkeeping.

    Subclasses may override any of these defaults after calling this
    initializer.
    """
    self.logger = logging.getLogger(__name__)

    # Proof backend used by default for every run.
    self.proof_system = ZKProofSystems.Expander

    # Default folder layout (relative names).
    self.input_folder = "inputs"
    self.proof_folder = "analysis"
    self.temp_folder = "temp"
    self.circuit_folder = ""
    self.weights_folder = "weights"
    self.output_folder = "output"

    # Filled in later by the `prepare_io_files` decorator.
    self._file_info = None
    # Placeholder; subclasses are expected to replace it before
    # `parse_inputs` runs (which rejects a None value).
    self.required_keys = None
|
|
61
|
+
|
|
62
|
+
def check_attributes(self: Circuit) -> None:
    """Verify that a subclass has defined every attribute the base class needs.

    Subclasses are expected to override or extend this check.

    Raises:
        CircuitConfigurationError: If any of ``required_keys``, ``name``,
            ``scale_exponent`` or ``scale_base`` is absent from the instance.
    """
    # NOTE(review): ``__init__`` assigns ``self.required_keys = None``, so the
    # ``hasattr`` test always passes for ``required_keys`` on instances built
    # through that initializer; in practice only the other three names can be
    # reported here.
    expected_attrs = ("required_keys", "name", "scale_exponent", "scale_base")
    absent = [attr_name for attr_name in expected_attrs if not hasattr(self, attr_name)]
    if not absent:
        return
    raise CircuitConfigurationError(missing_attributes=absent)
|
|
77
|
+
|
|
78
|
+
def parse_inputs(self: Circuit, **kwargs: Any) -> None:
    """Parse and validate required input parameters into instance attributes.

    Every name listed in ``self.required_keys`` must be present in
    ``kwargs``; each such value is stored on the instance under the same
    name. Keyword arguments not listed in ``required_keys`` are ignored.

    Args:
        **kwargs: Input parameters to parse and validate. Each required
            value must be an ``int`` or a ``list``.

    Raises:
        CircuitConfigurationError: If ``self.required_keys`` is ``None``.
        CircuitInputError: If a required parameter is missing, or its value
            is neither an ``int`` nor a ``list``.
    """
    if self.required_keys is None:
        # Parenthesized so both literals are concatenated into one message.
        msg = (
            "self.required_keys must"
            " be specified in the circuit definition."
        )
        raise CircuitConfigurationError(
            msg,
        )
    for key in self.required_keys:
        if key not in kwargs:
            msg = f"Missing required parameter: '{key}'"
            raise CircuitInputError(msg)

        value = kwargs[key]

        # Only the container type is checked; list elements are not inspected.
        # NOTE(review): the message promises "list of ints" — element-level
        # validation may be wanted.
        if not isinstance(value, (int, list)):
            msg = (
                f"Parameter '{key}' must be an int or list of ints, "
                f"got {type(value).__name__}."
            )
            raise CircuitInputError(
                msg,
            )
        setattr(self, key, value)
|
|
113
|
+
|
|
114
|
+
@compute_and_store_output
def get_outputs(self: Circuit) -> None:
    """Compute the circuit outputs.

    The base class provides no implementation; concrete circuits must
    override this method.

    Raises:
        NotImplementedError: Always, from the base implementation.
    """
    error_text = "get_outputs must be implemented"
    raise NotImplementedError(error_text)
|
|
122
|
+
|
|
123
|
+
def get_inputs(
    self: Circuit,
    file_path: str | None = None,
    *,
    is_scaled: bool | None = False,
) -> None:
    """Compute and return the circuit's input values.

    Abstract in the base class; subclasses must provide an implementation.

    Args:
        file_path (str | None): Optional path to an input file.
        is_scaled (bool | None): Whether the inputs are already scaled.

    Raises:
        NotImplementedError: Always, from the base implementation.
    """
    del file_path, is_scaled  # unused by the abstract base implementation
    error_text = "get_inputs must be implemented"
    raise NotImplementedError(error_text)
|
|
140
|
+
|
|
141
|
+
@prepare_io_files
def base_testing(self: Circuit, exec_config: CircuitExecutionConfig) -> None:
    """Execute the circuit in the mode selected by ``exec_config.run_type``.

    The ``prepare_io_files`` decorator is responsible for resolving file
    paths and populating ``self._file_info`` before this body runs.

    Args:
        exec_config (CircuitExecutionConfig): Execution settings (run type,
            file paths, proof system, flags). Mutated in place: a default
            ``circuit_path`` is filled in when missing, and the metadata,
            architecture and weights paths are copied from ``_file_info``.

    Raises:
        CircuitConfigurationError: If ``_file_info`` was not populated.
        CircuitRunError: Propagated from ``parse_proof_run_type``.
    """
    # Fall back to "<circuit_name>.txt" when no explicit circuit path is given.
    if exec_config.circuit_path is None:
        exec_config.circuit_path = f"{exec_config.circuit_name}.txt"

    file_info = self._file_info
    if not file_info:
        reason = (
            "Circuit file information (_file_info)"
            " must be set by the prepare_io_files decorator."
        )
        raise CircuitConfigurationError(
            reason,
            details={"decorator": "prepare_io_files"},
        )

    # Copy the resolved artifact paths onto the execution config.
    for path_key in ("metadata_path", "architecture_path", "w_and_b_path"):
        setattr(exec_config, path_key, file_info.get(path_key))

    # Hand off to the run-type dispatcher.
    self.parse_proof_run_type(exec_config)
|
|
174
|
+
|
|
175
|
+
def _raise_unknown_run_type(self: Circuit, run_type: RunType) -> None:
    """Log an unrecognized run type and raise ``CircuitRunError``.

    Args:
        run_type (RunType): The run type that no dispatch branch handled.

    Raises:
        CircuitRunError: Always; ``run_type`` is included in the details.
    """
    self.logger.error("Unknown run type: %s", run_type)
    problem = f"Unsupported run type: {run_type}"
    error_details = {"run_type": run_type}
    raise CircuitRunError(
        problem,
        operation="parse_proof_run_type",
        details=error_details,
    )
|
|
183
|
+
|
|
184
|
+
def parse_proof_run_type(
    self: Circuit,
    exec_config: CircuitExecutionConfig,
) -> None:
    """Dispatch proof-related operations based on the selected run type.

    Supported run types:
        * ``END_TO_END``: compile preprocessing, witness preprocessing, then
          ``run_end_to_end``.
        * ``COMPILE_CIRCUIT``: compile preprocessing, then ``compile_circuit``.
        * ``GEN_WITNESS``: witness preprocessing, then ``generate_witness``.
        * ``PROVE_WITNESS``: ``generate_proof``.
        * ``GEN_VERIFY``: check the witness against the expected inputs and
          outputs, then ``generate_verification``.

    Args:
        exec_config (CircuitExecutionConfig): Configuration object containing
            file paths, run type, and other parameters.

    Raises:
        CircuitRunError: If ``run_type`` is unknown or any operation fails.
            A ``CircuitRunError`` raised by an operation is logged and
            re-raised unchanged; any other exception is logged and wrapped
            in a new ``CircuitRunError`` chained to the original.
    """
    # Forwarded to _gen_witness_preprocessing (and from there to
    # get_inputs_from_file).
    is_scaled = True

    try:
        if exec_config.run_type == RunType.END_TO_END:
            self._compile_preprocessing(
                metadata_path=exec_config.metadata_path,
                architecture_path=exec_config.architecture_path,
                w_and_b_path=exec_config.w_and_b_path,
                quantized_path=exec_config.quantized_path,
            )
            processed_input_file = self._gen_witness_preprocessing(
                input_file=exec_config.input_file,
                output_file=exec_config.output_file,
                quantized_path=exec_config.quantized_path,
                write_json=exec_config.write_json,
                is_scaled=is_scaled,
            )
            run_end_to_end(
                circuit_name=exec_config.circuit_name,
                circuit_path=exec_config.circuit_path,
                input_file=processed_input_file,
                output_file=exec_config.output_file,
                proof_system=exec_config.proof_system,
                dev_mode=exec_config.dev_mode,
                ecc=exec_config.ecc,
            )
        elif exec_config.run_type == RunType.COMPILE_CIRCUIT:
            self._compile_preprocessing(
                metadata_path=exec_config.metadata_path,
                architecture_path=exec_config.architecture_path,
                w_and_b_path=exec_config.w_and_b_path,
                quantized_path=exec_config.quantized_path,
            )
            compile_circuit(
                circuit_name=exec_config.circuit_name,
                circuit_path=exec_config.circuit_path,
                metadata_path=exec_config.metadata_path,
                architecture_path=exec_config.architecture_path,
                w_and_b_path=exec_config.w_and_b_path,
                proof_system=exec_config.proof_system,
                dev_mode=exec_config.dev_mode,
                bench=exec_config.bench,
            )
        elif exec_config.run_type == RunType.GEN_WITNESS:
            processed_input_file = self._gen_witness_preprocessing(
                input_file=exec_config.input_file,
                output_file=exec_config.output_file,
                quantized_path=exec_config.quantized_path,
                write_json=exec_config.write_json,
                is_scaled=is_scaled,
            )
            generate_witness(
                circuit_name=exec_config.circuit_name,
                circuit_path=exec_config.circuit_path,
                witness_file=exec_config.witness_file,
                input_file=processed_input_file,
                output_file=exec_config.output_file,
                metadata_path=exec_config.metadata_path,
                proof_system=exec_config.proof_system,
                dev_mode=exec_config.dev_mode,
                bench=exec_config.bench,
            )
        elif exec_config.run_type == RunType.PROVE_WITNESS:
            generate_proof(
                circuit_name=exec_config.circuit_name,
                circuit_path=exec_config.circuit_path,
                witness_file=exec_config.witness_file,
                proof_file=exec_config.proof_file,
                metadata_path=exec_config.metadata_path,
                proof_system=exec_config.proof_system,
                dev_mode=exec_config.dev_mode,
                ecc=exec_config.ecc,
                bench=exec_config.bench,
            )
        elif exec_config.run_type == RunType.GEN_VERIFY:
            witness_file = exec_config.witness_file
            output_file = exec_config.output_file
            processed_input_file = self.rename_inputs(exec_config.input_file)
            proof_system = exec_config.proof_system
            # Refuse to verify when the witness's public values do not match
            # the expected inputs/outputs on disk.
            if not self.load_and_compare_witness_to_io(
                witness_path=witness_file,
                input_path=processed_input_file,
                output_path=output_file,
                proof_system=proof_system,
            ):
                raise WitnessMatchError  # noqa: TRY301
            generate_verification(
                circuit_name=exec_config.circuit_name,
                circuit_path=exec_config.circuit_path,
                input_file=processed_input_file,
                output_file=output_file,
                witness_file=witness_file,
                proof_file=exec_config.proof_file,
                metadata_path=exec_config.metadata_path,
                proof_system=proof_system,
                dev_mode=exec_config.dev_mode,
                ecc=exec_config.ecc,
                bench=exec_config.bench,
            )
        else:
            self._raise_unknown_run_type(exec_config.run_type)
    except CircuitRunError:
        # Already a run error: log it and propagate unchanged.
        self.logger.exception(
            "Operation %s failed",
            exec_config.run_type,
            extra={"run_type": exec_config.run_type},
        )
        raise
    except Exception as e:
        # Every other failure (including the project's Circuit*Error types,
        # which the previous explicit tuple listed redundantly alongside
        # Exception) is logged and wrapped in a CircuitRunError.
        self.logger.exception(
            "Operation %s failed",
            exec_config.run_type,
            extra={"run_type": exec_config.run_type},
        )
        raise CircuitRunError(
            operation=exec_config.run_type,
        ) from e
|
|
319
|
+
|
|
320
|
+
def load_and_compare_witness_to_io(
    self: Circuit,
    witness_path: str,
    input_path: str,
    output_path: str,
    proof_system: ZKProofSystems,
) -> bool:
    """Check a stored witness's public values against expected inputs/outputs.

    The witness is loaded with ``load_witness``, the expected inputs and
    outputs are read from JSON, and the comparison itself is delegated to
    ``compare_witness_to_io`` together with the witness modulus and this
    circuit's ``scale_and_round`` method.

    Args:
        witness_path (str): Path to the binary witness file.
        input_path (str): Path to a JSON file with the expected inputs.
        output_path (str): Path to a JSON file with the expected outputs.
        proof_system (ZKProofSystems): Proof system the witness belongs to.

    Returns:
        bool: The result of ``compare_witness_to_io`` (True when the witness
        matches the expected values).

    Raises:
        CircuitFileError: If either JSON file cannot be read.
        WitnessMatchError: If the loaded witness has no ``"modulus"`` entry.
    """
    witness = load_witness(witness_path, proof_system)
    inputs_expected = self._read_from_json_safely(input_path)
    outputs_expected = self._read_from_json_safely(output_path)

    if "modulus" in witness:
        return compare_witness_to_io(
            witness,
            inputs_expected,
            outputs_expected,
            witness["modulus"],
            proof_system,
            self.scale_and_round,
        )

    problem = "Witness not correctly formed. Missing modulus."
    raise WitnessMatchError(problem)
|
|
361
|
+
|
|
362
|
+
def contains_float(self: Circuit, obj: object) -> bool:
    """Recursively check whether an object contains any floating-point values.

    Handles Python floats, dicts (their values are inspected), lists and
    tuples, and array-like objects with a floating-point dtype (e.g. NumPy
    arrays/scalars and torch tensors). ``scale_and_round`` relies on this
    check to decide whether to scale, and it is typed to accept arrays and
    tensors, so those must be recognized here too. Integers, booleans,
    strings and ``None`` count as non-float.

    Args:
        obj: The object to inspect.

    Returns:
        bool: True if any float is found (including inside nested
        containers), False otherwise.
    """
    if isinstance(obj, float):
        return True
    if isinstance(obj, dict):
        return any(self.contains_float(v) for v in obj.values())
    if isinstance(obj, (list, tuple)):
        return any(self.contains_float(i) for i in obj)
    # torch.Tensor exposes is_floating_point(); duck-typed so torch need not
    # be imported here.
    is_floating_point = getattr(obj, "is_floating_point", None)
    if callable(is_floating_point):
        return bool(is_floating_point())
    # NumPy arrays and scalars report dtype.kind == "f" for float dtypes.
    dtype = getattr(obj, "dtype", None)
    return getattr(dtype, "kind", None) == "f"
|
|
380
|
+
|
|
381
|
+
def adjust_shape(
    self: Circuit,
    shape: list[int] | dict[str, list[int]],
) -> list[int]:
    """Normalize a shape representation into a list of positive integers.

    Args:
        shape (list[int] | dict[str, list[int]]): Either a list of
            dimensions, or a dict holding exactly one entry whose value is
            such a list.

    Raises:
        CircuitInputError: If ``shape`` is a dict that does not contain
            exactly one entry.

    Returns:
        list[int]: The dimensions with every non-positive or unknown
        (``None``) value replaced by 1.
    """
    if isinstance(shape, dict):
        if len(shape) != 1:
            # Distinguish an empty dict from one with several entries so the
            # error message describes the actual problem.
            if shape:
                msg = (
                    "Shape dictionary contains multiple entries;"
                    " only one input shape is allowed."
                )
            else:
                msg = (
                    "Shape dictionary is empty;"
                    " exactly one input shape is required."
                )
            raise CircuitInputError(
                msg,
                parameter="shape",
                expected="dict with exactly one key-value pair",
                details={"shape_keys": list(shape.keys())},
            )
        # Only one input is relevant here; take its shape list.
        shape = next(iter(shape.values()))
    # Dynamic/unknown dimensions (None, 0 or negative) collapse to 1.
    return [1 if dim is None or dim <= 0 else dim for dim in shape]
|
|
414
|
+
|
|
415
|
+
def scale_and_round(
    self: Circuit,
    value: list[int] | np.ndarray | torch.Tensor,
    scale_base: int,
    scale_exponent: int,
) -> list[int] | np.ndarray | torch.Tensor:
    """Convert float data to scaled integers using the circuit's scaling.

    The scaling factor comes from ``BaseOpQuantizer.get_scaling``. When
    ``contains_float`` finds no float in ``value``, it is returned as-is.
    Otherwise the values are multiplied by the factor, rounded, cast to
    64-bit integers, and returned as a nested Python list of the same shape.

    Args:
        value (list[int] | np.ndarray | torch.Tensor): The values to process.
        scale_base (int): Base of the scaling factor.
        scale_exponent (int): Exponent of the scaling factor.

    Returns:
        list[int] | np.ndarray | torch.Tensor: ``value`` unchanged, or the
        scaled and rounded values as a nested list.
    """
    import torch  # noqa: PLC0415

    from python.core.model_processing.onnx_quantizer.layers.base import (  # noqa: PLC0415
        BaseOpQuantizer,
    )

    factor = BaseOpQuantizer.get_scaling(
        scale_base=scale_base,
        scale_exponent=scale_exponent,
    )
    if not self.contains_float(value):
        # Nothing to scale: data without floats passes through untouched.
        return value

    # NOTE(review): torch.tensor infers the default float dtype (normally
    # float32) for Python floats, so large scaled magnitudes may lose
    # precision; confirm this is acceptable.
    scaled = torch.tensor(value) * factor
    return torch.round(scaled).long().tolist()
|
|
450
|
+
|
|
451
|
+
def adjust_inputs(self: Circuit, input_file: str) -> str:
    """Scale and reshape the inputs in a JSON file and save them to a new file.

    Circuits whose ``input_variables`` (default ``["input"]``) is exactly
    ``["input"]`` go through ``_adjust_single_input``; all others go through
    ``_adjust_multiple_inputs``.

    Args:
        input_file (str): Path to the JSON file with the original inputs.

    Returns:
        str: Path of the file holding the adjusted inputs, named
        ``<stem>_reshaped<suffix>``.

    Raises:
        CircuitFileError: If reading or writing a JSON file fails.
        CircuitInputError: If input validation fails (e.g. several keys
            containing 'input' when a single input is expected).
        CircuitConfigurationError: If a required shape attribute is missing.
        CircuitProcessingError: If reshaping fails.
    """
    raw_inputs = self._read_from_json_safely(input_file)

    input_variables = getattr(self, "input_variables", ["input"])
    adjusted = (
        self._adjust_single_input(raw_inputs)
        if input_variables == ["input"]
        else self._adjust_multiple_inputs(raw_inputs, input_variables)
    )

    # NOTE(review): only the file name (stem + suffix) is kept, so the
    # reshaped file is written relative to the current working directory
    # rather than next to ``input_file``; confirm this is intended.
    source = Path(input_file)
    reshaped_path = f"{source.stem}_reshaped{source.suffix}"
    self._to_json_safely(adjusted, reshaped_path, "adjusted input")
    return reshaped_path
|
|
484
|
+
|
|
485
|
+
def _adjust_single_input(self: Circuit, inputs: dict) -> dict:
    """Scale and reshape inputs for a circuit with a single ``input`` variable.

    Every value is scaled with ``scale_and_round``. The one key containing
    the substring ``"input"`` is reshaped to ``self.input_shape`` and stored
    under the canonical key ``"input"``; other keys are copied (scaled)
    under their own names. If no input key exists but an ``"output"`` key
    does, the scaled ``"output"`` entry is renamed to ``"input"``.

    Args:
        inputs (dict): Input values loaded from JSON. Not modified.

    Returns:
        dict: A new dict of adjusted values.

    Raises:
        CircuitInputError: If more than one key contains ``"input"``.
        CircuitConfigurationError: If ``input_shape`` is not defined
            (raised by ``_reshape_input_value``).
        CircuitProcessingError: If reshaping fails.
    """
    new_inputs: dict[str, Any] = {}
    has_input_been_found = False

    for key, value in inputs.items():
        value_adjusted = self.scale_and_round(
            value,
            self.scale_base,
            self.scale_exponent,
        )
        if "input" in key:
            if has_input_been_found:
                msg = (
                    "Multiple inputs found containing 'input'. "
                    "Only one allowed when input_variables = ['input']"
                )
                raise CircuitInputError(
                    msg,
                    parameter="input",
                    expected="single input key",
                    details={"input_keys": [k for k in inputs if "input" in k]},
                )
            has_input_been_found = True
            value_adjusted = self._reshape_input_value(
                value_adjusted,
                "input_shape",
                key,
            )
            new_inputs["input"] = value_adjusted
        else:
            new_inputs[key] = value_adjusted

    # Fallback: treat a lone "output" entry as the input. Rename the already
    # scaled entry inside new_inputs instead of copying the raw, unscaled
    # value and leaving a duplicate "output" key; the caller's dict stays
    # untouched.
    if "input" not in new_inputs and "output" in new_inputs:
        new_inputs["input"] = new_inputs.pop("output")

    return new_inputs
|
|
539
|
+
|
|
540
|
+
def _adjust_multiple_inputs(
    self: Circuit,
    inputs: dict,
    input_variables: list[str],
) -> dict:
    """Scale every input value and reshape those named in ``input_variables``.

    Each value is scaled with ``scale_and_round``. Keys listed in
    ``input_variables`` are additionally reshaped using the circuit
    attribute ``<key>_shape``; other keys are passed through scaled only.

    Args:
        inputs (dict): Input values loaded from JSON.
        input_variables (list[str]): Names of the inputs to reshape.

    Returns:
        dict: New dict with the adjusted values, in the original key order.

    Raises:
        CircuitConfigurationError: If a ``<key>_shape`` attribute is missing.
        CircuitProcessingError: If reshaping fails.
    """

    def _adjust(name: str, raw: Any) -> Any:
        scaled = self.scale_and_round(raw, self.scale_base, self.scale_exponent)
        if name not in input_variables:
            return scaled
        return self._reshape_input_value(scaled, f"{name}_shape", name)

    return {name: _adjust(name, raw) for name, raw in inputs.items()}
|
|
577
|
+
|
|
578
|
+
def _reshape_input_value(
    self: Circuit,
    value: list[int] | np.ndarray | torch.Tensor,
    shape_attr: str,
    input_key: str,
) -> list[Any]:
    """Reshape an input value to the shape stored in a circuit attribute.

    The target shape is read from ``getattr(self, shape_attr)`` and
    normalized with ``adjust_shape`` before reshaping.

    Args:
        value (list[int] | np.ndarray | torch.Tensor): The input value to
            reshape (anything ``torch.tensor`` accepts).
        shape_attr (str): Name of the attribute holding the target shape
            (e.g. ``'input_shape'``).
        input_key (str): Key of the input being reshaped, used in errors.

    Returns:
        list[Any]: The reshaped value as a nested Python list.

    Raises:
        CircuitConfigurationError: If the shape attribute is not defined.
        CircuitProcessingError: If the reshape fails (e.g. element count
            does not match the target shape).
    """
    if not hasattr(self, shape_attr):
        # Parenthesized so both f-string parts form a single message.
        msg = (
            f"Required shape attribute '{shape_attr}'"
            f" must be defined to reshape input '{input_key}'."
        )
        raise CircuitConfigurationError(
            msg,
            missing_attributes=[shape_attr],
            details={"input_key": input_key},
        )

    import torch  # noqa: PLC0415

    shape = self.adjust_shape(getattr(self, shape_attr))

    try:
        return torch.tensor(value).reshape(shape).tolist()
    except Exception as e:
        msg = f"Failed to reshape input data for '{input_key}'."
        raise CircuitProcessingError(
            msg,
            operation="reshape",
            data_type="tensor",
            details={"shape": shape, "input_key": input_key},
        ) from e
|
|
627
|
+
|
|
628
|
+
def _to_json_safely(
    self: Circuit,
    inputs: dict,
    input_file: str | Path,
    var_name: str,
) -> None:
    """Write data to a JSON file, converting any failure into CircuitFileError.

    Args:
        inputs (dict): Data to serialize.
        input_file (str | Path): Destination path. ``Path`` objects are
            accepted as well (they are passed when writing weight files).
        var_name (str): Human-readable label used in the error message.

    Raises:
        CircuitFileError: If ``to_json`` raises for any reason; the original
            exception is chained as the cause.
    """
    try:
        to_json(inputs, input_file)
    except Exception as e:
        msg = f"Failed to write {var_name} file: {input_file}"
        raise CircuitFileError(
            msg,
            file_path=input_file,
        ) from e
|
|
649
|
+
|
|
650
|
+
def _read_from_json_safely(
    self: Circuit,
    input_file: str,
) -> dict[str, Any]:
    """Read a JSON file, converting any failure into a ``CircuitFileError``.

    Args:
        input_file (str): Path of the JSON file to read.

    Returns:
        dict[str, Any]: The parsed contents, as returned by ``read_from_json``.

    Raises:
        CircuitFileError: If ``read_from_json`` raises for any reason; the
            original exception is chained as the cause.
    """
    try:
        data = read_from_json(input_file)
    except Exception as exc:
        failure = f"Failed to read input file: {input_file}"
        raise CircuitFileError(
            failure,
            file_path=input_file,
        ) from exc
    return data
|
|
670
|
+
|
|
671
|
+
def _gen_witness_preprocessing(
    self: Circuit,
    input_file: str,
    output_file: str,
    quantized_path: str,
    *,
    write_json: bool,
    is_scaled: bool,
) -> str:
    """Prepare the input and output JSON files before witness generation.

    The quantized model is loaded first, from ``quantized_path`` or, when
    that is empty, from ``self._file_info.get("quantized_model_path")``.

    With ``write_json`` set, fresh inputs come from ``get_inputs``, their
    outputs are computed, and both are formatted and written to
    ``input_file`` / ``output_file``. Otherwise the existing ``input_file``
    is adjusted by ``adjust_inputs`` (which writes a new file), and the
    outputs for those inputs are computed and written to ``output_file``.

    Args:
        input_file (str): Path to the input JSON file.
        output_file (str): Path where the computed outputs are saved.
        quantized_path (str): Path to the quantized model file (optional).
        write_json (bool): Whether to generate new inputs and write them.
        is_scaled (bool): Forwarded to ``get_inputs_from_file``; only used
            when ``write_json`` is False.

    Returns:
        str: Path of the input file to generate the witness from:
        ``input_file`` itself when ``write_json`` is True, otherwise the
        path returned by ``adjust_inputs``.
    """
    # Prefer the explicit quantized-model path; otherwise use the one
    # recorded by the prepare_io_files decorator.
    model_path = (
        quantized_path
        if quantized_path
        else self._file_info.get("quantized_model_path")
    )
    self.load_quantized_model(model_path)

    if write_json:
        # Generate fresh inputs, compute their outputs, and persist both.
        raw_inputs = self.get_inputs()
        raw_outputs = self.get_outputs(raw_inputs)
        formatted_inputs = self.format_inputs(raw_inputs)
        formatted_outputs = self.format_outputs(raw_outputs)
        self._to_json_safely(formatted_inputs, input_file, "input")
        self._to_json_safely(formatted_outputs, output_file, "output")
        return input_file

    # Adjust the existing input file, then derive and save its outputs.
    input_file = self.adjust_inputs(input_file)
    file_inputs = self.get_inputs_from_file(input_file, is_scaled=is_scaled)
    # get_outputs is wrapped by the compute_and_store_output decorator.
    formatted_outputs = self.format_outputs(self.get_outputs(file_inputs))
    self._to_json_safely(formatted_outputs, output_file, "output")
    return input_file
|
|
718
|
+
|
|
719
|
+
def _compile_preprocessing(
|
|
720
|
+
self: Circuit,
|
|
721
|
+
metadata_path: str,
|
|
722
|
+
architecture_path: str,
|
|
723
|
+
w_and_b_path: str,
|
|
724
|
+
quantized_path: str,
|
|
725
|
+
) -> None:
|
|
726
|
+
"""Prepare model weights and quantized files for circuit compilation.
|
|
727
|
+
|
|
728
|
+
Args:
|
|
729
|
+
metadata_path (str): Path to save model metadata in JSON format.
|
|
730
|
+
architecture_path (str): Path to save model architecture in JSON format.
|
|
731
|
+
w_and_b_path (str): Path to save model weights and biases in JSON format.
|
|
732
|
+
quantized_path (str): Path to save the quantized model.
|
|
733
|
+
|
|
734
|
+
Raises:
|
|
735
|
+
CircuitConfigurationError: If model weights type is unsupported.
|
|
736
|
+
"""
|
|
737
|
+
func_model_and_quantize = getattr(self, "get_model_and_quantize", None)
|
|
738
|
+
if callable(func_model_and_quantize):
|
|
739
|
+
func_model_and_quantize()
|
|
740
|
+
|
|
741
|
+
metadata = self.get_metadata()
|
|
742
|
+
architecture = self.get_architecture()
|
|
743
|
+
w_and_b = self.get_w_and_b()
|
|
744
|
+
|
|
745
|
+
if quantized_path:
|
|
746
|
+
self.save_quantized_model(quantized_path)
|
|
747
|
+
else:
|
|
748
|
+
self.save_quantized_model(self._file_info.get("quantized_model_path"))
|
|
749
|
+
|
|
750
|
+
if metadata:
|
|
751
|
+
self._to_json_safely(metadata, metadata_path, "metadata")
|
|
752
|
+
if architecture:
|
|
753
|
+
self._to_json_safely(architecture, architecture_path, "architecture")
|
|
754
|
+
|
|
755
|
+
if isinstance(w_and_b, list):
|
|
756
|
+
for i, w in enumerate(w_and_b):
|
|
757
|
+
if i == 0:
|
|
758
|
+
self._to_json_safely(w, Path(w_and_b_path), "w_and_b")
|
|
759
|
+
else:
|
|
760
|
+
val = i + 1
|
|
761
|
+
file_path = (
|
|
762
|
+
Path(w_and_b_path).parent
|
|
763
|
+
/ f"{Path(w_and_b_path).stem!s}{val}{Path(w_and_b_path).suffix}"
|
|
764
|
+
)
|
|
765
|
+
self._to_json_safely(w, file_path, "w_and_b")
|
|
766
|
+
elif isinstance(w_and_b, (dict, tuple)):
|
|
767
|
+
self._to_json_safely(w_and_b, w_and_b_path, "w_and_b")
|
|
768
|
+
else:
|
|
769
|
+
msg = f"Unsupported w_and_b type: {type(w_and_b)}."
|
|
770
|
+
" Expected list, dict, or tuple."
|
|
771
|
+
raise CircuitConfigurationError(
|
|
772
|
+
msg,
|
|
773
|
+
details={"w_and_b_type": str(type(w_and_b))},
|
|
774
|
+
)
|
|
775
|
+
|
|
776
|
+
def save_model(self: Circuit, file_path: str) -> None:
    """Persist the current model to ``file_path``.

    The base implementation is a no-op placeholder that writes nothing;
    subclasses are expected to override it.

    Args:
        file_path (str): Destination path for the saved model.
    """
    del file_path  # accepted for interface compatibility; unused here
|
|
783
|
+
|
|
784
|
+
def load_model(self: Circuit, file_path: str) -> None:
    """Load a model from ``file_path``.

    This base version is a placeholder: it reads nothing and leaves the
    instance unchanged. Subclasses should provide the real loading logic.

    Args:
        file_path (str): Location of the model to load.
    """
    del file_path  # accepted for interface compatibility; unused here
|
|
791
|
+
|
|
792
|
+
def save_quantized_model(self: Circuit, file_path: str) -> None:
    """Write the current quantized model to ``file_path``.

    The default implementation is a no-op; subclasses should override it.
    Callers in this class may pass ``self._file_info.get(...)``, which can
    be ``None`` when no path is configured.

    Args:
        file_path (str): Destination path for the quantized model.
    """
    del file_path  # accepted for interface compatibility; unused here
|
|
799
|
+
|
|
800
|
+
def load_quantized_model(self: Circuit, file_path: str) -> None:
    """Load a quantized model from ``file_path``.

    The default implementation does nothing; subclasses should override it.
    Callers in this class may pass ``self._file_info.get(...)``, which can
    be ``None`` when no path is configured.

    Args:
        file_path (str): Location of the quantized model to load.
    """
    del file_path  # accepted for interface compatibility; unused here
|
|
807
|
+
|
|
808
|
+
def get_weights(self: Circuit) -> dict:
    """Return the model's weights.

    The base class has no weights to report, so it returns a fresh empty
    dict on every call. Subclasses should override this.

    Returns:
        dict: Model weights (always empty here).
    """
    return {}
|
|
815
|
+
|
|
816
|
+
def get_metadata(self: Circuit) -> dict:
    """Return metadata describing the model.

    Default: a new, empty dict. Subclasses override this to supply real
    metadata.

    Returns:
        dict: Model metadata (empty in the base implementation).
    """
    return {}
|
|
823
|
+
|
|
824
|
+
def get_architecture(self: Circuit) -> dict:
    """Return a description of the model architecture.

    The base implementation knows no architecture and returns a new empty
    dict; override it in subclasses.

    Returns:
        dict: Model architecture (empty in the base implementation).
    """
    return {}
|
|
831
|
+
|
|
832
|
+
def get_w_and_b(self: Circuit) -> dict:
    """Return the model's weights and biases.

    By default this simply delegates to ``self.get_weights()``. Subclasses
    may override it. ``_compile_preprocessing`` in this class also accepts a
    ``list`` or ``tuple`` here, in addition to a ``dict``.

    Returns:
        dict: Model weights and biases, as returned by ``get_weights``.
    """
    return self.get_weights()
|
|
839
|
+
|
|
840
|
+
def get_inputs_from_file(
    self: Circuit,
    input_file: str,
    *,
    is_scaled: bool = True,
) -> dict[str, list[int]]:
    """Load input values from a JSON file, scaling them if necessary.

    When ``is_scaled`` is True the JSON content is returned exactly as read.
    Otherwise each value is converted with ``torch.as_tensor``. It is then
    multiplied by ``BaseOpQuantizer.get_scaling(self.scale_base,
    self.scale_exponent)`` and converted back to (nested) Python lists.

    Args:
        input_file (str): Path to the input JSON file.
        is_scaled (bool, optional): If False, scale inputs according to the
            circuit's ``scale_base`` / ``scale_exponent``. Defaults to True.

    Returns:
        dict[str, list[int]]: Mapping from input names to (nested) lists of
        input values.

    Raises:
        CircuitProcessingError: If scaling the value stored under a key fails.
    """
    if is_scaled:
        return self._read_from_json_safely(input_file)

    # Imported after the early return, so loading already-scaled inputs does
    # not import torch or the quantizer package.
    import torch  # noqa: PLC0415

    from python.core.model_processing.onnx_quantizer.layers.base import (  # noqa: PLC0415
        BaseOpQuantizer,
    )

    read = self._read_from_json_safely(input_file)
    scaling = BaseOpQuantizer.get_scaling(self.scale_base, self.scale_exponent)

    # NOTE(review): values are multiplied by `scaling` but not rounded, so
    # float inputs stay floats despite the list[int] annotation - confirm
    # downstream code expects this.
    out = {}
    for k in read:
        # The try is per key so `k` is always bound when the error is built;
        # a failure iterating `read` itself propagates unchanged.
        try:
            out[k] = (torch.as_tensor(read[k]) * scaling).tolist()
        except Exception as e:
            msg = f"Failed to scale input data for key '{k}'"
            raise CircuitProcessingError(
                msg,
                operation="scale",
                data_type="tensor",
                details={"key": k},
            ) from e
    return out
|
|
883
|
+
|
|
884
|
+
def scale_inputs_only(self: Circuit, input_file: str) -> str:
    """
    Load input values from a JSON file, scale them according to circuit parameters,
    without reshaping, and save the scaled inputs to a new file.

    Each value is passed through ``self.scale_and_round`` together with
    ``self.scale_base`` and ``self.scale_exponent``. The result is written
    to ``<stem>_scaled<suffix>`` in the same directory as ``input_file``.

    Args:
        input_file (str):
            Path to the input JSON file containing the original input values.

    Returns:
        str: Path to the new file containing the scaled input values.

    Raises:
        CircuitFileError: If reading from or writing to JSON files fails.
    """
    inputs = self._read_from_json_safely(input_file)

    new_inputs = {
        key: self.scale_and_round(value, self.scale_base, self.scale_exponent)
        for key, value in inputs.items()
    }

    # Keep the derived file next to the source file instead of writing it
    # into the current working directory.
    path = Path(input_file)
    new_input_file = str(path.parent / f"{path.stem}_scaled{path.suffix}")
    self._to_json_safely(new_inputs, new_input_file, "scaled input")
    return new_input_file
|
|
914
|
+
|
|
915
|
+
def rename_inputs(self: Circuit, input_file: str) -> str:
    """
    Load input values from a JSON file, rename keys according to circuit logic
    (similar to adjust_inputs but without scaling or reshaping),
    and save the renamed inputs to a new file.

    When ``self.input_variables`` is ``["input"]`` (also the default when the
    attribute is missing), keys are renamed via ``self._rename_single_input``.
    Otherwise the inputs are copied unchanged. The result is written to
    ``<stem>_renamed<suffix>`` in the same directory as ``input_file``.

    Args:
        input_file (str):
            Path to the input JSON file containing the original input values.

    Returns:
        str: Path to the new file containing the renamed input values.

    Raises:
        CircuitFileError: If reading from or writing to JSON files fails.
        CircuitInputError: If input validation fails.
    """
    inputs = self._read_from_json_safely(input_file)

    input_variables = getattr(self, "input_variables", ["input"])
    if input_variables == ["input"]:
        new_inputs = self._rename_single_input(inputs)
    else:
        new_inputs = dict(inputs)

    # Keep the derived file next to the source file instead of writing it
    # into the current working directory.
    path = Path(input_file)
    new_input_file = str(path.parent / f"{path.stem}_renamed{path.suffix}")
    self._to_json_safely(new_inputs, new_input_file, "renamed input")
    return new_input_file
|
|
945
|
+
|
|
946
|
+
def _rename_single_input(self: Circuit, inputs: dict) -> dict:
|
|
947
|
+
"""
|
|
948
|
+
Rename inputs when there is a single 'input' variable,
|
|
949
|
+
handling special cases like multiple keys containing 'input'
|
|
950
|
+
or fallback from 'output' to 'input'. No scaling or reshaping.
|
|
951
|
+
|
|
952
|
+
Args:
|
|
953
|
+
inputs (dict): Dictionary of input values loaded from JSON.
|
|
954
|
+
|
|
955
|
+
Returns:
|
|
956
|
+
dict: Renamed inputs with appropriate key mappings.
|
|
957
|
+
|
|
958
|
+
Raises:
|
|
959
|
+
CircuitInputError:
|
|
960
|
+
If multiple keys containing 'input' are found.
|
|
961
|
+
"""
|
|
962
|
+
new_inputs: dict[str, Any] = {}
|
|
963
|
+
has_input_been_found = False
|
|
964
|
+
|
|
965
|
+
for key, value in inputs.items():
|
|
966
|
+
if "input" in key:
|
|
967
|
+
if has_input_been_found:
|
|
968
|
+
msg = (
|
|
969
|
+
"Multiple inputs found containing 'input'. "
|
|
970
|
+
"Only one allowed when input_variables = ['input']"
|
|
971
|
+
)
|
|
972
|
+
raise CircuitInputError(
|
|
973
|
+
msg,
|
|
974
|
+
parameter="input",
|
|
975
|
+
expected="single input key",
|
|
976
|
+
details={"input_keys": [k for k in inputs if "input" in k]},
|
|
977
|
+
)
|
|
978
|
+
has_input_been_found = True
|
|
979
|
+
new_inputs["input"] = value
|
|
980
|
+
else:
|
|
981
|
+
new_inputs[key] = value
|
|
982
|
+
|
|
983
|
+
# Special case: fallback mapping output → input
|
|
984
|
+
if "input" not in new_inputs and "output" in new_inputs:
|
|
985
|
+
new_inputs["input"] = inputs["output"]
|
|
986
|
+
del inputs["output"]
|
|
987
|
+
|
|
988
|
+
return new_inputs
|
|
989
|
+
|
|
990
|
+
def format_outputs(self: Circuit, output: list) -> dict:
    """Wrap raw model outputs in the standard output dictionary.

    Subclasses may override this to customise the output format.

    Args:
        output (list): Raw model output.

    Returns:
        dict: A single-entry dict mapping ``"output"`` to ``output`` itself
        (the same object, not a copy).
    """
    return {"output": output}
|