JSTprove 1.1.0__py3-none-macosx_11_0_arm64.whl → 1.2.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: JSTprove
- Version: 1.1.0
+ Version: 1.2.0
  Summary: Zero-knowledge proofs of ML inference on ONNX models
  Author: Inference Labs Inc
  Requires-Python: >=3.10
@@ -45,7 +45,7 @@ Dynamic: license-file
  Zero-knowledge proofs of ML inference on **ONNX** models — powered by [Polyhedra Network’s **Expander**](https://github.com/PolyhedraZK/Expander) (GKR/sum-check prover) and [**Expander Compiler Collection (ECC)**](https://github.com/PolyhedraZK/ExpanderCompilerCollection).

  * 🎯 **You bring ONNX** → we quantize, compile to a circuit, generate a witness, prove, and verify — via a simple CLI.
- * ✅ Supported ops (current): **Conv2D**, **GEMM/MatMul (FC)**, **ReLU**, **MaxPool2D**, **Add**.
+ * ✅ Supported ops (current): **Conv2D**, **GEMM/MatMul (FC)**, **ReLU**, **MaxPool2D**, **Add**, **Mul**, **Sub**, **BatchNorm**.
  * 🧰 CLI details: see **[docs/cli.md](docs/cli.md)**

  👉 Just want to see it in action? Jump to [Quickstart (LeNet demo)](#quickstart-lenet-demo).<br>
@@ -85,7 +85,7 @@ You provide an **ONNX** model and inputs; JSTprove handles **quantization**, **c
  ### High-level architecture

  * **Python pipeline:** Converts **ONNX → quantized ONNX**, prepares I/O, drives the Rust runner, exposes the **CLI**.
- * **Rust crate:** `rust/jstprove_circuits` implements layer circuits (Conv2D, ReLU, MaxPool2D, GEMM/FC) and a runner.
+ * **Rust crate:** `rust/jstprove_circuits` implements layer circuits (Conv2D, ReLU, MaxPool2D, GEMM/FC, BatchNorm) and a runner.
  * **Circuit frontend:** [ECC](https://github.com/PolyhedraZK/ExpanderCompilerCollection) Rust API for arithmetic circuits.
  * **Prover backend:** [Expander](https://github.com/PolyhedraZK/Expander) (GKR/sum-check prover/verification).

@@ -1,9 +1,9 @@
- jstprove-1.1.0.dist-info/licenses/LICENSE,sha256=UXQRcYRUH-PfN27n3P-FMaZFY6jr9jFPKcwT7CWbljw,1160
+ jstprove-1.2.0.dist-info/licenses/LICENSE,sha256=UXQRcYRUH-PfN27n3P-FMaZFY6jr9jFPKcwT7CWbljw,1160
  python/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  python/core/__init__.py,sha256=RlfbqGAaUulKl44QGMCkkGJBQZ8R_AgC5bU5zS7BjnA,97
  python/core/binaries/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  python/core/binaries/expander-exec,sha256=C_1JcezdfLp9sFOQ2z3wp2gcq1k8zjIR09CxJKGGIuM,7095168
- python/core/binaries/onnx_generic_circuit_1-1-0,sha256=2YBhVx-neun-Dmx3ntyLq20qwsLrY9coOcU2bNLprZ0,3086160
+ python/core/binaries/onnx_generic_circuit_1-2-0,sha256=vLWr1O5PePljr54ZJ32dgHcuawzauRzuZpz7cZxvwgc,3144592
  python/core/circuit_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  python/core/circuit_models/generic_onnx.py,sha256=P65UZkfVBTE6YhaQ951S6QoTHPuU5ntDt8QL5pXghvw,8787
  python/core/circuit_models/simple_circuit.py,sha256=igQrZtQyreyHc26iAgCyDb0TuD2bJAoumYhc1pYPDzQ,4682
@@ -15,25 +15,30 @@ python/core/model_processing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NM
  python/core/model_processing/errors.py,sha256=uh2YFjuuU5JM3anMtSTLAH-zjlNAKStmLDZqRUgBWS8,4611
  python/core/model_processing/converters/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  python/core/model_processing/converters/base.py,sha256=eG7iRDbDJJDTG2cCVgYlPlfkpmYPEnMzjGNK9wrA1m0,4303
- python/core/model_processing/converters/onnx_converter.py,sha256=BJc6rU3wLHI3imt8yzm8Cngri3KvcBSUbJ3Urw2PoEQ,44560
+ python/core/model_processing/converters/onnx_converter.py,sha256=-eXdF6tfluFRxGgnQtJQ8R2309aYX-8z8HzMxk_Qv8I,44340
  python/core/model_processing/onnx_custom_ops/__init__.py,sha256=ofecV9pzpDJJl_r6inRw9JOKxtfK2rzzxWahAq9BKXE,475
+ python/core/model_processing/onnx_custom_ops/batchnorm.py,sha256=8kg4iGGdt6B_fIJkpt4v5eNFpoHa4bjTB0NnCSmKFvE,1693
  python/core/model_processing/onnx_custom_ops/conv.py,sha256=6jJm3fcGWzcU4RjVgf179mPFCqsl4C3AR7bqQTffDgA,3464
  python/core/model_processing/onnx_custom_ops/custom_helpers.py,sha256=2WdnHw9NAoN_6wjIBoAQDyL6wEIlZOqo6ysCZp5DpZs,1844
  python/core/model_processing/onnx_custom_ops/gemm.py,sha256=bnEUXhqQCEcH4TIfbMTsCTtAlAlRzFvl4jj8g2QZFWU,2674
  python/core/model_processing/onnx_custom_ops/maxpool.py,sha256=Sd3BwqpGLSVU2iuAAIXAHdI3WO27Aa3g3r29HPiECvM,2319
+ python/core/model_processing/onnx_custom_ops/mul.py,sha256=w6X1sl1HnzoUJx2Mm_LaoXGTpvtwXxr3zZDPySVHBcM,1888
  python/core/model_processing/onnx_custom_ops/onnx_helpers.py,sha256=utnJuc5sgb_z1LgxuY9y2cQbMpdEJ8xOOrcP8DhfDCM,5686
  python/core/model_processing/onnx_custom_ops/relu.py,sha256=pZsPXC_r0FPggURKDphh8P1IRXY0w4hH7ExBmYTlWjE,1202
  python/core/model_processing/onnx_quantizer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  python/core/model_processing/onnx_quantizer/exceptions.py,sha256=_YaXXEMbfD1P8N86L5YIz3uCilkuzlhv_2lU90T4FfA,5646
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py,sha256=POoDEBFzkr145P4INgAux2LQY2GdpsBtRpw_UuKVNhw,7679
+ python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py,sha256=ncL0rK5hXZUvssmw20PZO1WyjYSyenem23B6QLUHlLY,9213
  python/core/model_processing/onnx_quantizer/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  python/core/model_processing/onnx_quantizer/layers/add.py,sha256=AGxzqMa0jABIEKOIgPqEAA7EpZtynQtnD9nxI2NHc0s,1409
- python/core/model_processing/onnx_quantizer/layers/base.py,sha256=LvyTvmR2w6jYSJiBvyFluaDgL_Voc6dZ00TTWi6V7Tc,17426
+ python/core/model_processing/onnx_quantizer/layers/base.py,sha256=Vq6pwChw9eJMKYAJyA1C3wLycaBConkP9sNRInpWavo,19989
+ python/core/model_processing/onnx_quantizer/layers/batchnorm.py,sha256=KSBDPHd52f5Qyf-cnIDFPmfzssaJgMPiTmpIWEdM41U,7718
  python/core/model_processing/onnx_quantizer/layers/constant.py,sha256=l1IvgvXkmFMiaBsym8wchPF-y1ZH-c5PmFUy92IXWok,3694
  python/core/model_processing/onnx_quantizer/layers/conv.py,sha256=TlUpCRO6PPqH7MPkIrEiEcVfzuiN1WMYEiNIjhYXtWM,4451
  python/core/model_processing/onnx_quantizer/layers/gemm.py,sha256=7fCUMv8OLVZ45a2lYjA2XNvcW3By7lSbX7zeForNK-0,3950
  python/core/model_processing/onnx_quantizer/layers/maxpool.py,sha256=PJ8hZPPBpfWV_RZdySl50-BU8TATjcg8Tg_mrAVS1Ic,4916
+ python/core/model_processing/onnx_quantizer/layers/mul.py,sha256=qHsmnYPH-c5uiFeDCvV6e1xSgmIXJ64Sjvh0LYDYEqQ,1396
  python/core/model_processing/onnx_quantizer/layers/relu.py,sha256=d-5fyeKNLTgKKnqCwURpxkjl7QdbJQpuovtCFBM03FA,1685
+ python/core/model_processing/onnx_quantizer/layers/sub.py,sha256=M7D98TZBNP9-2R9MX6mcpYlrWFxTiX9JCs3XNcg1U-Q,1409
  python/core/model_templates/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  python/core/model_templates/circuit_template.py,sha256=X8bA4AdmtQeb3ltU74GaWYfrOFhqs_DOpUqRMFXLAD8,2352
  python/core/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -62,7 +67,7 @@ python/frontend/commands/bench/model.py,sha256=SaIWXAXZbWGbrNqEo5bs4NwgZfMOmmxaC
  python/frontend/commands/bench/sweep.py,sha256=rl-QBS9eXgQkuPJBhsU4CohfE1PdJvnM8NRhNU7ztQw,5279
  python/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  python/scripts/benchmark_runner.py,sha256=sjbqaLrdjt94AoyQXAxT4FhsN6aRu5idTRQ5uHmZOWM,28593
- python/scripts/gen_and_bench.py,sha256=9kcIj-K_nG-G194C68Uig-Yw-p3nYKESACIpWRflmts,16276
+ python/scripts/gen_and_bench.py,sha256=V36x7djYmHlveAJgYzMlXwnmF0gAGO3-1mg9PWOmpj8,16249
  python/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  python/tests/test_cli.py,sha256=OiAyG3aBpukk0i5FFWbiKaF42wf-7By-UWDHNjwtsqo,27042
  python/tests/circuit_e2e_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -82,27 +87,30 @@ python/tests/onnx_quantizer_tests/testing_helper_functions.py,sha256=N0fQv2pYzUC
  python/tests/onnx_quantizer_tests/layers/__init__.py,sha256=xP-RmW6LfIANgK1s9Q0KZet2yvNr-3c6YIVLAAQqGUY,404
  python/tests/onnx_quantizer_tests/layers/add_config.py,sha256=T3tGddupDtrvLck2SL2yETDblNtv0aU7Tl7fNyZUhO4,4133
  python/tests/onnx_quantizer_tests/layers/base.py,sha256=uLCqhMcBA7zWiRSLRMNKKb4A9N27l-RUqSEEQ8SR3xI,9393
+ python/tests/onnx_quantizer_tests/layers/batchnorm_config.py,sha256=P-sZuHAdEfNczcgTeLjqJnEbpqN3dKTsbqvY4-SBqiQ,8231
  python/tests/onnx_quantizer_tests/layers/constant_config.py,sha256=RdrKNMNZjI3Sk5o8WLNqmBUyYVJRWgtFbQ6oFWMwyQk,1193
  python/tests/onnx_quantizer_tests/layers/conv_config.py,sha256=H0ioW4H3ei5IK4tKhrA0ffThxJ4K5oO9jIs9A0T0VaM,6005
  python/tests/onnx_quantizer_tests/layers/factory.py,sha256=WLLEP9ECmSpTliSjhtdWOHcX1xOi6HM10S9Y4re1A74,4844
  python/tests/onnx_quantizer_tests/layers/flatten_config.py,sha256=Xln5Hh6gyeM5gGRCjLGvIL-u08NEs1tXSF32urCqPfE,2110
  python/tests/onnx_quantizer_tests/layers/gemm_config.py,sha256=t7nJY-Wnj6YUD821-jaWzgrQVPa6ytwER3hFMsvyY6Y,7294
  python/tests/onnx_quantizer_tests/layers/maxpool_config.py,sha256=XfTPk_ZQXEzaCjHHymSLVv2HS-PKH1rS9IuyyoEtM78,3176
+ python/tests/onnx_quantizer_tests/layers/mul_config.py,sha256=_Oy4b97ORxFlF3w0BmJ94hNA968HQx2AvwYiASrGPxw,4135
  python/tests/onnx_quantizer_tests/layers/relu_config.py,sha256=_aHuddDApLUBOa0FiR9h4fNfmMSnH5r4JzOMLW0KaTk,2197
  python/tests/onnx_quantizer_tests/layers/reshape_config.py,sha256=fZchSqIAy76m7j97wVC_UI6slSpv8nbwukhkbGR2sRE,2203
+ python/tests/onnx_quantizer_tests/layers/sub_config.py,sha256=IxF18mG9kjlEiKYSNG912CEcBxOFGxIWoRAwjvBXiRo,4133
  python/tests/onnx_quantizer_tests/layers_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  python/tests/onnx_quantizer_tests/layers_tests/base_test.py,sha256=UgbcT97tgcuTtO1pOADpww9bz_JElKiI2mxLJYKyF1k,2992
  python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py,sha256=Vxn4LEWHZeGa_vS1-7ptFqSSBb0D-3BG-ETocP4pvsI,3651
  python/tests/onnx_quantizer_tests/layers_tests/test_e2e.py,sha256=40779aaHgdryVwLlIO18F1d7uSLSXdJUG5Uj_5-xD4U,6712
  python/tests/onnx_quantizer_tests/layers_tests/test_error_cases.py,sha256=t5c_zqO4Ex3HIFWcykX4PTftdKN7UWnEOF5blShL0Ik,1881
  python/tests/onnx_quantizer_tests/layers_tests/test_integration.py,sha256=Mq1-PBKR3756i9VrFOP5DY3GkRE32D6Hjd1fK9wZdVk,7228
- python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py,sha256=zclzXxtgA5BEmNwSf_aNbJgbsArMXn5WDdlxiMR2-aM,9255
+ python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py,sha256=bVdMDkIq0gdHNLTFrWRdrCgCAG03rEF8aCRU-t4b4Kg,9391
  python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py,sha256=RfnIIiYbgPbU3620H6MPvSxE3MNR2G1yPELwdWV3mK4,4107
  python/tests/onnx_quantizer_tests/layers_tests/test_validation.py,sha256=jz-WtIEP-jjUklOOAnznwPUXbf07U2PAMGrhzMWP0JU,1371
  python/tests/utils_testing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  python/tests/utils_testing/test_helper_functions.py,sha256=xmeGQieh4LE9U-CDKBlHhSWqH0cAmmDU3qXNbDkkvms,27192
- jstprove-1.1.0.dist-info/METADATA,sha256=3gdOLaD4eYGawv4SuvofjuzBW-y564J4gpNPXHFNY1A,14056
- jstprove-1.1.0.dist-info/WHEEL,sha256=jc2C2uw104ioj1TL9cE0YO67_kdAwX4W8JgYPomxr5M,105
- jstprove-1.1.0.dist-info/entry_points.txt,sha256=nGcTSO-4q08gPl1IoWdrPaiY7IbO7XvmXKkd34dYHc8,49
- jstprove-1.1.0.dist-info/top_level.txt,sha256=J-z0poNcsv31IHB413--iOY8LoHBKiTHeybHX3abokI,7
- jstprove-1.1.0.dist-info/RECORD,,
+ jstprove-1.2.0.dist-info/METADATA,sha256=UVxR8iFm2kjvrvh1t4hEaCn0n4ZCYE2fcurGeCRmRCk,14100
+ jstprove-1.2.0.dist-info/WHEEL,sha256=jc2C2uw104ioj1TL9cE0YO67_kdAwX4W8JgYPomxr5M,105
+ jstprove-1.2.0.dist-info/entry_points.txt,sha256=nGcTSO-4q08gPl1IoWdrPaiY7IbO7XvmXKkd34dYHc8,49
+ jstprove-1.2.0.dist-info/top_level.txt,sha256=J-z0poNcsv31IHB413--iOY8LoHBKiTHeybHX3abokI,7
+ jstprove-1.2.0.dist-info/RECORD,,
@@ -247,6 +247,7 @@ class ONNXConverter(ModelConverter):

      def analyze_layers(
          self: ONNXConverter,
+         model: onnx.ModelProto,
          output_name_to_shape: dict[str, list[int]] | None = None,
      ) -> tuple[list[ONNXLayer], list[ONNXLayer]]:
          """Analyze the onnx model graph into
@@ -268,29 +269,29 @@ class ONNXConverter(ModelConverter):
              id_count = 0
              # Apply shape inference on the model
              if not output_name_to_shape:
-                 inferred_model = shape_inference.infer_shapes(self.model)
+                 inferred_model = shape_inference.infer_shapes(model)
                  self._onnx_check_model_safely(inferred_model)

                  output_name_to_shape = extract_shape_dict(inferred_model)
              domain_to_version = {
-                 opset.domain: opset.version for opset in self.model.opset_import
+                 opset.domain: opset.version for opset in model.opset_import
              }

              id_count = 0
              architecture = self.get_model_architecture(
-                 self.model,
+                 model,
                  output_name_to_shape,
                  domain_to_version,
              )
              w_and_b = self.get_model_w_and_b(
-                 self.model,
+                 model,
                  output_name_to_shape,
                  id_count,
                  domain_to_version,
              )
          except InvalidModelError:
              raise
-         except (ValueError, TypeError, RuntimeError, OSError, onnx.ONNXException) as e:
+         except (ValueError, TypeError, RuntimeError, OSError) as e:
              raise LayerAnalysisError(model_type=self.model_type, reason=str(e)) from e
          except Exception as e:
              raise LayerAnalysisError(model_type=self.model_type, reason=str(e)) from e
@@ -557,6 +558,7 @@ class ONNXConverter(ModelConverter):
          output_shapes = {
              out_name: output_name_to_shape.get(out_name, []) for out_name in outputs
          }
+
          return ONNXLayer(
              id=layer_id,
              name=name,
@@ -605,6 +607,7 @@ class ONNXConverter(ModelConverter):
              np_data = onnx.numpy_helper.to_array(node, constant_dtype)
          except (ValueError, TypeError, onnx.ONNXException, Exception) as e:
              raise SerializationError(
+                 model_type=self.model_type,
                  tensor_name=node.name,
                  reason=f"Failed to convert tensor: {e!s}",
              ) from e
@@ -1040,38 +1043,36 @@ class ONNXConverter(ModelConverter):
          ``rescale_config``.
          """
          inferred_model = shape_inference.infer_shapes(self.model)
-
-         scaling = BaseOpQuantizer.get_scaling(
-             scale_base=getattr(self, "scale_base", 2),
-             scale_exponent=(getattr(self, "scale_exponent", 18)),
-         )
+         scale_base = getattr(self, "scale_base", 2)
+         scale_exponent = getattr(self, "scale_exponent", 18)

          # Check the model and print Y's shape information
          self._onnx_check_model_safely(inferred_model)
          output_name_to_shape = extract_shape_dict(inferred_model)
-         (architecture, w_and_b) = self.analyze_layers(output_name_to_shape)
-         for w in w_and_b:
+         scaled_and_transformed_model = self.op_quantizer.apply_pre_analysis_transforms(
+             inferred_model,
+             scale_exponent=scale_exponent,
+             scale_base=scale_base,
+         )
+         # Get layers in correct format
+         (architecture, w_and_b) = self.analyze_layers(
+             scaled_and_transformed_model,
+             output_name_to_shape,
+         )
+
+         def _convert_tensor_to_int_list(w: ONNXLayer) -> list:
              try:
-                 w_and_b_array = np.asarray(w.tensor)
-             except (ValueError, TypeError, Exception) as e:
+                 arr = np.asarray(w.tensor).astype(np.int64)
+                 return arr.tolist()
+             except Exception as e:
                  raise SerializationError(
                      tensor_name=getattr(w, "name", None),
+                     model_type=self.model_type,
                      reason=f"cannot convert to ndarray: {e}",
                  ) from e

-             try:
-                 # TODO @jsgold-1: We need a better way to distinguish bias tensors from weight tensors # noqa: FIX002, TD003,E501
-                 if "bias" in w.name:
-                     w_and_b_scaled = w_and_b_array * scaling * scaling
-                 else:
-                     w_and_b_scaled = w_and_b_array * scaling
-                 w_and_b_out = w_and_b_scaled.astype(np.int64).tolist()
-                 w.tensor = w_and_b_out
-             except (ValueError, TypeError, OverflowError, Exception) as e:
-                 raise SerializationError(
-                     tensor_name=getattr(w, "name", None),
-                     reason=str(e),
-                 ) from e
+         for w in w_and_b:
+             w.tensor = _convert_tensor_to_int_list(w)

          inputs = []
          outputs = []
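With this change, weight/bias scaling moves out of serialization and into the pre-analysis transforms (see `apply_pre_analysis_transforms` further down); by the time `_convert_tensor_to_int_list` runs, every tensor in `w_and_b` is already scaled. A minimal sketch of the scaling convention, assuming the default `scale_base=2`, `scale_exponent=18` and hypothetical tensor values:

```python
import numpy as np

S = 2 ** 18  # scale factor built from the defaults read via getattr above

weight = np.array([0.5, -1.25])  # hypothetical float weight tensor
bias = np.array([0.1])           # hypothetical float bias tensor

# Weights carry one factor of S, biases two, so that x @ w + b stays
# consistent: (x*S) @ (w*S) sits at scale S**2, and b must match it.
scaled_weight = (weight * S).astype(np.int64)
scaled_bias = (bias * S * S).astype(np.int64)

# After the pre-analysis scaling, serialization is a plain int conversion:
assert scaled_weight.tolist() == [131072, -327680]
```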
@@ -0,0 +1,64 @@
+ from __future__ import annotations
+
+ import numpy as np
+ from onnxruntime_extensions import PyCustomOpDef, onnx_op
+
+ from .custom_helpers import rescaling
+
+
+ @onnx_op(
+     op_type="Int64BatchNorm",
+     domain="ai.onnx.contrib",
+     inputs=[
+         PyCustomOpDef.dt_int64,  # X (int64)
+         PyCustomOpDef.dt_int64,  # mul (int64 scaled multiplier)
+         PyCustomOpDef.dt_int64,  # add (int64 scaled adder)
+         PyCustomOpDef.dt_int64,  # scaling_factor
+     ],
+     outputs=[PyCustomOpDef.dt_int64],
+     attrs={"rescale": PyCustomOpDef.dt_int64},
+ )
+ def int64_batchnorm(
+     x: np.ndarray,
+     mul: np.ndarray,
+     add: np.ndarray,
+     scaling_factor: np.ndarray | None = None,
+     rescale: int | None = None,
+ ) -> np.ndarray:
+     """
+     Int64 BatchNorm (folded into an affine transform).
+
+     Computes:
+         Y = X * mul + add
+     where mul/add are already scaled to int64.
+
+     Parameters
+     ----------
+     x : Input int64 tensor
+     mul : Per-channel int64 scale multipliers
+     add : Per-channel int64 bias terms
+     scaling_factor : Factor used to rescale the output
+     rescale : Optional flag to apply post-scaling
+
+     Returns
+     -------
+     numpy.ndarray (int64)
+     """
+     try:
+         # Broadcasting shapes must match batchnorm layout: NCHW
+         # Typically mul/add have shape [C]
+         dims_x = len(x.shape)
+         dim_ones = (1,) * (dims_x - 2)
+         mul = mul.reshape(-1, *dim_ones)
+         add = add.reshape(-1, *dim_ones)
+
+         y = x * mul + add
+
+         if rescale is not None:
+             y = rescaling(scaling_factor, rescale, y)
+
+         return y.astype(np.int64)
+
+     except Exception as e:
+         msg = f"Int64BatchNorm failed: {e}"
+         raise RuntimeError(msg) from e
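As a sanity check of the broadcasting above, here is a small standalone sketch (not part of the package) of what the op computes for an NCHW input with per-channel `mul`/`add`, leaving the `rescaling` helper aside:

```python
import numpy as np

x = np.arange(2 * 3 * 2 * 2, dtype=np.int64).reshape(2, 3, 2, 2)  # NCHW
mul = np.array([2, 3, 4], dtype=np.int64)   # per-channel multipliers, shape [C]
add = np.array([10, 20, 30], dtype=np.int64)

# Same reshape as int64_batchnorm: [C] -> [C, 1, ..., 1], broadcast over H, W
dim_ones = (1,) * (x.ndim - 2)
y = x * mul.reshape(-1, *dim_ones) + add.reshape(-1, *dim_ones)

assert y[0, 1, 0, 0] == x[0, 1, 0, 0] * 3 + 20  # channel 1 uses mul=3, add=20
```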
@@ -0,0 +1,66 @@
+ import numpy as np
+ from onnxruntime_extensions import PyCustomOpDef, onnx_op
+
+ from .custom_helpers import rescaling
+
+
+ @onnx_op(
+     op_type="Int64Mul",
+     domain="ai.onnx.contrib",
+     inputs=[
+         PyCustomOpDef.dt_int64,
+         PyCustomOpDef.dt_int64,
+         PyCustomOpDef.dt_int64,  # Scalar
+     ],
+     outputs=[PyCustomOpDef.dt_int64],
+     attrs={
+         "rescale": PyCustomOpDef.dt_int64,
+     },
+ )
+ def int64_mul(
+     a: np.ndarray,
+     b: np.ndarray,
+     scaling_factor: np.ndarray | None = None,
+     rescale: int | None = None,
+ ) -> np.ndarray:
+     """
+     Performs a Mul (Hadamard product) operation on int64 input tensors.
+
+     This function is registered as a custom ONNX operator via onnxruntime_extensions
+     and is used in the JSTprove quantized inference pipeline.
+     It applies Mul, then rescales the outputs back to the original scale.
+
+     Parameters
+     ----------
+     a : np.ndarray
+         First input tensor with dtype int64.
+     b : np.ndarray
+         Second input tensor with dtype int64.
+     scaling_factor : np.ndarray, optional
+         Scalar tensor used to rescale the output when rescale=1.
+     rescale : int, optional
+         Whether to apply rescaling (0=no, 1=yes).
+
+     Returns
+     -------
+     numpy.ndarray
+         Product tensor with dtype int64.
+
+     Notes
+     -----
+     - This op is part of the `ai.onnx.contrib` custom domain.
+     - ONNX Runtime Extensions is required to register this op.
+
+     References
+     ----------
+     For more information on the Mul operation, please refer to the
+     ONNX standard Mul operator documentation:
+     https://onnx.ai/onnx/operators/onnx__Mul.html
+     """
+     try:
+         result = a * b
+         result = rescaling(scaling_factor, rescale, result)
+         return result.astype(np.int64)
+     except Exception as e:
+         msg = f"Int64Mul failed: {e}"
+         raise RuntimeError(msg) from e
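The `rescale` attribute exists because both operands carry the fixed-point scale: if `a ≈ â·S` and `b ≈ b̂·S`, the raw product sits at scale `S²`. A sketch of the arithmetic with an assumed scale factor (the exact rounding used by the shared `rescaling` helper is not shown in this diff):

```python
import numpy as np

S = 2 ** 18                      # assumed default scale factor
a_f, b_f = 1.5, -2.0             # real-valued operands
a = np.int64(round(a_f * S))     # fixed-point encodings at scale S
b = np.int64(round(b_f * S))

product = a * b                  # now at scale S**2
rescaled = product // S          # one division restores scale S

assert rescaled / S == a_f * b_f  # exactly -3.0 for these values
```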
@@ -479,6 +479,73 @@ class QuantizerBase:
          nodes.append(quantized_node)
          return nodes

+     def pre_analysis_transform(
+         self: QuantizerBase,
+         node: onnx.NodeProto,
+         graph: onnx.GraphProto,
+         initializer_map: dict[str, onnx.TensorProto],
+         scale_base: int,
+         scale_exponent: int,
+     ) -> None:
+         """
+         Transform the given layer in the same way it would be transformed for
+         the quantized model, but for the weights-and-biases file that is sent
+         to the backend.
+
+         Default pre-analysis behavior:
+
+         - If the subclass uses weights/bias (`USE_WB=True`), apply the SAME
+           scaling rules as quantization, but directly mutate the initializers.
+
+         - Subclasses can override this to implement more complex rewrites
+           (e.g., BatchNorm → Mul/Add).
+
+         Args:
+             node (onnx.NodeProto): Node to transform.
+             graph (onnx.GraphProto): Rest of the ONNX graph, for initializers.
+             initializer_map (dict[str, onnx.TensorProto]): The initializer map.
+             scale_base (int): Scaling base.
+             scale_exponent (int): Scaling exponent.
+
+         NOTE:
+             The resulting model will not make accurate predictions and should
+             be used solely for analysis and for keeping track of w_and_b.
+         """
+         # If subclass does not want auto-scaling, do nothing
+         if not getattr(self, "USE_WB", False):
+             return
+
+         # Each quantizer defines which inputs to scale (Weight: 1x, Bias: 2x, etc.)
+         scale_plan = getattr(self, "SCALE_PLAN", {})
+
+         # Perform the same scaling as quantization, but directly modify initializers
+         for input_idx, scale_mult in scale_plan.items():
+             if input_idx >= len(node.input):
+                 continue
+
+             name = node.input[input_idx]
+             if name not in initializer_map:
+                 continue  # optional input missing
+
+             tensor = initializer_map[name]
+             arr = numpy_helper.to_array(tensor).astype(np.float64)
+
+             scale = scale_base ** (scale_exponent * scale_mult)
+             new_arr = arr * scale
+
+             # Replace initializer directly
+             new_tensor = numpy_helper.from_array(new_arr, name=tensor.name)
+
+             # Modify graph initializer in place
+             for j in range(len(graph.initializer)):
+                 if graph.initializer[j].name == tensor.name:
+                     del graph.initializer[j]
+                     break
+             graph.initializer.append(new_tensor)
+
+             initializer_map[tensor.name] = new_tensor
+

  class PassthroughQuantizer(BaseOpQuantizer):
      """
@@ -0,0 +1,224 @@
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, ClassVar
+
+ from python.core.circuits.errors import CircuitConfigurationError
+
+ if TYPE_CHECKING:
+     import onnx
+
+ import numpy as np
+ from onnx import helper, numpy_helper
+
+ from python.core.model_processing.onnx_custom_ops.onnx_helpers import extract_attributes
+ from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError
+ from python.core.model_processing.onnx_quantizer.layers.base import (
+     BaseOpQuantizer,
+     QuantizerBase,
+     ScaleConfig,
+ )
+
+
+ class QuantizeBatchnorm(QuantizerBase):
+     OP_TYPE = "Int64BatchNorm"
+     USE_WB = True
+     USE_SCALING = False
+     SCALE_PLAN: ClassVar = {}
+
+
+ class BatchnormQuantizer(BaseOpQuantizer, QuantizeBatchnorm):
+     """
+     Quantizer for ONNX BatchNormalization layers.
+
+     - Folds BatchNorm into an affine transform, emits a custom Int64BatchNorm
+       node, and makes the relevant additional changes to the graph.
+     """
+
+     def __init__(
+         self: BatchnormQuantizer,
+         new_initializers: list[onnx.TensorProto] | None = None,
+     ) -> None:
+         super().__init__()
+         # Only replace if caller provided something
+         if new_initializers is not None:
+             self.new_initializers = new_initializers
+
+     def _compute_mul_add(
+         self: BatchnormQuantizer,
+         initializer_map: dict[str, onnx.TensorProto],
+         node: onnx.NodeProto,
+         scale_base: int,
+         scale_exponent: int,
+     ) -> tuple[np.ndarray, np.ndarray]:
+         """
+         Compute the 'mul' and 'add' tensors for BatchNorm folding.
+         """
+         self._validate_inputs(node=node)
+         # ONNX BatchNorm inputs: [X, scale, bias, mean, var]
+         scale_factor = scale_base**scale_exponent
+         scale = numpy_helper.to_array(initializer_map[node.input[1]]).astype(np.float32)
+         bias = numpy_helper.to_array(initializer_map[node.input[2]]).astype(np.float32)
+         mean = numpy_helper.to_array(initializer_map[node.input[3]]).astype(np.float32)
+         var = numpy_helper.to_array(initializer_map[node.input[4]]).astype(np.float32)
+
+         # Find epsilon attribute
+         epsilon_attr = next((a for a in node.attribute if a.name == "epsilon"), None)
+         epsilon = float(epsilon_attr.f) if epsilon_attr else 1e-5
+
+         mul = scale / np.sqrt(var + epsilon)
+         add = bias - mean * mul
+         scaled_add = add * (scale_factor**2)
+         scaled_mul = scale_factor * mul
+         return scaled_mul, scaled_add
+
+     def pre_analysis_transform(
+         self: BatchnormQuantizer,
+         node: onnx.NodeProto,
+         graph: onnx.GraphProto,
+         initializer_map: dict[str, onnx.TensorProto],
+         scale_base: int,
+         scale_exponent: int,
+     ) -> None:
+         # Compute linearized BN tensors
+         mul, add = self._compute_mul_add(
+             initializer_map,
+             node,
+             scale_base=scale_base,
+             scale_exponent=scale_exponent,
+         )
+
+         # Name base
+         node_name = node.name if node.name else node.input[0]
+         mul_name = f"{node_name}_mul"
+         add_name = f"{node_name}_add"
+
+         # Create ONNX tensors
+         mul_tensor = numpy_helper.from_array(mul.astype(np.int64), name=mul_name)
+         add_tensor = numpy_helper.from_array(add.astype(np.int64), name=add_name)
+
+         # Insert them into the graph
+         graph.initializer.extend([mul_tensor, add_tensor])
+         initializer_map[mul_name] = mul_tensor
+         initializer_map[add_name] = add_tensor
+         self.new_initializers.extend([mul_tensor, add_tensor])
+
+         node.input[:] = [node.input[0], mul_name, add_name]
+
+         del node.attribute[:]
+
+     def quantize(
+         self,
+         node: onnx.NodeProto,
+         graph: onnx.GraphProto,
+         scale_config: ScaleConfig,
+         initializer_map: dict[str, onnx.TensorProto],
+     ) -> list[onnx.NodeProto]:
+         _ = graph
+
+         nodes: list[onnx.NodeProto] = []
+
+         # 1. Compute unscaled float mul/add coefficients
+         mul, add = self._compute_mul_add(
+             initializer_map,
+             node,
+             scale_base=1,
+             scale_exponent=1,
+         )
+
+         node_name = node.name if node.name else node.input[0]
+         mul_name = f"{node_name}_mul"
+         add_name = f"{node_name}_add"
+
+         # 2. Store unscaled mul and add initializers (as floats)
+         scale_value = self.get_scaling(scale_config.base, scale_config.exponent)
+         scale_name = f"{node.name}_int_scaler"
+         scale_tensor = numpy_helper.from_array(
+             np.array([scale_value], dtype=np.int64),
+             name=scale_name,
+         )
+         self.new_initializers.append(scale_tensor)
+
+         mul_tensor = numpy_helper.from_array(mul.astype(np.float32), name=mul_name)
+         add_tensor = numpy_helper.from_array(add.astype(np.float32), name=add_name)
+
+         initializer_map[mul_name] = mul_tensor
+         initializer_map[add_name] = add_tensor
+
+         # 3. Insert scale and cast for mul_tensor
+         scaled_mul_name, mul_scale_node, mul_cast_node = self.insert_scale_node(
+             tensor=mul_tensor,
+             scale_base=scale_config.base,
+             scale_exponent=scale_config.exponent,
+         )
+
+         # 4. Insert scale and cast for add_tensor
+         scaled_add_name, add_scale_node, add_cast_node = self.insert_scale_node(
+             tensor=add_tensor,
+             scale_base=scale_config.base,
+             scale_exponent=scale_config.exponent * 2,
+         )
+         # Note, order is important here
+         nodes.extend(
+             [
+                 mul_scale_node,
+                 mul_cast_node,
+                 add_scale_node,
+                 add_cast_node,
+             ],
+         )
+
+         # 5. Build final Int64BatchNorm node
+         attrs = extract_attributes(node)
+         for k, v in getattr(self, "DEFAULT_ATTRS", {}).items():
+             attrs.setdefault(k, v)
+         attrs["rescale"] = 1
+
+         quant_node = helper.make_node(
+             self.OP_TYPE,  # Should be "Int64BatchNorm"
+             inputs=[
+                 node.input[0],    # original X
+                 scaled_mul_name,  # scaled mul
+                 scaled_add_name,  # scaled add
+                 scale_name,       # scaling factor
+             ],
+             outputs=node.output,
+             name=node.name,
+             domain=self.DOMAIN,
+             **attrs,
+         )
+
+         nodes.append(quant_node)
+         return nodes
+
+     def check_supported(
+         self: BatchnormQuantizer,
+         node: onnx.NodeProto,
+         initializer_map: dict[str, onnx.TensorProto] | None = None,
+     ) -> None:
+         """
+         For our current implementation, all batchnorm inputs
+         (scale, variance, mean, etc.)
+         must be initializers to the circuit and not inputs from earlier in the graph.
+         """
+         if initializer_map is None:
+             msg = "initializer_map is required for BatchNorm support check"
+             raise CircuitConfigurationError(node.name, node.op_type, msg)
+
+         self._validate_inputs(node=node)
+
+         # First, check that each of the batchnorm inputs is an initializer
+         initializer_inputs = node.input[1:]
+         if not all(i in initializer_map for i in initializer_inputs):
+             msg = "Unsupported BatchNorm with normalization inputs not in initializers"
+             raise InvalidParamError(node.name, node.op_type, msg)
+
+     def _validate_inputs(self, node: onnx.NodeProto) -> None:
+         """Validate that the BatchNorm node has the required five inputs."""
+         num_inputs = 5
+         if len(node.input) < num_inputs:
+             raise InvalidParamError(
+                 node.name,
+                 node.op_type,
+                 f"BatchNorm requires 5 inputs, got {len(node.input)}",
+             )
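For intuition, `_compute_mul_add` is the standard BatchNorm folding: Y = scale·(X − mean)/√(var + ε) + bias collapses into the affine form Y = X·mul + add. A numeric sketch with made-up values (ε = 0 only to keep the numbers round; the quantizer defaults to 1e-5):

```python
import numpy as np

scale = np.float32(1.0)
bias = np.float32(0.5)
mean = np.float32(2.0)
var = np.float32(4.0)
epsilon = 0.0  # made-up; chosen so sqrt(var + eps) == 2 exactly

mul = scale / np.sqrt(var + epsilon)  # 1 / 2 = 0.5
add = bias - mean * mul               # 0.5 - 2 * 0.5 = -0.5

x = np.float32(6.0)
folded = x * mul + add                                          # 2.5
textbook = scale * (x - mean) / np.sqrt(var + epsilon) + bias   # 2.5
assert folded == textbook
```

In the circuit path, `mul` is then scaled by the scale factor and `add` by its square, matching the int64 inputs expected by the `Int64BatchNorm` op above.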
@@ -0,0 +1,53 @@
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, ClassVar
+
+ if TYPE_CHECKING:
+     import onnx
+
+ from python.core.model_processing.onnx_quantizer.layers.base import (
+     BaseOpQuantizer,
+     QuantizerBase,
+     ScaleConfig,
+ )
+
+
+ class QuantizeMul(QuantizerBase):
+     OP_TYPE = "Int64Mul"
+     USE_WB = True
+     USE_SCALING = True
+     SCALE_PLAN: ClassVar = {0: 1, 1: 1}
+
+
+ class MulQuantizer(BaseOpQuantizer, QuantizeMul):
+     """
+     Quantizer for ONNX Mul layers.
+
+     - Uses custom Mul layer to incorporate rescaling, and
+       makes relevant additional changes to the graph.
+     """
+
+     def __init__(
+         self: MulQuantizer,
+         new_initializers: list[onnx.TensorProto] | None = None,
+     ) -> None:
+         super().__init__()
+         # Only replace if caller provided something
+         if new_initializers is not None:
+             self.new_initializers = new_initializers
+
+     def quantize(
+         self: MulQuantizer,
+         node: onnx.NodeProto,
+         graph: onnx.GraphProto,
+         scale_config: ScaleConfig,
+         initializer_map: dict[str, onnx.TensorProto],
+     ) -> list[onnx.NodeProto]:
+         return QuantizeMul.quantize(self, node, graph, scale_config, initializer_map)
+
+     def check_supported(
+         self: MulQuantizer,
+         node: onnx.NodeProto,
+         initializer_map: dict[str, onnx.TensorProto] | None = None,
+     ) -> None:
+         pass
@@ -0,0 +1,54 @@
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, ClassVar
+
+ if TYPE_CHECKING:
+     import onnx
+
+ from python.core.model_processing.onnx_quantizer.layers.base import (
+     BaseOpQuantizer,
+     QuantizerBase,
+     ScaleConfig,
+ )
+
+
+ class QuantizeSub(QuantizerBase):
+     OP_TYPE = "Sub"
+     DOMAIN = ""
+     USE_WB = True
+     USE_SCALING = False
+     SCALE_PLAN: ClassVar = {0: 1, 1: 1}
+
+
+ class SubQuantizer(BaseOpQuantizer, QuantizeSub):
+     """
+     Quantizer for ONNX Sub layers.
+
+     - Uses standard ONNX Sub layer in standard domain, and
+       makes relevant additional changes to the graph.
+     """
+
+     def __init__(
+         self: SubQuantizer,
+         new_initializers: list[onnx.TensorProto] | None = None,
+     ) -> None:
+         super().__init__()
+         # Only replace if caller provided something
+         if new_initializers is not None:
+             self.new_initializers = new_initializers
+
+     def quantize(
+         self: SubQuantizer,
+         node: onnx.NodeProto,
+         graph: onnx.GraphProto,
+         scale_config: ScaleConfig,
+         initializer_map: dict[str, onnx.TensorProto],
+     ) -> list[onnx.NodeProto]:
+         return QuantizeSub.quantize(self, node, graph, scale_config, initializer_map)
+
+     def check_supported(
+         self: SubQuantizer,
+         node: onnx.NodeProto,
+         initializer_map: dict[str, onnx.TensorProto] | None = None,
+     ) -> None:
+         pass
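Note that `QuantizeSub` keeps the stock ONNX `Sub` op (`DOMAIN = ""`) and sets `USE_SCALING = False`: when both operands already sit at the same scale S, their difference is still at scale S, so no rescale node is needed. A short check under that assumption:

```python
S = 2 ** 18
a_f, b_f = 3.0, 1.25
a, b = int(a_f * S), int(b_f * S)  # both encoded at scale S

diff = a - b                       # (a_f - b_f) * S, still at scale S
assert diff / S == a_f - b_f       # 1.75; no extra division required
```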
@@ -17,13 +17,18 @@ from python.core.model_processing.onnx_quantizer.layers.base import (
      PassthroughQuantizer,
      ScaleConfig,
  )
+ from python.core.model_processing.onnx_quantizer.layers.batchnorm import (
+     BatchnormQuantizer,
+ )
  from python.core.model_processing.onnx_quantizer.layers.constant import (
      ConstantQuantizer,
  )
  from python.core.model_processing.onnx_quantizer.layers.conv import ConvQuantizer
  from python.core.model_processing.onnx_quantizer.layers.gemm import GemmQuantizer
  from python.core.model_processing.onnx_quantizer.layers.maxpool import MaxpoolQuantizer
+ from python.core.model_processing.onnx_quantizer.layers.mul import MulQuantizer
  from python.core.model_processing.onnx_quantizer.layers.relu import ReluQuantizer
+ from python.core.model_processing.onnx_quantizer.layers.sub import SubQuantizer


  class ONNXOpQuantizer:
@@ -69,6 +74,8 @@ class ONNXOpQuantizer:

          # Register handlers
          self.register("Add", AddQuantizer(self.new_initializers))
+         self.register("Sub", SubQuantizer(self.new_initializers))
+         self.register("Mul", MulQuantizer(self.new_initializers))
          self.register("Conv", ConvQuantizer(self.new_initializers))
          self.register("Relu", ReluQuantizer())
          self.register("Reshape", PassthroughQuantizer())
@@ -76,6 +83,7 @@ class ONNXOpQuantizer:
          self.register("Constant", ConstantQuantizer())
          self.register("MaxPool", MaxpoolQuantizer())
          self.register("Flatten", PassthroughQuantizer())
+         self.register("BatchNormalization", BatchnormQuantizer(self.new_initializers))

      def register(
          self: ONNXOpQuantizer,
@@ -203,3 +211,32 @@ class ONNXOpQuantizer:
          dict[str, onnx.TensorProto]: Map from initializer name to tensors in graph.
          """
          return {init.name: init for init in model.graph.initializer}
+
+     def apply_pre_analysis_transforms(
+         self: ONNXOpQuantizer,
+         model: onnx.ModelProto,
+         scale_exponent: int,
+         scale_base: int,
+     ) -> onnx.ModelProto:
+         """
+         Give each registered handler a chance to rewrite the model before analysis.
+         """
+         graph = model.graph
+         initializer_map = self.get_initializer_map(model)
+
+         # We allow handlers to modify the graph in place.
+         # (Nodes may be replaced, removed, or new nodes added.)
+         for node in list(graph.node):
+             handler = self.handlers.get(node.op_type)
+             if handler and hasattr(handler, "pre_analysis_transform"):
+                 handler.pre_analysis_transform(
+                     node,
+                     graph,
+                     initializer_map,
+                     scale_exponent=scale_exponent,
+                     scale_base=scale_base,
+                 )
+                 # Refresh the map, since transforms may add initializers
+                 initializer_map = self.get_initializer_map(model)
+
+         return model
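A sketch of the pass in isolation, using a toy one-node Mul graph so the default `SCALE_PLAN = {0: 1, 1: 1}` scaling is visible; it assumes `ONNXOpQuantizer` can be imported from the module listed in the RECORD above and constructed with no arguments:

```python
import numpy as np
from onnx import TensorProto, helper, numpy_helper

from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import ONNXOpQuantizer

# Toy graph: Y = A * B, where B is an initializer and A is a graph input.
b_init = numpy_helper.from_array(np.array([2.0]), name="B")
graph = helper.make_graph(
    [helper.make_node("Mul", ["A", "B"], ["Y"])],
    "toy",
    [helper.make_tensor_value_info("A", TensorProto.DOUBLE, [1])],
    [helper.make_tensor_value_info("Y", TensorProto.DOUBLE, [1])],
    initializer=[b_init],
)
model = helper.make_model(graph)

quantizer = ONNXOpQuantizer()  # assumed: no-arg construction registers handlers
model = quantizer.apply_pre_analysis_transforms(model, scale_exponent=18, scale_base=2)

# Input 0 ("A") is not an initializer, so only "B" is scaled in place:
print(numpy_helper.to_array(model.graph.initializer[0]))  # [524288.] == 2.0 * 2**18
```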
@@ -247,12 +247,12 @@ def export_onnx(


  def write_input_json(json_path: Path, input_shape: tuple[int] = (1, 4, 28, 28)) -> None:
-     """Write a zero-valued input tensor to JSON alongside its [N,C,H,W] shape."""
+     """Write a zero-valued input tensor to JSON without shape information."""
      json_path.parent.mkdir(parents=True, exist_ok=True)
      n, c, h, w = input_shape
      arr = [0.0] * (n * c * h * w)
      with json_path.open("w", encoding="utf-8") as f:
-         json.dump({"input": arr, "shape": [n, c, h, w]}, f)
+         json.dump({"input": arr}, f)


  def run_bench(
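With the change above, the benchmark input file no longer carries a `shape` entry. A quick sketch of the resulting payload, using the `write_input_json` shown in this hunk and a hypothetical output path:

```python
import json
from pathlib import Path

write_input_json(Path("bench/input.json"), input_shape=(1, 1, 2, 2))
print(json.loads(Path("bench/input.json").read_text()))
# -> {'input': [0.0, 0.0, 0.0, 0.0]}   (previously also 'shape': [1, 1, 2, 2])
```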
@@ -0,0 +1,190 @@
+ import numpy as np
+
+ from python.tests.onnx_quantizer_tests import TEST_RNG_SEED
+ from python.tests.onnx_quantizer_tests.layers.base import (
+     BaseLayerConfigProvider,
+     LayerTestConfig,
+     LayerTestSpec,
+     e2e_test,
+     valid_test,
+ )
+
+
+ class BatchNormConfigProvider(BaseLayerConfigProvider):
+     """Test configuration provider for BatchNorm (ONNX BatchNormalization op)"""
+
+     @property
+     def layer_name(self) -> str:
+         return "BatchNormalization"
+
+     def get_config(self) -> LayerTestConfig:
+         rng = np.random.default_rng(TEST_RNG_SEED)
+
+         # default shapes: N x C x H x W
+         default_input_shape = [1, 3, 4, 4]
+         c = default_input_shape[1]
+
+         # typical required initializers (scale, bias, mean, var) are length C
+         return LayerTestConfig(
+             op_type="BatchNormalization",
+             valid_inputs=["X", "scale", "B", "input_mean", "input_var"],
+             valid_attributes={
+                 "epsilon": 1e-5,
+                 "momentum": 0.9,
+                 "training_mode": 0,
+             },
+             required_initializers={
+                 # Defaults are stored as numpy arrays with shape (C,)
+                 "scale": rng.normal(1.0, 0.5, c).astype(np.float32),
+                 "B": rng.normal(0.0, 0.5, c).astype(np.float32),
+                 "input_mean": rng.normal(0.0, 1.0, c).astype(np.float32),
+                 "input_var": np.abs(rng.normal(1.0, 0.5, c)).astype(np.float32),
+             },
+             input_shapes={"X": default_input_shape},
+             output_shapes={"batchnormalization_output": default_input_shape},
+         )
+
+     def get_test_specs(self) -> list[LayerTestSpec]:
+         rng = np.random.default_rng(TEST_RNG_SEED)
+         c = 3
+
+         return [
+             # Basic valid tests
+             valid_test("basic_inference")
+             .description("Basic BatchNormalization inference: standard shapes")
+             .tags("basic", "inference", "batchnorm")
+             .build(),
+             valid_test("different_input_shape")
+             .description("Inference with different spatial dims")
+             .override_input_shapes(X=[2, c, 8, 8])
+             .override_output_shapes(batchnormalization_output=[2, c, 8, 8])
+             .tags("inference", "spatial")
+             .build(),
+             valid_test("epsilon_variation")
+             .description("Inference with larger epsilon for numerical stability")
+             .override_attrs(epsilon=1e-3)
+             .tags("epsilon")
+             .build(),
+             valid_test("momentum_variation")
+             .description(
+                 "Inference with non-default momentum (has no effect in inference mode)",
+             )
+             .override_attrs(momentum=0.5)
+             .tags("momentum")
+             .build(),
+             valid_test("zero_mean_input")
+             .description("Input with zero mean")
+             .override_initializer("input_mean", np.zeros((c,), dtype=np.float32))
+             .tags("edge", "zero_mean")
+             .build(),
+             # Scalar / broadcast style tests
+             valid_test("per_channel_zero_variance")
+             .description(
+                 "Edge case: very small variance values (clamped by epsilon), inference",
+             )
+             .override_initializer("input_var", np.full((c,), 1e-8, dtype=np.float32))
+             .override_attrs(epsilon=1e-5)
+             .tags("edge", "small_variance")
+             .build(),
+             # E2E tests that set explicit initializer values
+             e2e_test("e2e_inference")
+             .description("E2E inference test with explicit initializers")
+             .override_input_shapes(X=[1, c, 2, 2])
+             .override_output_shapes(batchnormalization_output=[1, c, 2, 2])
+             .override_initializer("scale", rng.normal(1.0, 0.1, c).astype(np.float32))
+             .override_initializer("B", rng.normal(0.0, 0.1, c).astype(np.float32))
+             .override_initializer(
+                 "input_mean",
+                 rng.normal(0.0, 0.1, c).astype(np.float32),
+             )
+             .override_initializer(
+                 "input_var",
+                 np.abs(rng.normal(0.5, 0.2, c)).astype(np.float32),
+             )
+             .tags("e2e", "inference")
+             .build(),
+             e2e_test("e2e_inference_small_2x2")
+             .description("E2E inference with small 2x2 spatial input")
+             .override_input_shapes(X=[1, 3, 2, 2])
+             .override_output_shapes(batchnormalization_output=[1, 3, 2, 2])
+             .override_initializer("scale", np.array([1.0, 0.9, 1.1], dtype=np.float32))
+             .override_initializer("B", np.array([0.0, 0.1, -0.1], dtype=np.float32))
+             .override_initializer(
+                 "input_mean",
+                 np.array([0.5, -0.5, 0.0], dtype=np.float32),
+             )
+             .override_initializer(
+                 "input_var",
+                 np.array([0.25, 0.5, 0.1], dtype=np.float32),
+             )
+             .tags("e2e", "small", "2x2")
+             .build(),
+             e2e_test("e2e_inference_wide_input")
+             .description("E2E inference with wider input shape (C=4, H=2, W=8)")
+             .override_input_shapes(X=[2, 4, 2, 8])
+             .override_output_shapes(batchnormalization_output=[2, 4, 2, 8])
+             .override_initializer(
+                 "scale",
+                 np.array([1.0, 0.8, 1.2, 0.9], dtype=np.float32),
+             )
+             .override_initializer(
+                 "B",
+                 np.array([0.0, 0.1, -0.1, 0.05], dtype=np.float32),
+             )
+             .override_initializer(
+                 "input_mean",
+                 np.array([0.0, 0.5, -0.5, 0.2], dtype=np.float32),
+             )
+             .override_initializer(
+                 "input_var",
+                 np.array([1.0, 0.5, 0.25, 0.1], dtype=np.float32),
+             )
+             .tags("e2e", "wide", "C4")
+             .build(),
+             e2e_test("e2e_inference_batch2_channels3")
+             .description("E2E inference with batch size 2 and 3 channels")
+             .override_input_shapes(X=[2, 3, 4, 4])
+             .override_output_shapes(batchnormalization_output=[2, 3, 4, 4])
+             .override_initializer("scale", np.array([0.5, 1.0, 1.5], dtype=np.float32))
+             .override_initializer("B", np.array([0.0, 0.0, 0.0], dtype=np.float32))
+             .override_initializer(
+                 "input_mean",
+                 np.array([-0.5, 0.0, 0.5], dtype=np.float32),
+             )
+             .override_initializer(
+                 "input_var",
+                 np.array([0.2, 0.5, 0.8], dtype=np.float32),
+             )
+             .tags("e2e", "batch2", "C3")
+             .build(),
+             e2e_test("e2e_inference_high_epsilon")
+             .description("E2E inference with high epsilon for numerical stability")
+             .override_input_shapes(X=[1, 2, 4, 4])
+             .override_output_shapes(batchnormalization_output=[1, 2, 4, 4])
+             .override_initializer("scale", np.array([1.0, 1.0], dtype=np.float32))
+             .override_initializer("B", np.array([0.1, -0.1], dtype=np.float32))
+             .override_initializer("input_mean", np.array([0.0, 0.5], dtype=np.float32))
+             .override_initializer(
+                 "input_var",
+                 np.array([0.0, 0.0], dtype=np.float32),
+             )  # zero variance, clamped by epsilon
+             .override_attrs(epsilon=1e-2)
+             .tags("e2e", "high_epsilon", "numerical_stability")
+             .build(),
+             e2e_test("e2e_inference_non_square")
+             .description("E2E inference with non-square spatial dimensions")
+             .override_input_shapes(X=[1, 3, 2, 5])
+             .override_output_shapes(batchnormalization_output=[1, 3, 2, 5])
+             .override_initializer("scale", np.array([1.0, 0.9, 1.1], dtype=np.float32))
+             .override_initializer("B", np.array([0.0, 0.1, -0.1], dtype=np.float32))
+             .override_initializer(
+                 "input_mean",
+                 np.array([0.1, -0.1, 0.0], dtype=np.float32),
+             )
+             .override_initializer(
+                 "input_var",
+                 np.array([0.5, 0.25, 0.75], dtype=np.float32),
+             )
+             .tags("e2e", "non_square", "C3")
+             .build(),
+         ]
@@ -0,0 +1,102 @@
+ import numpy as np
+
+ from python.tests.onnx_quantizer_tests import TEST_RNG_SEED
+ from python.tests.onnx_quantizer_tests.layers.base import (
+     BaseLayerConfigProvider,
+     LayerTestConfig,
+     LayerTestSpec,
+     e2e_test,
+     edge_case_test,
+     valid_test,
+ )
+
+
+ class MulConfigProvider(BaseLayerConfigProvider):
+     """Test configuration provider for Mul layer"""
+
+     @property
+     def layer_name(self) -> str:
+         return "Mul"
+
+     def get_config(self) -> LayerTestConfig:
+         return LayerTestConfig(
+             op_type="Mul",
+             valid_inputs=["A", "B"],
+             valid_attributes={},  # Mul has no layer-specific attributes
+             required_initializers={},
+             input_shapes={
+                 "A": [1, 3, 4, 4],
+                 "B": [1, 3, 4, 4],
+             },
+             output_shapes={
+                 "mul_output": [1, 3, 4, 4],
+             },
+         )
+
+     def get_test_specs(self) -> list[LayerTestSpec]:
+         rng = np.random.default_rng(TEST_RNG_SEED)
+         return [
+             # --- VALID TESTS ---
+             valid_test("basic")
+             .description("Basic elementwise Mul of two same-shaped tensors")
+             .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 4, 4])
+             .tags("basic", "elementwise", "Mul")
+             .build(),
+             valid_test("broadcast_mul")
+             .description("Mul with NumPy-style broadcasting along spatial dimensions")
+             .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 1, 1])
+             .tags("broadcast", "elementwise", "mul", "onnx14")
+             .build(),
+             valid_test("initializer_mul")
+             .description(
+                 "Mul where the second input (B) is a tensor initializer instead of an input",
+             )
+             .override_input_shapes(A=[1, 3, 4, 4])
+             .override_initializer("B", rng.normal(0, 1, (1, 3, 4, 4)))
+             .tags("initializer", "elementwise", "mul", "onnxruntime")
+             .build(),
+             valid_test("scalar_mul")
+             .description("Mul of a tensor by a scalar (initializer)")
+             .override_input_shapes(A=[1, 3, 4, 4])
+             .override_initializer("B", np.array([2.0], dtype=np.float32))
+             .tags("scalar", "elementwise", "mul")
+             .build(),
+             # --- E2E TESTS ---
+             e2e_test("e2e_mul")
+             .description("End-to-end Mul test with random inputs")
+             .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 4, 4])
+             .override_output_shapes(mul_output=[1, 3, 4, 4])
+             .tags("e2e", "mul", "2d")
+             .build(),
+             e2e_test("e2e_initializer_mul")
+             .description(
+                 "Mul where the second input (B) is a tensor initializer instead of an input",
+             )
+             .override_input_shapes(A=[1, 3, 4, 4])
+             .override_initializer("B", rng.normal(0, 1, (1, 3, 4, 4)))
+             .tags("initializer", "elementwise", "mul", "onnxruntime")
+             .build(),
+             e2e_test("e2e_broadcast_mul")
+             .description("Mul with NumPy-style broadcasting along spatial dimensions")
+             .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 1, 1])
+             .tags("broadcast", "elementwise", "mul", "onnx14")
+             .build(),
+             e2e_test("e2e_scalar_mul")
+             .description("Mul of a tensor by a scalar (initializer)")
+             .override_input_shapes(A=[1, 3, 4, 4])
+             .override_initializer("B", np.array([2.0], dtype=np.float32))
+             .tags("scalar", "elementwise", "mul")
+             .build(),
+             # --- EDGE CASES ---
+             edge_case_test("empty_tensor")
+             .description("Mul with empty tensor input (zero elements)")
+             .override_input_shapes(A=[0], B=[0])
+             .tags("edge", "empty", "mul")
+             .build(),
+             edge_case_test("large_tensor")
+             .description("Large tensor Mul performance/stress test")
+             .override_input_shapes(A=[1, 64, 256, 256], B=[1, 64, 256, 256])
+             .tags("large", "performance", "mul")
+             .skip("Performance test, skipped by default")
+             .build(),
+         ]
@@ -0,0 +1,102 @@
+ import numpy as np
+
+ from python.tests.onnx_quantizer_tests import TEST_RNG_SEED
+ from python.tests.onnx_quantizer_tests.layers.base import (
+     BaseLayerConfigProvider,
+     LayerTestConfig,
+     LayerTestSpec,
+     e2e_test,
+     edge_case_test,
+     valid_test,
+ )
+
+
+ class SubConfigProvider(BaseLayerConfigProvider):
+     """Test configuration provider for Sub layer"""
+
+     @property
+     def layer_name(self) -> str:
+         return "Sub"
+
+     def get_config(self) -> LayerTestConfig:
+         return LayerTestConfig(
+             op_type="Sub",
+             valid_inputs=["A", "B"],
+             valid_attributes={},  # Sub has no layer-specific attributes
+             required_initializers={},
+             input_shapes={
+                 "A": [1, 3, 4, 4],
+                 "B": [1, 3, 4, 4],
+             },
+             output_shapes={
+                 "sub_output": [1, 3, 4, 4],
+             },
+         )
+
+     def get_test_specs(self) -> list[LayerTestSpec]:
+         rng = np.random.default_rng(TEST_RNG_SEED)
+         return [
+             # --- VALID TESTS ---
+             valid_test("basic")
+             .description("Basic elementwise Sub of two same-shaped tensors")
+             .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 4, 4])
+             .tags("basic", "elementwise", "Sub")
+             .build(),
+             valid_test("broadcast_Sub")
+             .description("Sub with NumPy-style broadcasting along spatial dimensions")
+             .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 1, 1])
+             .tags("broadcast", "elementwise", "Sub", "onnx14")
+             .build(),
+             valid_test("initializer_Sub")
+             .description(
+                 "Sub where the second input (B) is a tensor initializer instead of an input",
+             )
+             .override_input_shapes(A=[1, 3, 4, 4])
+             .override_initializer("B", rng.normal(0, 1, (1, 3, 4, 4)))
+             .tags("initializer", "elementwise", "Sub", "onnxruntime")
+             .build(),
+             valid_test("scalar_Sub")
+             .description("Sub of a scalar (initializer) from a tensor")
+             .override_input_shapes(A=[1, 3, 4, 4])
+             .override_initializer("B", np.array([2.0], dtype=np.float32))
+             .tags("scalar", "elementwise", "Sub")
+             .build(),
+             # --- E2E TESTS ---
+             e2e_test("e2e_Sub")
+             .description("End-to-end Sub test with random inputs")
+             .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 4, 4])
+             .override_output_shapes(sub_output=[1, 3, 4, 4])
+             .tags("e2e", "Sub", "2d")
+             .build(),
+             e2e_test("e2e_initializer_Sub")
+             .description(
+                 "Sub where the second input (B) is a tensor initializer instead of an input",
+             )
+             .override_input_shapes(A=[1, 3, 4, 4])
+             .override_initializer("B", rng.normal(0, 1, (1, 3, 4, 4)))
+             .tags("initializer", "elementwise", "Sub", "onnxruntime")
+             .build(),
+             e2e_test("e2e_broadcast_Sub")
+             .description("Sub with NumPy-style broadcasting along spatial dimensions")
+             .override_input_shapes(A=[1, 3, 4, 4], B=[1, 3, 1, 1])
+             .tags("broadcast", "elementwise", "Sub", "onnx14")
+             .build(),
+             e2e_test("e2e_scalar_Sub")
+             .description("Sub of a scalar (initializer) from a tensor")
+             .override_input_shapes(A=[1, 3, 4, 4])
+             .override_initializer("B", np.array([2.0], dtype=np.float32))
+             .tags("scalar", "elementwise", "Sub")
+             .build(),
+             # --- EDGE CASES ---
+             edge_case_test("empty_tensor")
+             .description("Sub with empty tensor input (zero elements)")
+             .override_input_shapes(A=[0], B=[0])
+             .tags("edge", "empty", "Sub")
+             .build(),
+             edge_case_test("large_tensor")
+             .description("Large tensor Sub performance/stress test")
+             .override_input_shapes(A=[1, 64, 256, 256], B=[1, 64, 256, 256])
+             .tags("large", "performance", "Sub")
+             .skip("Performance test, skipped by default")
+             .build(),
+         ]
@@ -139,6 +139,8 @@ class TestQuantize(BaseQuantizerTest):
          node: NodeProto,
          result_node: NodeProto,
      ) -> bool:
+         if node.op_type == "BatchNormalization":
+             pytest.skip(f"{node.op_type} alters the node structure by design")
          if node.op_type in result_node.op_type:
              # Assert there are no fewer attributes in the new node
              assert len(node.attribute) <= len(result_node.attribute)