JSTprove 1.0.0__py3-none-macosx_11_0_arm64.whl → 1.1.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/METADATA +2 -2
  2. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/RECORD +51 -24
  3. python/core/binaries/onnx_generic_circuit_1-1-0 +0 -0
  4. python/core/circuit_models/generic_onnx.py +43 -9
  5. python/core/circuits/base.py +231 -71
  6. python/core/model_processing/converters/onnx_converter.py +86 -32
  7. python/core/model_processing/onnx_custom_ops/maxpool.py +1 -1
  8. python/core/model_processing/onnx_custom_ops/relu.py +1 -1
  9. python/core/model_processing/onnx_quantizer/layers/add.py +54 -0
  10. python/core/model_processing/onnx_quantizer/layers/base.py +121 -1
  11. python/core/model_processing/onnx_quantizer/layers/constant.py +1 -1
  12. python/core/model_processing/onnx_quantizer/layers/conv.py +20 -68
  13. python/core/model_processing/onnx_quantizer/layers/gemm.py +20 -66
  14. python/core/model_processing/onnx_quantizer/layers/maxpool.py +53 -43
  15. python/core/model_processing/onnx_quantizer/layers/relu.py +20 -35
  16. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +6 -1
  17. python/core/utils/general_layer_functions.py +17 -12
  18. python/core/utils/model_registry.py +6 -3
  19. python/tests/circuit_e2e_tests/other_e2e_test.py +202 -9
  20. python/tests/circuit_parent_classes/test_circuit.py +561 -38
  21. python/tests/circuit_parent_classes/test_onnx_converter.py +22 -13
  22. python/tests/onnx_quantizer_tests/__init__.py +1 -0
  23. python/tests/onnx_quantizer_tests/layers/__init__.py +13 -0
  24. python/tests/onnx_quantizer_tests/layers/add_config.py +102 -0
  25. python/tests/onnx_quantizer_tests/layers/base.py +279 -0
  26. python/tests/onnx_quantizer_tests/layers/constant_config.py +39 -0
  27. python/tests/onnx_quantizer_tests/layers/conv_config.py +154 -0
  28. python/tests/onnx_quantizer_tests/layers/factory.py +142 -0
  29. python/tests/onnx_quantizer_tests/layers/flatten_config.py +61 -0
  30. python/tests/onnx_quantizer_tests/layers/gemm_config.py +160 -0
  31. python/tests/onnx_quantizer_tests/layers/maxpool_config.py +82 -0
  32. python/tests/onnx_quantizer_tests/layers/relu_config.py +61 -0
  33. python/tests/onnx_quantizer_tests/layers/reshape_config.py +61 -0
  34. python/tests/onnx_quantizer_tests/layers_tests/__init__.py +0 -0
  35. python/tests/onnx_quantizer_tests/layers_tests/base_test.py +94 -0
  36. python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py +115 -0
  37. python/tests/onnx_quantizer_tests/layers_tests/test_e2e.py +196 -0
  38. python/tests/onnx_quantizer_tests/layers_tests/test_error_cases.py +59 -0
  39. python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +198 -0
  40. python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +265 -0
  41. python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py +109 -0
  42. python/tests/onnx_quantizer_tests/layers_tests/test_validation.py +45 -0
  43. python/tests/onnx_quantizer_tests/test_base_layer.py +228 -0
  44. python/tests/onnx_quantizer_tests/test_exceptions.py +99 -0
  45. python/tests/onnx_quantizer_tests/test_onnx_op_quantizer.py +246 -0
  46. python/tests/onnx_quantizer_tests/test_registered_quantizers.py +121 -0
  47. python/tests/onnx_quantizer_tests/testing_helper_functions.py +17 -0
  48. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  49. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/WHEEL +0 -0
  50. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/entry_points.txt +0 -0
  51. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/licenses/LICENSE +0 -0
  52. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/top_level.txt +0 -0
python/tests/circuit_parent_classes/test_circuit.py +561 -38
@@ -1,10 +1,14 @@
1
+ import re
1
2
  import sys
2
3
  from collections.abc import Generator
3
4
  from pathlib import Path
4
5
  from unittest.mock import MagicMock, patch
5
6
 
7
+ import numpy as np
6
8
  import pytest
