JSTprove 1.0.0__py3-none-macosx_11_0_arm64.whl → 1.2.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of JSTprove might be problematic. See the package's registry page for more details.

Files changed (61)
  1. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/METADATA +3 -3
  2. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/RECORD +60 -25
  3. python/core/binaries/onnx_generic_circuit_1-2-0 +0 -0
  4. python/core/circuit_models/generic_onnx.py +43 -9
  5. python/core/circuits/base.py +231 -71
  6. python/core/model_processing/converters/onnx_converter.py +114 -59
  7. python/core/model_processing/onnx_custom_ops/batchnorm.py +64 -0
  8. python/core/model_processing/onnx_custom_ops/maxpool.py +1 -1
  9. python/core/model_processing/onnx_custom_ops/mul.py +66 -0
  10. python/core/model_processing/onnx_custom_ops/relu.py +1 -1
  11. python/core/model_processing/onnx_quantizer/layers/add.py +54 -0
  12. python/core/model_processing/onnx_quantizer/layers/base.py +188 -1
  13. python/core/model_processing/onnx_quantizer/layers/batchnorm.py +224 -0
  14. python/core/model_processing/onnx_quantizer/layers/constant.py +1 -1
  15. python/core/model_processing/onnx_quantizer/layers/conv.py +20 -68
  16. python/core/model_processing/onnx_quantizer/layers/gemm.py +20 -66
  17. python/core/model_processing/onnx_quantizer/layers/maxpool.py +53 -43
  18. python/core/model_processing/onnx_quantizer/layers/mul.py +53 -0
  19. python/core/model_processing/onnx_quantizer/layers/relu.py +20 -35
  20. python/core/model_processing/onnx_quantizer/layers/sub.py +54 -0
  21. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +43 -1
  22. python/core/utils/general_layer_functions.py +17 -12
  23. python/core/utils/model_registry.py +6 -3
  24. python/scripts/gen_and_bench.py +2 -2
  25. python/tests/circuit_e2e_tests/other_e2e_test.py +202 -9
  26. python/tests/circuit_parent_classes/test_circuit.py +561 -38
  27. python/tests/circuit_parent_classes/test_onnx_converter.py +22 -13
  28. python/tests/onnx_quantizer_tests/__init__.py +1 -0
  29. python/tests/onnx_quantizer_tests/layers/__init__.py +13 -0
  30. python/tests/onnx_quantizer_tests/layers/add_config.py +102 -0
  31. python/tests/onnx_quantizer_tests/layers/base.py +279 -0
  32. python/tests/onnx_quantizer_tests/layers/batchnorm_config.py +190 -0
  33. python/tests/onnx_quantizer_tests/layers/constant_config.py +39 -0
  34. python/tests/onnx_quantizer_tests/layers/conv_config.py +154 -0
  35. python/tests/onnx_quantizer_tests/layers/factory.py +142 -0
  36. python/tests/onnx_quantizer_tests/layers/flatten_config.py +61 -0
  37. python/tests/onnx_quantizer_tests/layers/gemm_config.py +160 -0
  38. python/tests/onnx_quantizer_tests/layers/maxpool_config.py +82 -0
  39. python/tests/onnx_quantizer_tests/layers/mul_config.py +102 -0
  40. python/tests/onnx_quantizer_tests/layers/relu_config.py +61 -0
  41. python/tests/onnx_quantizer_tests/layers/reshape_config.py +61 -0
  42. python/tests/onnx_quantizer_tests/layers/sub_config.py +102 -0
  43. python/tests/onnx_quantizer_tests/layers_tests/__init__.py +0 -0
  44. python/tests/onnx_quantizer_tests/layers_tests/base_test.py +94 -0
  45. python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py +115 -0
  46. python/tests/onnx_quantizer_tests/layers_tests/test_e2e.py +196 -0
  47. python/tests/onnx_quantizer_tests/layers_tests/test_error_cases.py +59 -0
  48. python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +198 -0
  49. python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +267 -0
  50. python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py +109 -0
  51. python/tests/onnx_quantizer_tests/layers_tests/test_validation.py +45 -0
  52. python/tests/onnx_quantizer_tests/test_base_layer.py +228 -0
  53. python/tests/onnx_quantizer_tests/test_exceptions.py +99 -0
  54. python/tests/onnx_quantizer_tests/test_onnx_op_quantizer.py +246 -0
  55. python/tests/onnx_quantizer_tests/test_registered_quantizers.py +121 -0
  56. python/tests/onnx_quantizer_tests/testing_helper_functions.py +17 -0
  57. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  58. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/WHEEL +0 -0
  59. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/entry_points.txt +0 -0
  60. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/licenses/LICENSE +0 -0
  61. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/top_level.txt +0 -0
