JSTprove 1.0.0__py3-none-macosx_11_0_arm64.whl → 1.2.0__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of JSTprove might be problematic. Click here for more details.
- {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/METADATA +3 -3
- {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/RECORD +60 -25
- python/core/binaries/onnx_generic_circuit_1-2-0 +0 -0
- python/core/circuit_models/generic_onnx.py +43 -9
- python/core/circuits/base.py +231 -71
- python/core/model_processing/converters/onnx_converter.py +114 -59
- python/core/model_processing/onnx_custom_ops/batchnorm.py +64 -0
- python/core/model_processing/onnx_custom_ops/maxpool.py +1 -1
- python/core/model_processing/onnx_custom_ops/mul.py +66 -0
- python/core/model_processing/onnx_custom_ops/relu.py +1 -1
- python/core/model_processing/onnx_quantizer/layers/add.py +54 -0
- python/core/model_processing/onnx_quantizer/layers/base.py +188 -1
- python/core/model_processing/onnx_quantizer/layers/batchnorm.py +224 -0
- python/core/model_processing/onnx_quantizer/layers/constant.py +1 -1
- python/core/model_processing/onnx_quantizer/layers/conv.py +20 -68
- python/core/model_processing/onnx_quantizer/layers/gemm.py +20 -66
- python/core/model_processing/onnx_quantizer/layers/maxpool.py +53 -43
- python/core/model_processing/onnx_quantizer/layers/mul.py +53 -0
- python/core/model_processing/onnx_quantizer/layers/relu.py +20 -35
- python/core/model_processing/onnx_quantizer/layers/sub.py +54 -0
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +43 -1
- python/core/utils/general_layer_functions.py +17 -12
- python/core/utils/model_registry.py +6 -3
- python/scripts/gen_and_bench.py +2 -2
- python/tests/circuit_e2e_tests/other_e2e_test.py +202 -9
- python/tests/circuit_parent_classes/test_circuit.py +561 -38
- python/tests/circuit_parent_classes/test_onnx_converter.py +22 -13
- python/tests/onnx_quantizer_tests/__init__.py +1 -0
- python/tests/onnx_quantizer_tests/layers/__init__.py +13 -0
- python/tests/onnx_quantizer_tests/layers/add_config.py +102 -0
- python/tests/onnx_quantizer_tests/layers/base.py +279 -0
- python/tests/onnx_quantizer_tests/layers/batchnorm_config.py +190 -0
- python/tests/onnx_quantizer_tests/layers/constant_config.py +39 -0
- python/tests/onnx_quantizer_tests/layers/conv_config.py +154 -0
- python/tests/onnx_quantizer_tests/layers/factory.py +142 -0
- python/tests/onnx_quantizer_tests/layers/flatten_config.py +61 -0
- python/tests/onnx_quantizer_tests/layers/gemm_config.py +160 -0
- python/tests/onnx_quantizer_tests/layers/maxpool_config.py +82 -0
- python/tests/onnx_quantizer_tests/layers/mul_config.py +102 -0
- python/tests/onnx_quantizer_tests/layers/relu_config.py +61 -0
- python/tests/onnx_quantizer_tests/layers/reshape_config.py +61 -0
- python/tests/onnx_quantizer_tests/layers/sub_config.py +102 -0
- python/tests/onnx_quantizer_tests/layers_tests/__init__.py +0 -0
- python/tests/onnx_quantizer_tests/layers_tests/base_test.py +94 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py +115 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_e2e.py +196 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_error_cases.py +59 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +198 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +267 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py +109 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_validation.py +45 -0
- python/tests/onnx_quantizer_tests/test_base_layer.py +228 -0
- python/tests/onnx_quantizer_tests/test_exceptions.py +99 -0
- python/tests/onnx_quantizer_tests/test_onnx_op_quantizer.py +246 -0
- python/tests/onnx_quantizer_tests/test_registered_quantizers.py +121 -0
- python/tests/onnx_quantizer_tests/testing_helper_functions.py +17 -0
- python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/WHEEL +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/entry_points.txt +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/top_level.txt +0 -0
python/core/circuits/base.py
CHANGED
|
@@ -4,6 +4,9 @@ import logging
|
|
|
4
4
|
from pathlib import Path
|
|
5
5
|
from typing import TYPE_CHECKING, Any
|
|
6
6
|
|
|
7
|
+
from numpy import asarray, ndarray
|
|
8
|
+
|
|
9
|
+
from python.core.utils.errors import ShapeMismatchError
|
|
7
10
|
from python.core.utils.witness_utils import compare_witness_to_io, load_witness
|
|
8
11
|
|
|
9
12
|
if TYPE_CHECKING:
|
|
@@ -88,8 +91,7 @@ class Circuit:
|
|
|
88
91
|
ValueError: If any parameter value is not an integer or list of integers.
|
|
89
92
|
"""
|
|
90
93
|
if self.required_keys is None:
|
|
91
|
-
msg = "self.required_keys must"
|
|
92
|
-
" be specified in the circuit definition."
|
|
94
|
+
msg = "self.required_keys must be specified in the circuit definition."
|
|
93
95
|
raise CircuitConfigurationError(
|
|
94
96
|
msg,
|
|
95
97
|
)
|
|
@@ -271,7 +273,8 @@ class Circuit:
|
|
|
271
273
|
elif exec_config.run_type == RunType.GEN_VERIFY:
|
|
272
274
|
witness_file = exec_config.witness_file
|
|
273
275
|
output_file = exec_config.output_file
|
|
274
|
-
processed_input_file = self.
|
|
276
|
+
processed_input_file = self.prepare_inputs_for_verification(exec_config)
|
|
277
|
+
|
|
275
278
|
proof_system = exec_config.proof_system
|
|
276
279
|
if not self.load_and_compare_witness_to_io(
|
|
277
280
|
witness_path=witness_file,
|
|
@@ -317,6 +320,30 @@ class Circuit:
|
|
|
317
320
|
operation=exec_config.run_type,
|
|
318
321
|
) from e
|
|
319
322
|
|
|
323
|
+
def prepare_inputs_for_verification(
    self: Circuit,
    exec_config: CircuitExecutionConfig,
) -> str:
    """
    Load the raw inputs and save them in the flattened circuit layout.

    Verification compares the witness against the inputs, so the inputs read
    from ``exec_config.input_file`` are passed through
    ``reshape_inputs_for_circuit`` and written next to the original file with
    ``_veri`` inserted before the extension.

    Args:
        exec_config (CircuitExecutionConfig): Execution configuration; only
            ``input_file`` is read here.

    Returns:
        str: Path of the file holding the processed inputs for verification.

    Raises:
        CircuitConfigurationError / CircuitProcessingError: Propagated from
            ``reshape_inputs_for_circuit`` when the inputs are malformed.
        Errors from ``_read_from_json_safely`` / ``_to_json_safely`` on file
            read/write failures (presumably ``CircuitFileError`` — confirm).
    """
    inputs = self._read_from_json_safely(exec_config.input_file)

    # Flatten into the same layout the circuit consumes so the witness
    # comparison sees matching data.
    processed_inputs = self.reshape_inputs_for_circuit(inputs)

    # e.g. "dir/input.json" -> "dir/input_veri.json"
    path = Path(exec_config.input_file)
    processed_input_file = str(path.parent / f"{path.stem}_veri{path.suffix}")

    # The label describes the data being written (used for diagnostics).
    self._to_json_safely(
        processed_inputs,
        processed_input_file,
        "verification input",
    )

    return processed_input_file
|
|
346
|
+
|
|
320
347
|
def load_and_compare_witness_to_io(
|
|
321
348
|
self: Circuit,
|
|
322
349
|
witness_path: str,
|
|
@@ -378,38 +405,55 @@ class Circuit:
|
|
|
378
405
|
return any(self.contains_float(i) for i in obj)
|
|
379
406
|
return False
|
|
380
407
|
|
|
381
|
-
def adjust_shape(
|
|
382
|
-
|
|
408
|
+
def adjust_shape(
    self: Circuit,
    shape: list[int] | dict[str, list[int]],
) -> list[int] | dict[str, list[int]]:
    """
    Normalize a shape representation into a valid list or dict of positive integers.

    Args:
        shape (list[int] | dict[str, list[int]]):
            The shape, which can be:
                a. a list (or tuple) of ints, or
                b. a dict mapping strings to lists (or tuples) of ints.
            Each non-positive integer is replaced by 1.

    Raises:
        CircuitInputError:
            If ``shape`` is neither a list/tuple nor a dict, or if a dict
            entry is not a list/tuple.

    Returns:
        list[int] | dict[str, list[int]]:
            The adjusted shape(s) where all non-positive values are replaced with 1.
            A single-entry dict is unwrapped and returned as a plain list;
            a dict with any other number of entries yields a dict of lists.
    """

    def _positive_dims(dims: list[int] | tuple[int, ...]) -> list[int]:
        # Non-positive dims (presumably dynamic/unknown placeholders such as
        # 0 or -1 — confirm against the shape producer) become 1.
        return [d if d > 0 else 1 for d in dims]

    if isinstance(shape, dict):
        if len(shape) == 1:
            # Single input: unwrap and treat it as a plain shape list.
            only_shape = next(iter(shape.values()))
            if not isinstance(only_shape, (list, tuple)):
                msg = f"Expected shape list for input, got {type(only_shape).__name__}"
                raise CircuitInputError(msg)
            return _positive_dims(only_shape)

        adjusted_shapes = {}
        for key, subshape in shape.items():
            if not isinstance(subshape, (list, tuple)):
                msg = (
                    f"Expected shape list for key '{key}', "
                    f"got {type(subshape).__name__}"
                )
                raise CircuitInputError(msg)
            adjusted_shapes[key] = _positive_dims(subshape)
        return adjusted_shapes

    if not isinstance(shape, (list, tuple)):
        msg = f"Expected list or dict for 'shape', got {type(shape).__name__}"
        raise CircuitInputError(msg)

    return _positive_dims(shape)
|
|
414
458
|
|
|
415
459
|
def scale_and_round(
|
|
@@ -448,15 +492,20 @@ class Circuit:
|
|
|
448
492
|
)
|
|
449
493
|
return value
|
|
450
494
|
|
|
451
|
-
def adjust_inputs(
|
|
495
|
+
def adjust_inputs(
|
|
496
|
+
self: Circuit,
|
|
497
|
+
inputs: dict[str, np.ndarray],
|
|
498
|
+
input_file: str,
|
|
499
|
+
) -> str:
|
|
452
500
|
"""
|
|
453
501
|
Load input values from a JSON file, adjust them by scaling
|
|
454
502
|
and reshaping according to circuit parameters,
|
|
455
503
|
and save the adjusted inputs to a new file.
|
|
456
504
|
|
|
457
505
|
Args:
|
|
458
|
-
|
|
459
|
-
|
|
506
|
+
inputs (dict[str, np.ndarray]):
|
|
507
|
+
inputs, read from json file
|
|
508
|
+
input_file (str): path to input_file
|
|
460
509
|
|
|
461
510
|
Returns:
|
|
462
511
|
str: Path to the new file containing the adjusted input values.
|
|
@@ -468,7 +517,6 @@ class Circuit:
|
|
|
468
517
|
CircuitConfigurationError: If required shape attributes are missing.
|
|
469
518
|
CircuitProcessingError: If reshaping or scaling operations fail.
|
|
470
519
|
"""
|
|
471
|
-
inputs = self._read_from_json_safely(input_file)
|
|
472
520
|
|
|
473
521
|
input_variables = getattr(self, "input_variables", ["input"])
|
|
474
522
|
if input_variables == ["input"]:
|
|
@@ -503,11 +551,6 @@ class Circuit:
|
|
|
503
551
|
has_input_been_found = False
|
|
504
552
|
|
|
505
553
|
for key, value in inputs.items():
|
|
506
|
-
value_adjusted = self.scale_and_round(
|
|
507
|
-
value,
|
|
508
|
-
self.scale_base,
|
|
509
|
-
self.scale_exponent,
|
|
510
|
-
)
|
|
511
554
|
if "input" in key:
|
|
512
555
|
if has_input_been_found:
|
|
513
556
|
msg = (
|
|
@@ -522,13 +565,13 @@ class Circuit:
|
|
|
522
565
|
)
|
|
523
566
|
has_input_been_found = True
|
|
524
567
|
value_adjusted = self._reshape_input_value(
|
|
525
|
-
|
|
568
|
+
value,
|
|
526
569
|
"input_shape",
|
|
527
570
|
key,
|
|
528
571
|
)
|
|
529
572
|
new_inputs["input"] = value_adjusted
|
|
530
573
|
else:
|
|
531
|
-
new_inputs[key] =
|
|
574
|
+
new_inputs[key] = value
|
|
532
575
|
|
|
533
576
|
# Special case: fallback mapping output → input
|
|
534
577
|
if "input" not in new_inputs and "output" in new_inputs:
|
|
@@ -560,11 +603,7 @@ class Circuit:
|
|
|
560
603
|
"""
|
|
561
604
|
new_inputs: dict[str, Any] = {}
|
|
562
605
|
for key, value in inputs.items():
|
|
563
|
-
value_adjusted =
|
|
564
|
-
value,
|
|
565
|
-
self.scale_base,
|
|
566
|
-
self.scale_exponent,
|
|
567
|
-
)
|
|
606
|
+
value_adjusted = value
|
|
568
607
|
if key in input_variables:
|
|
569
608
|
shape_attr = f"{key}_shape"
|
|
570
609
|
value_adjusted = self._reshape_input_value(
|
|
@@ -601,8 +640,10 @@ class Circuit:
|
|
|
601
640
|
CircuitProcessingError: If the reshaping operation fails.
|
|
602
641
|
"""
|
|
603
642
|
if not hasattr(self, shape_attr):
|
|
604
|
-
msg =
|
|
605
|
-
|
|
643
|
+
msg = (
|
|
644
|
+
f"Required shape attribute '{shape_attr}'"
|
|
645
|
+
f" must be defined to reshape input '{input_key}'."
|
|
646
|
+
)
|
|
606
647
|
raise CircuitConfigurationError(
|
|
607
648
|
msg,
|
|
608
649
|
missing_attributes=[shape_attr],
|
|
@@ -689,6 +730,7 @@ class Circuit:
|
|
|
689
730
|
Returns:
|
|
690
731
|
str: Path to the final processed input file.
|
|
691
732
|
"""
|
|
733
|
+
_ = is_scaled
|
|
692
734
|
# Rescale and reshape
|
|
693
735
|
if quantized_path:
|
|
694
736
|
self.load_quantized_model(quantized_path)
|
|
@@ -707,15 +749,142 @@ class Circuit:
|
|
|
707
749
|
self._to_json_safely(output, output_file, "output")
|
|
708
750
|
|
|
709
751
|
else:
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
752
|
+
# Get new json file name
|
|
753
|
+
path = Path(input_file)
|
|
754
|
+
new_input_file = str(path.with_name(path.stem + "_adjusted" + path.suffix))
|
|
755
|
+
# load inputs
|
|
756
|
+
inputs = self._read_from_json_safely(input_file)
|
|
757
|
+
# scale inputs
|
|
758
|
+
scaled_inputs = self.scale_inputs_only(inputs)
|
|
759
|
+
# reshape/format inputs for inference
|
|
760
|
+
inference_inputs = self.reshape_inputs_for_inference(scaled_inputs)
|
|
761
|
+
|
|
762
|
+
# reshape/format inputs for rust
|
|
763
|
+
circuit_inputs = self.reshape_inputs_for_circuit(scaled_inputs)
|
|
764
|
+
self._to_json_safely(circuit_inputs, new_input_file, "input")
|
|
765
|
+
|
|
766
|
+
# get outputs
|
|
767
|
+
output = self.get_outputs(inference_inputs)
|
|
714
768
|
outputs = self.format_outputs(output)
|
|
715
769
|
|
|
716
770
|
self._to_json_safely(outputs, output_file, "output")
|
|
771
|
+
|
|
772
|
+
input_file = new_input_file
|
|
717
773
|
return input_file
|
|
718
774
|
|
|
775
|
+
def reshape_inputs_for_inference(
    self: Circuit,
    inputs: dict[str, Any] | ndarray,
) -> ndarray | dict[str, ndarray]:
    """
    Reshape input tensors to match the model's expected input shape.

    The target shape is ``self.input_shape`` after normalization by
    ``adjust_shape`` (non-positive dims become 1).

    Parameters
    ----------
    inputs : dict[str, Any] or array-like
        Input tensors or a dictionary of tensors. A single-entry dict is
        unwrapped and reshaped as one tensor; a dict with any other number of
        entries is reshaped per key by ``_reshape_dict_inputs``.

    Returns
    -------
    ndarray or dict[str, ndarray]
        Reshaped input(s) ready for inference.

    Raises
    ------
    CircuitConfigurationError
        If the circuit does not define ``input_shape``.
    ShapeMismatchError
        If the data cannot be reshaped to the target shape.
    """
    if not hasattr(self, "input_shape"):
        msg = (
            "Required shape attribute 'input_shape'"
            " must be defined to reshape inputs for inference."
        )
        raise CircuitConfigurationError(msg, missing_attributes=["input_shape"])

    shape = self.input_shape
    if hasattr(self, "adjust_shape") and callable(self.adjust_shape):
        shape = self.adjust_shape(shape)

    # --- Case: inputs is a dict ---
    if isinstance(inputs, dict):
        if len(inputs) != 1:
            return self._reshape_dict_inputs(inputs, shape)
        inputs = next(iter(inputs.values()))

    # --- Regular reshape ---
    # Convert once, outside the try: only the reshape itself is reported as a
    # shape mismatch, and the handler never has to redo the conversion.
    tensor = asarray(inputs)
    try:
        return tensor.reshape(shape)
    except (ValueError, TypeError) as e:
        raise ShapeMismatchError(shape, list(tensor.shape)) from e
|
|
813
|
+
|
|
814
|
+
def _reshape_dict_inputs(
|
|
815
|
+
self: Circuit,
|
|
816
|
+
inputs: dict[str],
|
|
817
|
+
shape: dict[str, list[int]],
|
|
818
|
+
) -> dict[str]:
|
|
819
|
+
"""Reshape each item in an input dict based on shape dict."""
|
|
820
|
+
if not isinstance(shape, dict):
|
|
821
|
+
msg = (
|
|
822
|
+
"_reshape_dict_inputs requires dict "
|
|
823
|
+
f"shape, got {type(shape).__name__}"
|
|
824
|
+
)
|
|
825
|
+
raise CircuitInputError(msg, parameter="shape", expected="dict")
|
|
826
|
+
for key, value in inputs.items():
|
|
827
|
+
tensor = asarray(value)
|
|
828
|
+
try:
|
|
829
|
+
inputs[key] = tensor.reshape(shape[key])
|
|
830
|
+
except Exception as e:
|
|
831
|
+
raise ShapeMismatchError(shape[key], list(tensor.shape)) from e
|
|
832
|
+
return inputs
|
|
833
|
+
|
|
834
|
+
def reshape_inputs_for_circuit(
    self: Circuit,
    inputs: dict[str],
) -> dict[str, list[int]]:
    """
    Concatenate every model input into one flat list for the circuit.

    Parameters
    ----------
    inputs : dict[str]
        Mapping of input names to arrays, lists, or tuples.

    Returns
    -------
    dict[str, list[int]]
        ``{"input": [...]}`` holding all inputs flattened and concatenated.
        When ``self.input_shapes`` is a dict, its key order decides the
        concatenation order (keys absent from it are not included);
        otherwise the iteration order of ``inputs`` is used.

    Raises
    ------
    CircuitConfigurationError
        If ``inputs`` is not a dict.
    CircuitProcessingError
        If an expected key is missing, a value has an unsupported type, or
        flattening fails.
    """
    if not isinstance(inputs, dict):
        raise CircuitConfigurationError(
            message=f"Expected a dict, got {type(inputs).__name__}",
        )

    declared_shapes = getattr(self, "input_shapes", None)
    if isinstance(declared_shapes, dict):
        key_order = list(declared_shapes)
    else:
        key_order = list(inputs)

    flat_values = []
    for name in key_order:
        if name not in inputs:
            raise CircuitProcessingError(
                message=f"Missing expected input key '{name}'",
            )

        raw = inputs[name]
        # Reject unsupported types before attempting any conversion.
        if not isinstance(raw, (ndarray, list, tuple)):
            raise CircuitProcessingError(
                message=f"Unsupported input type for key '{name}': {type(raw).__name__}",
            )

        try:
            source = raw if isinstance(raw, ndarray) else asarray(raw)
            chunk = source.flatten().tolist()
        except Exception as err:
            raise CircuitProcessingError(
                message=f"Failed to flatten input '{name}' (type {type(raw).__name__})",
            ) from err

        flat_values.extend(chunk)

    return {"input": flat_values}
|
|
887
|
+
|
|
719
888
|
def _compile_preprocessing(
|
|
720
889
|
self: Circuit,
|
|
721
890
|
metadata_path: str,
|
|
@@ -766,8 +935,10 @@ class Circuit:
|
|
|
766
935
|
elif isinstance(w_and_b, (dict, tuple)):
|
|
767
936
|
self._to_json_safely(w_and_b, w_and_b_path, "w_and_b")
|
|
768
937
|
else:
|
|
769
|
-
msg =
|
|
770
|
-
|
|
938
|
+
msg = (
|
|
939
|
+
f"Unsupported w_and_b type: {type(w_and_b)}."
|
|
940
|
+
" Expected list, dict, or tuple."
|
|
941
|
+
)
|
|
771
942
|
raise CircuitConfigurationError(
|
|
772
943
|
msg,
|
|
773
944
|
details={"w_and_b_type": str(type(w_and_b))},
|
|
@@ -881,22 +1052,19 @@ class Circuit:
|
|
|
881
1052
|
) from e
|
|
882
1053
|
return out
|
|
883
1054
|
|
|
884
|
-
def scale_inputs_only(self: Circuit,
|
|
1055
|
+
def scale_inputs_only(self: Circuit, inputs: dict) -> dict:
|
|
885
1056
|
"""
|
|
886
|
-
|
|
887
|
-
without reshaping, and save the scaled inputs to a new file.
|
|
1057
|
+
Scale input values according to circuit parameters without reshaping.
|
|
888
1058
|
|
|
889
1059
|
Args:
|
|
890
|
-
|
|
891
|
-
Path to the input JSON file containing the original input values.
|
|
1060
|
+
inputs (dict): Dictionary of input values to scale.
|
|
892
1061
|
|
|
893
1062
|
Returns:
|
|
894
|
-
|
|
1063
|
+
dict: Dictionary of scaled input values.
|
|
895
1064
|
|
|
896
1065
|
Raises:
|
|
897
1066
|
CircuitFileError: If reading from or writing to JSON files fails.
|
|
898
1067
|
"""
|
|
899
|
-
inputs = self._read_from_json_safely(input_file)
|
|
900
1068
|
|
|
901
1069
|
new_inputs = {}
|
|
902
1070
|
for key, value in inputs.items():
|
|
@@ -905,31 +1073,27 @@ class Circuit:
|
|
|
905
1073
|
self.scale_base,
|
|
906
1074
|
self.scale_exponent,
|
|
907
1075
|
)
|
|
1076
|
+
return new_inputs
|
|
908
1077
|
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
return new_input_file
|
|
914
|
-
|
|
915
|
-
def rename_inputs(self: Circuit, input_file: str) -> str:
|
|
1078
|
+
def rename_inputs(
|
|
1079
|
+
self: Circuit,
|
|
1080
|
+
inputs: dict[str, np.ndarray],
|
|
1081
|
+
) -> dict[str, np.ndarray]:
|
|
916
1082
|
"""
|
|
917
1083
|
Load input values from a JSON file, rename keys according to circuit logic
|
|
918
1084
|
(similar to adjust_inputs but without scaling or reshaping),
|
|
919
1085
|
and save the renamed inputs to a new file.
|
|
920
1086
|
|
|
921
1087
|
Args:
|
|
922
|
-
|
|
923
|
-
Path to the input JSON file containing the original input values.
|
|
1088
|
+
inputs (dict[str, np.ndarray]): Original input values.
|
|
924
1089
|
|
|
925
1090
|
Returns:
|
|
926
|
-
str:
|
|
1091
|
+
dict[str, np.ndarray]: Dictionary of renamed input values.
|
|
927
1092
|
|
|
928
1093
|
Raises:
|
|
929
1094
|
CircuitFileError: If reading from or writing to JSON files fails.
|
|
930
1095
|
CircuitInputError: If input validation fails.
|
|
931
1096
|
"""
|
|
932
|
-
inputs = self._read_from_json_safely(input_file)
|
|
933
1097
|
|
|
934
1098
|
input_variables = getattr(self, "input_variables", ["input"])
|
|
935
1099
|
if input_variables == ["input"]:
|
|
@@ -937,11 +1101,7 @@ class Circuit:
|
|
|
937
1101
|
else:
|
|
938
1102
|
new_inputs = dict(inputs.items())
|
|
939
1103
|
|
|
940
|
-
|
|
941
|
-
path = Path(input_file)
|
|
942
|
-
new_input_file = path.stem + "_renamed" + path.suffix
|
|
943
|
-
self._to_json_safely(new_inputs, new_input_file, "renamed input")
|
|
944
|
-
return new_input_file
|
|
1104
|
+
return new_inputs
|
|
945
1105
|
|
|
946
1106
|
def _rename_single_input(self: Circuit, inputs: dict) -> dict:
|
|
947
1107
|
"""
|