JSTprove 1.0.0__py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81) hide show
  1. jstprove-1.0.0.dist-info/METADATA +397 -0
  2. jstprove-1.0.0.dist-info/RECORD +81 -0
  3. jstprove-1.0.0.dist-info/WHEEL +6 -0
  4. jstprove-1.0.0.dist-info/entry_points.txt +2 -0
  5. jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
  6. jstprove-1.0.0.dist-info/top_level.txt +1 -0
  7. python/__init__.py +0 -0
  8. python/core/__init__.py +3 -0
  9. python/core/binaries/__init__.py +0 -0
  10. python/core/binaries/expander-exec +0 -0
  11. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  12. python/core/circuit_models/__init__.py +0 -0
  13. python/core/circuit_models/generic_onnx.py +231 -0
  14. python/core/circuit_models/simple_circuit.py +133 -0
  15. python/core/circuits/__init__.py +0 -0
  16. python/core/circuits/base.py +1000 -0
  17. python/core/circuits/errors.py +188 -0
  18. python/core/circuits/zk_model_base.py +25 -0
  19. python/core/model_processing/__init__.py +0 -0
  20. python/core/model_processing/converters/__init__.py +0 -0
  21. python/core/model_processing/converters/base.py +143 -0
  22. python/core/model_processing/converters/onnx_converter.py +1181 -0
  23. python/core/model_processing/errors.py +147 -0
  24. python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
  25. python/core/model_processing/onnx_custom_ops/conv.py +111 -0
  26. python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
  27. python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
  28. python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
  29. python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
  30. python/core/model_processing/onnx_custom_ops/relu.py +43 -0
  31. python/core/model_processing/onnx_quantizer/__init__.py +0 -0
  32. python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
  33. python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
  34. python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
  35. python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
  36. python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
  37. python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
  38. python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
  39. python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
  40. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
  41. python/core/model_templates/__init__.py +0 -0
  42. python/core/model_templates/circuit_template.py +57 -0
  43. python/core/utils/__init__.py +0 -0
  44. python/core/utils/benchmarking_helpers.py +163 -0
  45. python/core/utils/constants.py +4 -0
  46. python/core/utils/errors.py +117 -0
  47. python/core/utils/general_layer_functions.py +268 -0
  48. python/core/utils/helper_functions.py +1138 -0
  49. python/core/utils/model_registry.py +166 -0
  50. python/core/utils/scratch_tests.py +66 -0
  51. python/core/utils/witness_utils.py +291 -0
  52. python/frontend/__init__.py +0 -0
  53. python/frontend/cli.py +115 -0
  54. python/frontend/commands/__init__.py +17 -0
  55. python/frontend/commands/args.py +100 -0
  56. python/frontend/commands/base.py +199 -0
  57. python/frontend/commands/bench/__init__.py +54 -0
  58. python/frontend/commands/bench/list.py +42 -0
  59. python/frontend/commands/bench/model.py +172 -0
  60. python/frontend/commands/bench/sweep.py +212 -0
  61. python/frontend/commands/compile.py +58 -0
  62. python/frontend/commands/constants.py +5 -0
  63. python/frontend/commands/model_check.py +53 -0
  64. python/frontend/commands/prove.py +50 -0
  65. python/frontend/commands/verify.py +73 -0
  66. python/frontend/commands/witness.py +64 -0
  67. python/scripts/__init__.py +0 -0
  68. python/scripts/benchmark_runner.py +833 -0
  69. python/scripts/gen_and_bench.py +482 -0
  70. python/tests/__init__.py +0 -0
  71. python/tests/circuit_e2e_tests/__init__.py +0 -0
  72. python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
  73. python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
  74. python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
  75. python/tests/circuit_parent_classes/__init__.py +0 -0
  76. python/tests/circuit_parent_classes/test_circuit.py +969 -0
  77. python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
  78. python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
  79. python/tests/test_cli.py +1021 -0
  80. python/tests/utils_testing/__init__.py +0 -0
  81. python/tests/utils_testing/test_helper_functions.py +891 -0