@@ -4,6 +4,9 @@ import logging
4
4
  from pathlib import Path
5
5
  from typing import TYPE_CHECKING, Any
6
6
 
7
+ from numpy import asarray, ndarray
8
+
9
+ from python.core.utils.errors import ShapeMismatchError
7
10
  from python.core.utils.witness_utils import compare_witness_to_io, load_witness
8
11
 
9
12
  if TYPE_CHECKING:
@@ -88,8 +91,7 @@ class Circuit:
88
91
  ValueError: If any parameter value is not an integer or list of integers.
89
92
  """
90
93
  if self.required_keys is None:
91
- msg = "self.required_keys must"
92
- " be specified in the circuit definition."
94
+ msg = "self.required_keys must be specified in the circuit definition."
93
95
  raise CircuitConfigurationError(
94
96
  msg,
95
97
  )
@@ -271,7 +273,8 @@ class Circuit:
271
273
  elif exec_config.run_type == RunType.GEN_VERIFY:
272
274
  witness_file = exec_config.witness_file
273
275
  output_file = exec_config.output_file
274
- processed_input_file = self.rename_inputs(exec_config.input_file)
276
+ processed_input_file = self.prepare_inputs_for_verification(exec_config)
277
+
275
278
  proof_system = exec_config.proof_system
276
279
  if not self.load_and_compare_witness_to_io(
277
280
  witness_path=witness_file,
@@ -317,6 +320,30 @@ class Circuit:
317
320
  operation=exec_config.run_type,
318
321
  ) from e
319
322
 
323
+ def prepare_inputs_for_verification(
324
+ self: Circuit,
325
+ exec_config: CircuitExecutionConfig,
326
+ ) -> str:
327
+ """
328
+ Load inputs, process them for analysis against witness
329
+
330
+ Args:
331
+ exec_config (CircuitExecutionConfig): Execution configuration
332
+
333
+ Returns:
334
+ str: name of file with processed inputs for verification
335
+ """
336
+ # read inputs
337
+ inputs = self._read_from_json_safely(exec_config.input_file)
338
+ # reshape inputs for circuit reading (or for verification check in this case)
339
+ processed_inputs = self.reshape_inputs_for_circuit(inputs)
340
+ # Send back to file
341
+ path = Path(exec_config.input_file)
342
+ processed_input_file = str(path.parent / (path.stem + "_veri" + path.suffix))
343
+ self._to_json_safely(processed_inputs, processed_input_file, "renamed input")
344
+
345
+ return processed_input_file
346
+
320
347
  def load_and_compare_witness_to_io(
321
348
  self: Circuit,
322
349
  witness_path: str,
@@ -378,38 +405,55 @@ class Circuit:
378
405
  return any(self.contains_float(i) for i in obj)
379
406
  return False
380
407
 
381
- def adjust_shape(self: Circuit, shape: list[int] | dict[str, int]) -> list[int]:
382
- """Normalize a shape representation into a valid list of positive integers.
408
+ def adjust_shape(
409
+ self: Circuit,
410
+ shape: list[int] | dict[str, list[int]],
411
+ ) -> list[int] | dict[str, list[int]]:
412
+ """
413
+ Normalize a shape representation into a valid list or dict of positive integers.
383
414
 
384
415
  Args:
385
- shape (list[int] | dict[str, int]):
386
- The shape, which can be a list of ints
387
- or a dict containing one shape list.
416
+ shape (list[int] | dict[str, list[int]]):
417
+ The shape, which can be:
418
+ a. a list of ints, or
419
+ b. a dict mapping strings to lists of ints.
420
+ Each non-positive integer is replaced by 1.
388
421
 
389
422
  Raises:
390
423
  CircuitInputError:
391
- If `shape` is a dict containing more than one shape definition.
424
+ If a dict contains invalid shape definitions.
392
425
 
393
426
  Returns:
394
- list[int]:
395
- The adjusted shape where all non-positive values are replaced with 1.
427
+ list[int] | dict[str, list[int]]:
428
+ The adjusted shape(s) where all non-positive values are replaced with 1.
429
+ For a multi-key dict, returns a dict with normalized lists of ints.
396
430
  """