7
9
 
10
+ from python.core.utils.errors import ShapeMismatchError
11
+
8
12
  sys.modules.pop("python.core.circuits.base", None)
9
13
 
10
14
 
@@ -30,6 +34,7 @@ with (
30
34
  CircuitInputError,
31
35
  CircuitProcessingError,
32
36
  CircuitRunError,
37
+ WitnessMatchError,
33
38
  )
34
39
 
35
40
 
@@ -148,6 +153,7 @@ def test_parse_proof_dispatch_logic(
148
153
  c._gen_witness_preprocessing = MagicMock(return_value="i")
149
154
  c.adjust_inputs = MagicMock(return_value="i")
150
155
  c.rename_inputs = MagicMock(return_value="i")
156
+ c.prepare_inputs_for_verification = MagicMock(return_value="i")
151
157
 
152
158
  c.load_and_compare_witness_to_io = MagicMock(return_value="True")
153
159
 
@@ -436,34 +442,53 @@ def test_gen_witness_preprocessing_write_json_true(mock_to_json: MagicMock) -> N
436
442
 
437
443
 
438
444
  @pytest.mark.unit
439
- @patch("python.core.circuits.base.to_json")
440
- def test_gen_witness_preprocessing_write_json_false(mock_to_json: MagicMock) -> None:
445
+ def test_gen_witness_preprocessing_write_json_false(tmp_path: Path) -> None:
441
446
  c = Circuit()
442
447
  c._file_info = {"quantized_model_path": "quant.pt"}
443
- c.load_quantized_model = MagicMock()
444
- c.get_inputs_from_file = MagicMock(return_value="mock_inputs")
445
- c.reshape_inputs = MagicMock(return_value="in.json")
446
- c.rescale_inputs = MagicMock(return_value="in.json")
447
- c.rename_inputs = MagicMock(return_value="in.json")
448
- c.rescale_and_reshape_inputs = MagicMock(return_value="in.json")
449
- c.adjust_inputs = MagicMock(return_value="in.json")
450
448
 
451
- c.get_outputs = MagicMock(return_value="mock_outputs")
452
- c.format_outputs = MagicMock(return_value={"output": 99})
453
-
454
- c._gen_witness_preprocessing(
455
- "in.json",
456
- "out.json",
449
+ # Mock all method calls used by _gen_witness_preprocessing
450
+ c.load_quantized_model = MagicMock()
451
+ c._read_from_json_safely = MagicMock(return_value={"mock": "inputs"})
452
+ c.scale_inputs_only = MagicMock(return_value={"scaled": "inputs"})
453
+ c.reshape_inputs_for_inference = MagicMock(return_value={"inference": "inputs"})
454
+ c.reshape_inputs_for_circuit = MagicMock(return_value={"input": [1, 2, 3]})
455
+ c._to_json_safely = MagicMock()
456
+ c.get_outputs = MagicMock(return_value={"raw_output": 123})
457
+ c.format_outputs = MagicMock(return_value={"formatted_output": 999})
458
+
459
+ input_path = tmp_path / "in.json"
460
+ output_path = tmp_path / "out.json"
461
+
462
+ result = c._gen_witness_preprocessing(
463
+ str(input_path),
464
+ str(output_path),
457
465
  None,
458
466
  write_json=False,
459
467
  is_scaled=False,
460
468
  )
461
469
 
470
+ # --- Assertions ---
462
471
  c.load_quantized_model.assert_called_once_with("quant.pt")
463
- c.get_inputs_from_file.assert_called_once_with("in.json", is_scaled=False)
464
- c.get_outputs.assert_called_once_with("mock_inputs")
465
- c.format_outputs.assert_called_once_with("mock_outputs")
466
- mock_to_json.assert_called_once_with({"output": 99}, "out.json")
472
+ c._read_from_json_safely.assert_called_once_with(str(input_path))
473
+ c.scale_inputs_only.assert_called_once_with({"mock": "inputs"})
474
+ c.reshape_inputs_for_inference.assert_called_once_with({"scaled": "inputs"})
475
+ c.reshape_inputs_for_circuit.assert_called_once_with({"scaled": "inputs"})
476
+
477
+ # Verify safe JSON writes
478
+ new_input_file = str(input_path.with_name("in_adjusted.json"))
479
+ c._to_json_safely.assert_any_call({"input": [1, 2, 3]}, new_input_file, "input")
480
+ c._to_json_safely.assert_any_call(
481
+ {"formatted_output": 999},
482
+ str(output_path),
483
+ "output",
484
+ )
485
+
486
+ # Verify output generation
487
+ c.get_outputs.assert_called_once_with({"inference": "inputs"})
488
+ c.format_outputs.assert_called_once_with({"raw_output": 123})
489
+
490
+ # Function should return the adjusted input file path
491
+ assert result == new_input_file
467
492
 
468
493
 
469
494
  # ---------- _compile_preprocessing ----------
@@ -896,24 +921,6 @@ def test_save_and_load_model_not_implemented() -> None:
896
921
  assert hasattr(c, "load_quantized_model")
897
922
 
898
923
 
899
- # ---------- New error handling tests ----------
900
- @pytest.mark.unit
901
- def test_adjust_inputs_file_error() -> None:
902
- c = Circuit()
903
- c.input_variables = ["input"]
904
- c.input_shape = [2, 2]
905
- c.scale_base = 2
906
- c.scale_exponent = 1
907
-
908
- with patch(
909
- "python.core.circuits.base.read_from_json",
910
- side_effect=FileNotFoundError("File not found"),
911
- ):
912
- _ = c
913
- with pytest.raises(CircuitFileError, match="Failed to read input file"):
914
- c.adjust_inputs("nonexistent.json")
915
-
916
-
917
924
  @pytest.mark.unit
918
925
  def test_adjust_inputs_processing_error() -> None:
919
926
  c = Circuit()
@@ -934,7 +941,7 @@ def test_adjust_inputs_processing_error() -> None:
934
941
  CircuitProcessingError,
935
942
  match="Failed to reshape input data",
936
943
  ):
937
- c.adjust_inputs("dummy.json")
944
+ c.adjust_inputs({"input": [1, 2, 3, 4]}, "dummy.json")
938
945
 
939
946
 
940
947
  @pytest.mark.unit
@@ -967,3 +974,519 @@ def test_get_inputs_from_file_processing_error() -> None:
967
974
  match="Failed to scale input data",
968
975
  ):
