JSTprove 1.0.0__py3-none-macosx_11_0_arm64.whl → 1.1.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/METADATA +2 -2
  2. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/RECORD +51 -24
  3. python/core/binaries/onnx_generic_circuit_1-1-0 +0 -0
  4. python/core/circuit_models/generic_onnx.py +43 -9
  5. python/core/circuits/base.py +231 -71
  6. python/core/model_processing/converters/onnx_converter.py +86 -32
  7. python/core/model_processing/onnx_custom_ops/maxpool.py +1 -1
  8. python/core/model_processing/onnx_custom_ops/relu.py +1 -1
  9. python/core/model_processing/onnx_quantizer/layers/add.py +54 -0
  10. python/core/model_processing/onnx_quantizer/layers/base.py +121 -1
  11. python/core/model_processing/onnx_quantizer/layers/constant.py +1 -1
  12. python/core/model_processing/onnx_quantizer/layers/conv.py +20 -68
  13. python/core/model_processing/onnx_quantizer/layers/gemm.py +20 -66
  14. python/core/model_processing/onnx_quantizer/layers/maxpool.py +53 -43
  15. python/core/model_processing/onnx_quantizer/layers/relu.py +20 -35
  16. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +6 -1
  17. python/core/utils/general_layer_functions.py +17 -12
  18. python/core/utils/model_registry.py +6 -3
  19. python/tests/circuit_e2e_tests/other_e2e_test.py +202 -9
  20. python/tests/circuit_parent_classes/test_circuit.py +561 -38
  21. python/tests/circuit_parent_classes/test_onnx_converter.py +22 -13
  22. python/tests/onnx_quantizer_tests/__init__.py +1 -0
  23. python/tests/onnx_quantizer_tests/layers/__init__.py +13 -0
  24. python/tests/onnx_quantizer_tests/layers/add_config.py +102 -0
  25. python/tests/onnx_quantizer_tests/layers/base.py +279 -0
  26. python/tests/onnx_quantizer_tests/layers/constant_config.py +39 -0
  27. python/tests/onnx_quantizer_tests/layers/conv_config.py +154 -0
  28. python/tests/onnx_quantizer_tests/layers/factory.py +142 -0
  29. python/tests/onnx_quantizer_tests/layers/flatten_config.py +61 -0
  30. python/tests/onnx_quantizer_tests/layers/gemm_config.py +160 -0
  31. python/tests/onnx_quantizer_tests/layers/maxpool_config.py +82 -0
  32. python/tests/onnx_quantizer_tests/layers/relu_config.py +61 -0
  33. python/tests/onnx_quantizer_tests/layers/reshape_config.py +61 -0
  34. python/tests/onnx_quantizer_tests/layers_tests/__init__.py +0 -0
  35. python/tests/onnx_quantizer_tests/layers_tests/base_test.py +94 -0
  36. python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py +115 -0
  37. python/tests/onnx_quantizer_tests/layers_tests/test_e2e.py +196 -0
  38. python/tests/onnx_quantizer_tests/layers_tests/test_error_cases.py +59 -0
  39. python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +198 -0
  40. python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +265 -0
  41. python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py +109 -0
  42. python/tests/onnx_quantizer_tests/layers_tests/test_validation.py +45 -0
  43. python/tests/onnx_quantizer_tests/test_base_layer.py +228 -0
  44. python/tests/onnx_quantizer_tests/test_exceptions.py +99 -0
  45. python/tests/onnx_quantizer_tests/test_onnx_op_quantizer.py +246 -0
  46. python/tests/onnx_quantizer_tests/test_registered_quantizers.py +121 -0
  47. python/tests/onnx_quantizer_tests/testing_helper_functions.py +17 -0
  48. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  49. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/WHEEL +0 -0
  50. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/entry_points.txt +0 -0
  51. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/licenses/LICENSE +0 -0
  52. {jstprove-1.0.0.dist-info → jstprove-1.1.0.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Callable
3
+ from typing import TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ from collections.abc import Callable
4
7
 
5
8
  import onnx
6
9
 
@@ -9,6 +12,7 @@ from python.core.model_processing.onnx_quantizer.exceptions import (
9
12
  MissingHandlerError,
10
13
  UnsupportedOpError,
11
14
  )
15
+ from python.core.model_processing.onnx_quantizer.layers.add import AddQuantizer
12
16
  from python.core.model_processing.onnx_quantizer.layers.base import (
13
17
  PassthroughQuantizer,
14
18
  ScaleConfig,
@@ -64,6 +68,7 @@ class ONNXOpQuantizer:
64
68
  self.new_initializers = []
65
69
 
66
70
  # Register handlers
71
+ self.register("Add", AddQuantizer(self.new_initializers))
67
72
  self.register("Conv", ConvQuantizer(self.new_initializers))
68
73
  self.register("Relu", ReluQuantizer())
69
74
  self.register("Reshape", PassthroughQuantizer())
@@ -83,7 +83,9 @@ class GeneralLayerFunctions:
83
83
  tensor = torch.mul(tensor, self.scale_base**self.scale_exponent)
84
84
 
85
85
  tensor = tensor.long()
86
+ return self.reshape_inputs(tensor)
86
87
 
88
+ def reshape_inputs(self: GeneralLayerFunctions, tensor: torch.Tensor) -> None:
87
89
  if hasattr(self, "input_shape"):
88
90
  shape = self.input_shape
89
91
  if hasattr(self, "adjust_shape") and callable(
@@ -94,6 +96,7 @@ class GeneralLayerFunctions:
94
96
  tensor = tensor.reshape(shape)
95
97
  except RuntimeError as e:
96
98
  raise ShapeMismatchError(shape, list(tensor.shape)) from e
99
+
97
100
  return tensor
98
101
 
99
102
  def get_inputs(
@@ -154,21 +157,19 @@ class GeneralLayerFunctions:
154
157
  # If unknown dim in batch spot, assume batch size of 1
155
158
  first_key = next(iter(keys))
156
159
  input_shape = self.input_shape[first_key]
157
- input_shape[0] = 1 if input_shape[0] < 1 else input_shape[0]
160
+ input_shape[0] = max(input_shape[0], 1)
158
161
  return self.get_rand_inputs(input_shape)
159
162
  inputs = {}
160
163
  for key in keys:
161
164
  # If unknown dim in batch spot, assume batch size of 1
162
- input_shape = self.input_shape[keys[key]]
163
- if not isinstance(input_shape, list) and not isinstance(
164
- input_shape,
165
- tuple,
166
- ):
165
+ input_shape = self.input_shape[key]
166
+ if not isinstance(input_shape, (list, tuple)):
167
167
  msg = f"Invalid input shape for key '{key}': {input_shape}"
168
168
  raise CircuitUtilsError(msg)
169
- input_shape[0] = 1 if input_shape[0] < 1 else input_shape[0]
169
+ input_shape[0] = max(input_shape[0], 1)
170
170
  inputs[key] = self.get_rand_inputs(input_shape)
171
171
  return inputs
172
+
172
173
  if not (hasattr(self, "scale_base") and hasattr(self, "scale_exponent")):
173
174
  attr_name = "scale_base/scale_exponent"
174
175
  context = "needed for scaling random inputs"
@@ -195,8 +196,10 @@ class GeneralLayerFunctions:
195
196
  torch.Tensor: A tensor of random values in [-1, 1).
196
197
  """
197
198
  if not isinstance(input_shape, (list, tuple)):
198
- msg = f"Invalid input_shape type: {type(input_shape)}."
199
- " Expected list or tuple of ints."
199
+ msg = (
200
+ f"Invalid input_shape type: {type(input_shape)}."
201
+ " Expected list or tuple of ints."
202
+ )
200
203
  raise CircuitUtilsError(msg)
201
204
  if not all(isinstance(x, int) and x > 0 for x in input_shape):
202
205
  raise ShapeMismatchError(
@@ -240,9 +243,11 @@ class GeneralLayerFunctions:
240
243
  try:
241
244
  rescaled = torch.div(outputs, self.scale_base**self.scale_exponent)
242
245
  except Exception as e:
243
- msg = "Failed to rescale outputs using scale_base="
244
- f"{getattr(self, 'scale_base', None)} "
245
- f"and scale_exponent={getattr(self, 'scale_exponent', None)}: {e}"
246
+ msg = (
247
+ "Failed to rescale outputs using scale_base="
248
+ f"{getattr(self, 'scale_base', None)} "
249
+ f"and scale_exponent={getattr(self, 'scale_exponent', None)}: {e}"
250
+ )
246
251
  raise CircuitUtilsError(msg) from e
247
252
  return {
248
253
  "output": outputs.long().tolist(),
@@ -147,7 +147,8 @@ def get_models_to_test(
147
147
 
148
148
  Args:
149
149
  selected_models (list[str], optional):
150
- A list of model names to include. Defaults to None.
150
+ A list of model names to include. If None, returns empty list.
151
+ If provided but no matches found, returns empty list. Defaults to None.
151
152
  source_filter (str, optional):
152
153
  Restrict models to a specific source (e.g., "onnx", "class").
153
154
  Defaults to None.
@@ -155,10 +156,12 @@ def get_models_to_test(
155
156
  Returns:
156
157
  list[ModelEntry]: A filtered list of model entries.
157
158
  """
159
+ if selected_models is None:
160
+ return []
161
+
158
162
  models = MODELS_TO_TEST
159
163
 
160
- if selected_models is not None:
161
- models = [m for m in models if m.name in selected_models]
164
+ models = [m for m in models if m.name in selected_models]
162
165
 
163
166
  if source_filter is not None:
164
167
  models = [m for m in models if m.source == source_filter]
@@ -1,13 +1,15 @@
1
+ # ruff: noqa: S603
1
2
  import json
2
3
  import subprocess
3
4
  import sys
5
+ from collections.abc import Generator
4
6
  from pathlib import Path
5
- from typing import Generator
6
7
 
7
8
  import numpy as np
8
9
  import onnx
9
10
  import pytest
10
- from onnx import helper, numpy_helper
11
+ import torch
12
+ from onnx import TensorProto, helper, numpy_helper
11
13
 
12
14
 
13
15
  def create_simple_gemm_onnx_model(
@@ -19,14 +21,14 @@ def create_simple_gemm_onnx_model(
19
21
  # Define input
20
22
  input_tensor = helper.make_tensor_value_info(
21
23
  "input",
22
- onnx.TensorProto.FLOAT,
24
+ TensorProto.FLOAT,
23
25
  [1, input_size],
24
26
  )
25
27
 
26
28
  # Define output
27
29
  output_tensor = helper.make_tensor_value_info(
28
30
  "output",
29
- onnx.TensorProto.FLOAT,
31
+ TensorProto.FLOAT,
30
32
  [1, output_size],
31
33
  )
32
34
 
@@ -67,7 +69,7 @@ def create_simple_gemm_onnx_model(
67
69
  onnx.save(model, str(model_path))
68
70
 
69
71
 
70
- @pytest.mark.e2e()
72
+ @pytest.mark.e2e
71
73
  def test_parallel_compile_and_witness_two_simple_models( # noqa: PLR0915
72
74
  tmp_path: str,
73
75
  capsys: Generator[pytest.CaptureFixture[str], None, None],
@@ -115,7 +117,7 @@ def test_parallel_compile_and_witness_two_simple_models( # noqa: PLR0915
115
117
 
116
118
  # Run compile commands
117
119
  result1 = subprocess.run(
118
- compile_cmd1, # noqa: S603
120
+ compile_cmd1,
119
121
  capture_output=True,
120
122
  text=True,
121
123
  check=False,
@@ -123,7 +125,7 @@ def test_parallel_compile_and_witness_two_simple_models( # noqa: PLR0915
123
125
  assert result1.returncode == 0, f"Compile failed for model1: {result1.stderr}"
124
126
 
125
127
  result2 = subprocess.run(
126
- compile_cmd2, # noqa: S603
128
+ compile_cmd2,
127
129
  capture_output=True,
128
130
  text=True,
129
131
  check=False,
@@ -179,8 +181,8 @@ def test_parallel_compile_and_witness_two_simple_models( # noqa: PLR0915
179
181
  ]
180
182
 
181
183
  # Start both processes
182
- proc1 = subprocess.Popen(witness_cmd1) # noqa: S603
183
- proc2 = subprocess.Popen(witness_cmd2) # noqa: S603
184
+ proc1 = subprocess.Popen(witness_cmd1)
185
+ proc2 = subprocess.Popen(witness_cmd2)
184
186
 
185
187
  # Wait for both to complete
186
188
  proc1.wait()
@@ -215,3 +217,194 @@ def test_parallel_compile_and_witness_two_simple_models( # noqa: PLR0915
215
217
  len(output2["output"]) == model2_output_size
216
218
  ), f"Output2 should have {model2_output_size} elements,"
217
219
  f" got {len(output2['output'])}"
220
+
221
+
222
+ def create_multi_input_multi_output_model(model_path: Path) -> None:
223
+ """Create a simple ONNX model with two inputs and two outputs."""
224
+ # Define inputs
225
+ x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 4, 4])
226
+ w = helper.make_tensor_value_info("W", TensorProto.FLOAT, [1, 1, 4, 4])
227
+
228
+ # Define outputs
229
+ y1 = helper.make_tensor_value_info("sum", TensorProto.FLOAT, [1, 1, 4, 4])
230
+ y2 = helper.make_tensor_value_info("pooled", TensorProto.FLOAT, [1, 1, 2, 2])
231
+
232
+ # Node 1: Add
233
+ add_node = helper.make_node("Add", inputs=["X", "W"], outputs=["sum"])
234
+
235
+ # Node 2: MaxPool
236
+ pool_node = helper.make_node(
237
+ "MaxPool",
238
+ inputs=["sum"],
239
+ outputs=["pooled"],
240
+ kernel_shape=[2, 2],
241
+ strides=[2, 2],
242
+ dilations=[1, 1],
243
+ pads=[0, 0, 0, 0],
244
+ ceil_mode=0,
245
+ )
246
+
247
+ # Build the graph
248
+ graph_def = helper.make_graph(
249
+ [add_node, pool_node],
250
+ "TwoOutputGraph",
251
+ [x, w],
252
+ [y1, y2],
253
+ )
254
+
255
+ model_def = helper.make_model(graph_def, producer_name="pytest-multi-output-model")
256
+ onnx.save(model_def, model_path)
257
+
258
+
259
+ @pytest.mark.e2e
260
+ def test_multi_input_multi_output_model_e2e(tmp_path: Path) -> None:
261
+ """
262
+ E2E test: compile, witness, and verify outputs
263
+ for a multi-input/multi-output ONNX model.
264
+ """
265
+ model_path = tmp_path / "multi_output_no_identity.onnx"
266
+ circuit_path = tmp_path / "circuit.txt"
267
+ input_path = tmp_path / "input.json"
268
+ output_path = tmp_path / "output.json"
269
+ witness_path = tmp_path / "witness.bin"
270
+ proof_path = tmp_path / "proof.bin"
271
+
272
+ # --- Step 1: Generate model ---
273
+ create_multi_input_multi_output_model(model_path)
274
+
275
+ # --- Step 2: Compile model ---
276
+ compile_cmd = [
277
+ sys.executable,
278
+ "-m",
279
+ "python.frontend.cli",
280
+ "compile",
281
+ "-m",
282
+ str(model_path),
283
+ "-c",
284
+ str(circuit_path),
285
+ ]
286
+ result = subprocess.run(compile_cmd, capture_output=True, text=True, check=False)
287
+ assert (
288
+ result.returncode == 0
289
+ ), f"Compile failed:\nSTDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
290
+
291
+ # --- Step 3: Create input JSON ---
292
+ # Simple constant tensors (shape [1,1,4,4])
293
+ x = [
294
+ [
295
+ [
296
+ [1.0, 2.0, 3.0, 4.0],
297
+ [5.0, 6.0, 7.0, 8.0],
298
+ [9.0, 10.0, 11.0, 12.0],
299
+ [13.0, 14.0, 15.0, 16.0],
300
+ ],
301
+ ],
302
+ ]
303
+ w = [
304
+ [
305
+ [
306
+ [0.1, 0.2, 0.3, 0.4],
307
+ [0.5, 0.6, 0.7, 0.8],
308
+ [0.9, 1.0, 1.1, 1.2],
309
+ [1.3, 1.4, 1.5, 1.6],
310
+ ],
311
+ ],
312
+ ]
313
+
314
+ with Path.open(input_path, "w") as f:
315
+ json.dump({"X": x, "W": w}, f)
316
+
317
+ # --- Step 4: Run witness ---
318
+ witness_cmd = [
319
+ sys.executable,
320
+ "-m",
321
+ "python.frontend.cli",
322
+ "witness",
323
+ "-c",
324
+ str(circuit_path),
325
+ "-i",
326
+ str(input_path),
327
+ "-o",
328
+ str(output_path),
329
+ "-w",
330
+ str(witness_path),
331
+ ]
332
+ result = subprocess.run(witness_cmd, capture_output=True, text=True, check=False)
333
+ assert (
334
+ result.returncode == 0
335
+ ), f"Witness failed:\nSTDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
336
+
337
+ # --- Step 5: Validate output files ---
338
+ assert output_path.exists(), "Output file not generated"
339
+ assert witness_path.exists(), "Witness file not generated"
340
+
341
+ with Path.open(output_path) as f:
342
+ outputs = json.load(f)
343
+
344
+ output_raw = (
345
+ (torch.as_tensor(x) * 2**18).long() + (torch.as_tensor(w) * 2**18).long()
346
+ ).flatten()
347
+
348
+ second_outputs = output_raw.clone().reshape([1, 1, 4, 4])
349
+
350
+ outputs_2 = torch.max_pool2d(
351
+ second_outputs,
352
+ kernel_size=2,
353
+ stride=2,
354
+ dilation=1,
355
+ padding=0,
356
+ ).flatten()
357
+
358
+ output_raw = torch.cat((output_raw, outputs_2))
359
+
360
+ assert torch.allclose(
361
+ torch.as_tensor(outputs["output"]),
362
+ output_raw,
363
+ rtol=1e-3,
364
+ atol=1e-5,
365
+ ), "Outputs do not match"
366
+
367
+ # --- Step 5: Prove ---
368
+ prove_cmd = [
369
+ sys.executable,
370
+ "-m",
371
+ "python.frontend.cli",
372
+ "prove",
373
+ "-c",
374
+ str(circuit_path),
375
+ "-w",
376
+ str(witness_path),
377
+ "-p",
378
+ str(proof_path),
379
+ ]
380
+ result = subprocess.run(prove_cmd, check=False, capture_output=True, text=True)
381
+ assert (
382
+ result.returncode == 0
383
+ ), f"Prove failed:\nSTDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
384
+
385
+ # --- Step 6: Verify ---
386
+ verify_cmd = [
387
+ sys.executable,
388
+ "-m",
389
+ "python.frontend.cli",
390
+ "verify",
391
+ "-c",
392
+ str(circuit_path),
393
+ "-i",
394
+ str(input_path),
395
+ "-o",
396
+ str(output_path),
397
+ "-w",
398
+ str(witness_path),
399
+ "-p",
400
+ str(proof_path),
401
+ ]
402
+ result = subprocess.run(verify_cmd, check=False, capture_output=True, text=True)
403
+ assert (
404
+ result.returncode == 0
405
+ ), f"Verify failed:\nSTDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
406
+
407
+ # --- Step 7: Validate output ---
408
+ assert output_path.exists(), "Output JSON not generated"
409
+ assert witness_path.exists(), "Witness not generated"
410
+ assert proof_path.exists(), "Proof not generated"