397
431
  if isinstance(shape, dict):
398
- # Get the first shape from the dict
399
- # (assuming only one input is relevant here)
432
+ # Handle dict-based shapes
400
433
  if len(shape.values()) == 1:
401
434
  shape = next(iter(shape.values()))
402
- else:
403
- msg = (
404
- "Shape dictionary contains multiple entries;"
405
- " only one input shape is allowed."
406
- )
407
- raise CircuitInputError(
408
- msg,
409
- parameter="shape",
410
- expected="dict with exactly one key-value pair",
411
- details={"shape_keys": list(shape.keys())},
412
- )
435
+ if not isinstance(shape, (list, tuple)):
436
+ msg = f"Expected shape list for input, got {type(shape).__name__}"
437
+ raise CircuitInputError(msg)
438
+ return [s if s > 0 else 1 for s in shape]
439
+
440
+ adjusted_shapes = {}
441
+ for key, subshape in shape.items():
442
+ if not isinstance(subshape, (list, tuple)):
443
+ msg = (
444
+ f"Expected shape list for key '{key}', "
445
+ f"got {type(subshape).__name__}"
446
+ )
447
+ raise CircuitInputError(msg)
448
+ adjusted_shapes[key] = [s if s > 0 else 1 for s in subshape]
449
+
450
+ return adjusted_shapes
451
+
452
+ # Handle list-based shape input (the missing return case)
453
+ if not isinstance(shape, (list, tuple)):
454
+ msg = f"Expected list or dict for 'shape', got {type(shape).__name__}"
455
+ raise CircuitInputError(msg)
456
+
413
457
  return [s if s > 0 else 1 for s in shape]
414
458
 
415
459
  def scale_and_round(
@@ -448,15 +492,20 @@ class Circuit:
448
492
  )
449
493
  return value
450
494
 
451
- def adjust_inputs(self: Circuit, input_file: str) -> str:
495
+ def adjust_inputs(
496
+ self: Circuit,
497
+ inputs: dict[str, np.ndarray],
498
+ input_file: str,
499
+ ) -> str:
452
500
  """
453
501
  Load input values from a JSON file, adjust them by scaling
454
502
  and reshaping according to circuit parameters,
455
503
  and save the adjusted inputs to a new file.
456
504
 
457
505
  Args:
458
- input_file (str):
459
- Path to the input JSON file containing the original input values.
506
+ inputs (dict[str, np.ndarray]):
507
+ inputs, read from json file
508
+ input_file (str): path to input_file
460
509
 
461
510
  Returns:
462
511
  str: Path to the new file containing the adjusted input values.
@@ -468,7 +517,6 @@ class Circuit:
468
517
  CircuitConfigurationError: If required shape attributes are missing.
469
518
  CircuitProcessingError: If reshaping or scaling operations fail.
470
519
  """
471
- inputs = self._read_from_json_safely(input_file)
472
520
 
473
521
  input_variables = getattr(self, "input_variables", ["input"])
474
522
  if input_variables == ["input"]:
@@ -503,11 +551,6 @@ class Circuit:
503
551
  has_input_been_found = False
504
552
 
505
553
  for key, value in inputs.items():