@@ -0,0 +1,201 @@
1
+ # test_converter.py
2
+ import tempfile
3
+ from pathlib import Path
4
+ from typing import Any, Generator
5
+ from unittest.mock import MagicMock, patch
6
+
7
+ import onnx
8
+ import onnxruntime as ort
9
+ import pytest
10
+ import torch
11
+ from onnx import TensorProto, helper
12
+
13
+ from python.core.model_processing.converters.onnx_converter import ONNXConverter
14
+
15
+
16
@pytest.fixture()
def temp_model_path(tmp_path: Path) -> Generator[Path, Any, None]:
    """Yield a path for a temporary ONNX model, removing the file afterwards.

    Note: pytest's built-in ``tmp_path`` fixture is a ``pathlib.Path``,
    not a generator — the original annotation was incorrect.
    """
    model_path = tmp_path / "temp_model.onnx"
    # Hand the (not-yet-created) path to the test.
    yield model_path

    # Teardown: delete the file if the test created it.
    # (Redundant with tmp_path's own cleanup, but harmless.)
    if model_path.exists():
        model_path.unlink()
27
+
28
+
29
@pytest.fixture()
def temp_quant_model_path(tmp_path: Path) -> Generator[Path, Any, None]:
    """Yield a path for a temporary quantized ONNX model; clean up afterwards.

    Note: pytest's built-in ``tmp_path`` fixture is a ``pathlib.Path``,
    not a generator — the original annotation was incorrect.
    """
    model_path = tmp_path / "temp_quantized_model.onnx"
    # Hand the (not-yet-created) path to the test.
    yield model_path

    # Teardown: delete the file if the test created it.
    # (Redundant with tmp_path's own cleanup, but harmless.)
    if model_path.exists():
        model_path.unlink()
40
+
41
+
42
@pytest.fixture()
def converter() -> ONNXConverter:
    """Provide an ONNXConverter whose model attributes are mocked out."""
    instance = ONNXConverter()
    instance.model = MagicMock(name="model")
    instance.quantized_model = MagicMock(name="quantized_model")
    return instance
48
+
49
+
50
@pytest.mark.unit()
@patch("python.core.model_processing.converters.onnx_converter.onnx.save")
def test_save_model(mock_save: MagicMock, converter: ONNXConverter) -> None:
    """save_model should delegate directly to onnx.save."""
    target = "model.onnx"
    converter.save_model(target)
    mock_save.assert_called_once_with(converter.model, target)
56
+
57
+
58
@pytest.mark.unit()
@patch("python.core.model_processing.converters.onnx_converter.onnx.load")
def test_load_model(mock_load: MagicMock, converter: ONNXConverter) -> None:
    """load_model should call onnx.load and store the result on .model."""
    loaded = MagicMock(name="onnx_model")
    mock_load.return_value = loaded

    target = "model.onnx"
    converter.load_model(target)

    mock_load.assert_called_once_with(target)
    assert converter.model == loaded
69
+
70
+
71
@pytest.mark.unit()
@patch("python.core.model_processing.converters.onnx_converter.onnx.save")
def test_save_quantized_model(mock_save: MagicMock, converter: ONNXConverter) -> None:
    """save_quantized_model should delegate directly to onnx.save."""
    target = "quantized_model.onnx"
    converter.save_quantized_model(target)
    mock_save.assert_called_once_with(converter.quantized_model, target)
77
+
78
+
79
@pytest.mark.unit()
@patch("python.core.model_processing.converters.onnx_converter.Path.exists")
@patch("python.core.model_processing.converters.onnx_converter.SessionOptions")
@patch("python.core.model_processing.converters.onnx_converter.InferenceSession")
@patch("python.core.model_processing.converters.onnx_converter.onnx.load")
def test_load_quantized_model(
    mock_load: MagicMock,
    mock_ort_sess: MagicMock,
    mock_session_opts: MagicMock,
    mock_exists: MagicMock,
    converter: ONNXConverter,
) -> None:
    """load_quantized_model should load the proto and build a CPU session."""
    loaded = MagicMock(name="onnx_model")
    mock_load.return_value = loaded
    # Pretend the file exists so the loader does not bail out early.
    mock_exists.return_value = True

    opts = MagicMock(name="session_options")
    mock_session_opts.return_value = opts

    target = "quantized_model.onnx"
    converter.load_quantized_model(target)

    mock_load.assert_called_once_with(target)
    mock_ort_sess.assert_called_once_with(
        target,
        opts,
        providers=["CPUExecutionProvider"],
    )
    assert converter.quantized_model == loaded
