JSTprove: jstprove-1.2.0-py3-none-macosx_11_0_arm64.whl → jstprove-1.4.0-py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {jstprove-1.2.0.dist-info → jstprove-1.4.0.dist-info}/METADATA +1 -1
  2. {jstprove-1.2.0.dist-info → jstprove-1.4.0.dist-info}/RECORD +32 -26
  3. python/core/binaries/onnx_generic_circuit_1-4-0 +0 -0
  4. python/core/circuits/base.py +29 -12
  5. python/core/circuits/errors.py +1 -2
  6. python/core/model_processing/converters/base.py +3 -3
  7. python/core/model_processing/onnx_custom_ops/__init__.py +5 -4
  8. python/core/model_processing/onnx_quantizer/exceptions.py +2 -2
  9. python/core/model_processing/onnx_quantizer/layers/base.py +79 -2
  10. python/core/model_processing/onnx_quantizer/layers/clip.py +92 -0
  11. python/core/model_processing/onnx_quantizer/layers/max.py +49 -0
  12. python/core/model_processing/onnx_quantizer/layers/maxpool.py +79 -4
  13. python/core/model_processing/onnx_quantizer/layers/min.py +54 -0
  14. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +6 -0
  15. python/core/model_templates/circuit_template.py +48 -38
  16. python/core/utils/errors.py +1 -1
  17. python/core/utils/scratch_tests.py +29 -23
  18. python/tests/circuit_e2e_tests/circuit_model_developer_test.py +18 -14
  19. python/tests/circuit_e2e_tests/helper_fns_for_tests.py +11 -13
  20. python/tests/circuit_parent_classes/test_ort_custom_layers.py +35 -53
  21. python/tests/onnx_quantizer_tests/layers/base.py +1 -3
  22. python/tests/onnx_quantizer_tests/layers/clip_config.py +127 -0
  23. python/tests/onnx_quantizer_tests/layers/max_config.py +100 -0
  24. python/tests/onnx_quantizer_tests/layers/maxpool_config.py +106 -0
  25. python/tests/onnx_quantizer_tests/layers/min_config.py +94 -0
  26. python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +6 -5
  27. python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +6 -1
  28. python/tests/onnx_quantizer_tests/test_registered_quantizers.py +17 -8
  29. python/core/binaries/onnx_generic_circuit_1-2-0 +0 -0
  30. {jstprove-1.2.0.dist-info → jstprove-1.4.0.dist-info}/WHEEL +0 -0
  31. {jstprove-1.2.0.dist-info → jstprove-1.4.0.dist-info}/entry_points.txt +0 -0
  32. {jstprove-1.2.0.dist-info → jstprove-1.4.0.dist-info}/licenses/LICENSE +0 -0
  33. {jstprove-1.2.0.dist-info → jstprove-1.4.0.dist-info}/top_level.txt +0 -0

python/core/model_processing/onnx_quantizer/layers/maxpool.py
@@ -3,9 +3,12 @@ from __future__ import annotations
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
+    from typing import ClassVar
+
     import onnx
 
 from python.core.model_processing.onnx_custom_ops.onnx_helpers import (
+    extract_attributes,
     get_attribute_ints,
 )
 from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError
@@ -21,6 +24,12 @@ class QuantizeMaxpool(QuantizerBase):
     USE_WB = False
     USE_SCALING = False
 
+    DEFAULT_ATTRS: ClassVar = {
+        "dilations": [1],
+        "pads": [0],
+        "strides": [1],
+    }
+
 
 class MaxpoolQuantizer(BaseOpQuantizer, QuantizeMaxpool):
     """
@@ -72,6 +81,35 @@ class MaxpoolQuantizer(BaseOpQuantizer, QuantizeMaxpool):
             InvalidParamError: If any requirement is not met.
         """
         _ = initializer_map
+        attributes = extract_attributes(node)
+        ceil_mode = attributes.get("ceil_mode", None)
+        auto_pad = attributes.get("auto_pad", None)
+        storage_order = attributes.get("storage_order", None)
+
+        if ceil_mode != 0 and ceil_mode is not None:
+            raise InvalidParamError(
+                node.name,
+                node.op_type,
+                "ceil_mode must be 0",
+                "ceil_mode",
+                "0",
+            )
+        if auto_pad != "NOTSET" and auto_pad is not None:
+            raise InvalidParamError(
+                node.name,
+                node.op_type,
+                "auto_pad must be NOTSET",
+                "auto_pad",
+                "NOTSET",
+            )
+        if storage_order != 0 and storage_order is not None:
+            raise InvalidParamError(
+                node.name,
+                node.op_type,
+                "storage_order must be 0",
+                "storage_order",
+                "0",
+            )
         self.check_all_params_exist(node)
         self.check_params_size(node)
         self.check_pool_pads(node)
@@ -85,8 +123,8 @@ class MaxpoolQuantizer(BaseOpQuantizer, QuantizeMaxpool):
         Raises:
             InvalidParamError: If shape requirement is not met.
         """
-        # May need: ["strides", "kernel_shape", "pads", "dilations"]
-        required_attrs = ["strides", "kernel_shape"]
+        required_attrs = ["kernel_shape"]
+
         self.validate_required_attrs(node, required_attrs)
 
         # Check dimension of kernel
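
Relaxing `required_attrs` to just `kernel_shape` is consistent with the ONNX MaxPool spec: `strides`, `pads`, and `dilations` all have defined defaults (1s and 0s) when absent, and the new `DEFAULT_ATTRS` on `QuantizeMaxpool` mirrors those defaults.
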
@@ -121,11 +159,23 @@ class MaxpoolQuantizer(BaseOpQuantizer, QuantizeMaxpool):
 
     def check_pool_pads(self: MaxpoolQuantizer, node: onnx.NodeProto) -> None:
         kernel_shape = get_attribute_ints(node, "kernel_shape", default=[])
-        pads = get_attribute_ints(node, "pads", default=None)
+        pads_raw = get_attribute_ints(
+            node,
+            "pads",
+            default=self.DEFAULT_ATTRS.get("pads", None),
+        )
+        pads = self.adjust_pads(node, pads_raw)
+
         if pads is None:
             return
         num_dims = len(kernel_shape)
-        if len(pads) != num_dims * 2:
+
+        if len(pads) == 1:
+            pads = pads * 2 * num_dims
+        elif len(pads) == num_dims:
+            # If only beginning pads given, repeat for end pads
+            pads = pads + pads
+        elif len(pads) != num_dims * 2:
             raise InvalidParamError(
                 node.name,
                 node.op_type,
@@ -148,3 +198,28 @@ class MaxpoolQuantizer(BaseOpQuantizer, QuantizeMaxpool):
                     node.op_type,
                     f"pads[{dim + num_dims}]={pad_after} >= kernel[{dim}]={kernel}",
                 )
+
+    def adjust_pads(
+        self: MaxpoolQuantizer,
+        node: onnx.NodeProto,
+        pads_raw: str | int | list[int] | None,
+    ) -> list[int]:
+        if pads_raw is None:
+            pads: list[int] = []
+        elif isinstance(pads_raw, str):
+            # single string, could be "0" or "1 2"
+            pads = [int(x) for x in pads_raw.split()]
+        elif isinstance(pads_raw, int):
+            # single integer
+            pads = [pads_raw]
+        elif isinstance(pads_raw, (list, tuple)):
+            # already a list of numbers (may be strings)
+            pads = [int(x) for x in pads_raw]
+        else:
+            raise InvalidParamError(
+                node.name,
+                node.op_type,
+                f"Cannot parse pads: {pads_raw}",
+            )
+
+        return pads
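
The new `check_pool_pads` logic accepts `pads` given as a single value or one value per spatial dimension and expands either form to the `2 * num_dims` layout ONNX expects. A minimal standalone sketch of that expansion (the `expand_pads` helper name is hypothetical, not part of the package):

```python
def expand_pads(pads: list[int], num_dims: int) -> list[int]:
    """Expand shorthand pads to ONNX's [begin..., end...] layout."""
    if len(pads) == 1:
        # One value: apply it to the begin and end of every spatial dim.
        return pads * 2 * num_dims
    if len(pads) == num_dims:
        # Begin pads only: mirror them as the end pads.
        return pads + pads
    if len(pads) != num_dims * 2:
        raise ValueError(f"expected {num_dims * 2} pads, got {len(pads)}")
    return pads

# For a 2-D kernel, all three spellings normalize to the same layout:
assert expand_pads([1], 2) == [1, 1, 1, 1]
assert expand_pads([1, 2], 2) == [1, 2, 1, 2]
assert expand_pads([1, 2, 1, 2], 2) == [1, 2, 1, 2]
```
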

python/core/model_processing/onnx_quantizer/layers/min.py
@@ -0,0 +1,54 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, ClassVar
+
+if TYPE_CHECKING:
+    import onnx
+
+from python.core.model_processing.onnx_quantizer.layers.base import (
+    BaseOpQuantizer,
+    QuantizerBase,
+    ScaleConfig,
+)
+
+
+class QuantizeMin(QuantizerBase):
+    OP_TYPE = "Min"
+    DOMAIN = ""  # standard ONNX domain
+    USE_WB = True  # let framework wire inputs/outputs normally
+    USE_SCALING = False  # passthrough: no internal scaling
+    SCALE_PLAN: ClassVar = {1: 1}  # elementwise arity plan
+
+
+class MinQuantizer(BaseOpQuantizer, QuantizeMin):
+    """
+    Passthrough quantizer for elementwise Min.
+    We rely on the converter to quantize graph inputs; no extra scaling here.
+    """
+
+    def __init__(
+        self: MinQuantizer,
+        new_initializers: list[onnx.TensorProto] | None = None,
+    ) -> None:
+        super().__init__()
+        if new_initializers is not None:
+            self.new_initializers = new_initializers
+
+    def quantize(
+        self: MinQuantizer,
+        node: onnx.NodeProto,
+        graph: onnx.GraphProto,
+        scale_config: ScaleConfig,
+        initializer_map: dict[str, onnx.TensorProto],
+    ) -> list[onnx.NodeProto]:
+        # Delegate to QuantizerBase's generic passthrough implementation.
+        return QuantizeMin.quantize(self, node, graph, scale_config, initializer_map)
+
+    def check_supported(
+        self: MinQuantizer,
+        node: onnx.NodeProto,
+        initializer_map: dict[str, onnx.TensorProto] | None = None,
+    ) -> None:
+        # Min has no attributes; elementwise, variadic ≥ 1 input per ONNX spec.
+        # We mirror Add/Max broadcasting behavior; no extra checks here.
+        _ = node, initializer_map
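
Treating `Min` (and `Max`) as passthrough is sound because multiplying by a positive scale factor is monotone, so it commutes with element-wise extrema. A quick numpy check of that identity (illustrative only, not package code):

```python
import numpy as np

a = np.array([3.0, -1.5, 0.25])
b = np.array([2.0, -2.0, 0.75])
scale = 2**8  # a positive quantization scale

# min(scale*a, scale*b) == scale * min(a, b) for any positive scale,
# so quantizing the inputs first and taking Min afterwards agrees with
# taking Min first and quantizing the result.
assert np.array_equal(np.minimum(scale * a, scale * b), scale * np.minimum(a, b))
```
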

python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py
@@ -20,12 +20,15 @@ from python.core.model_processing.onnx_quantizer.layers.base import (
 from python.core.model_processing.onnx_quantizer.layers.batchnorm import (
     BatchnormQuantizer,
 )
+from python.core.model_processing.onnx_quantizer.layers.clip import ClipQuantizer
 from python.core.model_processing.onnx_quantizer.layers.constant import (
     ConstantQuantizer,
 )
 from python.core.model_processing.onnx_quantizer.layers.conv import ConvQuantizer
 from python.core.model_processing.onnx_quantizer.layers.gemm import GemmQuantizer
+from python.core.model_processing.onnx_quantizer.layers.max import MaxQuantizer
 from python.core.model_processing.onnx_quantizer.layers.maxpool import MaxpoolQuantizer
+from python.core.model_processing.onnx_quantizer.layers.min import MinQuantizer
 from python.core.model_processing.onnx_quantizer.layers.mul import MulQuantizer
 from python.core.model_processing.onnx_quantizer.layers.relu import ReluQuantizer
 from python.core.model_processing.onnx_quantizer.layers.sub import SubQuantizer
@@ -74,6 +77,7 @@ class ONNXOpQuantizer:
 
         # Register handlers
         self.register("Add", AddQuantizer(self.new_initializers))
+        self.register("Clip", ClipQuantizer(self.new_initializers))
        self.register("Sub", SubQuantizer(self.new_initializers))
         self.register("Mul", MulQuantizer(self.new_initializers))
         self.register("Conv", ConvQuantizer(self.new_initializers))
@@ -83,6 +87,8 @@ class ONNXOpQuantizer:
         self.register("Constant", ConstantQuantizer())
         self.register("MaxPool", MaxpoolQuantizer())
         self.register("Flatten", PassthroughQuantizer())
+        self.register("Max", MaxQuantizer(self.new_initializers))
+        self.register("Min", MinQuantizer(self.new_initializers))
         self.register("BatchNormalization", BatchnormQuantizer(self.new_initializers))
 
     def register(
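
The registration pattern above is a plain dispatch table keyed by ONNX `op_type`. A simplified sketch of the mechanism (names simplified; not the package's exact implementation):

```python
class OpRegistry:
    def __init__(self) -> None:
        self.handlers = {}

    def register(self, op_type: str, handler) -> None:
        self.handlers[op_type] = handler

    def quantize_node(self, node, *args):
        # Look up the handler for this node's op_type and delegate to it;
        # unsupported ops surface as an explicit error.
        handler = self.handlers.get(node.op_type)
        if handler is None:
            raise NotImplementedError(f"no quantizer registered for {node.op_type}")
        return handler.quantize(node, *args)
```
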

python/core/model_templates/circuit_template.py
@@ -1,57 +1,67 @@
-import torch.nn as nn
+from __future__ import annotations
+
+from secrets import randbelow
+
 from python.core.circuits.base import Circuit
-from random import randint
+
 
 class SimpleCircuit(Circuit):
-    '''
-    Note: This template is irrelevant if using the onnx circuit builder.
-    The template only helps developers if they choose to incorporate other circuit builders into the framework.
+    """
+    Note: This template is irrelevant if using the ONNX circuit builder.
+    The template only helps developers if they choose to incorporate other circuit
+    builders into the framework.
 
-    To begin, we need to specify some basic attributes surrounding the circuit we will be using.
-    required_keys - specify the variables in the input dictionary (and input file).
-    name - name of the rust bin to be run by the circuit.
+    To begin, we need to specify some basic attributes surrounding the circuit we will
+    be using.
 
-    scale_base - specify the base of the scaling applied to each value
-    scale_exponent - the exponent applied to the base to get the scaling factor. Scaling factor will be multiplied by each input
+    - `required_keys`: the variables in the input dictionary (and input file).
+    - `name`: name of the Rust bin to be run by the circuit.
+    - `scale_base`: base of the scaling applied to each value.
+    - `scale_exponent`: exponent applied to the base to get the scaling factor.
+      Scaling factor will be multiplied by each input.
 
-    Other default inputs can be defined below
-    '''
-    def __init__(self, file_name):
+    Other default inputs can be defined below.
+    """
+
+    def __init__(self, file_name: str | None = None) -> None:
         # Initialize the base class
         super().__init__()
-
+        self.file_name = file_name
+
         # Circuit-specific parameters
         self.required_keys = ["input_a", "input_b", "nonce"]
         self.name = "simple_circuit"  # Use exact name that matches the binary
 
         self.scale_exponent = 1
         self.scale_base = 1
-
+
         self.input_a = 100
         self.input_b = 200
-        self.nonce = randint(0,10000)
-
-    '''
-    The following are some important functions used by the model. get inputs should be defined to specify the inputs to the circuit
-    '''
-    def get_inputs(self):
-        '''
-        Specify the inputs to the circuit, based on what was specified in the __init__. Can also have inputs to this function for the inputs.
-        '''
-        return {'input_a': self.input_a, 'input_b': self.input_b, 'nonce': self.nonce}
-
-    def get_outputs(self, inputs = None):
+        self.nonce = randbelow(10_000)
+
+    def get_inputs(self) -> dict[str, int]:
+        """
+        Specify the inputs to the circuit, based on what was specified
+        in `__init__`.
+        """
+        return {
+            "input_a": self.input_a,
+            "input_b": self.input_b,
+            "nonce": self.nonce,
+        }
+
+    def get_outputs(self, inputs: dict[str, int] | None = None) -> int:
         """
         Compute the output of the circuit.
-        This is overwritten from the base class to ensure computation happens only once.
+
+        This is overwritten from the base class to ensure computation happens
+        only once.
         """
-        if inputs == None:
-            inputs = {'input_a': self.input_a, 'input_b': self.input_b, 'nonce': self.nonce}
-        print(f"Performing addition operation: {inputs['input_a']} + {inputs['input_b']}")
-        return inputs['input_a'] + inputs['input_b']
-
-    # def format_inputs(self, inputs):
-    #     return {"input": inputs.long().tolist()}
-
-    # def format_outputs(self, outputs):
-    #     return {"output": outputs.long().tolist()}
+        if inputs is None:
+            inputs = {
+                "input_a": self.input_a,
+                "input_b": self.input_b,
+                "nonce": self.nonce,
+            }
+
+        return inputs["input_a"] + inputs["input_b"]
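
Assuming the template API shown above, exercising `SimpleCircuit` is direct. Note the nonce change: `secrets.randbelow(10_000)` draws from [0, 10000) with a CSPRNG (the swap ruff's S311 rule suggests), whereas the old `random.randint(0, 10000)` included 10000:

```python
circuit = SimpleCircuit()

inputs = circuit.get_inputs()
# e.g. {"input_a": 100, "input_b": 200, "nonce": 4821}

assert circuit.get_outputs(inputs) == inputs["input_a"] + inputs["input_b"]  # 300
```
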

python/core/circuits/errors.py
@@ -30,7 +30,7 @@ class FileCacheError(CircuitExecutionError):
 class ProofBackendError(CircuitExecutionError):
     """Raised when a Cargo command fails."""
 
-    def __init__(  # noqa: PLR0913
+    def __init__(
         self: ProofBackendError,
         message: str,
         command: list[str] | None = None,

python/core/utils/scratch_tests.py
@@ -1,66 +1,72 @@
+from __future__ import annotations
+
 import onnx
-from onnx import TensorProto, helper, shape_inference
-from onnx import numpy_helper
-from onnx import load, save
+from onnx import TensorProto, helper, load, shape_inference
 from onnx.utils import extract_model
 
-def prune_model(model_path, output_names, save_path):
+
+def prune_model(
+    model_path: str,
+    output_names: list[str],
+    save_path: str,
+) -> None:
+    """Extract a sub-model with the same inputs and new outputs."""
     model = load(model_path)
 
-    # Provide model input names and the new desired output names
+    # Provide model input names and the new desired output names.
     input_names = [i.name for i in model.graph.input]
 
     extract_model(
         input_path=model_path,
         output_path=save_path,
         input_names=input_names,
-        output_names=output_names
+        output_names=output_names,
     )
 
-    print(f"Pruned model saved to {save_path}")
+    print(f"Pruned model saved to {save_path}")  # noqa: T201
 
 
-def cut_model(model_path, output_names, save_path):
+def cut_model(
+    model_path: str,
+    output_names: list[str],
+    save_path: str,
+) -> None:
+    """Replace the graph outputs with the tensors named in `output_names`."""
     model = onnx.load(model_path)
     model = shape_inference.infer_shapes(model)
 
     graph = model.graph
 
-    # Remove all current outputs one by one (cannot use .clear() or assignment)
-    while len(graph.output) > 0:
+    # Remove all current outputs one by one (cannot use .clear() or assignment).
+    while graph.output:
         graph.output.pop()
 
-    # Add new outputs
+    # Add new outputs.
     for name in output_names:
-        # Look in value_info, input, or output
+        # Look in value_info, input, or output.
         candidates = list(graph.value_info) + list(graph.input) + list(graph.output)
         value_info = next((vi for vi in candidates if vi.name == name), None)
         if value_info is None:
-            raise ValueError(f"Tensor {name} not found in model graph.")
+            msg = f"Tensor {name} not found in model graph."
+            raise ValueError(msg)
 
         elem_type = value_info.type.tensor_type.elem_type
         shape = [dim.dim_value for dim in value_info.type.tensor_type.shape.dim]
         new_output = helper.make_tensor_value_info(name, elem_type, shape)
         graph.output.append(new_output)
+
     for output in graph.output:
-        print(output)
+        print(output)  # noqa: T201
         if output.name == "/conv1/Conv_output_0":
            output.type.tensor_type.elem_type = TensorProto.INT64
 
     onnx.save(model, save_path)
-    print(f"Saved cut model with outputs {output_names} to {save_path}")
+    print(f"Saved cut model with outputs {output_names} to {save_path}")  # noqa: T201
 
 
 if __name__ == "__main__":
-    # /conv1/Conv_output_0
-    # prune_model(
-    #     model_path="models_onnx/doom.onnx",
-    #     output_names=["/Relu_2_output_0"],  # replace with your intermediate tensor
-    #     save_path= "models_onnx/test_doom_cut.onnx"
-    # )
-    # cut_model("models_onnx/doom.onnx",["/Relu_2_output_0"], "test_doom_after_conv.onnx")
     prune_model(
         model_path="models_onnx/doom.onnx",
         output_names=["/Relu_3_output_0"],  # replace with your intermediate tensor
-        save_path= "models_onnx/test_doom_cut.onnx"
+        save_path="models_onnx/test_doom_cut.onnx",
     )

python/tests/circuit_e2e_tests/circuit_model_developer_test.py
@@ -2,7 +2,11 @@ from __future__ import annotations
 
 import json
 from pathlib import Path
-from typing import Any, Generator
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from collections.abc import Generator
+
 
 import pytest
 import torch
@@ -35,7 +39,7 @@ OUTPUTTWICE = 2
 OUTPUTTHREETIMES = 3
 
 
-@pytest.mark.e2e()
+@pytest.mark.e2e
 def test_circuit_compiles(model_fixture: dict[str, Any]) -> None:
     # Here you could just check that circuit file exists
     circuit_compile_results[model_fixture["model"]] = False
@@ -43,7 +47,7 @@ def test_circuit_compiles(model_fixture: dict[str, Any]) -> None:
     circuit_compile_results[model_fixture["model"]] = True
 
 
-@pytest.mark.e2e()
+@pytest.mark.e2e
 def test_witness_dev(
     model_fixture: dict[str, Any],
     capsys: Generator[pytest.CaptureFixture[str], None, None],
@@ -89,7 +93,7 @@ def test_witness_dev(
     witness_generated_results[model_fixture["model"]] = True
 
 
-@pytest.mark.e2e()
+@pytest.mark.e2e
 def test_witness_wrong_outputs_dev(
     model_fixture: dict[str, Any],
     capsys: Generator[pytest.CaptureFixture[str], None, None],
@@ -145,7 +149,7 @@ def test_witness_wrong_outputs_dev(
     ), f"Expected '{output}' in stdout, but it was not found."
 
 
-@pytest.mark.e2e()
+@pytest.mark.e2e
 def test_witness_prove_verify_true_inputs_dev(
     model_fixture: dict[str, Any],
     capsys: Generator[pytest.CaptureFixture[str], None, None],
@@ -239,7 +243,7 @@ def test_witness_prove_verify_true_inputs_dev(
     ), "Expected 'Verified' in stdout three times, but it was not found."
 
 
-@pytest.mark.e2e()
+@pytest.mark.e2e
 def test_witness_prove_verify_true_inputs_dev_expander_call(
     model_fixture: dict[str, Any],
     capsys: Generator[pytest.CaptureFixture[str], None, None],
@@ -339,7 +343,7 @@ def test_witness_prove_verify_true_inputs_dev_expander_call(
     assert stdout.count("proving") == 1, "Expected 'proving' but it was not found."
 
 
-@pytest.mark.e2e()
+@pytest.mark.e2e
 def test_witness_read_after_write_json(
     model_fixture: dict[str, Any],
     capsys: Generator[pytest.CaptureFixture[str], None, None],
@@ -421,7 +425,7 @@ def test_witness_read_after_write_json(
     ), "Input JSON read is not identical to what was written"
 
 
-@pytest.mark.e2e()
+@pytest.mark.e2e
 def test_witness_fresh_compile_dev(
     model_fixture: dict[str, Any],
     capsys: Generator[pytest.CaptureFixture[str], None, None],
@@ -467,7 +471,7 @@ def test_witness_fresh_compile_dev(
 
 
 # Use once fixed input shape read in rust
-@pytest.mark.e2e()
+@pytest.mark.e2e
 def test_witness_incorrect_input_shape(
     model_fixture: dict[str, Any],
     capsys: Generator[pytest.CaptureFixture[str], None, None],
@@ -555,7 +559,7 @@ def test_witness_incorrect_input_shape(
     ), f"Did not expect '{output}' in stdout, but it was found."
 
 
-@pytest.mark.e2e()
+@pytest.mark.e2e
 def test_witness_unscaled(
     model_fixture: dict[str, Any],
     capsys: Generator[pytest.CaptureFixture[str], None, None],
@@ -666,7 +670,7 @@ def test_witness_unscaled(
     )
 
 
-@pytest.mark.e2e()
+@pytest.mark.e2e
 def test_witness_unscaled_and_incorrect_shape_input(
     model_fixture: dict[str, Any],
     capsys: Generator[pytest.CaptureFixture[str], None, None],
@@ -790,7 +794,7 @@ def test_witness_unscaled_and_incorrect_shape_input(
     )
 
 
-@pytest.mark.e2e()
+@pytest.mark.e2e
 def test_witness_unscaled_and_incorrect_and_bad_named_input(  # noqa: PLR0915
     model_fixture: dict[str, Any],
     capsys: Generator[pytest.CaptureFixture[str], None, None],
@@ -937,7 +941,7 @@ def test_witness_unscaled_and_incorrect_and_bad_named_input(  # noqa: PLR0915
     )
 
 
-@pytest.mark.e2e()
+@pytest.mark.e2e
 def test_witness_wrong_name(
     model_fixture: dict[str, Any],
     capsys: Generator[pytest.CaptureFixture[str], None, None],
@@ -1063,7 +1067,7 @@ def add_to_first_scalar(data: list, delta: float = 0.1) -> bool:
     return False
 
 
-@pytest.mark.e2e()
+@pytest.mark.e2e
 def test_witness_prove_verify_false_inputs_dev(
     model_fixture: dict[str, Any],
     capsys: Generator[pytest.CaptureFixture[str], None, None],
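
Dropping the parentheses from `@pytest.mark.e2e()` matches the style modern pytest and lint rules such as ruff's PT023 prefer; both spellings apply the same marker, and the same change hits `@pytest.fixture` later in this diff (PT001). If the suite runs with `--strict-markers`, the `e2e` marker must be registered somewhere; a minimal `conftest.py` sketch (the project may already do this in its pytest config):

```python
# conftest.py
def pytest_configure(config):
    # Register the custom "e2e" marker used by these tests.
    config.addinivalue_line("markers", "e2e: end-to-end circuit tests")
```
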

python/tests/circuit_e2e_tests/helper_fns_for_tests.py
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
+from collections.abc import Generator, Mapping, Sequence
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Generator, Mapping, Sequence, TypeAlias, Union
+from typing import TYPE_CHECKING, Any, TypeAlias
 
 import numpy as np
 import pytest
@@ -64,7 +65,7 @@ def model_fixture(
     }
 
 
-@pytest.fixture()
+@pytest.fixture
 def temp_witness_file(tmp_path: str) -> Generator[Path, None, None]:
     witness_path = tmp_path / "temp_witness.txt"
     # Give it to the test
@@ -75,7 +76,7 @@ def temp_witness_file(tmp_path: str) -> Generator[Path, None, None]:
     witness_path.unlink()
 
 
-@pytest.fixture()
+@pytest.fixture
 def temp_input_file(tmp_path: str) -> Generator[Path, None, None]:
     input_path = tmp_path / "temp_input.txt"
     # Give it to the test
@@ -86,7 +87,7 @@ def temp_input_file(tmp_path: str) -> Generator[Path, None, None]:
     input_path.unlink()
 
 
-@pytest.fixture()
+@pytest.fixture
 def temp_output_file(tmp_path: str) -> Generator[Path, None, None]:
     output_path = tmp_path / "temp_output.txt"
     # Give it to the test
@@ -97,7 +98,7 @@ def temp_output_file(tmp_path: str) -> Generator[Path, None, None]:
     output_path.unlink()
 
 
-@pytest.fixture()
+@pytest.fixture
 def temp_proof_file(tmp_path: str) -> Generator[Path, None, None]:
     output_path = tmp_path / "temp_proof.txt"
     # Give it to the test
@@ -108,13 +109,10 @@ def temp_proof_file(tmp_path: str) -> Generator[Path, None, None]:
     output_path.unlink()
 
 
-ScalarOrTensor: TypeAlias = Union[int, float, torch.Tensor]
-NestedArray: TypeAlias = Union[
-    ScalarOrTensor,
-    list["NestedArray"],
-    tuple["NestedArray"],
-    np.ndarray,
-]
+ScalarOrTensor: TypeAlias = int | float | torch.Tensor
+NestedArray: TypeAlias = (
+    ScalarOrTensor | list["NestedArray"] | tuple["NestedArray"] | np.ndarray
+)
 
 
 def add_1_to_first_element(x: NestedArray) -> NestedArray:
@@ -137,7 +135,7 @@ def add_1_to_first_element(x: NestedArray) -> NestedArray:
 circuit_compile_results = {}
 witness_generated_results = {}
 
-Nested: TypeAlias = Union[float, Mapping[str, "Nested"], Sequence["Nested"]]
+Nested: TypeAlias = float | Mapping[str, "Nested"] | Sequence["Nested"]
 
 
 def contains_float(obj: Nested) -> bool:
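
One caveat worth noting about these `Union` → `X | Y` conversions: unlike annotations, `TypeAlias` assignments are evaluated at runtime even under `from __future__ import annotations`, so the pipe syntax here effectively pins the test helpers to Python 3.10+. A minimal illustration:

```python
from typing import TypeAlias

# The right-hand side is evaluated at import time and builds a
# types.UnionType, so this line raises TypeError on Python < 3.10.
Scalar: TypeAlias = int | float

def first_scalar(values: list[Scalar]) -> Scalar:
    return values[0]
```
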