506
- value_adjusted = self.scale_and_round(
507
- value,
508
- self.scale_base,
509
- self.scale_exponent,
510
- )
511
554
  if "input" in key:
512
555
  if has_input_been_found:
513
556
  msg = (
@@ -522,13 +565,13 @@ class Circuit:
522
565
  )
523
566
  has_input_been_found = True
524
567
  value_adjusted = self._reshape_input_value(
525
- value_adjusted,
568
+ value,
526
569
  "input_shape",
527
570
  key,
528
571
  )
529
572
  new_inputs["input"] = value_adjusted
530
573
  else:
531
- new_inputs[key] = value_adjusted
574
+ new_inputs[key] = value
532
575
 
533
576
  # Special case: fallback mapping output → input
534
577
  if "input" not in new_inputs and "output" in new_inputs:
@@ -560,11 +603,7 @@ class Circuit:
560
603
  """
561
604
  new_inputs: dict[str, Any] = {}
562
605
  for key, value in inputs.items():
563
- value_adjusted = self.scale_and_round(
564
- value,
565
- self.scale_base,
566
- self.scale_exponent,
567
- )
606
+ value_adjusted = value
568
607
  if key in input_variables:
569
608
  shape_attr = f"{key}_shape"
570
609
  value_adjusted = self._reshape_input_value(
@@ -601,8 +640,10 @@ class Circuit:
601
640
  CircuitProcessingError: If the reshaping operation fails.
602
641
  """
603
642
  if not hasattr(self, shape_attr):
