JSTprove 1.0.0__py3-none-macosx_11_0_arm64.whl → 1.2.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of JSTprove might be problematic. See the package's page on the registry for more details.

Files changed (61):
  1. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/METADATA +3 -3
  2. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/RECORD +60 -25
  3. python/core/binaries/onnx_generic_circuit_1-2-0 +0 -0
  4. python/core/circuit_models/generic_onnx.py +43 -9
  5. python/core/circuits/base.py +231 -71
  6. python/core/model_processing/converters/onnx_converter.py +114 -59
  7. python/core/model_processing/onnx_custom_ops/batchnorm.py +64 -0
  8. python/core/model_processing/onnx_custom_ops/maxpool.py +1 -1
  9. python/core/model_processing/onnx_custom_ops/mul.py +66 -0
  10. python/core/model_processing/onnx_custom_ops/relu.py +1 -1
  11. python/core/model_processing/onnx_quantizer/layers/add.py +54 -0
  12. python/core/model_processing/onnx_quantizer/layers/base.py +188 -1
  13. python/core/model_processing/onnx_quantizer/layers/batchnorm.py +224 -0
  14. python/core/model_processing/onnx_quantizer/layers/constant.py +1 -1
  15. python/core/model_processing/onnx_quantizer/layers/conv.py +20 -68
  16. python/core/model_processing/onnx_quantizer/layers/gemm.py +20 -66
  17. python/core/model_processing/onnx_quantizer/layers/maxpool.py +53 -43
  18. python/core/model_processing/onnx_quantizer/layers/mul.py +53 -0
  19. python/core/model_processing/onnx_quantizer/layers/relu.py +20 -35
  20. python/core/model_processing/onnx_quantizer/layers/sub.py +54 -0
  21. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +43 -1
  22. python/core/utils/general_layer_functions.py +17 -12
  23. python/core/utils/model_registry.py +6 -3
  24. python/scripts/gen_and_bench.py +2 -2
  25. python/tests/circuit_e2e_tests/other_e2e_test.py +202 -9
  26. python/tests/circuit_parent_classes/test_circuit.py +561 -38
  27. python/tests/circuit_parent_classes/test_onnx_converter.py +22 -13
  28. python/tests/onnx_quantizer_tests/__init__.py +1 -0
  29. python/tests/onnx_quantizer_tests/layers/__init__.py +13 -0
  30. python/tests/onnx_quantizer_tests/layers/add_config.py +102 -0
  31. python/tests/onnx_quantizer_tests/layers/base.py +279 -0
  32. python/tests/onnx_quantizer_tests/layers/batchnorm_config.py +190 -0
  33. python/tests/onnx_quantizer_tests/layers/constant_config.py +39 -0
  34. python/tests/onnx_quantizer_tests/layers/conv_config.py +154 -0
  35. python/tests/onnx_quantizer_tests/layers/factory.py +142 -0
  36. python/tests/onnx_quantizer_tests/layers/flatten_config.py +61 -0
  37. python/tests/onnx_quantizer_tests/layers/gemm_config.py +160 -0
  38. python/tests/onnx_quantizer_tests/layers/maxpool_config.py +82 -0
  39. python/tests/onnx_quantizer_tests/layers/mul_config.py +102 -0
  40. python/tests/onnx_quantizer_tests/layers/relu_config.py +61 -0
  41. python/tests/onnx_quantizer_tests/layers/reshape_config.py +61 -0
  42. python/tests/onnx_quantizer_tests/layers/sub_config.py +102 -0
  43. python/tests/onnx_quantizer_tests/layers_tests/__init__.py +0 -0
  44. python/tests/onnx_quantizer_tests/layers_tests/base_test.py +94 -0
  45. python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py +115 -0
  46. python/tests/onnx_quantizer_tests/layers_tests/test_e2e.py +196 -0
  47. python/tests/onnx_quantizer_tests/layers_tests/test_error_cases.py +59 -0
  48. python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +198 -0
  49. python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +267 -0
  50. python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py +109 -0
  51. python/tests/onnx_quantizer_tests/layers_tests/test_validation.py +45 -0
  52. python/tests/onnx_quantizer_tests/test_base_layer.py +228 -0
  53. python/tests/onnx_quantizer_tests/test_exceptions.py +99 -0
  54. python/tests/onnx_quantizer_tests/test_onnx_op_quantizer.py +246 -0
  55. python/tests/onnx_quantizer_tests/test_registered_quantizers.py +121 -0
  56. python/tests/onnx_quantizer_tests/testing_helper_functions.py +17 -0
  57. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  58. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/WHEEL +0 -0
  59. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/entry_points.txt +0 -0
  60. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/licenses/LICENSE +0 -0
  61. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/top_level.txt +0 -0
