JSTprove: jstprove-1.1.0-py3-none-macosx_11_0_arm64.whl → jstprove-1.3.0-py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of JSTprove might be problematic.

Files changed (41)
  1. {jstprove-1.1.0.dist-info → jstprove-1.3.0.dist-info}/METADATA +3 -3
  2. {jstprove-1.1.0.dist-info → jstprove-1.3.0.dist-info}/RECORD +40 -26
  3. python/core/binaries/onnx_generic_circuit_1-3-0 +0 -0
  4. python/core/circuits/base.py +29 -12
  5. python/core/circuits/errors.py +1 -2
  6. python/core/model_processing/converters/base.py +3 -3
  7. python/core/model_processing/converters/onnx_converter.py +28 -27
  8. python/core/model_processing/onnx_custom_ops/__init__.py +5 -4
  9. python/core/model_processing/onnx_custom_ops/batchnorm.py +64 -0
  10. python/core/model_processing/onnx_custom_ops/mul.py +66 -0
  11. python/core/model_processing/onnx_quantizer/exceptions.py +2 -2
  12. python/core/model_processing/onnx_quantizer/layers/base.py +101 -0
  13. python/core/model_processing/onnx_quantizer/layers/batchnorm.py +224 -0
  14. python/core/model_processing/onnx_quantizer/layers/clip.py +92 -0
  15. python/core/model_processing/onnx_quantizer/layers/max.py +49 -0
  16. python/core/model_processing/onnx_quantizer/layers/min.py +54 -0
  17. python/core/model_processing/onnx_quantizer/layers/mul.py +53 -0
  18. python/core/model_processing/onnx_quantizer/layers/sub.py +54 -0
  19. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +43 -0
  20. python/core/model_templates/circuit_template.py +48 -38
  21. python/core/utils/errors.py +1 -1
  22. python/core/utils/scratch_tests.py +29 -23
  23. python/scripts/gen_and_bench.py +2 -2
  24. python/tests/circuit_e2e_tests/circuit_model_developer_test.py +18 -14
  25. python/tests/circuit_e2e_tests/helper_fns_for_tests.py +11 -13
  26. python/tests/circuit_parent_classes/test_ort_custom_layers.py +35 -53
  27. python/tests/onnx_quantizer_tests/layers/base.py +1 -3
  28. python/tests/onnx_quantizer_tests/layers/batchnorm_config.py +190 -0
  29. python/tests/onnx_quantizer_tests/layers/clip_config.py +127 -0
  30. python/tests/onnx_quantizer_tests/layers/max_config.py +100 -0
  31. python/tests/onnx_quantizer_tests/layers/min_config.py +94 -0
  32. python/tests/onnx_quantizer_tests/layers/mul_config.py +102 -0
  33. python/tests/onnx_quantizer_tests/layers/sub_config.py +102 -0
  34. python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +6 -5
  35. python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +8 -1
  36. python/tests/onnx_quantizer_tests/test_registered_quantizers.py +17 -8
  37. python/core/binaries/onnx_generic_circuit_1-1-0 +0 -0
  38. {jstprove-1.1.0.dist-info → jstprove-1.3.0.dist-info}/WHEEL +0 -0
  39. {jstprove-1.1.0.dist-info → jstprove-1.3.0.dist-info}/entry_points.txt +0 -0
  40. {jstprove-1.1.0.dist-info → jstprove-1.3.0.dist-info}/licenses/LICENSE +0 -0
  41. {jstprove-1.1.0.dist-info → jstprove-1.3.0.dist-info}/top_level.txt +0 -0
python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py

@@ -17,13 +17,21 @@ from python.core.model_processing.onnx_quantizer.layers.base import (
     PassthroughQuantizer,
     ScaleConfig,
 )
+from python.core.model_processing.onnx_quantizer.layers.batchnorm import (
+    BatchnormQuantizer,
+)
+from python.core.model_processing.onnx_quantizer.layers.clip import ClipQuantizer
 from python.core.model_processing.onnx_quantizer.layers.constant import (
     ConstantQuantizer,
 )
 from python.core.model_processing.onnx_quantizer.layers.conv import ConvQuantizer
 from python.core.model_processing.onnx_quantizer.layers.gemm import GemmQuantizer
+from python.core.model_processing.onnx_quantizer.layers.max import MaxQuantizer
 from python.core.model_processing.onnx_quantizer.layers.maxpool import MaxpoolQuantizer
+from python.core.model_processing.onnx_quantizer.layers.min import MinQuantizer
+from python.core.model_processing.onnx_quantizer.layers.mul import MulQuantizer
 from python.core.model_processing.onnx_quantizer.layers.relu import ReluQuantizer
+from python.core.model_processing.onnx_quantizer.layers.sub import SubQuantizer


 class ONNXOpQuantizer:
@@ -69,6 +77,9 @@ class ONNXOpQuantizer:

         # Register handlers
         self.register("Add", AddQuantizer(self.new_initializers))
+        self.register("Clip", ClipQuantizer(self.new_initializers))
+        self.register("Sub", SubQuantizer(self.new_initializers))
+        self.register("Mul", MulQuantizer(self.new_initializers))
         self.register("Conv", ConvQuantizer(self.new_initializers))
         self.register("Relu", ReluQuantizer())
         self.register("Reshape", PassthroughQuantizer())
@@ -76,6 +87,9 @@ class ONNXOpQuantizer:
         self.register("Constant", ConstantQuantizer())
         self.register("MaxPool", MaxpoolQuantizer())
         self.register("Flatten", PassthroughQuantizer())
+        self.register("Max", MaxQuantizer(self.new_initializers))
+        self.register("Min", MinQuantizer(self.new_initializers))
+        self.register("BatchNormalization", BatchnormQuantizer(self.new_initializers))

     def register(
         self: ONNXOpQuantizer,
@@ -203,3 +217,32 @@ class ONNXOpQuantizer:
             dict[str, onnx.TensorProto]: Map from initializer name to tensors in graph.
         """
         return {init.name: init for init in model.graph.initializer}
+
+    def apply_pre_analysis_transforms(
+        self: ONNXOpQuantizer,
+        model: onnx.ModelProto,
+        scale_exponent: int,
+        scale_base: int,
+    ) -> onnx.ModelProto:
+        """
+        Give each registered handler a chance to rewrite the model before analysis.
+        """
+        graph = model.graph
+        initializer_map = self.get_initializer_map(model)
+
+        # We allow handlers to modify graph in-place.
+        # (Nodes may be replaced, removed, or new nodes added.)
+        for node in list(graph.node):
+            handler = self.handlers.get(node.op_type)
+            if handler and hasattr(handler, "pre_analysis_transform"):
+                handler.pre_analysis_transform(
+                    node,
+                    graph,
+                    initializer_map,
+                    scale_exponent=scale_exponent,
+                    scale_base=scale_base,
+                )
+                # Refresh map if transforms may add initializers
+                initializer_map = self.get_initializer_map(model)
+
+        return model
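Note on the new hook: `apply_pre_analysis_transforms` dispatches on `node.op_type` to the handlers registered above and only invokes handlers that define a `pre_analysis_transform` method, so existing handlers are unaffected. A minimal sketch of a participating handler follows; the class name and body are hypothetical, and the constructor shape is assumed from the `register()` calls above rather than taken from the package:

import onnx


class DocTagQuantizer:
    """Illustrative only: annotates nodes before scale analysis."""

    def __init__(self, new_initializers: list[onnx.TensorProto]) -> None:
        self.new_initializers = new_initializers

    def pre_analysis_transform(
        self,
        node: onnx.NodeProto,
        graph: onnx.GraphProto,
        initializer_map: dict[str, onnx.TensorProto],
        *,
        scale_exponent: int,
        scale_base: int,
    ) -> None:
        # Record the fixed-point scale this node will be analyzed under.
        node.doc_string = f"scale={scale_base ** scale_exponent}"

A handler like this would be wired in with the same pattern as the built-ins, e.g. `quantizer.register("MyOp", DocTagQuantizer(quantizer.new_initializers))`.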
python/core/model_templates/circuit_template.py

@@ -1,57 +1,67 @@
-import torch.nn as nn
+from __future__ import annotations
+
+from secrets import randbelow
+
 from python.core.circuits.base import Circuit
-from random import randint
+

 class SimpleCircuit(Circuit):
-    '''
-    Note: This template is irrelevant if using the onnx circuit builder.
-    The template only helps developers if they choose to incorporate other circuit builders into the framework.
+    """
+    Note: This template is irrelevant if using the ONNX circuit builder.
+    The template only helps developers if they choose to incorporate other circuit
+    builders into the framework.

-    To begin, we need to specify some basic attributes surrounding the circuit we will be using.
-    required_keys - specify the variables in the input dictionary (and input file).
-    name - name of the rust bin to be run by the circuit.
+    To begin, we need to specify some basic attributes surrounding the circuit we will
+    be using.

-    scale_base - specify the base of the scaling applied to each value
-    scale_exponent - the exponent applied to the base to get the scaling factor. Scaling factor will be multiplied by each input
+    - `required_keys`: the variables in the input dictionary (and input file).
+    - `name`: name of the Rust bin to be run by the circuit.
+    - `scale_base`: base of the scaling applied to each value.
+    - `scale_exponent`: exponent applied to the base to get the scaling factor.
+      Scaling factor will be multiplied by each input.

-    Other default inputs can be defined below
-    '''
-    def __init__(self, file_name):
+    Other default inputs can be defined below.
+    """
+
+    def __init__(self, file_name: str | None = None) -> None:
         # Initialize the base class
         super().__init__()
-
+        self.file_name = file_name
+
         # Circuit-specific parameters
         self.required_keys = ["input_a", "input_b", "nonce"]
         self.name = "simple_circuit"  # Use exact name that matches the binary

         self.scale_exponent = 1
         self.scale_base = 1
-
+
         self.input_a = 100
         self.input_b = 200
-        self.nonce = randint(0,10000)
-
-    '''
-    The following are some important functions used by the model. get inputs should be defined to specify the inputs to the circuit
-    '''
-    def get_inputs(self):
-        '''
-        Specify the inputs to the circuit, based on what was specified in the __init__. Can also have inputs to this function for the inputs.
-        '''
-        return {'input_a': self.input_a, 'input_b': self.input_b, 'nonce': self.nonce}
-
-    def get_outputs(self, inputs = None):
+        self.nonce = randbelow(10_000)
+
+    def get_inputs(self) -> dict[str, int]:
+        """
+        Specify the inputs to the circuit, based on what was specified
+        in `__init__`.
+        """
+        return {
+            "input_a": self.input_a,
+            "input_b": self.input_b,
+            "nonce": self.nonce,
+        }
+
+    def get_outputs(self, inputs: dict[str, int] | None = None) -> int:
         """
         Compute the output of the circuit.
-        This is overwritten from the base class to ensure computation happens only once.
+
+        This is overwritten from the base class to ensure computation happens
+        only once.
         """
-        if inputs == None:
-            inputs = {'input_a': self.input_a, 'input_b': self.input_b, 'nonce': self.nonce}
-        print(f"Performing addition operation: {inputs['input_a']} + {inputs['input_b']}")
-        return inputs['input_a'] + inputs['input_b']
-
-    # def format_inputs(self, inputs):
-    #     return {"input": inputs.long().tolist()}
-
-    # def format_outputs(self, outputs):
-    #     return {"output": outputs.long().tolist()}
+        if inputs is None:
+            inputs = {
+                "input_a": self.input_a,
+                "input_b": self.input_b,
+                "nonce": self.nonce,
+            }
+
+        return inputs["input_a"] + inputs["input_b"]
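The swap from `random.randint` to `secrets.randbelow` looks motivated by lint rules against non-cryptographic randomness (e.g. Ruff's S311); note the range tightens slightly, since `randbelow(10_000)` yields 0..9999 while `randint(0, 10000)` included 10000. A quick usage sketch of the rewritten template, assuming `Circuit` needs no constructor arguments beyond those shown:

# Hypothetical driver for the template above; values follow from the
# defaults set in __init__ (input_a=100, input_b=200).
circuit = SimpleCircuit()
inputs = circuit.get_inputs()         # {"input_a": 100, "input_b": 200, "nonce": ...}
result = circuit.get_outputs(inputs)  # 300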
python/core/circuits/errors.py

@@ -30,7 +30,7 @@ class FileCacheError(CircuitExecutionError):
 class ProofBackendError(CircuitExecutionError):
     """Raised when a Cargo command fails."""

-    def __init__(  # noqa: PLR0913
+    def __init__(
         self: ProofBackendError,
         message: str,
         command: list[str] | None = None,
python/core/utils/scratch_tests.py

@@ -1,66 +1,72 @@
+from __future__ import annotations
+
 import onnx
-from onnx import TensorProto, helper, shape_inference
-from onnx import numpy_helper
-from onnx import load, save
+from onnx import TensorProto, helper, load, shape_inference
 from onnx.utils import extract_model

-def prune_model(model_path, output_names, save_path):
+
+def prune_model(
+    model_path: str,
+    output_names: list[str],
+    save_path: str,
+) -> None:
+    """Extract a sub-model with the same inputs and new outputs."""
     model = load(model_path)

-    # Provide model input names and the new desired output names
+    # Provide model input names and the new desired output names.
     input_names = [i.name for i in model.graph.input]

     extract_model(
         input_path=model_path,
         output_path=save_path,
         input_names=input_names,
-        output_names=output_names
+        output_names=output_names,
     )

-    print(f"Pruned model saved to {save_path}")
+    print(f"Pruned model saved to {save_path}")  # noqa: T201


-def cut_model(model_path, output_names, save_path):
+def cut_model(
+    model_path: str,
+    output_names: list[str],
+    save_path: str,
+) -> None:
+    """Replace the graph outputs with the tensors named in `output_names`."""
     model = onnx.load(model_path)
     model = shape_inference.infer_shapes(model)

     graph = model.graph

-    # Remove all current outputs one by one (cannot use .clear() or assignment)
-    while len(graph.output) > 0:
+    # Remove all current outputs one by one (cannot use .clear() or assignment).
+    while graph.output:
         graph.output.pop()

-    # Add new outputs
+    # Add new outputs.
     for name in output_names:
-        # Look in value_info, input, or output
+        # Look in value_info, input, or output.
         candidates = list(graph.value_info) + list(graph.input) + list(graph.output)
         value_info = next((vi for vi in candidates if vi.name == name), None)
         if value_info is None:
-            raise ValueError(f"Tensor {name} not found in model graph.")
+            msg = f"Tensor {name} not found in model graph."
+            raise ValueError(msg)

         elem_type = value_info.type.tensor_type.elem_type
         shape = [dim.dim_value for dim in value_info.type.tensor_type.shape.dim]
         new_output = helper.make_tensor_value_info(name, elem_type, shape)
         graph.output.append(new_output)
+
     for output in graph.output:
-        print(output)
+        print(output)  # noqa: T201
         if output.name == "/conv1/Conv_output_0":
             output.type.tensor_type.elem_type = TensorProto.INT64

     onnx.save(model, save_path)
-    print(f"Saved cut model with outputs {output_names} to {save_path}")
+    print(f"Saved cut model with outputs {output_names} to {save_path}")  # noqa: T201


 if __name__ == "__main__":
-    # /conv1/Conv_output_0
-    # prune_model(
-    #     model_path="models_onnx/doom.onnx",
-    #     output_names=["/Relu_2_output_0"], # replace with your intermediate tensor
-    #     save_path= "models_onnx/test_doom_cut.onnx"
-    # )
-    # cut_model("models_onnx/doom.onnx",["/Relu_2_output_0"], "test_doom_after_conv.onnx")
     prune_model(
         model_path="models_onnx/doom.onnx",
         output_names=["/Relu_3_output_0"],  # replace with your intermediate tensor
-        save_path= "models_onnx/test_doom_cut.onnx"
+        save_path="models_onnx/test_doom_cut.onnx",
     )
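Both helpers now carry type hints and docstrings; functionally, `prune_model` delegates to `onnx.utils.extract_model`, which validates that the requested tensors exist, while `cut_model` rewrites `graph.output` by hand after shape inference (and still hard-codes an INT64 override for `/conv1/Conv_output_0`). A usage sketch with placeholder paths and tensor names:

# Placeholders: substitute a real model and an intermediate tensor name
# that appears in value_info after shape inference.
prune_model(
    model_path="models_onnx/doom.onnx",
    output_names=["/Relu_2_output_0"],
    save_path="models_onnx/doom_pruned.onnx",
)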
python/scripts/gen_and_bench.py

@@ -247,12 +247,12 @@ def export_onnx(


 def write_input_json(json_path: Path, input_shape: tuple[int] = (1, 4, 28, 28)) -> None:
-    """Write a zero-valued input tensor to JSON alongside its [N,C,H,W] shape."""
+    """Write a zero-valued input tensor to JSON without shape information."""
     json_path.parent.mkdir(parents=True, exist_ok=True)
     n, c, h, w = input_shape
     arr = [0.0] * (n * c * h * w)
     with json_path.open("w", encoding="utf-8") as f:
-        json.dump({"input": arr, "shape": [n, c, h, w]}, f)
+        json.dump({"input": arr}, f)


 def run_bench(
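The visible change in `write_input_json` is the dropped `shape` key: the benchmark input JSON now carries only the flattened tensor, presumably because the consumer recovers the shape elsewhere (the "Use once fixed input shape read in rust" comment later in this diff points the same way). For `input_shape=(1, 1, 2, 2)` the payloads compare as:

# Before (1.1.0) vs. after (1.3.0), illustrated for input_shape=(1, 1, 2, 2).
old_payload = {"input": [0.0, 0.0, 0.0, 0.0], "shape": [1, 1, 2, 2]}
new_payload = {"input": [0.0, 0.0, 0.0, 0.0]}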
python/tests/circuit_e2e_tests/circuit_model_developer_test.py

@@ -2,7 +2,11 @@ from __future__ import annotations

 import json
 from pathlib import Path
-from typing import Any, Generator
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from collections.abc import Generator
+

 import pytest
 import torch
@@ -35,7 +39,7 @@ OUTPUTTWICE = 2
 OUTPUTTHREETIMES = 3


-@pytest.mark.e2e()
+@pytest.mark.e2e
 def test_circuit_compiles(model_fixture: dict[str, Any]) -> None:
     # Here you could just check that circuit file exists
     circuit_compile_results[model_fixture["model"]] = False
@@ -43,7 +47,7 @@ def test_circuit_compiles(model_fixture: dict[str, Any]) -> None:
     circuit_compile_results[model_fixture["model"]] = True


-@pytest.mark.e2e()
+@pytest.mark.e2e
 def test_witness_dev(
     model_fixture: dict[str, Any],
     capsys: Generator[pytest.CaptureFixture[str], None, None],
@@ -89,7 +93,7 @@ def test_witness_dev(
     witness_generated_results[model_fixture["model"]] = True


-@pytest.mark.e2e()
+@pytest.mark.e2e
 def test_witness_wrong_outputs_dev(
     model_fixture: dict[str, Any],
     capsys: Generator[pytest.CaptureFixture[str], None, None],
@@ -145,7 +149,7 @@ def test_witness_wrong_outputs_dev(
     ), f"Expected '{output}' in stdout, but it was not found."


-@pytest.mark.e2e()
+@pytest.mark.e2e
 def test_witness_prove_verify_true_inputs_dev(
     model_fixture: dict[str, Any],
     capsys: Generator[pytest.CaptureFixture[str], None, None],
@@ -239,7 +243,7 @@ def test_witness_prove_verify_true_inputs_dev(
     ), "Expected 'Verified' in stdout three times, but it was not found."


-@pytest.mark.e2e()
+@pytest.mark.e2e
 def test_witness_prove_verify_true_inputs_dev_expander_call(
     model_fixture: dict[str, Any],
     capsys: Generator[pytest.CaptureFixture[str], None, None],
@@ -339,7 +343,7 @@ def test_witness_prove_verify_true_inputs_dev_expander_call(
     assert stdout.count("proving") == 1, "Expected 'proving' but it was not found."


-@pytest.mark.e2e()
+@pytest.mark.e2e
 def test_witness_read_after_write_json(
     model_fixture: dict[str, Any],
     capsys: Generator[pytest.CaptureFixture[str], None, None],
@@ -421,7 +425,7 @@ def test_witness_read_after_write_json(
     ), "Input JSON read is not identical to what was written"


-@pytest.mark.e2e()
+@pytest.mark.e2e
 def test_witness_fresh_compile_dev(
     model_fixture: dict[str, Any],
     capsys: Generator[pytest.CaptureFixture[str], None, None],
@@ -467,7 +471,7 @@ def test_witness_fresh_compile_dev(


 # Use once fixed input shape read in rust
-@pytest.mark.e2e()
+@pytest.mark.e2e
 def test_witness_incorrect_input_shape(
     model_fixture: dict[str, Any],
     capsys: Generator[pytest.CaptureFixture[str], None, None],
@@ -555,7 +559,7 @@ def test_witness_incorrect_input_shape(
     ), f"Did not expect '{output}' in stdout, but it was found."


-@pytest.mark.e2e()
+@pytest.mark.e2e
 def test_witness_unscaled(
     model_fixture: dict[str, Any],
     capsys: Generator[pytest.CaptureFixture[str], None, None],
@@ -666,7 +670,7 @@ def test_witness_unscaled(
     )


-@pytest.mark.e2e()
+@pytest.mark.e2e
 def test_witness_unscaled_and_incorrect_shape_input(
     model_fixture: dict[str, Any],
     capsys: Generator[pytest.CaptureFixture[str], None, None],
@@ -790,7 +794,7 @@ def test_witness_unscaled_and_incorrect_shape_input(
     )


-@pytest.mark.e2e()
+@pytest.mark.e2e
 def test_witness_unscaled_and_incorrect_and_bad_named_input(  # noqa: PLR0915
     model_fixture: dict[str, Any],
     capsys: Generator[pytest.CaptureFixture[str], None, None],
@@ -937,7 +941,7 @@ def test_witness_unscaled_and_incorrect_and_bad_named_input(  # noqa: PLR0915
     )


-@pytest.mark.e2e()
+@pytest.mark.e2e
 def test_witness_wrong_name(
     model_fixture: dict[str, Any],
     capsys: Generator[pytest.CaptureFixture[str], None, None],
@@ -1063,7 +1067,7 @@ def add_to_first_scalar(data: list, delta: float = 0.1) -> bool:
     return False


-@pytest.mark.e2e()
+@pytest.mark.e2e
 def test_witness_prove_verify_false_inputs_dev(
     model_fixture: dict[str, Any],
     capsys: Generator[pytest.CaptureFixture[str], None, None],
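Every hunk above is the same mechanical change: `@pytest.mark.e2e()` becomes `@pytest.mark.e2e`. The two forms are equivalent at runtime; the bare form is what flake8-pytest-style/Ruff rule PT023 prefers. Custom marks like `e2e` still need registering to avoid PytestUnknownMarkWarning; a minimal sketch, assuming registration lives in a conftest.py (the mark descriptions are invented):

# conftest.py — hypothetical registration of the custom marks seen in this diff.
def pytest_configure(config):
    config.addinivalue_line("markers", "e2e: end-to-end circuit tests")
    config.addinivalue_line("markers", "integration: integration tests")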
python/tests/circuit_e2e_tests/helper_fns_for_tests.py

@@ -1,7 +1,8 @@
 from __future__ import annotations

+from collections.abc import Generator, Mapping, Sequence
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Generator, Mapping, Sequence, TypeAlias, Union
+from typing import TYPE_CHECKING, Any, TypeAlias

 import numpy as np
 import pytest
@@ -64,7 +65,7 @@ def model_fixture(
     }


-@pytest.fixture()
+@pytest.fixture
 def temp_witness_file(tmp_path: str) -> Generator[Path, None, None]:
     witness_path = tmp_path / "temp_witness.txt"
     # Give it to the test
@@ -75,7 +76,7 @@ def temp_witness_file(tmp_path: str) -> Generator[Path, None, None]:
     witness_path.unlink()


-@pytest.fixture()
+@pytest.fixture
 def temp_input_file(tmp_path: str) -> Generator[Path, None, None]:
     input_path = tmp_path / "temp_input.txt"
     # Give it to the test
@@ -86,7 +87,7 @@ def temp_input_file(tmp_path: str) -> Generator[Path, None, None]:
     input_path.unlink()


-@pytest.fixture()
+@pytest.fixture
 def temp_output_file(tmp_path: str) -> Generator[Path, None, None]:
     output_path = tmp_path / "temp_output.txt"
     # Give it to the test
@@ -97,7 +98,7 @@ def temp_output_file(tmp_path: str) -> Generator[Path, None, None]:
     output_path.unlink()


-@pytest.fixture()
+@pytest.fixture
 def temp_proof_file(tmp_path: str) -> Generator[Path, None, None]:
     output_path = tmp_path / "temp_proof.txt"
     # Give it to the test
@@ -108,13 +109,10 @@ def temp_proof_file(tmp_path: str) -> Generator[Path, None, None]:
     output_path.unlink()


-ScalarOrTensor: TypeAlias = Union[int, float, torch.Tensor]
-NestedArray: TypeAlias = Union[
-    ScalarOrTensor,
-    list["NestedArray"],
-    tuple["NestedArray"],
-    np.ndarray,
-]
+ScalarOrTensor: TypeAlias = int | float | torch.Tensor
+NestedArray: TypeAlias = (
+    ScalarOrTensor | list["NestedArray"] | tuple["NestedArray"] | np.ndarray
+)


 def add_1_to_first_element(x: NestedArray) -> NestedArray:
@@ -137,7 +135,7 @@ def add_1_to_first_element(x: NestedArray) -> NestedArray:
 circuit_compile_results = {}
 witness_generated_results = {}

-Nested: TypeAlias = Union[float, Mapping[str, "Nested"], Sequence["Nested"]]
+Nested: TypeAlias = float | Mapping[str, "Nested"] | Sequence["Nested"]


 def contains_float(obj: Nested) -> bool:
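One subtlety in the alias modernization: unlike annotations, `TypeAlias` assignments are evaluated at runtime, so the `int | float | torch.Tensor` form requires Python 3.10+ regardless of `from __future__ import annotations`; the quoted "NestedArray" self-references stay lazy in either spelling. A self-contained illustration of the same recursive-alias pattern (names invented, torch/numpy legs omitted):

from typing import TypeAlias

Nested: TypeAlias = int | float | list["Nested"] | tuple["Nested", ...]


def depth(x: Nested) -> int:
    # Count nesting levels across the container legs of the alias.
    if isinstance(x, (list, tuple)):
        return 1 + max((depth(item) for item in x), default=0)
    return 0


assert depth([1, [2, [3]]]) == 3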
python/tests/circuit_parent_classes/test_ort_custom_layers.py

@@ -1,115 +1,97 @@
-import pytest
+from pathlib import Path
+
 import numpy as np
-import torch
 import onnx
-
-from onnx import TensorProto, shape_inference, helper, numpy_helper
+import pytest
+import torch
+from onnx import TensorProto, helper, shape_inference

 from python.core.model_processing.converters.onnx_converter import ONNXConverter
-from python.core.model_processing.onnx_custom_ops.onnx_helpers import extract_shape_dict
-from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import ONNXOpQuantizer
-
-from onnxruntime import InferenceSession, SessionOptions
-from onnxruntime_extensions import get_library_path, OrtPyFunction
-from python.core.model_processing.onnx_custom_ops import conv
-
-from python.core.model_processing.onnx_custom_ops.conv import int64_conv
-from python.core.model_processing.onnx_custom_ops.gemm import int64_gemm7


 @pytest.fixture
-def tiny_conv_model_path(tmp_path):
+def tiny_conv_model_path(tmp_path: Path) -> Path:
     # Create input and output tensor info
-    input_tensor = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1, 4, 4])
-    output_tensor = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 1, 2, 2])
+    input_tensor = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 4, 4])
+    output_tensor = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 2, 2])

     # Kernel weights (3x3 ones)
-    W_init = helper.make_tensor(
-        name='W',
+    w_init = helper.make_tensor(
+        name="W",
         data_type=TensorProto.FLOAT,
         dims=[1, 1, 3, 3],
-        vals=np.ones((1 * 1 * 3 * 3), dtype=np.float32).tolist()
+        vals=np.ones((1 * 1 * 3 * 3), dtype=np.float32).tolist(),
     )
-    Z_init = helper.make_tensor(
-        name='Z',
+    z_init = helper.make_tensor(
+        name="Z",
         data_type=TensorProto.FLOAT,
         dims=[1],
-        vals=np.ones(( 1), dtype=np.float32).tolist()
+        vals=np.ones((1), dtype=np.float32).tolist(),
     )

     # Conv node with no padding, stride 1
     conv_node = helper.make_node(
-        'Conv',
-        inputs=['X', 'W', 'Z'],
-        outputs=['Y'],
+        "Conv",
+        inputs=["X", "W", "Z"],
+        outputs=["Y"],
         kernel_shape=[3, 3],
         pads=[0, 0, 0, 0],
         strides=[1, 1],
-        dilations = [1,1]
+        dilations=[1, 1],
     )

     # Build graph and model
     graph = helper.make_graph(
         nodes=[conv_node],
-        name='TinyConvGraph',
+        name="TinyConvGraph",
         inputs=[input_tensor],
         outputs=[output_tensor],
-        initializer=[W_init, Z_init]
+        initializer=[w_init, z_init],
     )

-    model = helper.make_model(graph, producer_name='tiny-conv-example')
+    model = helper.make_model(graph, producer_name="tiny-conv-example")

     # Save to a temporary file
     model_path = tmp_path / "tiny_conv.onnx"
     onnx.save(model, str(model_path))

-    return str(model_path)
+    return model_path
+

 @pytest.mark.integration
-def test_tiny_conv(tiny_conv_model_path):
+def test_tiny_conv(tiny_conv_model_path: Path, tmp_path: Path) -> None:
     path = tiny_conv_model_path

     converter = ONNXConverter()

-    X_input = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
-    id_count = 0
+    # Load and validate original model
     model = onnx.load(path)
-    # Fix, can remove this next line
-    onnx.checker.check_model(model)
-
-    # Check the model and print Y"s shape information
     onnx.checker.check_model(model)
-    print(f"Before shape inference, the shape info of Y is:\n{model.graph.value_info}")

-    # Apply shape inference on the model
+    # Apply shape inference and validate
     inferred_model = shape_inference.infer_shapes(model)
-
-    # Check the model and print Y"s shape information
     onnx.checker.check_model(inferred_model)
-    # print(f"After shape inference, the shape info of Y is:\n{inferred_model.graph.value_info}")
-
-
-    domain_to_version = {opset.domain: opset.version for opset in model.opset_import}
-
-    inferred_model = shape_inference.infer_shapes(model)
-    output_name_to_shape = extract_shape_dict(inferred_model)
-    id_count = 0

+    # Quantize and add custom domain
     new_model = converter.quantize_model(model, 2, 21)
     custom_domain = onnx.helper.make_operatorsetid(domain="ai.onnx.contrib", version=1)
     new_model.opset_import.append(custom_domain)
     onnx.checker.check_model(new_model)

-    with open("model.onnx", "wb") as f:
+    # Save quantized model
+    out_path = tmp_path / "model_quant.onnx"
+    with out_path.open("wb") as f:
         f.write(new_model.SerializeToString())

-    model = onnx.load("model.onnx")
-    onnx.checker.check_model(model)  # This throws a descriptive error
+    # Reload quantized model to ensure it is valid
+    model_quant = onnx.load(str(out_path))
+    onnx.checker.check_model(model_quant)

+    # Prepare inputs and compare outputs
     inputs = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
     outputs_true = converter.run_model_onnx_runtime(path, inputs)
+    outputs_quant = converter.run_model_onnx_runtime(out_path, inputs)

-    outputs_quant = converter.run_model_onnx_runtime("model.onnx", inputs)
     true = torch.tensor(np.array(outputs_true), dtype=torch.float32)
     quant = torch.tensor(np.array(outputs_quant), dtype=torch.float32) / (2**21)
python/tests/onnx_quantizer_tests/layers/base.py

@@ -113,9 +113,7 @@ class LayerTestConfig:
         # respect that; otherwise use original valid_inputs.
         inputs = test_spec.input_overrides or self.valid_inputs

-        # Prepare attributes
-        attrs = {**self.valid_attributes, **test_spec.attr_overrides}
-        # Remove omitted attributes if specified
+        # Prepare attributes and remove omitted attributes if specified
         attrs = {**self.valid_attributes, **test_spec.attr_overrides}
         for key in getattr(test_spec, "omit_attrs", []):
             attrs.pop(key, None)