JSTprove 1.0.0__py3-none-macosx_11_0_arm64.whl → 1.1.0__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/METADATA +2 -2
- {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/RECORD +51 -24
- python/core/binaries/onnx_generic_circuit_1-1-0 +0 -0
- python/core/circuit_models/generic_onnx.py +43 -9
- python/core/circuits/base.py +231 -71
- python/core/model_processing/converters/onnx_converter.py +86 -32
- python/core/model_processing/onnx_custom_ops/maxpool.py +1 -1
- python/core/model_processing/onnx_custom_ops/relu.py +1 -1
- python/core/model_processing/onnx_quantizer/layers/add.py +54 -0
- python/core/model_processing/onnx_quantizer/layers/base.py +121 -1
- python/core/model_processing/onnx_quantizer/layers/constant.py +1 -1
- python/core/model_processing/onnx_quantizer/layers/conv.py +20 -68
- python/core/model_processing/onnx_quantizer/layers/gemm.py +20 -66
- python/core/model_processing/onnx_quantizer/layers/maxpool.py +53 -43
- python/core/model_processing/onnx_quantizer/layers/relu.py +20 -35
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +6 -1
- python/core/utils/general_layer_functions.py +17 -12
- python/core/utils/model_registry.py +6 -3
- python/tests/circuit_e2e_tests/other_e2e_test.py +202 -9
- python/tests/circuit_parent_classes/test_circuit.py +561 -38
- python/tests/circuit_parent_classes/test_onnx_converter.py +22 -13
- python/tests/onnx_quantizer_tests/__init__.py +1 -0
- python/tests/onnx_quantizer_tests/layers/__init__.py +13 -0
- python/tests/onnx_quantizer_tests/layers/add_config.py +102 -0
- python/tests/onnx_quantizer_tests/layers/base.py +279 -0
- python/tests/onnx_quantizer_tests/layers/constant_config.py +39 -0
- python/tests/onnx_quantizer_tests/layers/conv_config.py +154 -0
- python/tests/onnx_quantizer_tests/layers/factory.py +142 -0
- python/tests/onnx_quantizer_tests/layers/flatten_config.py +61 -0
- python/tests/onnx_quantizer_tests/layers/gemm_config.py +160 -0
- python/tests/onnx_quantizer_tests/layers/maxpool_config.py +82 -0
- python/tests/onnx_quantizer_tests/layers/relu_config.py +61 -0
- python/tests/onnx_quantizer_tests/layers/reshape_config.py +61 -0
- python/tests/onnx_quantizer_tests/layers_tests/__init__.py +0 -0
- python/tests/onnx_quantizer_tests/layers_tests/base_test.py +94 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py +115 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_e2e.py +196 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_error_cases.py +59 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +198 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +265 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py +109 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_validation.py +45 -0
- python/tests/onnx_quantizer_tests/test_base_layer.py +228 -0
- python/tests/onnx_quantizer_tests/test_exceptions.py +99 -0
- python/tests/onnx_quantizer_tests/test_onnx_op_quantizer.py +246 -0
- python/tests/onnx_quantizer_tests/test_registered_quantizers.py +121 -0
- python/tests/onnx_quantizer_tests/testing_helper_functions.py +17 -0
- python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/WHEEL +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/entry_points.txt +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/licenses/LICENSE +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/top_level.txt +0 -0
|
@@ -1,10 +1,14 @@
|
|
|
1
|
+
import re
|
|
1
2
|
import sys
|
|
2
3
|
from collections.abc import Generator
|
|
3
4
|
from pathlib import Path
|
|
4
5
|
from unittest.mock import MagicMock, patch
|
|
5
6
|
|
|
7
|
+
import numpy as np
|
|
6
8
|
import pytest
|
|
7
9
|
|
|
10
|
+
from python.core.utils.errors import ShapeMismatchError
|
|
11
|
+
|
|
8
12
|
sys.modules.pop("python.core.circuits.base", None)
|
|
9
13
|
|
|
10
14
|
|
|
@@ -30,6 +34,7 @@ with (
|
|
|
30
34
|
CircuitInputError,
|
|
31
35
|
CircuitProcessingError,
|
|
32
36
|
CircuitRunError,
|
|
37
|
+
WitnessMatchError,
|
|
33
38
|
)
|
|
34
39
|
|
|
35
40
|
|
|
@@ -148,6 +153,7 @@ def test_parse_proof_dispatch_logic(
|
|
|
148
153
|
c._gen_witness_preprocessing = MagicMock(return_value="i")
|
|
149
154
|
c.adjust_inputs = MagicMock(return_value="i")
|
|
150
155
|
c.rename_inputs = MagicMock(return_value="i")
|
|
156
|
+
c.prepare_inputs_for_verification = MagicMock(return_value="i")
|
|
151
157
|
|
|
152
158
|
c.load_and_compare_witness_to_io = MagicMock(return_value="True")
|
|
153
159
|
|
|
@@ -436,34 +442,53 @@ def test_gen_witness_preprocessing_write_json_true(mock_to_json: MagicMock) -> N
|
|
|
436
442
|
|
|
437
443
|
|
|
438
444
|
@pytest.mark.unit
|
|
439
|
-
|
|
440
|
-
def test_gen_witness_preprocessing_write_json_false(mock_to_json: MagicMock) -> None:
|
|
445
|
+
def test_gen_witness_preprocessing_write_json_false(tmp_path: Path) -> None:
|
|
441
446
|
c = Circuit()
|
|
442
447
|
c._file_info = {"quantized_model_path": "quant.pt"}
|
|
443
|
-
c.load_quantized_model = MagicMock()
|
|
444
|
-
c.get_inputs_from_file = MagicMock(return_value="mock_inputs")
|
|
445
|
-
c.reshape_inputs = MagicMock(return_value="in.json")
|
|
446
|
-
c.rescale_inputs = MagicMock(return_value="in.json")
|
|
447
|
-
c.rename_inputs = MagicMock(return_value="in.json")
|
|
448
|
-
c.rescale_and_reshape_inputs = MagicMock(return_value="in.json")
|
|
449
|
-
c.adjust_inputs = MagicMock(return_value="in.json")
|
|
450
448
|
|
|
451
|
-
|
|
452
|
-
c.
|
|
453
|
-
|
|
454
|
-
c.
|
|
455
|
-
|
|
456
|
-
|
|
449
|
+
# Mock all method calls used by _gen_witness_preprocessing
|
|
450
|
+
c.load_quantized_model = MagicMock()
|
|
451
|
+
c._read_from_json_safely = MagicMock(return_value={"mock": "inputs"})
|
|
452
|
+
c.scale_inputs_only = MagicMock(return_value={"scaled": "inputs"})
|
|
453
|
+
c.reshape_inputs_for_inference = MagicMock(return_value={"inference": "inputs"})
|
|
454
|
+
c.reshape_inputs_for_circuit = MagicMock(return_value={"input": [1, 2, 3]})
|
|
455
|
+
c._to_json_safely = MagicMock()
|
|
456
|
+
c.get_outputs = MagicMock(return_value={"raw_output": 123})
|
|
457
|
+
c.format_outputs = MagicMock(return_value={"formatted_output": 999})
|
|
458
|
+
|
|
459
|
+
input_path = tmp_path / "in.json"
|
|
460
|
+
output_path = tmp_path / "out.json"
|
|
461
|
+
|
|
462
|
+
result = c._gen_witness_preprocessing(
|
|
463
|
+
str(input_path),
|
|
464
|
+
str(output_path),
|
|
457
465
|
None,
|
|
458
466
|
write_json=False,
|
|
459
467
|
is_scaled=False,
|
|
460
468
|
)
|
|
461
469
|
|
|
470
|
+
# --- Assertions ---
|
|
462
471
|
c.load_quantized_model.assert_called_once_with("quant.pt")
|
|
463
|
-
c.
|
|
464
|
-
c.
|
|
465
|
-
c.
|
|
466
|
-
|
|
472
|
+
c._read_from_json_safely.assert_called_once_with(str(input_path))
|
|
473
|
+
c.scale_inputs_only.assert_called_once_with({"mock": "inputs"})
|
|
474
|
+
c.reshape_inputs_for_inference.assert_called_once_with({"scaled": "inputs"})
|
|
475
|
+
c.reshape_inputs_for_circuit.assert_called_once_with({"scaled": "inputs"})
|
|
476
|
+
|
|
477
|
+
# Verify safe JSON writes
|
|
478
|
+
new_input_file = str(input_path.with_name("in_adjusted.json"))
|
|
479
|
+
c._to_json_safely.assert_any_call({"input": [1, 2, 3]}, new_input_file, "input")
|
|
480
|
+
c._to_json_safely.assert_any_call(
|
|
481
|
+
{"formatted_output": 999},
|
|
482
|
+
str(output_path),
|
|
483
|
+
"output",
|
|
484
|
+
)
|
|
485
|
+
|
|
486
|
+
# Verify output generation
|
|
487
|
+
c.get_outputs.assert_called_once_with({"inference": "inputs"})
|
|
488
|
+
c.format_outputs.assert_called_once_with({"raw_output": 123})
|
|
489
|
+
|
|
490
|
+
# Function should return the adjusted input file path
|
|
491
|
+
assert result == new_input_file
|
|
467
492
|
|
|
468
493
|
|
|
469
494
|
# ---------- _compile_preprocessing ----------
|
|
@@ -896,24 +921,6 @@ def test_save_and_load_model_not_implemented() -> None:
|
|
|
896
921
|
assert hasattr(c, "load_quantized_model")
|
|
897
922
|
|
|
898
923
|
|
|
899
|
-
# ---------- New error handling tests ----------
|
|
900
|
-
@pytest.mark.unit
|
|
901
|
-
def test_adjust_inputs_file_error() -> None:
|
|
902
|
-
c = Circuit()
|
|
903
|
-
c.input_variables = ["input"]
|
|
904
|
-
c.input_shape = [2, 2]
|
|
905
|
-
c.scale_base = 2
|
|
906
|
-
c.scale_exponent = 1
|
|
907
|
-
|
|
908
|
-
with patch(
|
|
909
|
-
"python.core.circuits.base.read_from_json",
|
|
910
|
-
side_effect=FileNotFoundError("File not found"),
|
|
911
|
-
):
|
|
912
|
-
_ = c
|
|
913
|
-
with pytest.raises(CircuitFileError, match="Failed to read input file"):
|
|
914
|
-
c.adjust_inputs("nonexistent.json")
|
|
915
|
-
|
|
916
|
-
|
|
917
924
|
@pytest.mark.unit
|
|
918
925
|
def test_adjust_inputs_processing_error() -> None:
|
|
919
926
|
c = Circuit()
|
|
@@ -934,7 +941,7 @@ def test_adjust_inputs_processing_error() -> None:
|
|
|
934
941
|
CircuitProcessingError,
|
|
935
942
|
match="Failed to reshape input data",
|
|
936
943
|
):
|
|
937
|
-
c.adjust_inputs("dummy.json")
|
|
944
|
+
c.adjust_inputs({"input": [1, 2, 3, 4]}, "dummy.json")
|
|
938
945
|
|
|
939
946
|
|
|
940
947
|
@pytest.mark.unit
|
|
@@ -967,3 +974,519 @@ def test_get_inputs_from_file_processing_error() -> None:
|
|
|
967
974
|
match="Failed to scale input data",
|
|
968
975
|
):
|
|
969
976
|
c.get_inputs_from_file("dummy.json", is_scaled=False)
|
|
977
|
+
|
|
978
|
+
|
|
979
|
+
# ---------- Test _raise_unknown_run_type ----------
|
|
980
|
+
@pytest.mark.unit
|
|
981
|
+
def test_raise_unknown_run_type() -> None:
|
|
982
|
+
c = Circuit()
|
|
983
|
+
|
|
984
|
+
with pytest.raises(CircuitRunError, match="Unsupported run type: INVALID_TYPE"):
|
|
985
|
+
c._raise_unknown_run_type("INVALID_TYPE")
|
|
986
|
+
|
|
987
|
+
|
|
988
|
+
# ---------- Test contains_float ----------
|
|
989
|
+
@pytest.mark.unit
|
|
990
|
+
def test_contains_float_with_float() -> None:
|
|
991
|
+
c = Circuit()
|
|
992
|
+
assert c.contains_float(3.14) is True
|
|
993
|
+
assert c.contains_float(2.0) is True
|
|
994
|
+
assert c.contains_float(1.5) is True
|
|
995
|
+
|
|
996
|
+
|
|
997
|
+
@pytest.mark.unit
|
|
998
|
+
def test_contains_float_with_int() -> None:
|
|
999
|
+
c = Circuit()
|
|
1000
|
+
assert c.contains_float(1) is False
|
|
1001
|
+
assert c.contains_float(0) is False
|
|
1002
|
+
assert c.contains_float(-5) is False
|
|
1003
|
+
|
|
1004
|
+
|
|
1005
|
+
@pytest.mark.unit
|
|
1006
|
+
def test_contains_float_with_list() -> None:
|
|
1007
|
+
c = Circuit()
|
|
1008
|
+
assert c.contains_float([1, 2, 3]) is False
|
|
1009
|
+
assert c.contains_float([1.0, 2, 3]) is True
|
|
1010
|
+
assert c.contains_float([1, 2.5, 3]) is True
|
|
1011
|
+
assert c.contains_float([]) is False
|
|
1012
|
+
|
|
1013
|
+
|
|
1014
|
+
@pytest.mark.unit
|
|
1015
|
+
def test_contains_float_with_dict() -> None:
|
|
1016
|
+
c = Circuit()
|
|
1017
|
+
assert c.contains_float({"a": 1, "b": 2}) is False
|
|
1018
|
+
assert c.contains_float({"a": 1.0, "b": 2}) is True
|
|
1019
|
+
assert c.contains_float({"a": 1, "b": 2.5}) is True
|
|
1020
|
+
assert c.contains_float({}) is False
|
|
1021
|
+
|
|
1022
|
+
|
|
1023
|
+
@pytest.mark.unit
|
|
1024
|
+
def test_contains_float_nested_structures() -> None:
|
|
1025
|
+
c = Circuit()
|
|
1026
|
+
nested_with_float = {"a": [1, 2.0, 3], "b": {"c": 4.5}}
|
|
1027
|
+
nested_without_float = {"a": [1, 2, 3], "b": {"c": 4}}
|
|
1028
|
+
|
|
1029
|
+
assert c.contains_float(nested_with_float) is True
|
|
1030
|
+
assert c.contains_float(nested_without_float) is False
|
|
1031
|
+
|
|
1032
|
+
|
|
1033
|
+
# ---------- Test adjust_shape ----------
|
|
1034
|
+
@pytest.mark.unit
|
|
1035
|
+
def test_adjust_shape_list() -> None:
|
|
1036
|
+
c = Circuit()
|
|
1037
|
+
assert c.adjust_shape([1, 2, 3]) == [1, 2, 3]
|
|
1038
|
+
assert c.adjust_shape([0, -1, 5]) == [1, 1, 5]
|
|
1039
|
+
assert c.adjust_shape([-5, 0, 3]) == [1, 1, 3]
|
|
1040
|
+
|
|
1041
|
+
|
|
1042
|
+
@pytest.mark.unit
|
|
1043
|
+
def test_adjust_shape_dict_single_value() -> None:
|
|
1044
|
+
c = Circuit()
|
|
1045
|
+
result = c.adjust_shape({"key": [2, 3, 4]})
|
|
1046
|
+
assert result == [2, 3, 4]
|
|
1047
|
+
assert result == [2, 3, 4]
|
|
1048
|
+
|
|
1049
|
+
|
|
1050
|
+
@pytest.mark.unit
|
|
1051
|
+
def test_adjust_shape_dict_multiple_values() -> None:
|
|
1052
|
+
c = Circuit()
|
|
1053
|
+
input_dict = {"input": [2, 3, 4], "weight": [1, -1, 5], "bias": [0, 0, 3]}
|
|
1054
|
+
expected = {"input": [2, 3, 4], "weight": [1, 1, 5], "bias": [1, 1, 3]}
|
|
1055
|
+
assert c.adjust_shape(input_dict) == expected
|
|
1056
|
+
|
|
1057
|
+
|
|
1058
|
+
@pytest.mark.unit
|
|
1059
|
+
def test_adjust_shape_invalid_type() -> None:
|
|
1060
|
+
c = Circuit()
|
|
1061
|
+
with pytest.raises(CircuitInputError, match="Expected list or dict for 'shape'"):
|
|
1062
|
+
c.adjust_shape("invalid")
|
|
1063
|
+
|
|
1064
|
+
|
|
1065
|
+
@pytest.mark.unit
|
|
1066
|
+
def test_adjust_shape_dict_invalid_value() -> None:
|
|
1067
|
+
c = Circuit()
|
|
1068
|
+
with pytest.raises(
|
|
1069
|
+
CircuitInputError,
|
|
1070
|
+
match="Expected shape list for input, got str",
|
|
1071
|
+
):
|
|
1072
|
+
c.adjust_shape({"bad": "not_a_list"})
|
|
1073
|
+
|
|
1074
|
+
|
|
1075
|
+
# ---------- Test scale_and_round ----------
|
|
1076
|
+
@pytest.mark.unit
|
|
1077
|
+
def test_scale_and_round_with_floats() -> None:
|
|
1078
|
+
c = Circuit()
|
|
1079
|
+
c.scale_base = 2
|
|
1080
|
+
c.scale_exponent = 2
|
|
1081
|
+
|
|
1082
|
+
with patch(
|
|
1083
|
+
"python.core.model_processing.onnx_quantizer.layers.base.BaseOpQuantizer.get_scaling",
|
|
1084
|
+
return_value=4.0,
|
|
1085
|
+
):
|
|
1086
|
+
result = c.scale_and_round([1.5, 2.5], 2, 2)
|
|
1087
|
+
assert result == [6, 10] # rounded(1.5 * 4) = 6, rounded(2.5 * 4) = 10
|
|
1088
|
+
|
|
1089
|
+
|
|
1090
|
+
@pytest.mark.unit
|
|
1091
|
+
def test_scale_and_round_with_ints() -> None:
|
|
1092
|
+
c = Circuit()
|
|
1093
|
+
c.scale_base = 2
|
|
1094
|
+
c.scale_exponent = 2
|
|
1095
|
+
|
|
1096
|
+
with patch(
|
|
1097
|
+
"python.core.model_processing.onnx_quantizer.layers.base.BaseOpQuantizer.get_scaling",
|
|
1098
|
+
return_value=4.0,
|
|
1099
|
+
):
|
|
1100
|
+
result = c.scale_and_round([1, 2, 3], 2, 2)
|
|
1101
|
+
assert result == [1, 2, 3] # No change for integers
|
|
1102
|
+
|
|
1103
|
+
|
|
1104
|
+
@pytest.mark.unit
|
|
1105
|
+
def test_scale_and_round_with_tensors() -> None:
|
|
1106
|
+
c = Circuit()
|
|
1107
|
+
c.scale_base = 2
|
|
1108
|
+
c.scale_exponent = 2
|
|
1109
|
+
|
|
1110
|
+
with patch(
|
|
1111
|
+
"python.core.model_processing.onnx_quantizer.layers.base.BaseOpQuantizer.get_scaling",
|
|
1112
|
+
return_value=4.0,
|
|
1113
|
+
):
|
|
1114
|
+
|
|
1115
|
+
tensor_input = [1.5, 2.5]
|
|
1116
|
+
result = c.scale_and_round(tensor_input, 2, 2)
|
|
1117
|
+
assert result == [6, 10]
|
|
1118
|
+
|
|
1119
|
+
|
|
1120
|
+
# ---------- Test _to_json_safely and _read_from_json_safely ----------
|
|
1121
|
+
@pytest.mark.unit
|
|
1122
|
+
@patch("python.core.circuits.base.to_json")
|
|
1123
|
+
def test_to_json_safely_success(mock_to_json: MagicMock) -> None:
|
|
1124
|
+
c = Circuit()
|
|
1125
|
+
c._to_json_safely({"key": "value"}, "file.json", "test var")
|
|
1126
|
+
mock_to_json.assert_called_once_with({"key": "value"}, "file.json")
|
|
1127
|
+
|
|
1128
|
+
|
|
1129
|
+
@pytest.mark.unit
|
|
1130
|
+
@patch("python.core.circuits.base.to_json", side_effect=Exception("Write failed"))
|
|
1131
|
+
def test_to_json_safely_failure(mock_to_json: MagicMock) -> None:
|
|
1132
|
+
c = Circuit()
|
|
1133
|
+
with pytest.raises(
|
|
1134
|
+
CircuitFileError,
|
|
1135
|
+
match=re.escape("Failed to write test var file: file.json"),
|
|
1136
|
+
):
|
|
1137
|
+
c._to_json_safely({"key": "value"}, "file.json", "test var")
|
|
1138
|
+
|
|
1139
|
+
|
|
1140
|
+
@pytest.mark.unit
|
|
1141
|
+
@patch("python.core.circuits.base.read_from_json", return_value={"key": "value"})
|
|
1142
|
+
def test_read_from_json_safely_success(mock_read: MagicMock) -> None:
|
|
1143
|
+
c = Circuit()
|
|
1144
|
+
result = c._read_from_json_safely("file.json")
|
|
1145
|
+
mock_read.assert_called_once_with("file.json")
|
|
1146
|
+
assert result == {"key": "value"}
|
|
1147
|
+
|
|
1148
|
+
|
|
1149
|
+
@pytest.mark.unit
|
|
1150
|
+
@patch("python.core.circuits.base.read_from_json", side_effect=Exception("Read failed"))
|
|
1151
|
+
def test_read_from_json_safely_failure(mock_read: MagicMock) -> None:
|
|
1152
|
+
c = Circuit()
|
|
1153
|
+
with pytest.raises(
|
|
1154
|
+
CircuitFileError,
|
|
1155
|
+
match=re.escape("Failed to read input file: file.json"),
|
|
1156
|
+
):
|
|
1157
|
+
c._read_from_json_safely("file.json")
|
|
1158
|
+
|
|
1159
|
+
|
|
1160
|
+
# ---------- Test _adjust_single_input ----------
|
|
1161
|
+
@pytest.mark.unit
|
|
1162
|
+
def test_adjust_single_input_success() -> None:
|
|
1163
|
+
c = Circuit()
|
|
1164
|
+
c.input_shape = [2, 2]
|
|
1165
|
+
c.scale_base = 2
|
|
1166
|
+
c.scale_exponent = 1
|
|
1167
|
+
five = 5
|
|
1168
|
+
|
|
1169
|
+
inputs = {"input": [1, 2, 3, 4], "extra": 5}
|
|
1170
|
+
result = c._adjust_single_input(inputs)
|
|
1171
|
+
|
|
1172
|
+
assert "input" in result
|
|
1173
|
+
assert "extra" in result
|
|
1174
|
+
assert result["extra"] == five
|
|
1175
|
+
|
|
1176
|
+
|
|
1177
|
+
# ---------- Test _adjust_multiple_inputs ----------
|
|
1178
|
+
@pytest.mark.unit
|
|
1179
|
+
def test_adjust_multiple_inputs_success() -> None:
|
|
1180
|
+
c = Circuit()
|
|
1181
|
+
c.x_shape = [2]
|
|
1182
|
+
c.y_shape = [2]
|
|
1183
|
+
c.scale_base = 2
|
|
1184
|
+
c.scale_exponent = 1
|
|
1185
|
+
five = 5
|
|
1186
|
+
|
|
1187
|
+
inputs = {"x": [1, 2], "y": [3, 4], "z": 5}
|
|
1188
|
+
input_variables = ["x", "y"]
|
|
1189
|
+
result = c._adjust_multiple_inputs(inputs, input_variables)
|
|
1190
|
+
|
|
1191
|
+
assert "x" in result
|
|
1192
|
+
assert "y" in result
|
|
1193
|
+
assert "z" in result
|
|
1194
|
+
assert result["z"] == five
|
|
1195
|
+
|
|
1196
|
+
|
|
1197
|
+
# ---------- Test _reshape_input_value ----------
|
|
1198
|
+
@pytest.mark.unit
|
|
1199
|
+
def test_reshape_input_value_success() -> None:
|
|
1200
|
+
c = Circuit()
|
|
1201
|
+
c.input_shape = [2, 2]
|
|
1202
|
+
|
|
1203
|
+
result = c._reshape_input_value([1, 2, 3, 4], "input_shape", "input")
|
|
1204
|
+
assert result == [[1, 2], [3, 4]]
|
|
1205
|
+
|
|
1206
|
+
|
|
1207
|
+
@pytest.mark.unit
|
|
1208
|
+
def test_reshape_input_value_missing_shape_attr() -> None:
|
|
1209
|
+
c = Circuit()
|
|
1210
|
+
|
|
1211
|
+
with pytest.raises(
|
|
1212
|
+
CircuitConfigurationError,
|
|
1213
|
+
match="Required shape attribute 'missing_shape'",
|
|
1214
|
+
):
|
|
1215
|
+
c._reshape_input_value([1, 2, 3, 4], "missing_shape", "input")
|
|
1216
|
+
|
|
1217
|
+
|
|
1218
|
+
@pytest.mark.unit
|
|
1219
|
+
def test_reshape_input_value_invalid_shape() -> None:
|
|
1220
|
+
c = Circuit()
|
|
1221
|
+
c.input_shape = [2, 3] # 6 elements needed
|
|
1222
|
+
|
|
1223
|
+
with pytest.raises(CircuitProcessingError, match="Failed to reshape input data"):
|
|
1224
|
+
c._reshape_input_value([1, 2, 3, 4], "input_shape", "input")
|
|
1225
|
+
|
|
1226
|
+
|
|
1227
|
+
# ---------- Test scale_inputs_only ----------
|
|
1228
|
+
@pytest.mark.unit
|
|
1229
|
+
def test_scale_inputs_only_success() -> None:
|
|
1230
|
+
c = Circuit()
|
|
1231
|
+
c.scale_base = 2
|
|
1232
|
+
c.scale_exponent = 1
|
|
1233
|
+
|
|
1234
|
+
inputs = {"x": [1, 2], "y": [3, 4]}
|
|
1235
|
+
with patch.object(
|
|
1236
|
+
c,
|
|
1237
|
+
"scale_and_round",
|
|
1238
|
+
side_effect=lambda v, _sb, _se: [v[0] * 2, v[1] * 2],
|
|
1239
|
+
):
|
|
1240
|
+
result = c.scale_inputs_only(inputs)
|
|
1241
|
+
assert result == {"x": [2, 4], "y": [6, 8]}
|
|
1242
|
+
|
|
1243
|
+
|
|
1244
|
+
# ---------- Test rename_inputs ----------
|
|
1245
|
+
@pytest.mark.unit
|
|
1246
|
+
def test_rename_inputs_single_input() -> None:
|
|
1247
|
+
c = Circuit()
|
|
1248
|
+
c.input_variables = ["input"]
|
|
1249
|
+
|
|
1250
|
+
inputs = {"input_data": [1, 2, 3], "extra": 4}
|
|
1251
|
+
result = c.rename_inputs(inputs)
|
|
1252
|
+
|
|
1253
|
+
assert result == {"input": [1, 2, 3], "extra": 4}
|
|
1254
|
+
|
|
1255
|
+
|
|
1256
|
+
@pytest.mark.unit
|
|
1257
|
+
def test_rename_inputs_multiple_inputs() -> None:
|
|
1258
|
+
c = Circuit()
|
|
1259
|
+
c.input_variables = ["x", "y"]
|
|
1260
|
+
|
|
1261
|
+
inputs = {"x": [1, 2], "y": [3, 4], "z": 5}
|
|
1262
|
+
result = c.rename_inputs(inputs)
|
|
1263
|
+
|
|
1264
|
+
assert result == inputs # Should remain unchanged
|
|
1265
|
+
|
|
1266
|
+
|
|
1267
|
+
# ---------- Test _rename_single_input ----------
|
|
1268
|
+
@pytest.mark.unit
|
|
1269
|
+
def test_rename_single_input_success() -> None:
|
|
1270
|
+
c = Circuit()
|
|
1271
|
+
inputs = {"input_vec": [1, 2, 3], "extra": 4}
|
|
1272
|
+
result = c._rename_single_input(inputs)
|
|
1273
|
+
|
|
1274
|
+
assert result == {"input": [1, 2, 3], "extra": 4}
|
|
1275
|
+
|
|
1276
|
+
|
|
1277
|
+
@pytest.mark.unit
|
|
1278
|
+
def test_rename_single_input_multiple_keys_error() -> None:
|
|
1279
|
+
c = Circuit()
|
|
1280
|
+
inputs = {"input1": [1, 2], "input2": [3, 4]}
|
|
1281
|
+
|
|
1282
|
+
with pytest.raises(
|
|
1283
|
+
CircuitInputError,
|
|
1284
|
+
match="Multiple inputs found containing 'input'",
|
|
1285
|
+
):
|
|
1286
|
+
c._rename_single_input(inputs)
|
|
1287
|
+
|
|
1288
|
+
|
|
1289
|
+
# ---------- Test reshape_inputs_for_inference ----------
|
|
1290
|
+
@pytest.mark.unit
|
|
1291
|
+
def test_reshape_inputs_for_inference_single_input() -> None:
|
|
1292
|
+
c = Circuit()
|
|
1293
|
+
c.input_shape = [2, 2]
|
|
1294
|
+
|
|
1295
|
+
inputs = {"data": [1, 2, 3, 4]}
|
|
1296
|
+
result = c.reshape_inputs_for_inference(inputs)
|
|
1297
|
+
|
|
1298
|
+
expected = np.array([[1, 2], [3, 4]])
|
|
1299
|
+
np.testing.assert_array_equal(result, expected)
|
|
1300
|
+
|
|
1301
|
+
|
|
1302
|
+
@pytest.mark.unit
|
|
1303
|
+
def test_reshape_inputs_for_inference_dict_shapes() -> None:
|
|
1304
|
+
|
|
1305
|
+
c = Circuit()
|
|
1306
|
+
c.input_shape = {"x": [2], "y": [2]}
|
|
1307
|
+
|
|
1308
|
+
inputs = {"x": [1, 2], "y": [3, 4]}
|
|
1309
|
+
result = c.reshape_inputs_for_inference(inputs)
|
|
1310
|
+
|
|
1311
|
+
_ = {"x": np.array([1, 2]), "y": np.array([3, 4])}
|
|
1312
|
+
assert list(result.keys()) == ["x", "y"]
|
|
1313
|
+
|
|
1314
|
+
|
|
1315
|
+
@pytest.mark.unit
|
|
1316
|
+
def test_reshape_inputs_for_inference_missing_shape() -> None:
|
|
1317
|
+
c = Circuit()
|
|
1318
|
+
inputs = {"data": [1, 2, 3, 4]}
|
|
1319
|
+
|
|
1320
|
+
with pytest.raises(CircuitConfigurationError, match="input_shape"):
|
|
1321
|
+
c.reshape_inputs_for_inference(inputs)
|
|
1322
|
+
|
|
1323
|
+
|
|
1324
|
+
@pytest.mark.unit
|
|
1325
|
+
def test_reshape_inputs_for_inference_shape_mismatch() -> None:
|
|
1326
|
+
c = Circuit()
|
|
1327
|
+
c.input_shape = [2, 3] # Needs 6 elements
|
|
1328
|
+
|
|
1329
|
+
inputs = {"data": [1, 2, 3, 4]} # Only 4 elements
|
|
1330
|
+
|
|
1331
|
+
with pytest.raises(ShapeMismatchError):
|
|
1332
|
+
c.reshape_inputs_for_inference(inputs)
|
|
1333
|
+
|
|
1334
|
+
|
|
1335
|
+
# ---------- Test _reshape_dict_inputs ----------
|
|
1336
|
+
@pytest.mark.unit
|
|
1337
|
+
def test_reshape_dict_inputs_success() -> None:
|
|
1338
|
+
c = Circuit()
|
|
1339
|
+
shape_dict = {"x": [2], "y": [2, 1]}
|
|
1340
|
+
|
|
1341
|
+
inputs = {"x": [1, 2], "y": [3, 4]}
|
|
1342
|
+
result = c._reshape_dict_inputs(inputs, shape_dict)
|
|
1343
|
+
|
|
1344
|
+
np.testing.assert_array_equal(result["x"], np.array([1, 2]))
|
|
1345
|
+
np.testing.assert_array_equal(result["y"], np.array([[3], [4]]))
|
|
1346
|
+
|
|
1347
|
+
|
|
1348
|
+
@pytest.mark.unit
|
|
1349
|
+
def test_reshape_dict_inputs_non_dict_shape() -> None:
|
|
1350
|
+
c = Circuit()
|
|
1351
|
+
shape_list = [2, 2]
|
|
1352
|
+
|
|
1353
|
+
with pytest.raises(
|
|
1354
|
+
CircuitInputError,
|
|
1355
|
+
match="_reshape_dict_inputs requires dict shape",
|
|
1356
|
+
):
|
|
1357
|
+
c._reshape_dict_inputs({"x": [1, 2]}, shape_list)
|
|
1358
|
+
|
|
1359
|
+
|
|
1360
|
+
@pytest.mark.unit
|
|
1361
|
+
def test_reshape_dict_inputs_shape_mismatch() -> None:
|
|
1362
|
+
c = Circuit()
|
|
1363
|
+
shape_dict = {"x": [2, 2]} # Needs 4 elements
|
|
1364
|
+
|
|
1365
|
+
inputs = {"x": [1, 2]} # Only 2 elements
|
|
1366
|
+
|
|
1367
|
+
with pytest.raises(ShapeMismatchError):
|
|
1368
|
+
c._reshape_dict_inputs(inputs, shape_dict)
|
|
1369
|
+
|
|
1370
|
+
|
|
1371
|
+
# ---------- Test reshape_inputs_for_circuit ----------
|
|
1372
|
+
@pytest.mark.unit
|
|
1373
|
+
def test_reshape_inputs_for_circuit_success() -> None:
|
|
1374
|
+
c = Circuit()
|
|
1375
|
+
inputs = {"x": [1, 2], "y": [3, 4]}
|
|
1376
|
+
|
|
1377
|
+
result = c.reshape_inputs_for_circuit(inputs)
|
|
1378
|
+
|
|
1379
|
+
assert result == {"input": [1, 2, 3, 4]}
|
|
1380
|
+
|
|
1381
|
+
|
|
1382
|
+
@pytest.mark.unit
|
|
1383
|
+
def test_reshape_inputs_for_circuit_with_input_shapes() -> None:
|
|
1384
|
+
c = Circuit()
|
|
1385
|
+
c.input_shapes = {"y": [2], "x": [2]} # Ordered differently
|
|
1386
|
+
|
|
1387
|
+
inputs = {"x": [1, 2], "y": [3, 4]}
|
|
1388
|
+
|
|
1389
|
+
result = c.reshape_inputs_for_circuit(inputs)
|
|
1390
|
+
|
|
1391
|
+
assert result == {"input": [3, 4, 1, 2]} # Respects order from input_shapes
|
|
1392
|
+
|
|
1393
|
+
|
|
1394
|
+
@pytest.mark.unit
|
|
1395
|
+
def test_reshape_inputs_for_circuit_invalid_type() -> None:
|
|
1396
|
+
c = Circuit()
|
|
1397
|
+
|
|
1398
|
+
with pytest.raises(CircuitConfigurationError, match="Expected a dict, got list"):
|
|
1399
|
+
c.reshape_inputs_for_circuit([1, 2, 3, 4])
|
|
1400
|
+
|
|
1401
|
+
|
|
1402
|
+
@pytest.mark.unit
|
|
1403
|
+
def test_reshape_inputs_for_circuit_missing_key() -> None:
|
|
1404
|
+
c = Circuit()
|
|
1405
|
+
c.input_shapes = {"x": [2], "y": [2]}
|
|
1406
|
+
|
|
1407
|
+
inputs = {"x": [1, 2]} # Missing "y"
|
|
1408
|
+
|
|
1409
|
+
with pytest.raises(CircuitProcessingError, match="Missing expected input key 'y'"):
|
|
1410
|
+
c.reshape_inputs_for_circuit(inputs)
|
|
1411
|
+
|
|
1412
|
+
|
|
1413
|
+
@pytest.mark.unit
|
|
1414
|
+
def test_reshape_inputs_for_circuit_unsupported_type() -> None:
|
|
1415
|
+
c = Circuit()
|
|
1416
|
+
|
|
1417
|
+
inputs = {"x": "invalid_type"}
|
|
1418
|
+
|
|
1419
|
+
with pytest.raises(
|
|
1420
|
+
CircuitProcessingError,
|
|
1421
|
+
match="Unsupported input type for key 'x'",
|
|
1422
|
+
):
|
|
1423
|
+
c.reshape_inputs_for_circuit(inputs)
|
|
1424
|
+
|
|
1425
|
+
|
|
1426
|
+
# ---------- Test load_and_compare_witness_to_io ----------
|
|
1427
|
+
@pytest.mark.unit
|
|
1428
|
+
@patch("python.core.circuits.base.load_witness")
|
|
1429
|
+
@patch("python.core.circuits.base.compare_witness_to_io")
|
|
1430
|
+
def test_load_and_compare_witness_to_io_success(
|
|
1431
|
+
mock_compare: MagicMock,
|
|
1432
|
+
mock_load: MagicMock,
|
|
1433
|
+
) -> None:
|
|
1434
|
+
c = Circuit()
|
|
1435
|
+
c._read_from_json_safely = MagicMock
|
|
1436
|
+
mock_load.return_value = {"modulus": 10, "public_inputs": [1, 2, 3]}
|
|
1437
|
+
mock_compare.return_value = True
|
|
1438
|
+
|
|
1439
|
+
_ = c.load_and_compare_witness_to_io(
|
|
1440
|
+
"witness.bin",
|
|
1441
|
+
"inputs.json",
|
|
1442
|
+
"outputs.json",
|
|
1443
|
+
ZKProofSystems.Expander,
|
|
1444
|
+
)
|
|
1445
|
+
|
|
1446
|
+
mock_load.assert_called_once_with("witness.bin", ZKProofSystems.Expander)
|
|
1447
|
+
mock_compare.assert_called_once()
|
|
1448
|
+
|
|
1449
|
+
|
|
1450
|
+
@pytest.mark.unit
|
|
1451
|
+
@patch("python.core.circuits.base.load_witness")
|
|
1452
|
+
def test_load_and_compare_witness_to_io_missing_modulus(mock_load: MagicMock) -> None:
|
|
1453
|
+
c = Circuit()
|
|
1454
|
+
c._read_from_json_safely = MagicMock
|
|
1455
|
+
mock_load.return_value = {"public_inputs": [1, 2, 3]} # No modulus
|
|
1456
|
+
|
|
1457
|
+
with pytest.raises(
|
|
1458
|
+
WitnessMatchError,
|
|
1459
|
+
match=r"Witness not correctly formed\. Missing modulus\.",
|
|
1460
|
+
):
|
|
1461
|
+
c.load_and_compare_witness_to_io(
|
|
1462
|
+
"witness.bin",
|
|
1463
|
+
"inputs.json",
|
|
1464
|
+
"outputs.json",
|
|
1465
|
+
ZKProofSystems.Expander,
|
|
1466
|
+
)
|
|
1467
|
+
|
|
1468
|
+
|
|
1469
|
+
# ---------- Test prepare_inputs_for_verification ----------
|
|
1470
|
+
@pytest.mark.unit
|
|
1471
|
+
def test_prepare_inputs_for_verification_success(tmp_path: Path) -> None:
|
|
1472
|
+
c = Circuit()
|
|
1473
|
+
c._read_from_json_safely = MagicMock(return_value={"input": [1, 2, 3, 4]})
|
|
1474
|
+
c.reshape_inputs_for_circuit = MagicMock(return_value={"input": [1, 2, 3, 4]})
|
|
1475
|
+
c._to_json_safely = MagicMock()
|
|
1476
|
+
|
|
1477
|
+
input_file = tmp_path / "input.json"
|
|
1478
|
+
exec_config = MagicMock()
|
|
1479
|
+
exec_config.input_file = str(input_file)
|
|
1480
|
+
|
|
1481
|
+
result = c.prepare_inputs_for_verification(exec_config)
|
|
1482
|
+
|
|
1483
|
+
expected_file = str(tmp_path / "input_veri.json")
|
|
1484
|
+
assert result == expected_file
|
|
1485
|
+
|
|
1486
|
+
c._read_from_json_safely.assert_called_once_with(str(input_file))
|
|
1487
|
+
c.reshape_inputs_for_circuit.assert_called_once_with({"input": [1, 2, 3, 4]})
|
|
1488
|
+
c._to_json_safely.assert_called_once_with(
|
|
1489
|
+
{"input": [1, 2, 3, 4]},
|
|
1490
|
+
expected_file,
|
|
1491
|
+
"renamed input",
|
|
1492
|
+
)
|