@@ -1,20 +1,28 @@
1
1
  from __future__ import annotations
2
2
 
3
- import onnx
4
- from onnx import helper
3
+ from typing import TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ import onnx
5
7
 
6
8
  from python.core.model_processing.onnx_custom_ops.onnx_helpers import (
7
- extract_attributes,
8
9
  get_attribute_ints,
9
10
  )
10
11
  from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError
11
12
  from python.core.model_processing.onnx_quantizer.layers.base import (
12
13
  BaseOpQuantizer,
14
+ QuantizerBase,
13
15
  ScaleConfig,
14
16
  )
15
17
 
16
18
 
17
- class MaxpoolQuantizer(BaseOpQuantizer):
19
+ class QuantizeMaxpool(QuantizerBase):
20
+ OP_TYPE = "Int64MaxPool"
21
+ USE_WB = False
22
+ USE_SCALING = False
23
+
24
+
25
+ class MaxpoolQuantizer(BaseOpQuantizer, QuantizeMaxpool):
18
26
  """
19
27
  Quantizer for ONNX MaxPool layers.
20
28
 
@@ -25,55 +33,26 @@ class MaxpoolQuantizer(BaseOpQuantizer):
25
33
 
26
34
  def __init__(
27
35
  self: MaxpoolQuantizer,
28
- new_initializer: dict[str, onnx.TensorProto] | None = None,
36
+ new_initializer: list[onnx.TensorProto] | None = None,
29
37
  ) -> None:
30
38
  super().__init__()
31
39
  self.accepted_kernel_shapes = [2]
32
40
  _ = new_initializer
33
41
 
34
42
  def quantize(
35
- self: BaseOpQuantizer,
43
+ self: MaxpoolQuantizer,
36
44
  node: onnx.NodeProto,
37
45
  graph: onnx.GraphProto,
38
46
  scale_config: ScaleConfig,
39
47
  initializer_map: dict[str, onnx.TensorProto],
40
48
  ) -> list[onnx.NodeProto]:
41
- """
42
- Quantize a node by converting the node to Int64 version
43
-
44
- Args:
45
- node (onnx.NodeProto): The node to quantize.
46
- rescale (bool): Whether rescaling is enabled
47
- (Doesnt have an affect on this op type)
48
- graph (onnx.GraphProto): The ONNX graph.
49
- scale_exponent (int): Scale exponent.
50
- scale_base (int): The base of scaling.
51
- initializer_map (dict[str, onnx.TensorProto]):
52
- Map of initializer names to tensor data.
53
-
54
- Returns:
55
- List[onnx.NodeProto]: A list of ONNX nodes
56
- (quantized MaxPool and any auxiliary nodes).
57
- """
58
- _ = initializer_map, graph
59
-
60
- attrs = extract_attributes(node)
61
- attrs["rescale"] = int(scale_config.rescale)
62
-
63
- attr_str = {
64
- k: ",".join(map(str, v)) if isinstance(v, list) else str(v)
65
- for k, v in attrs.items()
66
- }
67
- return [
68
- helper.make_node(
69
- "Int64MaxPool",
70
- inputs=node.input,
71
- outputs=node.output,
72
- name=node.name,
73
- domain="ai.onnx.contrib",
74
- **attr_str,
75
- ),
76
- ]
49
+ return QuantizeMaxpool.quantize(
50
+ self,
51
+ node,
52
+ graph,
53
+ scale_config,
54
+ initializer_map,
55
+ )
77
56
 
78
57
  def check_supported(
79
58
  self: MaxpoolQuantizer,
@@ -95,6 +74,7 @@ class MaxpoolQuantizer(BaseOpQuantizer):
95
74
  _ = initializer_map
96
75
  self.check_all_params_exist(node)
97
76
  self.check_params_size(node)
77
+ self.check_pool_pads(node)
98
78
 
99
79
  def check_all_params_exist(self: MaxpoolQuantizer, node: onnx.NodeProto) -> None:
100
80
  """Checks all parameters that are needed, do exist
