JSTprove 1.0.0 (jstprove-1.0.0-py3-none-macosx_11_0_arm64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of JSTprove might be problematic.

Files changed (81)
  1. jstprove-1.0.0.dist-info/METADATA +397 -0
  2. jstprove-1.0.0.dist-info/RECORD +81 -0
  3. jstprove-1.0.0.dist-info/WHEEL +5 -0
  4. jstprove-1.0.0.dist-info/entry_points.txt +2 -0
  5. jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
  6. jstprove-1.0.0.dist-info/top_level.txt +1 -0
  7. python/__init__.py +0 -0
  8. python/core/__init__.py +3 -0
  9. python/core/binaries/__init__.py +0 -0
  10. python/core/binaries/expander-exec +0 -0
  11. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  12. python/core/circuit_models/__init__.py +0 -0
  13. python/core/circuit_models/generic_onnx.py +231 -0
  14. python/core/circuit_models/simple_circuit.py +133 -0
  15. python/core/circuits/__init__.py +0 -0
  16. python/core/circuits/base.py +1000 -0
  17. python/core/circuits/errors.py +188 -0
  18. python/core/circuits/zk_model_base.py +25 -0
  19. python/core/model_processing/__init__.py +0 -0
  20. python/core/model_processing/converters/__init__.py +0 -0
  21. python/core/model_processing/converters/base.py +143 -0
  22. python/core/model_processing/converters/onnx_converter.py +1181 -0
  23. python/core/model_processing/errors.py +147 -0
  24. python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
  25. python/core/model_processing/onnx_custom_ops/conv.py +111 -0
  26. python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
  27. python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
  28. python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
  29. python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
  30. python/core/model_processing/onnx_custom_ops/relu.py +43 -0
  31. python/core/model_processing/onnx_quantizer/__init__.py +0 -0
  32. python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
  33. python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
  34. python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
  35. python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
  36. python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
  37. python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
  38. python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
  39. python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
  40. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
  41. python/core/model_templates/__init__.py +0 -0
  42. python/core/model_templates/circuit_template.py +57 -0
  43. python/core/utils/__init__.py +0 -0
  44. python/core/utils/benchmarking_helpers.py +163 -0
  45. python/core/utils/constants.py +4 -0
  46. python/core/utils/errors.py +117 -0
  47. python/core/utils/general_layer_functions.py +268 -0
  48. python/core/utils/helper_functions.py +1138 -0
  49. python/core/utils/model_registry.py +166 -0
  50. python/core/utils/scratch_tests.py +66 -0
  51. python/core/utils/witness_utils.py +291 -0
  52. python/frontend/__init__.py +0 -0
  53. python/frontend/cli.py +115 -0
  54. python/frontend/commands/__init__.py +17 -0
  55. python/frontend/commands/args.py +100 -0
  56. python/frontend/commands/base.py +199 -0
  57. python/frontend/commands/bench/__init__.py +54 -0
  58. python/frontend/commands/bench/list.py +42 -0
  59. python/frontend/commands/bench/model.py +172 -0
  60. python/frontend/commands/bench/sweep.py +212 -0
  61. python/frontend/commands/compile.py +58 -0
  62. python/frontend/commands/constants.py +5 -0
  63. python/frontend/commands/model_check.py +53 -0
  64. python/frontend/commands/prove.py +50 -0
  65. python/frontend/commands/verify.py +73 -0
  66. python/frontend/commands/witness.py +64 -0
  67. python/scripts/__init__.py +0 -0
  68. python/scripts/benchmark_runner.py +833 -0
  69. python/scripts/gen_and_bench.py +482 -0
  70. python/tests/__init__.py +0 -0
  71. python/tests/circuit_e2e_tests/__init__.py +0 -0
  72. python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
  73. python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
  74. python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
  75. python/tests/circuit_parent_classes/__init__.py +0 -0
  76. python/tests/circuit_parent_classes/test_circuit.py +969 -0
  77. python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
  78. python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
  79. python/tests/test_cli.py +1021 -0
  80. python/tests/utils_testing/__init__.py +0 -0
  81. python/tests/utils_testing/test_helper_functions.py +891 -0
python/core/circuits/base.py
@@ -0,0 +1,1000 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ from python.core.utils.witness_utils import compare_witness_to_io, load_witness
8
+
9
+ if TYPE_CHECKING:
10
+ import numpy as np
11
+ import torch
12
+
13
+ from python.core.circuits.errors import (
14
+ CircuitConfigurationError,
15
+ CircuitFileError,
16
+ CircuitInputError,
17
+ CircuitProcessingError,
18
+ CircuitRunError,
19
+ WitnessMatchError,
20
+ )
21
+ from python.core.utils.helper_functions import (
22
+ CircuitExecutionConfig,
23
+ RunType,
24
+ ZKProofSystems,
25
+ compile_circuit,
26
+ compute_and_store_output,
27
+ generate_proof,
28
+ generate_verification,
29
+ generate_witness,
30
+ prepare_io_files,
31
+ read_from_json,
32
+ run_end_to_end,
33
+ to_json,
34
+ )
35
+
36
+
37
+ class Circuit:
38
+ """
39
+ Base class for all ZK circuits.
40
+
41
+ This class defines the standard interface and common utilities for
42
+ building, testing, and running ZK circuits.
43
+ Subclasses are expected to implement circuit-specific logic such as
44
+ input preparation, output computation, and model handling.
45
+ """
46
+
47
+ def __init__(self: Circuit) -> None:
48
+ # Default folder paths - can be overridden in subclasses
49
+ self.input_folder = "inputs"
50
+ self.proof_folder = "analysis"
51
+ self.temp_folder = "temp"
52
+ self.circuit_folder = ""
53
+ self.weights_folder = "weights"
54
+ self.output_folder = "output"
55
+ self.proof_system = ZKProofSystems.Expander
56
+
57
+ # This will be set by prepare_io_files decorator
58
+ self._file_info = None
59
+ self.required_keys = None
60
+ self.logger = logging.getLogger(__name__)
61
+
62
+ def check_attributes(self: Circuit) -> None:
63
+ """
64
+ Check if the necessary attributes are defined in subclasses.
65
+ Must be overridden in subclasses
66
+
67
+ Raises:
68
+ CircuitConfigurationError: If required attributes are missing.
69
+ """
70
+ missing = [
71
+ attr
72
+ for attr in ("required_keys", "name", "scale_exponent", "scale_base")
73
+ if not hasattr(self, attr)
74
+ ]
75
+ if missing:
76
+ raise CircuitConfigurationError(missing_attributes=missing)
77
+
78
+ def parse_inputs(self: Circuit, **kwargs: dict[str, Any]) -> None:
79
+ """Parse and validate required input parameters
80
+ for the circuit into an instance attribute.
81
+
82
+ Args:
83
+ **kwargs (dict[str, Any]): Input parameters to parse and validate.
84
+
85
+ Raises:
86
+ NotImplementedError: If `required_keys` is not set.
87
+ KeyError: If any required parameter is missing.
88
+ ValueError: If any parameter value is not an integer or list of integers.
89
+ """
90
+ if self.required_keys is None:
91
+ msg = "self.required_keys must"
92
+ " be specified in the circuit definition."
93
+ raise CircuitConfigurationError(
94
+ msg,
95
+ )
96
+ for key in self.required_keys:
97
+ if key not in kwargs:
98
+ msg = f"Missing required parameter: '{key}'"
99
+ raise CircuitInputError(msg)
100
+
101
+ value = kwargs[key]
102
+
103
+ # Validate type (must be an int or a list of ints)
104
+ if not isinstance(value, (int, list)):
105
+ msg = (
106
+ f"Parameter '{key}' must be an int or list of ints, "
107
+ f"got {type(value).__name__}."
108
+ )
109
+ raise CircuitInputError(
110
+ msg,
111
+ )
112
+ setattr(self, key, value)
113
+
114
+ @compute_and_store_output
115
+ def get_outputs(self: Circuit) -> None:
116
+ """
117
+ Compute circuit outputs.
118
+ This method should be implemented by subclasses.
119
+ """
120
+ msg = "get_outputs must be implemented"
121
+ raise NotImplementedError(msg)
122
+
123
+ def get_inputs(
124
+ self: Circuit,
125
+ file_path: str | None = None,
126
+ *,
127
+ is_scaled: bool | None = False,
128
+ ) -> None:
129
+ """
130
+ Compute and return the circuit's input values.
131
+ This method should be implemented by subclasses.
132
+
133
+ Args:
134
+ file_path (str | None): Optional path to input file.
135
+ is_scaled (bool | None): Whether inputs are scaled.
136
+ """
137
+ _ = file_path, is_scaled
138
+ msg = "get_inputs must be implemented"
139
+ raise NotImplementedError(msg)
140
+
141
+ @prepare_io_files
142
+ def base_testing(self: Circuit, exec_config: CircuitExecutionConfig) -> None:
143
+ """Run the circuit in a specified mode
144
+ (testing, proving, compiling, etc.).
145
+
146
+ File path resolution is handled automatically by the
147
+ `prepare_io_files` decorator.
148
+
149
+ Args:
150
+ exec_config (CircuitExecutionConfig): Configuration object containing
151
+ run_type, file paths, and other execution parameters.
152
+
153
+ Raises:
154
+ CircuitConfigurationError: If `_file_info` is not set by the decorator.
155
+ """
156
+ if exec_config.circuit_path is None:
157
+ exec_config.circuit_path = f"{exec_config.circuit_name}.txt"
158
+
159
+ if not self._file_info:
160
+ msg = (
161
+ "Circuit file information (_file_info)"
162
+ " must be set by the prepare_io_files decorator."
163
+ )
164
+ raise CircuitConfigurationError(
165
+ msg,
166
+ details={"decorator": "prepare_io_files"},
167
+ )
168
+ exec_config.metadata_path = self._file_info.get("metadata_path")
169
+ exec_config.architecture_path = self._file_info.get("architecture_path")
170
+ exec_config.w_and_b_path = self._file_info.get("w_and_b_path")
171
+
172
+ # Run the appropriate proof operation based on run_type
173
+ self.parse_proof_run_type(exec_config)
174
+
175
+ def _raise_unknown_run_type(self: Circuit, run_type: RunType) -> None:
176
+ self.logger.error("Unknown run type: %s", run_type)
177
+ msg = f"Unsupported run type: {run_type}"
178
+ raise CircuitRunError(
179
+ msg,
180
+ operation="parse_proof_run_type",
181
+ details={"run_type": run_type},
182
+ )
183
+
184
+ def parse_proof_run_type(
185
+ self: Circuit,
186
+ exec_config: CircuitExecutionConfig,
187
+ ) -> None:
188
+ """Dispatch proof-related operations based on the selected run type.
189
+
190
+ Args:
191
+ exec_config (CircuitExecutionConfig): Configuration object containing
192
+ file paths, run type, and other parameters.
193
+
194
+ Raises:
195
+ CircuitRunError: If `run_type` is unknown or operation fails.
196
+ """
197
+ is_scaled = True
198
+
199
+ try:
200
+ if exec_config.run_type == RunType.END_TO_END:
201
+ self._compile_preprocessing(
202
+ metadata_path=exec_config.metadata_path,
203
+ architecture_path=exec_config.architecture_path,
204
+ w_and_b_path=exec_config.w_and_b_path,
205
+ quantized_path=exec_config.quantized_path,
206
+ )
207
+ processed_input_file = self._gen_witness_preprocessing(
208
+ input_file=exec_config.input_file,
209
+ output_file=exec_config.output_file,
210
+ quantized_path=exec_config.quantized_path,
211
+ write_json=exec_config.write_json,
212
+ is_scaled=is_scaled,
213
+ )
214
+ run_end_to_end(
215
+ circuit_name=exec_config.circuit_name,
216
+ circuit_path=exec_config.circuit_path,
217
+ input_file=processed_input_file,
218
+ output_file=exec_config.output_file,
219
+ proof_system=exec_config.proof_system,
220
+ dev_mode=exec_config.dev_mode,
221
+ ecc=exec_config.ecc,
222
+ )
223
+ elif exec_config.run_type == RunType.COMPILE_CIRCUIT:
224
+ self._compile_preprocessing(
225
+ metadata_path=exec_config.metadata_path,
226
+ architecture_path=exec_config.architecture_path,
227
+ w_and_b_path=exec_config.w_and_b_path,
228
+ quantized_path=exec_config.quantized_path,
229
+ )
230
+ compile_circuit(
231
+ circuit_name=exec_config.circuit_name,
232
+ circuit_path=exec_config.circuit_path,
233
+ metadata_path=exec_config.metadata_path,
234
+ architecture_path=exec_config.architecture_path,
235
+ w_and_b_path=exec_config.w_and_b_path,
236
+ proof_system=exec_config.proof_system,
237
+ dev_mode=exec_config.dev_mode,
238
+ bench=exec_config.bench,
239
+ )
240
+ elif exec_config.run_type == RunType.GEN_WITNESS:
241
+ processed_input_file = self._gen_witness_preprocessing(
242
+ input_file=exec_config.input_file,
243
+ output_file=exec_config.output_file,
244
+ quantized_path=exec_config.quantized_path,
245
+ write_json=exec_config.write_json,
246
+ is_scaled=is_scaled,
247
+ )
248
+ generate_witness(
249
+ circuit_name=exec_config.circuit_name,
250
+ circuit_path=exec_config.circuit_path,
251
+ witness_file=exec_config.witness_file,
252
+ input_file=processed_input_file,
253
+ output_file=exec_config.output_file,
254
+ metadata_path=exec_config.metadata_path,
255
+ proof_system=exec_config.proof_system,
256
+ dev_mode=exec_config.dev_mode,
257
+ bench=exec_config.bench,
258
+ )
259
+ elif exec_config.run_type == RunType.PROVE_WITNESS:
260
+ generate_proof(
261
+ circuit_name=exec_config.circuit_name,
262
+ circuit_path=exec_config.circuit_path,
263
+ witness_file=exec_config.witness_file,
264
+ proof_file=exec_config.proof_file,
265
+ metadata_path=exec_config.metadata_path,
266
+ proof_system=exec_config.proof_system,
267
+ dev_mode=exec_config.dev_mode,
268
+ ecc=exec_config.ecc,
269
+ bench=exec_config.bench,
270
+ )
271
+ elif exec_config.run_type == RunType.GEN_VERIFY:
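+ # Cross-check the witness against the declared inputs/outputs before invoking verification.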
272
+ witness_file = exec_config.witness_file
273
+ output_file = exec_config.output_file
274
+ processed_input_file = self.rename_inputs(exec_config.input_file)
275
+ proof_system = exec_config.proof_system
276
+ if not self.load_and_compare_witness_to_io(
277
+ witness_path=witness_file,
278
+ input_path=processed_input_file,
279
+ output_path=output_file,
280
+ proof_system=proof_system,
281
+ ):
282
+ raise WitnessMatchError # noqa: TRY301
283
+ generate_verification(
284
+ circuit_name=exec_config.circuit_name,
285
+ circuit_path=exec_config.circuit_path,
286
+ input_file=processed_input_file,
287
+ output_file=output_file,
288
+ witness_file=witness_file,
289
+ proof_file=exec_config.proof_file,
290
+ metadata_path=exec_config.metadata_path,
291
+ proof_system=proof_system,
292
+ dev_mode=exec_config.dev_mode,
293
+ ecc=exec_config.ecc,
294
+ bench=exec_config.bench,
295
+ )
296
+ else:
297
+ self._raise_unknown_run_type(exec_config.run_type)
298
+ except CircuitRunError:
299
+ self.logger.exception(
300
+ "Operation %s failed",
301
+ exec_config.run_type,
302
+ extra={"run_type": exec_config.run_type},
303
+ )
304
+ raise
305
+ except (
306
+ CircuitProcessingError,
307
+ CircuitInputError,
308
+ CircuitFileError,
309
+ Exception,
310
+ ) as e:
311
+ self.logger.exception(
312
+ "Operation %s failed",
313
+ exec_config.run_type,
314
+ extra={"run_type": exec_config.run_type},
315
+ )
316
+ raise CircuitRunError(
317
+ operation=exec_config.run_type,
318
+ ) from e
319
+
320
+ def load_and_compare_witness_to_io(
321
+ self: Circuit,
322
+ witness_path: str,
323
+ input_path: str,
324
+ output_path: str,
325
+ proof_system: ZKProofSystems,
326
+ ) -> bool:
327
+ """
328
+ Load a witness from disk and compare its
329
+ public inputs to expected inputs and outputs.
330
+
331
+ Args:
332
+ witness_path (str): Path to the binary witness file.
333
+ input_path (str): Path to a JSON file containing expected inputs.
334
+ output_path (str): Path to a JSON file containing expected outputs.
335
+ Only the `"outputs"` field is used for comparison.
336
+ proof_system(ZKProofSystems): Proof system to be used.
337
+
338
+ Returns:
339
+ bool:
340
+ True if all witness public inputs match the expected inputs and outputs,
341
+ False otherwise.
342
+
343
+ Raises:
344
+ WitnessMatchError:
345
+ If the witness file is malformed or missing the modulus field.
346
+ """
347
+ w = load_witness(witness_path, proof_system)
348
+ expected_inputs = self._read_from_json_safely(input_path)
349
+ expected_outputs = self._read_from_json_safely(output_path)
350
+ if "modulus" not in w:
351
+ msg = "Witness not correctly formed. Missing modulus."
352
+ raise WitnessMatchError(msg)
353
+ return compare_witness_to_io(
354
+ w,
355
+ expected_inputs,
356
+ expected_outputs,
357
+ w["modulus"],
358
+ proof_system,
359
+ self.scale_and_round,
360
+ )
361
+
362
+ def contains_float(self: Circuit, obj: float | dict | list) -> bool:
363
+ """Recursively check whether an object contains any float values.
364
+
365
+ Args:
366
+ obj (float | dict | list): The object to inspect.
367
+ Can be a float, a list, or a dict.
368
+
369
+ Returns:
370
+ bool: True if any float is found within the object
371
+ (including nested lists/dicts), False otherwise.
372
+ """
373
+ if isinstance(obj, float):
374
+ return True
375
+ if isinstance(obj, dict):
376
+ return any(self.contains_float(v) for v in obj.values())
377
+ if isinstance(obj, list):
378
+ return any(self.contains_float(i) for i in obj)
379
+ return False
380
+
381
+ def adjust_shape(self: Circuit, shape: list[int] | dict[str, int]) -> list[int]:
382
+ """Normalize a shape representation into a valid list of positive integers.
383
+
384
+ Args:
385
+ shape (list[int] | dict[str, int]):
386
+ The shape, which can be a list of ints
387
+ or a dict containing one shape list.
388
+
389
+ Raises:
390
+ CircuitInputError:
391
+ If `shape` is a dict containing more than one shape definition.
392
+
393
+ Returns:
394
+ list[int]:
395
+ The adjusted shape where all non-positive values are replaced with 1.
396
+ """
397
+ if isinstance(shape, dict):
398
+ # Get the first shape from the dict
399
+ # (assuming only one input is relevant here)
400
+ if len(shape.values()) == 1:
401
+ shape = next(iter(shape.values()))
402
+ else:
403
+ msg = (
404
+ "Shape dictionary contains multiple entries;"
405
+ " only one input shape is allowed."
406
+ )
407
+ raise CircuitInputError(
408
+ msg,
409
+ parameter="shape",
410
+ expected="dict with exactly one key-value pair",
411
+ details={"shape_keys": list(shape.keys())},
412
+ )
413
+ return [s if s > 0 else 1 for s in shape]
414
+
415
+ def scale_and_round(
416
+ self: Circuit,
417
+ value: list[int] | np.ndarray | torch.Tensor,
418
+ scale_base: int,
419
+ scale_exponent: int,
420
+ ) -> list[int] | np.ndarray | torch.Tensor:
421
+ """Scale and round numeric values to integers based on
422
+ circuit scaling parameters.
423
+
424
+ Args:
425
+ value (list[int] | np.ndarray | torch.Tensor): The values to process.
426
+
427
+ Returns:
428
+ list[int] | np.ndarray | torch.Tensor: The scaled and rounded values,
429
+ preserving the original structure.
430
+ """
431
+ import torch # noqa: PLC0415
432
+
433
+ from python.core.model_processing.onnx_quantizer.layers.base import ( # noqa: PLC0415
434
+ BaseOpQuantizer,
435
+ )
436
+
437
+ scaling = BaseOpQuantizer.get_scaling(
438
+ scale_base=scale_base,
439
+ scale_exponent=scale_exponent,
440
+ )
441
+ if self.contains_float(value):
442
+ return (
443
+ torch.round(
444
+ torch.tensor(value) * scaling,
445
+ )
446
+ .long()
447
+ .tolist()
448
+ )
449
+ return value
450
+
451
+ def adjust_inputs(self: Circuit, input_file: str) -> str:
452
+ """
453
+ Load input values from a JSON file, adjust them by scaling
454
+ and reshaping according to circuit parameters,
455
+ and save the adjusted inputs to a new file.
456
+
457
+ Args:
458
+ input_file (str):
459
+ Path to the input JSON file containing the original input values.
460
+
461
+ Returns:
462
+ str: Path to the new file containing the adjusted input values.
463
+
464
+ Raises:
465
+ CircuitFileError: If reading from or writing to JSON files fails.
466
+ CircuitInputError: If input validation fails
467
+ (e.g., multiple 'input' keys when expecting single).
468
+ CircuitConfigurationError: If required shape attributes are missing.
469
+ CircuitProcessingError: If reshaping or scaling operations fail.
470
+ """
471
+ inputs = self._read_from_json_safely(input_file)
472
+
473
+ input_variables = getattr(self, "input_variables", ["input"])
474
+ if input_variables == ["input"]:
475
+ new_inputs = self._adjust_single_input(inputs)
476
+ else:
477
+ new_inputs = self._adjust_multiple_inputs(inputs, input_variables)
478
+
479
+ # Save reshaped inputs
480
+ path = Path(input_file)
481
+ new_input_file = path.stem + "_reshaped" + path.suffix
482
+ self._to_json_safely(new_inputs, new_input_file, "adjusted input")
483
+ return new_input_file
484
+
485
+ def _adjust_single_input(self: Circuit, inputs: dict) -> dict:
486
+ """
487
+ Adjust inputs when there is a single 'input' variable,
488
+ handling special cases like multiple keys containing 'input'
489
+ or fallback from 'output' to 'input'.
490
+
491
+ Args:
492
+ inputs (dict): Dictionary of input values loaded from JSON.
493
+
494
+ Returns:
495
+ dict: Adjusted inputs with scaled and reshaped values.
496
+
497
+ Raises:
498
+ CircuitInputError:
499
+ If multiple keys containing 'input' are found
500
+ or if required shape attributes are missing.
501
+ """
502
+ new_inputs: dict[str, Any] = {}
503
+ has_input_been_found = False
504
+
505
+ for key, value in inputs.items():
506
+ value_adjusted = self.scale_and_round(
507
+ value,
508
+ self.scale_base,
509
+ self.scale_exponent,
510
+ )
511
+ if "input" in key:
512
+ if has_input_been_found:
513
+ msg = (
514
+ "Multiple inputs found containing 'input'. "
515
+ "Only one allowed when input_variables = ['input']"
516
+ )
517
+ raise CircuitInputError(
518
+ msg,
519
+ parameter="input",
520
+ expected="single input key",
521
+ details={"input_keys": [k for k in inputs if "input" in k]},
522
+ )
523
+ has_input_been_found = True
524
+ value_adjusted = self._reshape_input_value(
525
+ value_adjusted,
526
+ "input_shape",
527
+ key,
528
+ )
529
+ new_inputs["input"] = value_adjusted
530
+ else:
531
+ new_inputs[key] = value_adjusted
532
+
533
+ # Special case: fallback mapping output → input
534
+ if "input" not in new_inputs and "output" in new_inputs:
535
+ new_inputs["input"] = inputs["output"]
536
+ del inputs["output"]
537
+
538
+ return new_inputs
539
+
540
+ def _adjust_multiple_inputs(
541
+ self: Circuit,
542
+ inputs: dict,
543
+ input_variables: list[str],
544
+ ) -> dict:
545
+ """
546
+ Adjust inputs when there are multiple named input variables,
547
+ scaling and reshaping each according to their respective shape attributes.
548
+
549
+ Args:
550
+ inputs (dict): Dictionary of input values loaded from JSON.
551
+ input_variables (list[str]): List of input variable names to adjust.
552
+
553
+ Returns:
554
+ dict: Adjusted inputs with scaled and reshaped values.
555
+
556
+ Raises:
557
+ CircuitConfigurationError:
558
+ If required shape attributes are missing for any input variable.
559
+ CircuitProcessingError: If reshaping operations fail.
560
+ """
561
+ new_inputs: dict[str, Any] = {}
562
+ for key, value in inputs.items():
563
+ value_adjusted = self.scale_and_round(
564
+ value,
565
+ self.scale_base,
566
+ self.scale_exponent,
567
+ )
568
+ if key in input_variables:
569
+ shape_attr = f"{key}_shape"
570
+ value_adjusted = self._reshape_input_value(
571
+ value_adjusted,
572
+ shape_attr,
573
+ key,
574
+ )
575
+ new_inputs[key] = value_adjusted
576
+ return new_inputs
577
+
578
+ def _reshape_input_value(
579
+ self: Circuit,
580
+ value: list[int] | np.ndarray | torch.Tensor,
581
+ shape_attr: str,
582
+ input_key: str,
583
+ ) -> list[int]:
584
+ """
585
+ Reshape an input value to match the
586
+ specified shape attribute of the circuit.
587
+
588
+ Args:
589
+ value (list[int] | np.ndarray | torch.Tensor):
590
+ The input value to reshape, typically a list or tensor-like structure.
591
+ shape_attr (str):
592
+ Name of the attribute containing the target shape (e.g., 'input_shape').
593
+ input_key (str):
594
+ Key of the input being reshaped, used for error messages.
595
+
596
+ Returns:
597
+ list[int]: The reshaped value as a list.
598
+
599
+ Raises:
600
+ CircuitConfigurationError: If the required shape attribute is not defined.
601
+ CircuitProcessingError: If the reshaping operation fails.
602
+ """
603
+ if not hasattr(self, shape_attr):
604
+ msg = f"Required shape attribute '{shape_attr}'"
605
+ f" must be defined to reshape input '{input_key}'."
606
+ raise CircuitConfigurationError(
607
+ msg,
608
+ missing_attributes=[shape_attr],
609
+ details={"input_key": input_key},
610
+ )
611
+
612
+ import torch # noqa: PLC0415
613
+
614
+ shape = getattr(self, shape_attr)
615
+ shape = self.adjust_shape(shape)
616
+
617
+ try:
618
+ return torch.tensor(value).reshape(shape).tolist()
619
+ except Exception as e:
620
+ msg = f"Failed to reshape input data for '{input_key}'."
621
+ raise CircuitProcessingError(
622
+ msg,
623
+ operation="reshape",
624
+ data_type="tensor",
625
+ details={"shape": shape},
626
+ ) from e
627
+
628
+ def _to_json_safely(
629
+ self: Circuit,
630
+ inputs: dict,
631
+ input_file: str,
632
+ var_name: str,
633
+ ) -> None:
634
+ """Safely write data to a JSON file, handling exceptions.
635
+
636
+ Args:
637
+ inputs (dict): Data to write.
638
+ input_file (str): Path to the output file.
639
+ var_name (str): Name of the variable for error messages.
640
+ """
641
+ try:
642
+ to_json(inputs, input_file)
643
+ except Exception as e:
644
+ msg = f"Failed to write {var_name} file: {input_file}"
645
+ raise CircuitFileError(
646
+ msg,
647
+ file_path=input_file,
648
+ ) from e
649
+
650
+ def _read_from_json_safely(
651
+ self: Circuit,
652
+ input_file: str,
653
+ ) -> dict[str, Any]:
654
+ """Safely read data from a JSON file, handling exceptions.
655
+
656
+ Args:
657
+ input_file (str): Path to the input file.
658
+
659
+ Returns:
660
+ dict[str, Any]: Data read from the file.
661
+ """
662
+ try:
663
+ return read_from_json(input_file)
664
+ except Exception as e:
665
+ msg = f"Failed to read input file: {input_file}"
666
+ raise CircuitFileError(
667
+ msg,
668
+ file_path=input_file,
669
+ ) from e
670
+
671
+ def _gen_witness_preprocessing(
672
+ self: Circuit,
673
+ input_file: str,
674
+ output_file: str,
675
+ quantized_path: str,
676
+ *,
677
+ write_json: bool,
678
+ is_scaled: bool,
679
+ ) -> str:
680
+ """Preprocess inputs and outputs before witness generation.
681
+
682
+ Args:
683
+ input_file (str): Path to the input JSON file.
684
+ output_file (str): Path to save computed outputs.
685
+ quantized_path (str): Path to quantized model file.
686
+ write_json (bool): Whether to compute new inputs and write to JSON.
687
+ is_scaled (bool): Whether the inputs are already scaled.
688
+
689
+ Returns:
690
+ str: Path to the final processed input file.
691
+ """
692
+ # Load the quantized model (explicit path if provided, otherwise the default from _file_info)
693
+ if quantized_path:
694
+ self.load_quantized_model(quantized_path)
695
+ else:
696
+ self.load_quantized_model(self._file_info.get("quantized_model_path"))
697
+
698
+ if write_json:
699
+ inputs = self.get_inputs()
700
+ outputs = self.get_outputs(inputs)
701
+
702
+ inputs = self.format_inputs(inputs)
703
+
704
+ output = self.format_outputs(outputs)
705
+
706
+ self._to_json_safely(inputs, input_file, "input")
707
+ self._to_json_safely(output, output_file, "output")
708
+
709
+ else:
710
+ input_file = self.adjust_inputs(input_file)
711
+ inputs = self.get_inputs_from_file(input_file, is_scaled=is_scaled)
712
+ # Compute output (with caching via decorator)
713
+ output = self.get_outputs(inputs)
714
+ outputs = self.format_outputs(output)
715
+
716
+ self._to_json_safely(outputs, output_file, "output")
717
+ return input_file
718
+
719
+ def _compile_preprocessing(
720
+ self: Circuit,
721
+ metadata_path: str,
722
+ architecture_path: str,
723
+ w_and_b_path: str,
724
+ quantized_path: str,
725
+ ) -> None:
726
+ """Prepare model weights and quantized files for circuit compilation.
727
+
728
+ Args:
729
+ metadata_path (str): Path to save model metadata in JSON format.
730
+ architecture_path (str): Path to save model architecture in JSON format.
731
+ w_and_b_path (str): Path to save model weights and biases in JSON format.
732
+ quantized_path (str): Path to save the quantized model.
733
+
734
+ Raises:
735
+ CircuitConfigurationError: If model weights type is unsupported.
736
+ """
737
+ func_model_and_quantize = getattr(self, "get_model_and_quantize", None)
738
+ if callable(func_model_and_quantize):
739
+ func_model_and_quantize()
740
+
741
+ metadata = self.get_metadata()
742
+ architecture = self.get_architecture()
743
+ w_and_b = self.get_w_and_b()
744
+
745
+ if quantized_path:
746
+ self.save_quantized_model(quantized_path)
747
+ else:
748
+ self.save_quantized_model(self._file_info.get("quantized_model_path"))
749
+
750
+ if metadata:
751
+ self._to_json_safely(metadata, metadata_path, "metadata")
752
+ if architecture:
753
+ self._to_json_safely(architecture, architecture_path, "architecture")
754
+
755
+ if isinstance(w_and_b, list):
756
+ for i, w in enumerate(w_and_b):
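+ # First entry is written to w_and_b_path; later entries get a numeric suffix (2, 3, ...).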
757
+ if i == 0:
758
+ self._to_json_safely(w, Path(w_and_b_path), "w_and_b")
759
+ else:
760
+ val = i + 1
761
+ file_path = (
762
+ Path(w_and_b_path).parent
763
+ / f"{Path(w_and_b_path).stem!s}{val}{Path(w_and_b_path).suffix}"
764
+ )
765
+ self._to_json_safely(w, file_path, "w_and_b")
766
+ elif isinstance(w_and_b, (dict, tuple)):
767
+ self._to_json_safely(w_and_b, w_and_b_path, "w_and_b")
768
+ else:
769
+ msg = f"Unsupported w_and_b type: {type(w_and_b)}."
770
+ " Expected list, dict, or tuple."
771
+ raise CircuitConfigurationError(
772
+ msg,
773
+ details={"w_and_b_type": str(type(w_and_b))},
774
+ )
775
+
776
+ def save_model(self: Circuit, file_path: str) -> None:
777
+ """
778
+ Save the current model to a file. Should be overridden in subclasses
779
+
780
+ Args:
781
+ file_path (str): Path to save the model.
782
+ """
783
+
784
+ def load_model(self: Circuit, file_path: str) -> None:
785
+ """
786
+ Load the model from a file. Should be overridden in subclasses
787
+
788
+ Args:
789
+ file_path (str): Path to load the model.
790
+ """
791
+
792
+ def save_quantized_model(self: Circuit, file_path: str) -> None:
793
+ """
794
+ Save the current quantized model to a file. Should be overridden in subclasses
795
+
796
+ Args:
797
+ file_path (str): Path to save the model.
798
+ """
799
+
800
+ def load_quantized_model(self: Circuit, file_path: str) -> None:
801
+ """
802
+ Load the quantized model from a file. Should be overridden in subclasses
803
+
804
+ Args:
805
+ file_path (str): Path to load the model.
806
+ """
807
+
808
+ def get_weights(self: Circuit) -> dict:
809
+ """Retrieve model weights. Should be overridden in subclasses
810
+
811
+ Returns:
812
+ dict: Model weights.
813
+ """
814
+ return {}
815
+
816
+ def get_metadata(self: Circuit) -> dict:
817
+ """Retrieve model metadata. Should be overridden in subclasses
818
+
819
+ Returns:
820
+ dict: Model metadata.
821
+ """
822
+ return {}
823
+
824
+ def get_architecture(self: Circuit) -> dict:
825
+ """Retrieve model architecture. Should be overridden in subclasses
826
+
827
+ Returns:
828
+ dict: Model architecture.
829
+ """
830
+ return {}
831
+
832
+ def get_w_and_b(self: Circuit) -> dict:
833
+ """Retrieve model weights and biases. Should be overridden in subclasses
834
+
835
+ Returns:
836
+ dict: Model weights and biases.
837
+ """
838
+ return self.get_weights()
839
+
840
+ def get_inputs_from_file(
841
+ self: Circuit,
842
+ input_file: str,
843
+ *,
844
+ is_scaled: bool = True,
845
+ ) -> dict[str, list[int]]:
846
+ """Load input values from a JSON file, scaling if necessary.
847
+
848
+ Args:
849
+ input_file (str): Path to the input JSON file.
850
+ is_scaled (bool, optional): If False, scale inputs
851
+ according to circuit settings. Defaults to True.
852
+
853
+ Returns:
854
+ dict[str, list[int]]: Mapping from input names to integer lists of inputs.
855
+ """
856
+ if is_scaled:
857
+ return self._read_from_json_safely(input_file)
858
+
859
+ import torch # noqa: PLC0415
860
+
861
+ from python.core.model_processing.onnx_quantizer.layers.base import ( # noqa: PLC0415
862
+ BaseOpQuantizer,
863
+ )
864
+
865
+ out = {}
866
+ read = self._read_from_json_safely(input_file)
867
+
868
+ scaling = BaseOpQuantizer.get_scaling(self.scale_base, self.scale_exponent)
869
+ try:
870
+ for k in read:
871
+
872
+ out[k] = torch.as_tensor(read[k]) * scaling
873
+ out[k] = out[k].tolist()
874
+ except Exception as e:
875
+ msg = f"Failed to scale input data for key '{k}'"
876
+ raise CircuitProcessingError(
877
+ msg,
878
+ operation="scale",
879
+ data_type="tensor",
880
+ details={"key": k},
881
+ ) from e
882
+ return out
883
+
884
+ def scale_inputs_only(self: Circuit, input_file: str) -> str:
885
+ """
886
+ Load input values from a JSON file, scale them according to circuit parameters,
887
+ without reshaping, and save the scaled inputs to a new file.
888
+
889
+ Args:
890
+ input_file (str):
891
+ Path to the input JSON file containing the original input values.
892
+
893
+ Returns:
894
+ str: Path to the new file containing the scaled input values.
895
+
896
+ Raises:
897
+ CircuitFileError: If reading from or writing to JSON files fails.
898
+ """
899
+ inputs = self._read_from_json_safely(input_file)
900
+
901
+ new_inputs = {}
902
+ for key, value in inputs.items():
903
+ new_inputs[key] = self.scale_and_round(
904
+ value,
905
+ self.scale_base,
906
+ self.scale_exponent,
907
+ )
908
+
909
+ # Save scaled inputs
910
+ path = Path(input_file)
911
+ new_input_file = path.stem + "_scaled" + path.suffix
912
+ self._to_json_safely(new_inputs, new_input_file, "scaled input")
913
+ return new_input_file
914
+
915
+ def rename_inputs(self: Circuit, input_file: str) -> str:
916
+ """
917
+ Load input values from a JSON file, rename keys according to circuit logic
918
+ (similar to adjust_inputs but without scaling or reshaping),
919
+ and save the renamed inputs to a new file.
920
+
921
+ Args:
922
+ input_file (str):
923
+ Path to the input JSON file containing the original input values.
924
+
925
+ Returns:
926
+ str: Path to the new file containing the renamed input values.
927
+
928
+ Raises:
929
+ CircuitFileError: If reading from or writing to JSON files fails.
930
+ CircuitInputError: If input validation fails.
931
+ """
932
+ inputs = self._read_from_json_safely(input_file)
933
+
934
+ input_variables = getattr(self, "input_variables", ["input"])
935
+ if input_variables == ["input"]:
936
+ new_inputs = self._rename_single_input(inputs)
937
+ else:
938
+ new_inputs = dict(inputs.items())
939
+
940
+ # Save renamed inputs
941
+ path = Path(input_file)
942
+ new_input_file = path.stem + "_renamed" + path.suffix
943
+ self._to_json_safely(new_inputs, new_input_file, "renamed input")
944
+ return new_input_file
945
+
946
+ def _rename_single_input(self: Circuit, inputs: dict) -> dict:
947
+ """
948
+ Rename inputs when there is a single 'input' variable,
949
+ handling special cases like multiple keys containing 'input'
950
+ or fallback from 'output' to 'input'. No scaling or reshaping.
951
+
952
+ Args:
953
+ inputs (dict): Dictionary of input values loaded from JSON.
954
+
955
+ Returns:
956
+ dict: Renamed inputs with appropriate key mappings.
957
+
958
+ Raises:
959
+ CircuitInputError:
960
+ If multiple keys containing 'input' are found.
961
+ """
962
+ new_inputs: dict[str, Any] = {}
963
+ has_input_been_found = False
964
+
965
+ for key, value in inputs.items():
966
+ if "input" in key:
967
+ if has_input_been_found:
968
+ msg = (
969
+ "Multiple inputs found containing 'input'. "
970
+ "Only one allowed when input_variables = ['input']"
971
+ )
972
+ raise CircuitInputError(
973
+ msg,
974
+ parameter="input",
975
+ expected="single input key",
976
+ details={"input_keys": [k for k in inputs if "input" in k]},
977
+ )
978
+ has_input_been_found = True
979
+ new_inputs["input"] = value
980
+ else:
981
+ new_inputs[key] = value
982
+
983
+ # Special case: fallback mapping output → input
984
+ if "input" not in new_inputs and "output" in new_inputs:
985
+ new_inputs["input"] = inputs["output"]
986
+ del inputs["output"]
987
+
988
+ return new_inputs
989
+
990
+ def format_outputs(self: Circuit, output: list) -> dict:
991
+ """Format raw model outputs into a standard dictionary format.
992
+ Can be overridden in subclasses
993
+
994
+ Args:
995
+ output (list): Raw model output.
996
+
997
+ Returns:
998
+ dict: dictionary containing the formatted output under the key 'output'.
999
+ """
1000
+ return {"output": output}