969
976
  c.get_inputs_from_file("dummy.json", is_scaled=False)
977
+
978
+
979
+ # ---------- Test _raise_unknown_run_type ----------
980
+ @pytest.mark.unit
981
+ def test_raise_unknown_run_type() -> None:
982
+ c = Circuit()
983
+
984
+ with pytest.raises(CircuitRunError, match="Unsupported run type: INVALID_TYPE"):
985
+ c._raise_unknown_run_type("INVALID_TYPE")
986
+
987
+
988
+ # ---------- Test contains_float ----------
989
+ @pytest.mark.unit
990
+ def test_contains_float_with_float() -> None:
991
+ c = Circuit()
992
+ assert c.contains_float(3.14) is True
993
+ assert c.contains_float(2.0) is True
994
+ assert c.contains_float(1.5) is True
995
+
996
+
997
+ @pytest.mark.unit
998
+ def test_contains_float_with_int() -> None:
999
+ c = Circuit()
1000
+ assert c.contains_float(1) is False
1001
+ assert c.contains_float(0) is False
1002
+ assert c.contains_float(-5) is False
1003
+
1004
+
1005
+ @pytest.mark.unit
1006
+ def test_contains_float_with_list() -> None:
1007
+ c = Circuit()
1008
+ assert c.contains_float([1, 2, 3]) is False
1009
+ assert c.contains_float([1.0, 2, 3]) is True
1010
+ assert c.contains_float([1, 2.5, 3]) is True
1011
+ assert c.contains_float([]) is False
1012
+
1013
+
1014
+ @pytest.mark.unit
1015
+ def test_contains_float_with_dict() -> None:
1016
+ c = Circuit()
1017
+ assert c.contains_float({"a": 1, "b": 2}) is False
1018
+ assert c.contains_float({"a": 1.0, "b": 2}) is True
1019
+ assert c.contains_float({"a": 1, "b": 2.5}) is True
1020
+ assert c.contains_float({}) is False
1021
+
1022
+
1023
+ @pytest.mark.unit
1024
+ def test_contains_float_nested_structures() -> None:
1025
+ c = Circuit()
1026
+ nested_with_float = {"a": [1, 2.0, 3], "b": {"c": 4.5}}
1027
+ nested_without_float = {"a": [1, 2, 3], "b": {"c": 4}}
1028
+
1029
+ assert c.contains_float(nested_with_float) is True
1030
+ assert c.contains_float(nested_without_float) is False
1031
+
1032
+
1033
+ # ---------- Test adjust_shape ----------
1034
+ @pytest.mark.unit
1035
+ def test_adjust_shape_list() -> None:
1036
+ c = Circuit()
1037
+ assert c.adjust_shape([1, 2, 3]) == [1, 2, 3]
1038
+ assert c.adjust_shape([0, -1, 5]) == [1, 1, 5]
1039
+ assert c.adjust_shape([-5, 0, 3]) == [1, 1, 3]
1040
+
1041
+
1042
+ @pytest.mark.unit
1043
+ def test_adjust_shape_dict_single_value() -> None:
1044
+ c = Circuit()
1045
+ result = c.adjust_shape({"key": [2, 3, 4]})
1046
+ assert result == [2, 3, 4]
1047
+ assert result == [2, 3, 4]
1048
+
1049
+
1050
+ @pytest.mark.unit
1051
+ def test_adjust_shape_dict_multiple_values() -> None:
1052
+ c = Circuit()
1053
+ input_dict = {"input": [2, 3, 4], "weight": [1, -1, 5], "bias": [0, 0, 3]}
1054
+ expected = {"input": [2, 3, 4], "weight": [1, 1, 5], "bias": [1, 1, 3]}
1055
+ assert c.adjust_shape(input_dict) == expected
1056
+
1057
+
1058
+ @pytest.mark.unit
1059
+ def test_adjust_shape_invalid_type() -> None:
1060
+ c = Circuit()
1061
+ with pytest.raises(CircuitInputError, match="Expected list or dict for 'shape'"):
1062
+ c.adjust_shape("invalid")
1063
+
1064
+
1065
+ @pytest.mark.unit
1066
+ def test_adjust_shape_dict_invalid_value() -> None:
1067
+ c = Circuit()
1068
+ with pytest.raises(
1069
+ CircuitInputError,
1070
+ match="Expected shape list for input, got str",
1071
+ ):
1072
+ c.adjust_shape({"bad": "not_a_list"})
1073
+
1074
+
1075
+ # ---------- Test scale_and_round ----------
1076
+ @pytest.mark.unit
1077
+ def test_scale_and_round_with_floats() -> None:
1078
+ c = Circuit()
1079
+ c.scale_base = 2
1080
+ c.scale_exponent = 2
1081
+
1082
+ with patch(
1083
+ "python.core.model_processing.onnx_quantizer.layers.base.BaseOpQuantizer.get_scaling",
1084
+ return_value=4.0,
1085
+ ):
1086
+ result = c.scale_and_round([1.5, 2.5], 2, 2)
1087
+ assert result == [6, 10] # rounded(1.5 * 4) = 6, rounded(2.5 * 4) = 10
1088
+
1089
+
1090
+ @pytest.mark.unit
1091
+ def test_scale_and_round_with_ints() -> None:
1092
+ c = Circuit()
1093
+ c.scale_base = 2
1094
+ c.scale_exponent = 2
1095
+
1096
+ with patch(
1097
+ "python.core.model_processing.onnx_quantizer.layers.base.BaseOpQuantizer.get_scaling",
1098
+ return_value=4.0,
1099
+ ):
1100
+ result = c.scale_and_round([1, 2, 3], 2, 2)
1101
+ assert result == [1, 2, 3] # No change for integers
1102
+
1103
+
1104
+ @pytest.mark.unit
1105
+ def test_scale_and_round_with_tensors() -> None:
1106
+ c = Circuit()
1107
+ c.scale_base = 2
1108
+ c.scale_exponent = 2
1109
+
1110
+ with patch(
1111
+ "python.core.model_processing.onnx_quantizer.layers.base.BaseOpQuantizer.get_scaling",
1112
+ return_value=4.0,
1113
+ ):
1114
+
1115
+ tensor_input = [1.5, 2.5]
1116
+ result = c.scale_and_round(tensor_input, 2, 2)
1117
+ assert result == [6, 10]
1118
+
1119
+
1120
+ # ---------- Test _to_json_safely and _read_from_json_safely ----------
1121
+ @pytest.mark.unit
1122
+ @patch("python.core.circuits.base.to_json")
1123
+ def test_to_json_safely_success(mock_to_json: MagicMock) -> None:
1124
+ c = Circuit()
1125
+ c._to_json_safely({"key": "value"}, "file.json", "test var")
1126
+ mock_to_json.assert_called_once_with({"key": "value"}, "file.json")
1127
+
1128
+
1129
+ @pytest.mark.unit
1130
+ @patch("python.core.circuits.base.to_json", side_effect=Exception("Write failed"))
1131
+ def test_to_json_safely_failure(mock_to_json: MagicMock) -> None:
1132
+ c = Circuit()
1133
+ with pytest.raises(
1134
+ CircuitFileError,
1135
+ match=re.escape("Failed to write test var file: file.json"),
1136
+ ):
1137
+ c._to_json_safely({"key": "value"}, "file.json", "test var")
1138
+
1139
+
1140
+ @pytest.mark.unit
1141
+ @patch("python.core.circuits.base.read_from_json", return_value={"key": "value"})
1142
+ def test_read_from_json_safely_success(mock_read: MagicMock) -> None:
1143
+ c = Circuit()
1144
+ result = c._read_from_json_safely("file.json")
1145
+ mock_read.assert_called_once_with("file.json")
1146
+ assert result == {"key": "value"}
1147
+
1148
+
1149
+ @pytest.mark.unit
1150
+ @patch("python.core.circuits.base.read_from_json", side_effect=Exception("Read failed"))
1151
+ def test_read_from_json_safely_failure(mock_read: MagicMock) -> None:
1152
+ c = Circuit()
1153
+ with pytest.raises(
1154
+ CircuitFileError,
1155
+ match=re.escape("Failed to read input file: file.json"),
1156
+ ):
1157
+ c._read_from_json_safely("file.json")
1158
+
1159
+
1160
+ # ---------- Test _adjust_single_input ----------
1161
+ @pytest.mark.unit
1162
+ def test_adjust_single_input_success() -> None:
1163
+ c = Circuit()
1164
+ c.input_shape = [2, 2]
1165
+ c.scale_base = 2
1166
+ c.scale_exponent = 1
1167
+ five = 5
1168
+
1169
+ inputs = {"input": [1, 2, 3, 4], "extra": 5}
1170
+ result = c._adjust_single_input(inputs)
1171
+
1172
+ assert "input" in result
1173
+ assert "extra" in result
1174
+ assert result["extra"] == five
1175
+
1176
+
1177
+ # ---------- Test _adjust_multiple_inputs ----------
1178
+ @pytest.mark.unit
1179
+ def test_adjust_multiple_inputs_success() -> None:
1180
+ c = Circuit()
1181
+ c.x_shape = [2]
1182
+ c.y_shape = [2]
1183
+ c.scale_base = 2
1184
+ c.scale_exponent = 1
1185
+ five = 5
1186
+
1187
+ inputs = {"x": [1, 2], "y": [3, 4], "z": 5}
1188
+ input_variables = ["x", "y"]
1189
+ result = c._adjust_multiple_inputs(inputs, input_variables)
1190
+
1191
+ assert "x" in result
1192
+ assert "y" in result
1193
+ assert "z" in result
1194
+ assert result["z"] == five
1195
+
1196
+
1197
+ # ---------- Test _reshape_input_value ----------
1198
+ @pytest.mark.unit
1199
+ def test_reshape_input_value_success() -> None:
1200
+ c = Circuit()
1201
+ c.input_shape = [2, 2]
1202
+
1203
+ result = c._reshape_input_value([1, 2, 3, 4], "input_shape", "input")
1204
+ assert result == [[1, 2], [3, 4]]
1205
+
1206
+
1207
+ @pytest.mark.unit
1208
+ def test_reshape_input_value_missing_shape_attr() -> None:
1209
+ c = Circuit()
1210
+
1211
+ with pytest.raises(
1212
+ CircuitConfigurationError,
1213
+ match="Required shape attribute 'missing_shape'",
1214
+ ):
1215
+ c._reshape_input_value([1, 2, 3, 4], "missing_shape", "input")
1216
+
1217
+
1218
+ @pytest.mark.unit
1219
+ def test_reshape_input_value_invalid_shape() -> None:
1220
+ c = Circuit()
1221
+ c.input_shape = [2, 3] # 6 elements needed
1222
+
1223
+ with pytest.raises(CircuitProcessingError, match="Failed to reshape input data"):
1224
+ c._reshape_input_value([1, 2, 3, 4], "input_shape", "input")
1225
+
1226
+
1227
+ # ---------- Test scale_inputs_only ----------
1228
+ @pytest.mark.unit
1229
+ def test_scale_inputs_only_success() -> None:
1230
+ c = Circuit()
1231
+ c.scale_base = 2
1232
+ c.scale_exponent = 1
1233
+
1234
+ inputs = {"x": [1, 2], "y": [3, 4]}
1235
+ with patch.object(
1236
+ c,
1237
+ "scale_and_round",
1238
+ side_effect=lambda v, _sb, _se: [v[0] * 2, v[1] * 2],
1239
+ ):
1240
+ result = c.scale_inputs_only(inputs)
1241
+ assert result == {"x": [2, 4], "y": [6, 8]}
1242
+
1243
+
1244
+ # ---------- Test rename_inputs ----------
1245
+ @pytest.mark.unit
1246
+ def test_rename_inputs_single_input() -> None:
1247
+ c = Circuit()
1248
+ c.input_variables = ["input"]
1249
+
1250
+ inputs = {"input_data": [1, 2, 3], "extra": 4}
1251
+ result = c.rename_inputs(inputs)
1252
+
1253
+ assert result == {"input": [1, 2, 3], "extra": 4}
1254
+
1255
+
1256
+ @pytest.mark.unit
1257
+ def test_rename_inputs_multiple_inputs() -> None:
1258
+ c = Circuit()
1259
+ c.input_variables = ["x", "y"]
1260
+
1261
+ inputs = {"x": [1, 2], "y": [3, 4], "z": 5}
1262
+ result = c.rename_inputs(inputs)
1263
+
1264
+ assert result == inputs # Should remain unchanged
1265
+
1266
+
1267
+ # ---------- Test _rename_single_input ----------
1268
+ @pytest.mark.unit
1269
+ def test_rename_single_input_success() -> None:
1270
+ c = Circuit()
1271
+ inputs = {"input_vec": [1, 2, 3], "extra": 4}
1272
+ result = c._rename_single_input(inputs)
1273
+
1274
+ assert result == {"input": [1, 2, 3], "extra": 4}
1275
+
1276
+
1277
+ @pytest.mark.unit
1278
+ def test_rename_single_input_multiple_keys_error() -> None:
1279
+ c = Circuit()
1280
+ inputs = {"input1": [1, 2], "input2": [3, 4]}
1281
+
1282
+ with pytest.raises(
1283
+ CircuitInputError,
1284
+ match="Multiple inputs found containing 'input'",
1285
+ ):
1286
+ c._rename_single_input(inputs)
1287
+
1288
+
1289
+ # ---------- Test reshape_inputs_for_inference ----------
1290
+ @pytest.mark.unit
1291
+ def test_reshape_inputs_for_inference_single_input() -> None:
1292
+ c = Circuit()
1293
+ c.input_shape = [2, 2]
1294
+
1295
+ inputs = {"data": [1, 2, 3, 4]}
1296
+ result = c.reshape_inputs_for_inference(inputs)
1297
+
1298
+ expected = np.array([[1, 2], [3, 4]])
1299
+ np.testing.assert_array_equal(result, expected)
1300
+
1301
+
1302
+ @pytest.mark.unit
1303
+ def test_reshape_inputs_for_inference_dict_shapes() -> None:
1304
+
1305
+ c = Circuit()
1306
+ c.input_shape = {"x": [2], "y": [2]}
1307
+
1308
+ inputs = {"x": [1, 2], "y": [3, 4]}
1309
+ result = c.reshape_inputs_for_inference(inputs)
1310
+
1311
+ _ = {"x": np.array([1, 2]), "y": np.array([3, 4])}
1312
+ assert list(result.keys()) == ["x", "y"]
1313
+
1314
+
1315
+ @pytest.mark.unit
1316
+ def test_reshape_inputs_for_inference_missing_shape() -> None:
1317
+ c = Circuit()
1318
+ inputs = {"data": [1, 2, 3, 4]}
1319
+
1320
+ with pytest.raises(CircuitConfigurationError, match="input_shape"):
1321
+ c.reshape_inputs_for_inference(inputs)
1322
+
1323
+
1324
+ @pytest.mark.unit
1325
+ def test_reshape_inputs_for_inference_shape_mismatch() -> None:
1326
+ c = Circuit()
1327
+ c.input_shape = [2, 3] # Needs 6 elements
1328
+
1329
+ inputs = {"data": [1, 2, 3, 4]} # Only 4 elements
1330
+
1331
+ with pytest.raises(ShapeMismatchError):
1332
+ c.reshape_inputs_for_inference(inputs)
1333
+
1334
+
1335
+ # ---------- Test _reshape_dict_inputs ----------
1336
+ @pytest.mark.unit
1337
+ def test_reshape_dict_inputs_success() -> None:
1338
+ c = Circuit()
1339
+ shape_dict = {"x": [2], "y": [2, 1]}
1340
+
1341
+ inputs = {"x": [1, 2], "y": [3, 4]}
1342
+ result = c._reshape_dict_inputs(inputs, shape_dict)
1343
+
1344
+ np.testing.assert_array_equal(result["x"], np.array([1, 2]))
1345
+ np.testing.assert_array_equal(result["y"], np.array([[3], [4]]))
1346
+
1347
+
1348
+ @pytest.mark.unit
1349
+ def test_reshape_dict_inputs_non_dict_shape() -> None:
1350
+ c = Circuit()
1351
+ shape_list = [2, 2]
1352
+
1353
+ with pytest.raises(
1354
+ CircuitInputError,
1355
+ match="_reshape_dict_inputs requires dict shape",
1356
+ ):
1357
+ c._reshape_dict_inputs({"x": [1, 2]}, shape_list)
1358
+
1359
+
1360
+ @pytest.mark.unit
1361
+ def test_reshape_dict_inputs_shape_mismatch() -> None:
1362
+ c = Circuit()
1363
+ shape_dict = {"x": [2, 2]} # Needs 4 elements
1364
+
1365
+ inputs = {"x": [1, 2]} # Only 2 elements
1366
+
1367
+ with pytest.raises(ShapeMismatchError):
1368
+ c._reshape_dict_inputs(inputs, shape_dict)
1369
+
1370
+
1371
+ # ---------- Test reshape_inputs_for_circuit ----------
1372
+ @pytest.mark.unit
1373
+ def test_reshape_inputs_for_circuit_success() -> None:
1374
+ c = Circuit()
1375
+ inputs = {"x": [1, 2], "y": [3, 4]}
1376
+
1377
+ result = c.reshape_inputs_for_circuit(inputs)
1378
+
1379
+ assert result == {"input": [1, 2, 3, 4]}
1380
+
1381
+
1382
+ @pytest.mark.unit
1383
+ def test_reshape_inputs_for_circuit_with_input_shapes() -> None:
1384
+ c = Circuit()
1385
+ c.input_shapes = {"y": [2], "x": [2]} # Ordered differently
1386
+
1387
+ inputs = {"x": [1, 2], "y": [3, 4]}
1388
+
1389
+ result = c.reshape_inputs_for_circuit(inputs)
1390
+
1391
+ assert result == {"input": [3, 4, 1, 2]} # Respects order from input_shapes
1392
+
1393
+
1394
+ @pytest.mark.unit
1395
+ def test_reshape_inputs_for_circuit_invalid_type() -> None:
1396
+ c = Circuit()
1397
+
1398
+ with pytest.raises(CircuitConfigurationError, match="Expected a dict, got list"):
1399
+ c.reshape_inputs_for_circuit([1, 2, 3, 4])
1400
+
1401
+
1402
+ @pytest.mark.unit
1403
+ def test_reshape_inputs_for_circuit_missing_key() -> None:
1404
+ c = Circuit()
1405
+ c.input_shapes = {"x": [2], "y": [2]}
1406
+
1407
+ inputs = {"x": [1, 2]} # Missing "y"
1408
+
1409
+ with pytest.raises(CircuitProcessingError, match="Missing expected input key 'y'"):
1410
+ c.reshape_inputs_for_circuit(inputs)
1411
+
1412
+
1413
+ @pytest.mark.unit
1414
+ def test_reshape_inputs_for_circuit_unsupported_type() -> None:
1415
+ c = Circuit()
1416
+
1417
+ inputs = {"x": "invalid_type"}
1418
+
1419
+ with pytest.raises(
1420
+ CircuitProcessingError,
1421
+ match="Unsupported input type for key 'x'",
1422
+ ):
1423
+ c.reshape_inputs_for_circuit(inputs)
1424
+
1425
+
1426
+ # ---------- Test load_and_compare_witness_to_io ----------
1427
+ @pytest.mark.unit
1428
+ @patch("python.core.circuits.base.load_witness")
1429
+ @patch("python.core.circuits.base.compare_witness_to_io")
1430
+ def test_load_and_compare_witness_to_io_success(
1431
+ mock_compare: MagicMock,
1432
+ mock_load: MagicMock,
1433
+ ) -> None:
1434
+ c = Circuit()
1435
+ c._read_from_json_safely = MagicMock
1436
+ mock_load.return_value = {"modulus": 10, "public_inputs": [1, 2, 3]}
1437
+ mock_compare.return_value = True
1438
+
1439
+ _ = c.load_and_compare_witness_to_io(
1440
+ "witness.bin",
1441
+ "inputs.json",
1442
+ "outputs.json",
1443
+ ZKProofSystems.Expander,
1444
+ )
1445
+
1446
+ mock_load.assert_called_once_with("witness.bin", ZKProofSystems.Expander)
1447
+ mock_compare.assert_called_once()
1448
+
1449
+
1450
+ @pytest.mark.unit
1451
+ @patch("python.core.circuits.base.load_witness")
1452
+ def test_load_and_compare_witness_to_io_missing_modulus(mock_load: MagicMock) -> None:
1453
+ c = Circuit()
1454
+ c._read_from_json_safely = MagicMock
1455
+ mock_load.return_value = {"public_inputs": [1, 2, 3]} # No modulus
1456
+
1457
+ with pytest.raises(
1458
+ WitnessMatchError,
1459
+ match=r"Witness not correctly formed\. Missing modulus\.",
1460
+ ):
1461
+ c.load_and_compare_witness_to_io(
1462
+ "witness.bin",
1463
+ "inputs.json",
1464
+ "outputs.json",
1465
+ ZKProofSystems.Expander,
1466
+ )
1467
+
1468
+
1469
+ # ---------- Test prepare_inputs_for_verification ----------
1470
+ @pytest.mark.unit
1471
+ def test_prepare_inputs_for_verification_success(tmp_path: Path) -> None:
1472
+ c = Circuit()
1473
+ c._read_from_json_safely = MagicMock(return_value={"input": [1, 2, 3, 4]})
1474
+ c.reshape_inputs_for_circuit = MagicMock(return_value={"input": [1, 2, 3, 4]})
1475
+ c._to_json_safely = MagicMock()
1476
+
1477
+ input_file = tmp_path / "input.json"
1478
+ exec_config = MagicMock()
1479
+ exec_config.input_file = str(input_file)
1480
+
1481
+ result = c.prepare_inputs_for_verification(exec_config)
1482
+
1483
+ expected_file = str(tmp_path / "input_veri.json")
1484
+ assert result == expected_file
1485
+
1486
+ c._read_from_json_safely.assert_called_once_with(str(input_file))
1487
+ c.reshape_inputs_for_circuit.assert_called_once_with({"input": [1, 2, 3, 4]})
1488
+ c._to_json_safely.assert_called_once_with(
1489
+ {"input": [1, 2, 3, 4]},
1490
+ expected_file,
1491
+ "renamed input",
1492
+ )