@@ -131,10 +111,40 @@ class MaxpoolQuantizer(BaseOpQuantizer):
131
111
  InvalidParamError: If shape requirement is not met.
132
112
  """
133
113
 
134
- kernel_shape = get_attribute_ints(node, "kernel_shape", default="N/A")
114
+ kernel_shape = get_attribute_ints(node, "kernel_shape", default=[])
135
115
  if len(kernel_shape) not in self.accepted_kernel_shapes:
136
116
  raise InvalidParamError(
137
117
  node.name,
138
118
  node.op_type,
139
119
  f"Currently only maxpool2d is supported. Found {len(kernel_shape)}D",
140
120
  )
121
+
122
+ def check_pool_pads(self: MaxpoolQuantizer, node: onnx.NodeProto) -> None:
123
+ kernel_shape = get_attribute_ints(node, "kernel_shape", default=[])
124
+ pads = get_attribute_ints(node, "pads", default=None)
125
+ if pads is None:
126
+ return
127
+ num_dims = len(kernel_shape)
128
+ if len(pads) != num_dims * 2:
129
+ raise InvalidParamError(
130
+ node.name,
131
+ node.op_type,
132
+ f"Expected {num_dims * 2} pads, got {len(pads)}",
133
+ )
134
+
135
+ for dim in range(num_dims):
136
+ pad_before = pads[dim]
137
+ pad_after = pads[dim + num_dims]
138
+ kernel = kernel_shape[dim]
139
+ if pad_before >= kernel:
140
+ raise InvalidParamError(
141
+ node.name,
142
+ node.op_type,
143
+ f"pads[{dim}]={pad_before} >= kernel[{dim}]={kernel}",
144
+ )
145
+ if pad_after >= kernel:
146
+ raise InvalidParamError(
147
+ node.name,
148
+ node.op_type,
149
+ f"pads[{dim + num_dims}]={pad_after} >= kernel[{dim}]={kernel}",
150
+ )
@@ -0,0 +1,53 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, ClassVar
4
+
5
+ if TYPE_CHECKING:
6
+ import onnx
7
+
8
+ from python.core.model_processing.onnx_quantizer.layers.base import (
9
+ BaseOpQuantizer,
10
+ QuantizerBase,
11
+ ScaleConfig,
12
+ )
13
+
14
+
15
+ class QuantizeMul(QuantizerBase):
16
+ OP_TYPE = "Int64Mul"
17
+ USE_WB = True
18
+ USE_SCALING = True
19
+ SCALE_PLAN: ClassVar = {0: 1, 1: 1}
20
+
21
+
22
+ class MulQuantizer(BaseOpQuantizer, QuantizeMul):
23
+ """
24
+ Quantizer for ONNX Mul layers.
25
+
26
+ - Uses custom Mul layer to incorporate rescaling, and
27
+ makes relevant additional changes to the graph.
28
+ """
29
+
30
+ def __init__(
31
+ self: MulQuantizer,
32
+ new_initializers: list[onnx.TensorProto] | None = None,
33
+ ) -> None:
34
+ super().__init__()
35
+ # Only replace if caller provided something
36
+ if new_initializers is not None:
37
+ self.new_initializers = new_initializers
38
+
39
+ def quantize(
40
+ self: MulQuantizer,
41
+ node: onnx.NodeProto,
42
+ graph: onnx.GraphProto,
43
+ scale_config: ScaleConfig,
44
+ initializer_map: dict[str, onnx.TensorProto],
45
+ ) -> list[onnx.NodeProto]:
46
+ return QuantizeMul.quantize(self, node, graph, scale_config, initializer_map)
47
+
48
+ def check_supported(
49
+ self: MulQuantizer,
50
+ node: onnx.NodeProto,
51
+ initializer_map: dict[str, onnx.TensorProto] | None = None,
52
+ ) -> None:
53
+ pass
@@ -1,14 +1,24 @@
1
1
  from __future__ import annotations
2
2
 
3
- import onnx
3
+ from typing import TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ from onnx import GraphProto, NodeProto, TensorProto
4
7
 
5
8
  from python.core.model_processing.onnx_quantizer.layers.base import (
6
9
  BaseOpQuantizer,
10
+ QuantizerBase,
7
11
  ScaleConfig,
8
12
  )
9
13
 
10
14
 
11
- class ReluQuantizer(BaseOpQuantizer):
15
+ class QuantizeRelu(QuantizerBase):
16
+ OP_TYPE = "Int64Relu"
17
+ USE_WB = False
18
+ USE_SCALING = False
19
+
20
+
21
+ class ReluQuantizer(BaseOpQuantizer, QuantizeRelu):
12
22
  """