109
+
110
+
111
@pytest.mark.unit()
def test_get_outputs_with_mocked_session(converter: ONNXConverter) -> None:
    """get_outputs should feed inputs to the ORT session and return its result."""
    feed = [[1.0]]
    expected = [[2.0]]

    session = MagicMock()

    # get_inputs()[0].name -> "input"
    input_meta = MagicMock()
    input_meta.name = "input"
    session.get_inputs.return_value = [input_meta]

    # get_outputs()[0].name -> "output"
    output_meta = MagicMock()
    output_meta.name = "output"
    session.get_outputs.return_value = [output_meta]

    # Canned inference result.
    session.run.return_value = expected

    converter.ort_sess = session
    result = converter.get_outputs(feed)

    session.run.assert_called_once_with(["output"], {"input": feed})
    assert result == expected
137
+
138
+
139
+ # Integration test
140
+
141
+
142
def create_dummy_model() -> onnx.ModelProto:
    """Build a minimal single-node (Identity) ONNX model for round-trip tests."""
    inp = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1])
    out = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1])
    identity = helper.make_node("Identity", inputs=["input"], outputs=["output"])
    graph = helper.make_graph([identity], "test-graph", [inp], [out])
    return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 21)])
149
+
150
+
151
@pytest.mark.integration()
def test_save_and_load_real_model() -> None:
    """Round-trip both the plain and the quantized model through disk.

    Bug fix: the second validation block previously re-asserted on
    ``converter.model``; ``load_quantized_model`` populates
    ``converter.quantized_model`` (see the unit test above), so that is
    what must be validated after the quantized round-trip.
    """
    converter = ONNXConverter()
    model = create_dummy_model()
    converter.model = model
    converter.quantized_model = model

    with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp:
        # Save and reload the plain model.
        converter.save_model(tmp.name)
        converter.load_model(tmp.name)

        # Validate the reloaded plain model.
        assert isinstance(converter.model, onnx.ModelProto)
        assert converter.model.graph.name == model.graph.name
        assert len(converter.model.graph.node) == 1
        assert converter.model.graph.node[0].op_type == "Identity"

        # Save and reload the quantized model.
        converter.save_quantized_model(tmp.name)
        converter.load_quantized_model(tmp.name)

        # Validate the reloaded QUANTIZED model (was converter.model before).
        assert isinstance(converter.quantized_model, onnx.ModelProto)
        assert converter.quantized_model.graph.name == model.graph.name
        assert len(converter.quantized_model.graph.node) == 1
        assert converter.quantized_model.graph.node[0].op_type == "Identity"
182
+
183
+
184
@pytest.mark.integration()
def test_real_inference_from_onnx() -> None:
    """Run the Identity model through onnxruntime and check it echoes its input.

    Improvement: the original only asserted ``isinstance(result, list)`` and
    printed the result; an Identity op's output equals its input, so that is
    now asserted explicitly.
    """
    converter = ONNXConverter()
    converter.model = create_dummy_model()

    # Save the model and load it into a CPU-only onnxruntime session.
    with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp:
        onnx.save(converter.model, tmp.name)
        converter.ort_sess = ort.InferenceSession(
            tmp.name,
            providers=["CPUExecutionProvider"],
        )

        dummy_input = torch.tensor([1.0], dtype=torch.float32).numpy()
        result = converter.get_outputs(dummy_input)

        assert isinstance(result, list)
        # Identity op must return its input unchanged.
        assert result[0].tolist() == dummy_input.tolist()