604
- msg = f"Required shape attribute '{shape_attr}'"
605
- f" must be defined to reshape input '{input_key}'."
643
+ msg = (
644
+ f"Required shape attribute '{shape_attr}'"
645
+ f" must be defined to reshape input '{input_key}'."
646
+ )
606
647
  raise CircuitConfigurationError(
607
648
  msg,
608
649
  missing_attributes=[shape_attr],
@@ -689,6 +730,7 @@ class Circuit:
689
730
  Returns:
690
731
  str: Path to the final processed input file.
691
732
  """
733
+ _ = is_scaled
692
734
  # Rescale and reshape
693
735
  if quantized_path:
694
736
  self.load_quantized_model(quantized_path)
@@ -707,15 +749,142 @@ class Circuit:
707
749
  self._to_json_safely(output, output_file, "output")
708
750
 
709
751
  else:
710
- input_file = self.adjust_inputs(input_file)
711
- inputs = self.get_inputs_from_file(input_file, is_scaled=is_scaled)
712
- # Compute output (with caching via decorator)
713
- output = self.get_outputs(inputs)
752
+ # Get new json file name
753
+ path = Path(input_file)
754
+ new_input_file = str(path.with_name(path.stem + "_adjusted" + path.suffix))
755
+ # load inputs
756
+ inputs = self._read_from_json_safely(input_file)
757
+ # scale inputs
758
+ scaled_inputs = self.scale_inputs_only(inputs)
759
+ # reshape/format inputs for inference
760
+ inference_inputs = self.reshape_inputs_for_inference(scaled_inputs)
761
+
762
+ # reshape/format inputs for rust
763
+ circuit_inputs = self.reshape_inputs_for_circuit(scaled_inputs)
764
+ self._to_json_safely(circuit_inputs, new_input_file, "input")
765
+
766
+ # get outputs
767
+ output = self.get_outputs(inference_inputs)
714
768
  outputs = self.format_outputs(output)
715
769
 
716
770
  self._to_json_safely(outputs, output_file, "output")
771
+
772
+ input_file = new_input_file
717
773
  return input_file
718
774
 
775
+ def reshape_inputs_for_inference(
776
+ self: Circuit,
777
+ inputs: dict[str],
778
+ ) -> ndarray | dict[str, ndarray]:
779
+ """
780
+ Reshape input tensors to match the model's expected input shape.
781
+
782
+ Parameters
783
+ ----------
784
+ inputs : dict[str] or ndarray
785
+ Input tensors or a dictionary of tensors.
786
+
787
+ Returns
788
+ -------
789
+ ndarray or dict[str, ndarray]
790
+ Reshaped input(s) ready for inference.
791
+ """
792
+
793
+ if not hasattr(self, "input_shape"):
794
+ raise CircuitConfigurationError(missing_attributes=["input_shape"])
795
+
796
+ shape = self.input_shape
797
+ if hasattr(self, "adjust_shape") and callable(self.adjust_shape):
798
+ shape = self.adjust_shape(shape)
799
+
800
+ # --- Case: inputs is a dict ---
801
+ if isinstance(inputs, dict):
802
+ if len(inputs) == 1:
803
+ only_key = next(iter(inputs))
804
+ inputs = asarray(inputs[only_key])
805
+ else:
806
+ return self._reshape_dict_inputs(inputs, shape)
807
+
808
+ # --- Regular reshape ---
809
+ try:
810
+ return asarray(inputs).reshape(shape)
811
+ except Exception as e:
812
+ raise ShapeMismatchError(shape, list(asarray(inputs).shape)) from e
813
+
814
+ def _reshape_dict_inputs(
815
+ self: Circuit,
816
+ inputs: dict[str],
817
+ shape: dict[str, list[int]],
818
+ ) -> dict[str]:
819
+ """Reshape each item in an input dict based on shape dict."""
820
+ if not isinstance(shape, dict):
821
+ msg = (
822
+ "_reshape_dict_inputs requires dict "
823
+ f"shape, got {type(shape).__name__}"
824
+ )
825
+ raise CircuitInputError(msg, parameter="shape", expected="dict")
826
+ for key, value in inputs.items():
827
+ tensor = asarray(value)
828
+ try:
829
+ inputs[key] = tensor.reshape(shape[key])
830
+ except Exception as e:
831
+ raise ShapeMismatchError(shape[key], list(tensor.shape)) from e
832
+ return inputs
833
+
834
+ def reshape_inputs_for_circuit(
835
+ self: Circuit,
836
+ inputs: dict[str],
837
+ ) -> dict[str, list[int]]:
838
+ """
839
+ Flatten model inputs for circuit processing.
840
+
841
+ Parameters
842
+ ----------
843
+ inputs : dict[str]
844
+ Mapping of input names to arrays, lists, or tuples.
845
+
846
+ Returns
847
+ -------
848
+ dict[str, list[int]]
849
+ A dictionary with a single flattened input list.
850
+ """
851
+ if not isinstance(inputs, dict):
852
+ msg = f"Expected a dict, got {type(inputs).__name__}"
853
+ raise CircuitConfigurationError(message=msg)
854
+
855
+ if hasattr(self, "input_shapes") and isinstance(self.input_shapes, dict):
856
+ ordered_keys = list(self.input_shapes.keys())
857
+ else:
858
+ ordered_keys = inputs.keys()
859
+
860
+ all_flattened = []
861
+
862
+ for key in ordered_keys:
863
+ if key not in inputs:
864
+ msg = f"Missing expected input key '{key}'"
865
+ raise CircuitProcessingError(message=msg)
866
+
867
+ value = inputs[key]
868
+
869
+ # --- handle unsupported input types BEFORE entering try ---
870
+ if not isinstance(value, (ndarray, list, tuple)):
871
+ msg = f"Unsupported input type for key '{key}': {type(value).__name__}"
872
+ raise CircuitProcessingError(message=msg)
873
+
874
+ try:
875
+ # Convert to tensor, flatten, and back to list
876
+ if isinstance(value, ndarray):
877
+ flattened = value.flatten().tolist()
878
+ else:
879
+ flattened = asarray(value).flatten().tolist()
880
+ except Exception as e:
881
+ msg = f"Failed to flatten input '{key}' (type {type(value).__name__})"
882
+ raise CircuitProcessingError(message=msg) from e
883
+
884
+ all_flattened.extend(flattened)
885
+
886
+ return {"input": all_flattened}
887
+
719
888
  def _compile_preprocessing(
720
889
  self: Circuit,
721
890
  metadata_path: str,
@@ -766,8 +935,10 @@ class Circuit:
766
935
  elif isinstance(w_and_b, (dict, tuple)):
767
936
  self._to_json_safely(w_and_b, w_and_b_path, "w_and_b")
768
937
  else:
769
- msg = f"Unsupported w_and_b type: {type(w_and_b)}."
770
- " Expected list, dict, or tuple."
938
+ msg = (
939
+ f"Unsupported w_and_b type: {type(w_and_b)}."
940
+ " Expected list, dict, or tuple."
941
+ )
771
942
  raise CircuitConfigurationError(
772
943
  msg,
773
944
  details={"w_and_b_type": str(type(w_and_b))},
@@ -881,22 +1052,19 @@ class Circuit:
881
1052
  ) from e
882
1053
  return out
883
1054
 
884
- def scale_inputs_only(self: Circuit, input_file: str) -> str:
1055
+ def scale_inputs_only(self: Circuit, inputs: dict) -> dict:
885
1056
  """
886
- Load input values from a JSON file, scale them according to circuit parameters,
887
- without reshaping, and save the scaled inputs to a new file.
1057
+ Scale input values according to circuit parameters without reshaping.
888
1058
 
889
1059
  Args:
890
- input_file (str):
891
- Path to the input JSON file containing the original input values.
1060
+ inputs (dict): Dictionary of input values to scale.
892
1061
 
893
1062
  Returns:
894
- str: Path to the new file containing the scaled input values.
1063
+ dict: Dictionary of scaled input values.
895
1064
 
896
1065
  Raises:
897
1066
  CircuitFileError: If reading from or writing to JSON files fails.
898
1067
  """
899
- inputs = self._read_from_json_safely(input_file)
900
1068
 
901
1069
  new_inputs = {}
902
1070
  for key, value in inputs.items():
@@ -905,31 +1073,27 @@ class Circuit:
905
1073
  self.scale_base,
906
1074
  self.scale_exponent,
907
1075
  )
1076
+ return new_inputs
908
1077
 
909
- # Save scaled inputs
910
- path = Path(input_file)
911
- new_input_file = path.stem + "_scaled" + path.suffix
912
- self._to_json_safely(new_inputs, new_input_file, "scaled input")
913
- return new_input_file
914
-
915
- def rename_inputs(self: Circuit, input_file: str) -> str:
1078
+ def rename_inputs(
1079
+ self: Circuit,
1080
+ inputs: dict[str, np.ndarray],
1081
+ ) -> dict[str, np.ndarray]:
916
1082
  """
917
1083
  Load input values from a JSON file, rename keys according to circuit logic
918
1084
  (similar to adjust_inputs but without scaling or reshaping),
919
1085
  and save the renamed inputs to a new file.
920
1086
 
921
1087
  Args:
922
- input_file (str):
923
- Path to the input JSON file containing the original input values.
1088
+ inputs (dict[str, np.ndarray]): Original input values.
924
1089
 
925
1090
  Returns:
926
- str: Path to the new file containing the renamed input values.
1091
+ dict[str, np.ndarray]: Dictionary of renamed input values.
927
1092
 
928
1093
  Raises:
929
1094
  CircuitFileError: If reading from or writing to JSON files fails.
930
1095
  CircuitInputError: If input validation fails.
931
1096
  """
932
- inputs = self._read_from_json_safely(input_file)
933
1097
 
934
1098
  input_variables = getattr(self, "input_variables", ["input"])
935
1099
  if input_variables == ["input"]:
@@ -937,11 +1101,7 @@ class Circuit:
937
1101
  else:
938
1102
  new_inputs = dict(inputs.items())
939
1103
 
940
- # Save renamed inputs
941
- path = Path(input_file)
942
- new_input_file = path.stem + "_renamed" + path.suffix
943
- self._to_json_safely(new_inputs, new_input_file, "renamed input")
944
- return new_input_file
1104
+ return new_inputs
945
1105
 
946
1106
  def _rename_single_input(self: Circuit, inputs: dict) -> dict:
947
1107
  """