13
23
  Quantizer for ONNX ReLU layers.
14
24
 
@@ -19,49 +29,24 @@ class ReluQuantizer(BaseOpQuantizer):
19
29
 
20
30
  def __init__(
21
31
  self: ReluQuantizer,
22
- new_initializer: dict[str, onnx.TensorProto] | None = None,
32
+ new_initializer: list[TensorProto] | None = None,
23
33
  ) -> None:
24
34
  super().__init__()
25
35
  _ = new_initializer
26
36
 
27
37
  def quantize(
28
38
  self: ReluQuantizer,
29
- node: onnx.NodeProto,
30
- graph: onnx.GraphProto,
39
+ node: NodeProto,
40
+ graph: GraphProto,
31
41
  scale_config: ScaleConfig,
32
- initializer_map: dict[str, onnx.TensorProto],
33
- ) -> list[onnx.NodeProto]:
34
- """
35
- Quantize a node by converting the node to Int64 version
36
-
37
- Args:
38
- node (onnx.NodeProto): The node to quantize.
39
- rescale (bool): Whether rescaling is enabled
40
- (Doesnt have an affect on this op type)
41
- graph (onnx.GraphProto): The ONNX graph.
42
- scale_exponent (int): Scale exponent.
43
- scale_base (int): The base of scaling.
44
- initializer_map (dict[str, onnx.TensorProto]):
45
- Map of initializer names to tensor data.
46
-
47
- Returns:
48
- List[onnx.NodeProto]: The quantized ONNX node.
49
- """
50
- _ = graph, scale_config, initializer_map
51
- return [
52
- onnx.helper.make_node(
53
- "Int64Relu",
54
- inputs=node.input,
55
- outputs=node.output, # preserve original output name
56
- name=node.name,
57
- domain="ai.onnx.contrib",
58
- ),
59
- ]
42
+ initializer_map: dict[str, TensorProto],
43
+ ) -> list[NodeProto]:
44
+ return QuantizeRelu.quantize(self, node, graph, scale_config, initializer_map)
60
45
 
61
46
  def check_supported(
62
47
  self: ReluQuantizer,
63
- node: onnx.NodeProto,
64
- initializer_map: dict[str, onnx.TensorProto] | None = None,
48
+ node: NodeProto,
49
+ initializer_map: dict[str, TensorProto] | None = None,
65
50
  ) -> None:
66
51
  """
67
52
  Perform high-level validation to ensure that this node
@@ -0,0 +1,54 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, ClassVar
4
+
5
+ if TYPE_CHECKING:
6
+ import onnx
7
+
8
+ from python.core.model_processing.onnx_quantizer.layers.base import (
9
+ BaseOpQuantizer,
10
+ QuantizerBase,
11
+ ScaleConfig,
12
+ )
13
+
14
+
15
+ class QuantizeSub(QuantizerBase):
16
+ OP_TYPE = "Sub"
17
+ DOMAIN = ""
18
+ USE_WB = True
19
+ USE_SCALING = False
20
+ SCALE_PLAN: ClassVar = {0: 1, 1: 1}
21
+
22
+
23
+ class SubQuantizer(BaseOpQuantizer, QuantizeSub):
24
+ """
25
+ Quantizer for ONNX Sub layers.
26
+
27
+ - Uses standard ONNX Sub layer in standard domain, and
28
+ makes relevant additional changes to the graph.
29
+ """
30
+
31
+ def __init__(
32
+ self: SubQuantizer,
33
+ new_initializers: list[onnx.TensorProto] | None = None,
34
+ ) -> None:
35
+ super().__init__()
36
+ # Only replace if caller provided something
37
+ if new_initializers is not None:
38
+ self.new_initializers = new_initializers
39
+
40
+ def quantize(
41
+ self: SubQuantizer,
42
+ node: onnx.NodeProto,
43
+ graph: onnx.GraphProto,
44
+ scale_config: ScaleConfig,
45
+ initializer_map: dict[str, onnx.TensorProto],
46
+ ) -> list[onnx.NodeProto]:
47
+ return QuantizeSub.quantize(self, node, graph, scale_config, initializer_map)
48
+
49
+ def check_supported(
50
+ self: SubQuantizer,
51
+ node: onnx.NodeProto,
52
+ initializer_map: dict[str, onnx.TensorProto] | None = None,
53
+ ) -> None:
54
+ pass
@@ -1,6 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Callable
3
+ from typing import TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ from collections.abc import Callable
4
7
 
5
8
  import onnx
6
9
 
@@ -9,17 +12,23 @@ from python.core.model_processing.onnx_quantizer.exceptions import (
9
12
  MissingHandlerError,
10
13
  UnsupportedOpError,
11
14
  )
15
+ from python.core.model_processing.onnx_quantizer.layers.add import AddQuantizer
12
16
  from python.core.model_processing.onnx_quantizer.layers.base import (
13
17
  PassthroughQuantizer,
14
18
  ScaleConfig,
15
19
  )
20
+ from python.core.model_processing.onnx_quantizer.layers.batchnorm import (
21
+ BatchnormQuantizer,
22
+ )
16
23
  from python.core.model_processing.onnx_quantizer.layers.constant import (
17
24
  ConstantQuantizer,
18
25
  )
19
26
  from python.core.model_processing.onnx_quantizer.layers.conv import ConvQuantizer
20
27
  from python.core.model_processing.onnx_quantizer.layers.gemm import GemmQuantizer
21
28
  from python.core.model_processing.onnx_quantizer.layers.maxpool import MaxpoolQuantizer
29
+ from python.core.model_processing.onnx_quantizer.layers.mul import MulQuantizer
22
30
  from python.core.model_processing.onnx_quantizer.layers.relu import ReluQuantizer
31
+ from python.core.model_processing.onnx_quantizer.layers.sub import SubQuantizer
23
32
 
24
33
 
25
34
  class ONNXOpQuantizer:
@@ -64,6 +73,9 @@ class ONNXOpQuantizer:
64
73
  self.new_initializers = []
65
74
 
66
75
  # Register handlers
76
+ self.register("Add", AddQuantizer(self.new_initializers))
77
+ self.register("Sub", SubQuantizer(self.new_initializers))
78
+ self.register("Mul", MulQuantizer(self.new_initializers))
67
79
  self.register("Conv", ConvQuantizer(self.new_initializers))
68
80
  self.register("Relu", ReluQuantizer())
69
81
  self.register("Reshape", PassthroughQuantizer())
@@ -71,6 +83,7 @@ class ONNXOpQuantizer:
71
83
  self.register("Constant", ConstantQuantizer())
72
84
  self.register("MaxPool", MaxpoolQuantizer())
73
85
  self.register("Flatten", PassthroughQuantizer())
86
+ self.register("BatchNormalization", BatchnormQuantizer(self.new_initializers))
74
87
 
75
88
  def register(
76
89
  self: ONNXOpQuantizer,
@@ -198,3 +211,32 @@ class ONNXOpQuantizer:
198
211
  dict[str, onnx.TensorProto]: Map from initializer name to tensors in graph.
199
212
  """