@@ -0,0 +1,116 @@
1
+ import pytest
2
+ import numpy as np
3
+ import torch
4
+ import onnx
5
+
6
+ from onnx import TensorProto, shape_inference, helper, numpy_helper
7
+
8
+ from python.core.model_processing.converters.onnx_converter import ONNXConverter
9
+ from python.core.model_processing.onnx_custom_ops.onnx_helpers import extract_shape_dict
10
+ from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import ONNXOpQuantizer
11
+
12
+ from onnxruntime import InferenceSession, SessionOptions
13
+ from onnxruntime_extensions import get_library_path, OrtPyFunction
14
+ from python.core.model_processing.onnx_custom_ops import conv
15
+
16
+ from python.core.model_processing.onnx_custom_ops.conv import int64_conv
17
+ from python.core.model_processing.onnx_custom_ops.gemm import int64_gemm7
18
+
19
+
20
@pytest.fixture
def tiny_conv_model_path(tmp_path):
    """Write a one-node Conv model (1x1x4x4 -> 1x1x2x2) and return its path."""
    # Input/output tensor metadata.
    x_info = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 4, 4])
    y_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 2, 2])

    # 3x3 kernel of ones plus a single-element bias of one.
    weight = helper.make_tensor(
        name="W",
        data_type=TensorProto.FLOAT,
        dims=[1, 1, 3, 3],
        vals=np.ones(9, dtype=np.float32).tolist(),
    )
    bias = helper.make_tensor(
        name="Z",
        data_type=TensorProto.FLOAT,
        dims=[1],
        vals=np.ones(1, dtype=np.float32).tolist(),
    )

    # Convolution with no padding, unit stride, unit dilation.
    node = helper.make_node(
        "Conv",
        inputs=["X", "W", "Z"],
        outputs=["Y"],
        kernel_shape=[3, 3],
        pads=[0, 0, 0, 0],
        strides=[1, 1],
        dilations=[1, 1],
    )

    # Assemble graph and model.
    graph = helper.make_graph(
        nodes=[node],
        name="TinyConvGraph",
        inputs=[x_info],
        outputs=[y_info],
        initializer=[weight, bias],
    )
    model = helper.make_model(graph, producer_name="tiny-conv-example")

    # Persist into the per-test temporary directory.
    out_path = tmp_path / "tiny_conv.onnx"
    onnx.save(model, str(out_path))
    return str(out_path)
67
+
68
@pytest.mark.integration
def test_tiny_conv(tiny_conv_model_path):
    """Quantize the tiny Conv model and compare its outputs with the float model.

    Fixes over the original:
    - the quantized model is written into the fixture's temp directory
      instead of ``model.onnx`` in the current working directory (which
      polluted the repo and raced under parallel test runs);
    - removed duplicated ``check_model``/``infer_shapes`` calls, the unused
      ``id_count``/``domain_to_version``/``output_name_to_shape`` locals,
      and the debug prints.
    """
    from pathlib import Path  # local import: file-level imports lack pathlib

    path = tiny_conv_model_path
    converter = ONNXConverter()

    model = onnx.load(path)
    onnx.checker.check_model(model)

    # Shape inference must succeed on the source model.
    inferred_model = shape_inference.infer_shapes(model)
    onnx.checker.check_model(inferred_model)

    # Quantize (scale factor 2**21) and register the custom-op domain.
    new_model = converter.quantize_model(model, 2, 21)
    custom_domain = onnx.helper.make_operatorsetid(domain="ai.onnx.contrib", version=1)
    new_model.opset_import.append(custom_domain)
    onnx.checker.check_model(new_model)

    # Keep the quantized artifact inside the per-test temp directory.
    quant_path = str(Path(path).parent / "tiny_conv_quantized.onnx")
    with open(quant_path, "wb") as f:
        f.write(new_model.SerializeToString())
    onnx.checker.check_model(onnx.load(quant_path))  # throws a descriptive error

    # Run both models on the same input and compare (quantized output is
    # scaled by 2**21, matching the quantize_model call above).
    inputs = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
    outputs_true = converter.run_model_onnx_runtime(path, inputs)
    outputs_quant = converter.run_model_onnx_runtime(quant_path, inputs)

    true = torch.tensor(np.array(outputs_true), dtype=torch.float32)
    quant = torch.tensor(np.array(outputs_quant), dtype=torch.float32) / (2**21)
    assert torch.allclose(true, quant, rtol=1e-3, atol=1e-5), "Outputs do not match"