200
213
  return {init.name: init for init in model.graph.initializer}
214
+
215
+ def apply_pre_analysis_transforms(
216
+ self: ONNXOpQuantizer,
217
+ model: onnx.ModelProto,
218
+ scale_exponent: int,
219
+ scale_base: int,
220
+ ) -> onnx.ModelProto:
221
+ """
222
+ Give each registered handler a chance to rewrite the model before analysis.
223
+ """
224
+ graph = model.graph
225
+ initializer_map = self.get_initializer_map(model)
226
+
227
+ # We allow handlers to modify graph in-place.
228
+ # (Nodes may be replaced, removed, or new nodes added.)
229
+ for node in list(graph.node):
230
+ handler = self.handlers.get(node.op_type)
231
+ if handler and hasattr(handler, "pre_analysis_transform"):
232
+ handler.pre_analysis_transform(
233
+ node,
234
+ graph,
235
+ initializer_map,
236
+ scale_exponent=scale_exponent,
237
+ scale_base=scale_base,
238
+ )
239
+ # Refresh map if transforms may add initializers
240
+ initializer_map = self.get_initializer_map(model)
241
+
242
+ return model
@@ -83,7 +83,9 @@ class GeneralLayerFunctions:
83
83
  tensor = torch.mul(tensor, self.scale_base**self.scale_exponent)
84
84
 
85
85
  tensor = tensor.long()
86
+ return self.reshape_inputs(tensor)
86
87
 
88
+ def reshape_inputs(self: GeneralLayerFunctions, tensor: torch.Tensor) -> None:
87
89
  if hasattr(self, "input_shape"):
88
90
  shape = self.input_shape
89
91
  if hasattr(self, "adjust_shape") and callable(
@@ -94,6 +96,7 @@ class GeneralLayerFunctions:
94
96
  tensor = tensor.reshape(shape)
95
97
  except RuntimeError as e:
96
98
  raise ShapeMismatchError(shape, list(tensor.shape)) from e
99
+
97
100
  return tensor
98
101
 
99
102
  def get_inputs(
@@ -154,21 +157,19 @@ class GeneralLayerFunctions:
154
157
  # If unknown dim in batch spot, assume batch size of 1
155
158
  first_key = next(iter(keys))
156
159
  input_shape = self.input_shape[first_key]
157
- input_shape[0] = 1 if input_shape[0] < 1 else input_shape[0]
160
+ input_shape[0] = max(input_shape[0], 1)
158
161
  return self.get_rand_inputs(input_shape)
159
162
  inputs = {}
160
163
  for key in keys:
161
164
  # If unknown dim in batch spot, assume batch size of 1
162
- input_shape = self.input_shape[keys[key]]
163
- if not isinstance(input_shape, list) and not isinstance(
164
- input_shape,
165
- tuple,
166
- ):
165
+ input_shape = self.input_shape[key]
166
+ if not isinstance(input_shape, (list, tuple)):
167
167
  msg = f"Invalid input shape for key '{key}': {input_shape}"
168
168
  raise CircuitUtilsError(msg)
169
- input_shape[0] = 1 if input_shape[0] < 1 else input_shape[0]
169
+ input_shape[0] = max(input_shape[0], 1)
170
170
  inputs[key] = self.get_rand_inputs(input_shape)
171
171
  return inputs
172
+
172
173
  if not (hasattr(self, "scale_base") and hasattr(self, "scale_exponent")):
173
174
  attr_name = "scale_base/scale_exponent"
174
175
  context = "needed for scaling random inputs"
@@ -195,8 +196,10 @@ class GeneralLayerFunctions:
195
196
  torch.Tensor: A tensor of random values in [-1, 1).
196
197
  """
197
198
  if not isinstance(input_shape, (list, tuple)):
198
- msg = f"Invalid input_shape type: {type(input_shape)}."
199
- " Expected list or tuple of ints."
199
+ msg = (
200
+ f"Invalid input_shape type: {type(input_shape)}."
201
+ " Expected list or tuple of ints."
202
+ )
200
203
  raise CircuitUtilsError(msg)
201
204
  if not all(isinstance(x, int) and x > 0 for x in input_shape):
202
205
  raise ShapeMismatchError(
@@ -240,9 +243,11 @@ class GeneralLayerFunctions:
240
243
  try:
241
244
  rescaled = torch.div(outputs, self.scale_base**self.scale_exponent)
242
245
  except Exception as e:
243
- msg = "Failed to rescale outputs using scale_base="
244
- f"{getattr(self, 'scale_base', None)} "
245
- f"and scale_exponent={getattr(self, 'scale_exponent', None)}: {e}"
246
+ msg = (
247
+ "Failed to rescale outputs using scale_base="
248
+ f"{getattr(self, 'scale_base', None)} "
249
+ f"and scale_exponent={getattr(self, 'scale_exponent', None)}: {e}"
250
+ )
246
251
  raise CircuitUtilsError(msg) from e
247
252
  return {
248
253
  "output": outputs.long().tolist(),
@@ -147,7 +147,8 @@ def get_models_to_test(
147
147
 
148
148
  Args:
149
149
  selected_models (list[str], optional):
150
- A list of model names to include. Defaults to None.
150
+ A list of model names to include. If None, returns empty list.
151
+ If provided but no matches found, returns empty list. Defaults to None.
151
152
  source_filter (str, optional):
152
153
  Restrict models to a specific source (e.g., "onnx", "class").
153
154
  Defaults to None.
@@ -155,10 +156,12 @@ def get_models_to_test(
155
156
  Returns:
156
157
  list[ModelEntry]: A filtered list of model entries.
157
158
  """
159
+ if selected_models is None:
160
+ return []
161
+
158
162
  models = MODELS_TO_TEST
159
163
 
160
- if selected_models is not None:
161
- models = [m for m in models if m.name in selected_models]
164
+ models = [m for m in models if m.name in selected_models]
162
165
 
163
166
  if source_filter is not None:
164
167
  models = [m for m in models if m.source == source_filter]
@@ -247,12 +247,12 @@ def export_onnx(
247
247
 
248
248
 
249
249
  def write_input_json(json_path: Path, input_shape: tuple[int] = (1, 4, 28, 28)) -> None:
250
- """Write a zero-valued input tensor to JSON alongside its [N,C,H,W] shape."""
250
+ """Write a zero-valued input tensor to JSON without shape information."""
251
251
  json_path.parent.mkdir(parents=True, exist_ok=True)
252
252
  n, c, h, w = input_shape
253
253
  arr = [0.0] * (n * c * h * w)
254
254
  with json_path.open("w", encoding="utf-8") as f:
255
- json.dump({"input": arr, "shape": [n, c, h, w]}, f)
255
+ json.dump({"input": arr}, f)
256
256
 
257
257
 
258
258
  def run_bench(