JSTprove: jstprove-1.1.0-py3-none-macosx_11_0_arm64.whl → jstprove-1.3.0-py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. {jstprove-1.1.0.dist-info → jstprove-1.3.0.dist-info}/METADATA +3 -3
  2. {jstprove-1.1.0.dist-info → jstprove-1.3.0.dist-info}/RECORD +40 -26
  3. python/core/binaries/onnx_generic_circuit_1-3-0 +0 -0
  4. python/core/circuits/base.py +29 -12
  5. python/core/circuits/errors.py +1 -2
  6. python/core/model_processing/converters/base.py +3 -3
  7. python/core/model_processing/converters/onnx_converter.py +28 -27
  8. python/core/model_processing/onnx_custom_ops/__init__.py +5 -4
  9. python/core/model_processing/onnx_custom_ops/batchnorm.py +64 -0
  10. python/core/model_processing/onnx_custom_ops/mul.py +66 -0
  11. python/core/model_processing/onnx_quantizer/exceptions.py +2 -2
  12. python/core/model_processing/onnx_quantizer/layers/base.py +101 -0
  13. python/core/model_processing/onnx_quantizer/layers/batchnorm.py +224 -0
  14. python/core/model_processing/onnx_quantizer/layers/clip.py +92 -0
  15. python/core/model_processing/onnx_quantizer/layers/max.py +49 -0
  16. python/core/model_processing/onnx_quantizer/layers/min.py +54 -0
  17. python/core/model_processing/onnx_quantizer/layers/mul.py +53 -0
  18. python/core/model_processing/onnx_quantizer/layers/sub.py +54 -0
  19. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +43 -0
  20. python/core/model_templates/circuit_template.py +48 -38
  21. python/core/utils/errors.py +1 -1
  22. python/core/utils/scratch_tests.py +29 -23
  23. python/scripts/gen_and_bench.py +2 -2
  24. python/tests/circuit_e2e_tests/circuit_model_developer_test.py +18 -14
  25. python/tests/circuit_e2e_tests/helper_fns_for_tests.py +11 -13
  26. python/tests/circuit_parent_classes/test_ort_custom_layers.py +35 -53
  27. python/tests/onnx_quantizer_tests/layers/base.py +1 -3
  28. python/tests/onnx_quantizer_tests/layers/batchnorm_config.py +190 -0
  29. python/tests/onnx_quantizer_tests/layers/clip_config.py +127 -0
  30. python/tests/onnx_quantizer_tests/layers/max_config.py +100 -0
  31. python/tests/onnx_quantizer_tests/layers/min_config.py +94 -0
  32. python/tests/onnx_quantizer_tests/layers/mul_config.py +102 -0
  33. python/tests/onnx_quantizer_tests/layers/sub_config.py +102 -0
  34. python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +6 -5
  35. python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +8 -1
  36. python/tests/onnx_quantizer_tests/test_registered_quantizers.py +17 -8
  37. python/core/binaries/onnx_generic_circuit_1-1-0 +0 -0
  38. {jstprove-1.1.0.dist-info → jstprove-1.3.0.dist-info}/WHEEL +0 -0
  39. {jstprove-1.1.0.dist-info → jstprove-1.3.0.dist-info}/entry_points.txt +0 -0
  40. {jstprove-1.1.0.dist-info → jstprove-1.3.0.dist-info}/licenses/LICENSE +0 -0
  41. {jstprove-1.1.0.dist-info → jstprove-1.3.0.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: JSTprove
3
- Version: 1.1.0
3
+ Version: 1.3.0
4
4
  Summary: Zero-knowledge proofs of ML inference on ONNX models
5
5
  Author: Inference Labs Inc
6
6
  Requires-Python: >=3.10
@@ -45,7 +45,7 @@ Dynamic: license-file
45
45
  Zero-knowledge proofs of ML inference on **ONNX** models — powered by [Polyhedra Network’s **Expander**](https://github.com/PolyhedraZK/Expander) (GKR/sum-check prover) and [**Expander Compiler Collection (ECC)**](https://github.com/PolyhedraZK/ExpanderCompilerCollection).
46
46
 
47
47
  * 🎯 **You bring ONNX** → we quantize, compile to a circuit, generate a witness, prove, and verify — via a simple CLI.
48
- * ✅ Supported ops (current): **Conv2D**, **GEMM/MatMul (FC)**, **ReLU**, **MaxPool2D**, **Add**.
48
+ * ✅ Supported ops (current): **Conv2D**, **GEMM/MatMul (FC)**, **ReLU**, **MaxPool2D**, **Add**, **Mul**, **Sub**, **BatchNorm**.
49
49
  * 🧰 CLI details: see **[docs/cli.md](docs/cli.md)**
50
50
 
51
51
  👉 Just want to see it in action? Jump to [Quickstart (LeNet demo)](#quickstart-lenet-demo).<br>
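
Note: the README hunk above adds Mul, Sub, and BatchNorm to the supported-ops list. The sketch below is illustrative only (not part of the package): it builds a tiny ONNX graph exercising those newly supported elementwise ops with the standard onnx.helper API, the kind of model the pipeline would now accept. All tensor names and shapes are made up.

# Illustrative sketch: a minimal ONNX graph using the newly supported ops.
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

C = 3
x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, C, 8, 8])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, C, 8, 8])

def const(name, arr):
    return numpy_helper.from_array(np.asarray(arr, dtype=np.float32), name)

nodes = [
    helper.make_node("Mul", ["x", "gain"], ["t0"]),
    helper.make_node("Sub", ["t0", "offset"], ["t1"]),
    helper.make_node("BatchNormalization", ["t1", "scale", "bias", "mean", "var"], ["y"]),
]
inits = [
    const("gain", np.full((1, C, 1, 1), 2.0)),
    const("offset", np.full((1, C, 1, 1), 0.5)),
    const("scale", np.ones(C)), const("bias", np.zeros(C)),
    const("mean", np.zeros(C)), const("var", np.ones(C)),
]
graph = helper.make_graph(nodes, "toy", [x], [y], initializer=inits)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
onnx.checker.check_model(model)
onnx.save(model, "toy_elementwise.onnx")
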
@@ -85,7 +85,7 @@ You provide an **ONNX** model and inputs; JSTprove handles **quantization**, **c
85
85
  ### High-level architecture
86
86
 
87
87
  * **Python pipeline:** Converts **ONNX → quantized ONNX**, prepares I/O, drives the Rust runner, exposes the **CLI**.
88
- * **Rust crate:** `rust/jstprove_circuits` implements layer circuits (Conv2D, ReLU, MaxPool2D, GEMM/FC) and a runner.
88
+ * **Rust crate:** `rust/jstprove_circuits` implements layer circuits (Conv2D, ReLU, MaxPool2D, GEMM/FC, BatchNorm) and a runner.
89
89
  * **Circuit frontend:** [ECC](https://github.com/PolyhedraZK/ExpanderCompilerCollection) Rust API for arithmetic circuits.
90
90
  * **Prover backend:** [Expander](https://github.com/PolyhedraZK/Expander) (GKR/sum-check prover/verification).
91
91
 
@@ -1,49 +1,57 @@
1
- jstprove-1.1.0.dist-info/licenses/LICENSE,sha256=UXQRcYRUH-PfN27n3P-FMaZFY6jr9jFPKcwT7CWbljw,1160
1
+ jstprove-1.3.0.dist-info/licenses/LICENSE,sha256=UXQRcYRUH-PfN27n3P-FMaZFY6jr9jFPKcwT7CWbljw,1160
2
2
  python/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  python/core/__init__.py,sha256=RlfbqGAaUulKl44QGMCkkGJBQZ8R_AgC5bU5zS7BjnA,97
4
4
  python/core/binaries/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
5
  python/core/binaries/expander-exec,sha256=C_1JcezdfLp9sFOQ2z3wp2gcq1k8zjIR09CxJKGGIuM,7095168
6
- python/core/binaries/onnx_generic_circuit_1-1-0,sha256=2YBhVx-neun-Dmx3ntyLq20qwsLrY9coOcU2bNLprZ0,3086160
6
+ python/core/binaries/onnx_generic_circuit_1-3-0,sha256=qbHC9SV_NuNv-vZs5MwXV0NxXkqTTZlmLH59WYCVrC8,3221088
7
7
  python/core/circuit_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
8
  python/core/circuit_models/generic_onnx.py,sha256=P65UZkfVBTE6YhaQ951S6QoTHPuU5ntDt8QL5pXghvw,8787
9
9
  python/core/circuit_models/simple_circuit.py,sha256=igQrZtQyreyHc26iAgCyDb0TuD2bJAoumYhc1pYPDzQ,4682
10
10
  python/core/circuits/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- python/core/circuits/base.py,sha256=tvCHwk_V2ftEocQkmoK5Nf98Iy0F0Ce1FWp03HhNHfA,41274
12
- python/core/circuits/errors.py,sha256=KzIXyi2ssVvBmXV0Rgn0dBfsTgweKHjeSvP2byRmqGc,5964
11
+ python/core/circuits/base.py,sha256=_XSs2LFyBMZEkBSvRp53zc-XaGOopNxYg2xHQz9sqt0,41991
12
+ python/core/circuits/errors.py,sha256=JDNa23wMwNQDTFY0IpDpHDMZ9gOdjDdQmB4GBhL_DCg,5913
13
13
  python/core/circuits/zk_model_base.py,sha256=5ggOaJjs2_MJvn-PO1cPN3i7U-XR4L-0zJGYuLVKOLc,820
14
14
  python/core/model_processing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
15
  python/core/model_processing/errors.py,sha256=uh2YFjuuU5JM3anMtSTLAH-zjlNAKStmLDZqRUgBWS8,4611
16
16
  python/core/model_processing/converters/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
- python/core/model_processing/converters/base.py,sha256=eG7iRDbDJJDTG2cCVgYlPlfkpmYPEnMzjGNK9wrA1m0,4303
18
- python/core/model_processing/converters/onnx_converter.py,sha256=BJc6rU3wLHI3imt8yzm8Cngri3KvcBSUbJ3Urw2PoEQ,44560
19
- python/core/model_processing/onnx_custom_ops/__init__.py,sha256=ofecV9pzpDJJl_r6inRw9JOKxtfK2rzzxWahAq9BKXE,475
17
+ python/core/model_processing/converters/base.py,sha256=o6bNwmqD9sOM9taqMb0ed6804RugQiU3va0rY_EA5SE,4265
18
+ python/core/model_processing/converters/onnx_converter.py,sha256=-eXdF6tfluFRxGgnQtJQ8R2309aYX-8z8HzMxk_Qv8I,44340
19
+ python/core/model_processing/onnx_custom_ops/__init__.py,sha256=ZKUC4ToRxgEEMHcTyERATVEN0KSDs-9cM1T-tTw3I1g,525
20
+ python/core/model_processing/onnx_custom_ops/batchnorm.py,sha256=8kg4iGGdt6B_fIJkpt4v5eNFpoHa4bjTB0NnCSmKFvE,1693
20
21
  python/core/model_processing/onnx_custom_ops/conv.py,sha256=6jJm3fcGWzcU4RjVgf179mPFCqsl4C3AR7bqQTffDgA,3464
21
22
  python/core/model_processing/onnx_custom_ops/custom_helpers.py,sha256=2WdnHw9NAoN_6wjIBoAQDyL6wEIlZOqo6ysCZp5DpZs,1844
22
23
  python/core/model_processing/onnx_custom_ops/gemm.py,sha256=bnEUXhqQCEcH4TIfbMTsCTtAlAlRzFvl4jj8g2QZFWU,2674
23
24
  python/core/model_processing/onnx_custom_ops/maxpool.py,sha256=Sd3BwqpGLSVU2iuAAIXAHdI3WO27Aa3g3r29HPiECvM,2319
25
+ python/core/model_processing/onnx_custom_ops/mul.py,sha256=w6X1sl1HnzoUJx2Mm_LaoXGTpvtwXxr3zZDPySVHBcM,1888
24
26
  python/core/model_processing/onnx_custom_ops/onnx_helpers.py,sha256=utnJuc5sgb_z1LgxuY9y2cQbMpdEJ8xOOrcP8DhfDCM,5686
25
27
  python/core/model_processing/onnx_custom_ops/relu.py,sha256=pZsPXC_r0FPggURKDphh8P1IRXY0w4hH7ExBmYTlWjE,1202
26
28
  python/core/model_processing/onnx_quantizer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
27
- python/core/model_processing/onnx_quantizer/exceptions.py,sha256=_YaXXEMbfD1P8N86L5YIz3uCilkuzlhv_2lU90T4FfA,5646
28
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py,sha256=POoDEBFzkr145P4INgAux2LQY2GdpsBtRpw_UuKVNhw,7679
29
+ python/core/model_processing/onnx_quantizer/exceptions.py,sha256=vzxBRbpvk4ZZbgacDISnqmQQKj7Ls46V08ilHnhaJy0,5645
30
+ python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py,sha256=5I67frJn4j2T1LTvODHixQK4VaqazJFJ0T1BCvqLPgg,9655
29
31
  python/core/model_processing/onnx_quantizer/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
32
  python/core/model_processing/onnx_quantizer/layers/add.py,sha256=AGxzqMa0jABIEKOIgPqEAA7EpZtynQtnD9nxI2NHc0s,1409
31
- python/core/model_processing/onnx_quantizer/layers/base.py,sha256=LvyTvmR2w6jYSJiBvyFluaDgL_Voc6dZ00TTWi6V7Tc,17426
33
+ python/core/model_processing/onnx_quantizer/layers/base.py,sha256=zUUAZpXCtbxbhHzuYczTNZPe-xWr6TxpmoAIDe4kCo4,21176
34
+ python/core/model_processing/onnx_quantizer/layers/batchnorm.py,sha256=KSBDPHd52f5Qyf-cnIDFPmfzssaJgMPiTmpIWEdM41U,7718
35
+ python/core/model_processing/onnx_quantizer/layers/clip.py,sha256=HrhiLtqC3cIAvU0wRCqp8_8ZSFH8a3F1Jf_qkXlY44s,3043
32
36
  python/core/model_processing/onnx_quantizer/layers/constant.py,sha256=l1IvgvXkmFMiaBsym8wchPF-y1ZH-c5PmFUy92IXWok,3694
33
37
  python/core/model_processing/onnx_quantizer/layers/conv.py,sha256=TlUpCRO6PPqH7MPkIrEiEcVfzuiN1WMYEiNIjhYXtWM,4451
34
38
  python/core/model_processing/onnx_quantizer/layers/gemm.py,sha256=7fCUMv8OLVZ45a2lYjA2XNvcW3By7lSbX7zeForNK-0,3950
39
+ python/core/model_processing/onnx_quantizer/layers/max.py,sha256=3gUxrdXwcVAtgR-_j4xQ0085Wj0oEBLT897TImxF2d4,1343
35
40
  python/core/model_processing/onnx_quantizer/layers/maxpool.py,sha256=PJ8hZPPBpfWV_RZdySl50-BU8TATjcg8Tg_mrAVS1Ic,4916
41
+ python/core/model_processing/onnx_quantizer/layers/min.py,sha256=cQbXzGOApR6HUJZMARXy87W8IbUC562jnAQm8J8ynQI,1709
42
+ python/core/model_processing/onnx_quantizer/layers/mul.py,sha256=qHsmnYPH-c5uiFeDCvV6e1xSgmIXJ64Sjvh0LYDYEqQ,1396
36
43
  python/core/model_processing/onnx_quantizer/layers/relu.py,sha256=d-5fyeKNLTgKKnqCwURpxkjl7QdbJQpuovtCFBM03FA,1685
44
+ python/core/model_processing/onnx_quantizer/layers/sub.py,sha256=M7D98TZBNP9-2R9MX6mcpYlrWFxTiX9JCs3XNcg1U-Q,1409
37
45
  python/core/model_templates/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
38
- python/core/model_templates/circuit_template.py,sha256=X8bA4AdmtQeb3ltU74GaWYfrOFhqs_DOpUqRMFXLAD8,2352
46
+ python/core/model_templates/circuit_template.py,sha256=OAqMRshi9OiJYoqpjkg5tUfNf18MfZmhsxxD6SANm_4,2106
39
47
  python/core/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
40
48
  python/core/utils/benchmarking_helpers.py,sha256=0nT38SCrjP_BlvJODsc9twF9ZmIFg_1sAvSyeNfv4mQ,5235
41
49
  python/core/utils/constants.py,sha256=Qu5_6OUe1XIsL-IY5_4923eN7x1-SPv6ohQonztAobA,102
42
- python/core/utils/errors.py,sha256=vTlluhbSqmyI5e1JNLEZ1mQ-dG_Wbxe4p5l4aa59zAY,3739
50
+ python/core/utils/errors.py,sha256=Uf57cRKpot_u5Yr8HRmjLmInkdd_x0x5YpTGBncZgl4,3722
43
51
  python/core/utils/general_layer_functions.py,sha256=tg2WWhmR-4TlKn8OeCu1qNbLf8qdKVP3jl9mhZn_sTg,9781
44
52
  python/core/utils/helper_functions.py,sha256=3JwJa4wHoUBteukDw4bAetqMsQLeJ0_sJ0qIdKy7GCY,37097
45
53
  python/core/utils/model_registry.py,sha256=aZg_9LEqsBXK84oxQ8A3NGZl-9aGnLgfR-kgxkOwV50,4895
46
- python/core/utils/scratch_tests.py,sha256=UYXsWIBh_27OxnyfH9CuxeNFT-OWCK0YpJ-j-8f0QHc,2332
54
+ python/core/utils/scratch_tests.py,sha256=o2VDTk8QBKA3UHHE-h7Ghtoge6kGG7G-8qwvesuTFFc,2281
47
55
  python/core/utils/witness_utils.py,sha256=ukvbF6EaHMPzRQVZad9wQ9gISRwBGQ1hEAHzc5TpGuw,9488
48
56
  python/frontend/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
49
57
  python/frontend/cli.py,sha256=lkvhzQC6bv0AgWUypg_cH-JT574r89qgTIsgHDT9GRg,3106
@@ -62,47 +70,53 @@ python/frontend/commands/bench/model.py,sha256=SaIWXAXZbWGbrNqEo5bs4NwgZfMOmmxaC
62
70
  python/frontend/commands/bench/sweep.py,sha256=rl-QBS9eXgQkuPJBhsU4CohfE1PdJvnM8NRhNU7ztQw,5279
63
71
  python/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
64
72
  python/scripts/benchmark_runner.py,sha256=sjbqaLrdjt94AoyQXAxT4FhsN6aRu5idTRQ5uHmZOWM,28593
65
- python/scripts/gen_and_bench.py,sha256=9kcIj-K_nG-G194C68Uig-Yw-p3nYKESACIpWRflmts,16276
73
+ python/scripts/gen_and_bench.py,sha256=V36x7djYmHlveAJgYzMlXwnmF0gAGO3-1mg9PWOmpj8,16249
66
74
  python/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
67
75
  python/tests/test_cli.py,sha256=OiAyG3aBpukk0i5FFWbiKaF42wf-7By-UWDHNjwtsqo,27042
68
76
  python/tests/circuit_e2e_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
69
- python/tests/circuit_e2e_tests/circuit_model_developer_test.py,sha256=Ic9hprCn1Rs-XAF-SUmBNEDn65yaCxUK9z5875KPg5o,39416
70
- python/tests/circuit_e2e_tests/helper_fns_for_tests.py,sha256=4toXD0pJMYVZFL1O9JZAJF-iqbi9k1eyuk_goUnchRo,5190
77
+ python/tests/circuit_e2e_tests/circuit_model_developer_test.py,sha256=8hl8SKw7obXplo0jsiKoKIZLxlu1_HhXvGDeSBDBars,39456
78
+ python/tests/circuit_e2e_tests/helper_fns_for_tests.py,sha256=uEThqTsRdNJivHwAv-aJIUtSPlmVHdhMZqZSH1OqhDE,5177
71
79
  python/tests/circuit_e2e_tests/other_e2e_test.py,sha256=amWRa1tIBHdQpd9-XS7vBXG0tkdV_9K9fH-FT5LFh7E,11301
72
80
  python/tests/circuit_parent_classes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
73
81
  python/tests/circuit_parent_classes/test_circuit.py,sha256=5vgcZHD2wY_pIRFNAhEZuBJD4uw2QyTck75Z9CJaACE,45968
74
82
  python/tests/circuit_parent_classes/test_onnx_converter.py,sha256=sJ0o8sducNUtmYKmsqfx7WEsIEd6oNbnWk71rXS_nIU,6575
75
- python/tests/circuit_parent_classes/test_ort_custom_layers.py,sha256=FEEY8nbuMC2xb6WrBsik7TeDde6SlMwwG9PKSqmCymo,3980
83
+ python/tests/circuit_parent_classes/test_ort_custom_layers.py,sha256=PBKjt6Mu3lRco4ijD2BLwAHPWRFic-OUwWPVsvBoEpU,3042
76
84
  python/tests/onnx_quantizer_tests/__init__.py,sha256=IZPGWHgjoay3gM1p2WJNh5cnZ79EP2VP-bcKy8AfJjY,18
77
85
  python/tests/onnx_quantizer_tests/test_base_layer.py,sha256=Ro7k-eUbGCyfIZ-OVNjLlCIz3mb02uHFWboFuWOdXKs,6526
78
86
  python/tests/onnx_quantizer_tests/test_exceptions.py,sha256=pwhARalEXx7REkcnIVZPi-4J1wgzgZN4xG-wLsx4rTs,3473
79
87
  python/tests/onnx_quantizer_tests/test_onnx_op_quantizer.py,sha256=m6mNe1KDRFIE2P0YURTIAim9-Di0BoPPAaaOOlorDIk,7367
80
- python/tests/onnx_quantizer_tests/test_registered_quantizers.py,sha256=M8N3KxApfIpZIu2Swh_z8eSy3DDqB3XxebN685hHHlw,4052
88
+ python/tests/onnx_quantizer_tests/test_registered_quantizers.py,sha256=lw_jYSbQ9ZM9P-jSt_LFSuve9vQ22cLtSui0W3zGqpo,4209
81
89
  python/tests/onnx_quantizer_tests/testing_helper_functions.py,sha256=N0fQv2pYzUCVZ7wkcR8gEKs5zTXT1hWrK-HKSTQYvYU,534
82
90
  python/tests/onnx_quantizer_tests/layers/__init__.py,sha256=xP-RmW6LfIANgK1s9Q0KZet2yvNr-3c6YIVLAAQqGUY,404
83
91
  python/tests/onnx_quantizer_tests/layers/add_config.py,sha256=T3tGddupDtrvLck2SL2yETDblNtv0aU7Tl7fNyZUhO4,4133
84
- python/tests/onnx_quantizer_tests/layers/base.py,sha256=uLCqhMcBA7zWiRSLRMNKKb4A9N27l-RUqSEEQ8SR3xI,9393
92
+ python/tests/onnx_quantizer_tests/layers/base.py,sha256=3nqmU2PgOdK_mPkz-YHg3idgr-PXYbu5kCIY-Uic5yo,9317
93
+ python/tests/onnx_quantizer_tests/layers/batchnorm_config.py,sha256=P-sZuHAdEfNczcgTeLjqJnEbpqN3dKTsbqvY4-SBqiQ,8231
94
+ python/tests/onnx_quantizer_tests/layers/clip_config.py,sha256=-OuhnUgz6xY4iW1jUR7W-J__Ie9lXI9vplmzp8qXqRc,4973
85
95
  python/tests/onnx_quantizer_tests/layers/constant_config.py,sha256=RdrKNMNZjI3Sk5o8WLNqmBUyYVJRWgtFbQ6oFWMwyQk,1193
86
96
  python/tests/onnx_quantizer_tests/layers/conv_config.py,sha256=H0ioW4H3ei5IK4tKhrA0ffThxJ4K5oO9jIs9A0T0VaM,6005
87
97
  python/tests/onnx_quantizer_tests/layers/factory.py,sha256=WLLEP9ECmSpTliSjhtdWOHcX1xOi6HM10S9Y4re1A74,4844
88
98
  python/tests/onnx_quantizer_tests/layers/flatten_config.py,sha256=Xln5Hh6gyeM5gGRCjLGvIL-u08NEs1tXSF32urCqPfE,2110
89
99
  python/tests/onnx_quantizer_tests/layers/gemm_config.py,sha256=t7nJY-Wnj6YUD821-jaWzgrQVPa6ytwER3hFMsvyY6Y,7294
100
+ python/tests/onnx_quantizer_tests/layers/max_config.py,sha256=vzR8-2wbPGcH0GMmAJ_sXSEdMtZOjVNGufU__N3Jfyw,3906
90
101
  python/tests/onnx_quantizer_tests/layers/maxpool_config.py,sha256=XfTPk_ZQXEzaCjHHymSLVv2HS-PKH1rS9IuyyoEtM78,3176
102
+ python/tests/onnx_quantizer_tests/layers/min_config.py,sha256=izKtCaMXoQHiAfmcGlJRQdKMQz3Su8n0p2mEn0y56Do,3774
103
+ python/tests/onnx_quantizer_tests/layers/mul_config.py,sha256=_Oy4b97ORxFlF3w0BmJ94hNA968HQx2AvwYiASrGPxw,4135
91
104
  python/tests/onnx_quantizer_tests/layers/relu_config.py,sha256=_aHuddDApLUBOa0FiR9h4fNfmMSnH5r4JzOMLW0KaTk,2197
92
105
  python/tests/onnx_quantizer_tests/layers/reshape_config.py,sha256=fZchSqIAy76m7j97wVC_UI6slSpv8nbwukhkbGR2sRE,2203
106
+ python/tests/onnx_quantizer_tests/layers/sub_config.py,sha256=IxF18mG9kjlEiKYSNG912CEcBxOFGxIWoRAwjvBXiRo,4133
93
107
  python/tests/onnx_quantizer_tests/layers_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
94
108
  python/tests/onnx_quantizer_tests/layers_tests/base_test.py,sha256=UgbcT97tgcuTtO1pOADpww9bz_JElKiI2mxLJYKyF1k,2992
95
109
  python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py,sha256=Vxn4LEWHZeGa_vS1-7ptFqSSBb0D-3BG-ETocP4pvsI,3651
96
110
  python/tests/onnx_quantizer_tests/layers_tests/test_e2e.py,sha256=40779aaHgdryVwLlIO18F1d7uSLSXdJUG5Uj_5-xD4U,6712
97
111
  python/tests/onnx_quantizer_tests/layers_tests/test_error_cases.py,sha256=t5c_zqO4Ex3HIFWcykX4PTftdKN7UWnEOF5blShL0Ik,1881
98
- python/tests/onnx_quantizer_tests/layers_tests/test_integration.py,sha256=Mq1-PBKR3756i9VrFOP5DY3GkRE32D6Hjd1fK9wZdVk,7228
99
- python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py,sha256=zclzXxtgA5BEmNwSf_aNbJgbsArMXn5WDdlxiMR2-aM,9255
112
+ python/tests/onnx_quantizer_tests/layers_tests/test_integration.py,sha256=xNt2STeXB33NcpteDThwGTSW1Hm15POf8a4aPBSVrvI,7254
113
+ python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py,sha256=DatmgvibQazP100B4NHDu7u-O2-f90juPKvPOXuPnXo,9491
100
114
  python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py,sha256=RfnIIiYbgPbU3620H6MPvSxE3MNR2G1yPELwdWV3mK4,4107
101
115
  python/tests/onnx_quantizer_tests/layers_tests/test_validation.py,sha256=jz-WtIEP-jjUklOOAnznwPUXbf07U2PAMGrhzMWP0JU,1371
102
116
  python/tests/utils_testing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
103
117
  python/tests/utils_testing/test_helper_functions.py,sha256=xmeGQieh4LE9U-CDKBlHhSWqH0cAmmDU3qXNbDkkvms,27192
104
- jstprove-1.1.0.dist-info/METADATA,sha256=3gdOLaD4eYGawv4SuvofjuzBW-y564J4gpNPXHFNY1A,14056
105
- jstprove-1.1.0.dist-info/WHEEL,sha256=jc2C2uw104ioj1TL9cE0YO67_kdAwX4W8JgYPomxr5M,105
106
- jstprove-1.1.0.dist-info/entry_points.txt,sha256=nGcTSO-4q08gPl1IoWdrPaiY7IbO7XvmXKkd34dYHc8,49
107
- jstprove-1.1.0.dist-info/top_level.txt,sha256=J-z0poNcsv31IHB413--iOY8LoHBKiTHeybHX3abokI,7
108
- jstprove-1.1.0.dist-info/RECORD,,
118
+ jstprove-1.3.0.dist-info/METADATA,sha256=CqGuzrQWy_MUYtOcy-0i8mB_2eAKUvQ1R3tjEX-3N4o,14100
119
+ jstprove-1.3.0.dist-info/WHEEL,sha256=jc2C2uw104ioj1TL9cE0YO67_kdAwX4W8JgYPomxr5M,105
120
+ jstprove-1.3.0.dist-info/entry_points.txt,sha256=nGcTSO-4q08gPl1IoWdrPaiY7IbO7XvmXKkd34dYHc8,49
121
+ jstprove-1.3.0.dist-info/top_level.txt,sha256=J-z0poNcsv31IHB413--iOY8LoHBKiTHeybHX3abokI,7
122
+ jstprove-1.3.0.dist-info/RECORD,,
@@ -4,13 +4,12 @@ import logging
4
4
  from pathlib import Path
5
5
  from typing import TYPE_CHECKING, Any
6
6
 
7
- from numpy import asarray, ndarray
7
+ import numpy as np
8
8
 
9
9
  from python.core.utils.errors import ShapeMismatchError
10
10
  from python.core.utils.witness_utils import compare_witness_to_io, load_witness
11
11
 
12
12
  if TYPE_CHECKING:
13
- import numpy as np
14
13
  import torch
15
14
 
16
15
  from python.core.circuits.errors import (
@@ -775,18 +774,18 @@ class Circuit:
775
774
  def reshape_inputs_for_inference(
776
775
  self: Circuit,
777
776
  inputs: dict[str],
778
- ) -> ndarray | dict[str, ndarray]:
777
+ ) -> np.ndarray | dict[str, np.ndarray]:
779
778
  """
780
779
  Reshape input tensors to match the model's expected input shape.
781
780
 
782
781
  Parameters
783
782
  ----------
784
- inputs : dict[str] or ndarray
783
+ inputs : dict[str] or np.ndarray
785
784
  Input tensors or a dictionary of tensors.
786
785
 
787
786
  Returns
788
787
  -------
789
- ndarray or dict[str, ndarray]
788
+ np.ndarray or dict[str, np.ndarray]
790
789
  Reshaped input(s) ready for inference.
791
790
  """
792
791
 
@@ -801,15 +800,33 @@ class Circuit:
801
800
  if isinstance(inputs, dict):
802
801
  if len(inputs) == 1:
803
802
  only_key = next(iter(inputs))
804
- inputs = asarray(inputs[only_key])
803
+ value = np.asarray(inputs[only_key])
804
+
805
+ # If shape is a dict, extract the shape for this key
806
+ if isinstance(shape, dict):
807
+ key_shape = shape.get(only_key, None)
808
+ if key_shape is None:
809
+ raise CircuitConfigurationError(
810
+ missing_attributes=[f"input_shape[{only_key!r}]"],
811
+ )
812
+ shape = key_shape
813
+
814
+ # From here on, treat it as a regular reshape
815
+ inputs = value
805
816
  else:
806
817
  return self._reshape_dict_inputs(inputs, shape)
807
818
 
808
819
  # --- Regular reshape ---
820
+ if not isinstance(shape, (list, tuple)):
821
+ msg = (
822
+ f"Expected list or tuple shape for reshape, got {type(shape).__name__}"
823
+ )
824
+ raise CircuitInputError(msg)
825
+
809
826
  try:
810
- return asarray(inputs).reshape(shape)
827
+ return np.asarray(inputs).reshape(shape)
811
828
  except Exception as e:
812
- raise ShapeMismatchError(shape, list(asarray(inputs).shape)) from e
829
+ raise ShapeMismatchError(shape, list(np.asarray(inputs).shape)) from e
813
830
 
814
831
  def _reshape_dict_inputs(
815
832
  self: Circuit,
@@ -824,7 +841,7 @@ class Circuit:
824
841
  )
825
842
  raise CircuitInputError(msg, parameter="shape", expected="dict")
826
843
  for key, value in inputs.items():
827
- tensor = asarray(value)
844
+ tensor = np.asarray(value)
828
845
  try:
829
846
  inputs[key] = tensor.reshape(shape[key])
830
847
  except Exception as e:
@@ -867,16 +884,16 @@ class Circuit:
867
884
  value = inputs[key]
868
885
 
869
886
  # --- handle unsupported input types BEFORE entering try ---
870
- if not isinstance(value, (ndarray, list, tuple)):
887
+ if not isinstance(value, (np.ndarray, list, tuple)):
871
888
  msg = f"Unsupported input type for key '{key}': {type(value).__name__}"
872
889
  raise CircuitProcessingError(message=msg)
873
890
 
874
891
  try:
875
892
  # Convert to tensor, flatten, and back to list
876
- if isinstance(value, ndarray):
893
+ if isinstance(value, np.ndarray):
877
894
  flattened = value.flatten().tolist()
878
895
  else:
879
- flattened = asarray(value).flatten().tolist()
896
+ flattened = np.asarray(value).flatten().tolist()
880
897
  except Exception as e:
881
898
  msg = f"Failed to flatten input '{key}' (type {type(value).__name__})"
882
899
  raise CircuitProcessingError(message=msg) from e
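
Note: the base.py hunks above change reshape_inputs_for_inference to unwrap a single-entry input dict, accept a per-key shape dict, and reject non-list/tuple shapes before reshaping. The following is a condensed, hedged sketch of that control flow; plain ValueError stands in for the package's CircuitConfigurationError/CircuitInputError/ShapeMismatchError types.

# Condensed sketch of the reshape behaviour shown in the diff (simplified errors).
import numpy as np

def reshape_inputs(inputs, shape):
    if isinstance(inputs, dict):
        if len(inputs) == 1:
            only_key = next(iter(inputs))
            value = np.asarray(inputs[only_key])
            # If shape is a per-key dict, pull out the shape for this key.
            if isinstance(shape, dict):
                if only_key not in shape:
                    raise ValueError(f"missing input_shape[{only_key!r}]")
                shape = shape[only_key]
            inputs = value
        else:
            return {k: np.asarray(v).reshape(shape[k]) for k, v in inputs.items()}
    if not isinstance(shape, (list, tuple)):
        raise ValueError(f"Expected list or tuple shape, got {type(shape).__name__}")
    return np.asarray(inputs).reshape(shape)

# e.g. reshape_inputs({"input": [0.0] * 12}, {"input": [1, 3, 2, 2]}).shape == (1, 3, 2, 2)
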
@@ -1,4 +1,3 @@
1
- # python/core/utils/exceptions.py
2
1
  from __future__ import annotations
3
2
 
4
3
  from python.core.utils.helper_functions import RunType
@@ -68,7 +67,7 @@ class CircuitInputError(CircuitError):
68
67
  actual (any): Actual value encountered (optional).
69
68
  """
70
69
 
71
- def __init__( # noqa: PLR0913
70
+ def __init__(
72
71
  self: CircuitInputError,
73
72
  message: str | None = None,
74
73
  parameter: str | None = None,
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  from abc import ABC, abstractmethod
4
4
  from enum import Enum
5
- from typing import TYPE_CHECKING, Optional, Union
5
+ from typing import TYPE_CHECKING
6
6
 
7
7
  if TYPE_CHECKING:
8
8
  import numpy as np
@@ -16,10 +16,10 @@ class ModelType(Enum):
16
16
 
17
17
  ONNXLayerDict = dict[
18
18
  str,
19
- Union[int, str, list[str], dict[str, list[int]], Optional[list], Optional[dict]],
19
+ int | str | list[str] | dict[str, list[int]] | list | None | dict,
20
20
  ]
21
21
 
22
- CircuitParamsDict = dict[str, Union[int, dict[str, bool]]]
22
+ CircuitParamsDict = dict[str, int | dict[str, bool]]
23
23
 
24
24
 
25
25
  class ModelConverter(ABC):
@@ -247,6 +247,7 @@ class ONNXConverter(ModelConverter):
247
247
 
248
248
  def analyze_layers(
249
249
  self: ONNXConverter,
250
+ model: onnx.ModelProto,
250
251
  output_name_to_shape: dict[str, list[int]] | None = None,
251
252
  ) -> tuple[list[ONNXLayer], list[ONNXLayer]]:
252
253
  """Analyze the onnx model graph into
@@ -268,29 +269,29 @@ class ONNXConverter(ModelConverter):
268
269
  id_count = 0
269
270
  # Apply shape inference on the model
270
271
  if not output_name_to_shape:
271
- inferred_model = shape_inference.infer_shapes(self.model)
272
+ inferred_model = shape_inference.infer_shapes(model)
272
273
  self._onnx_check_model_safely(inferred_model)
273
274
 
274
275
  output_name_to_shape = extract_shape_dict(inferred_model)
275
276
  domain_to_version = {
276
- opset.domain: opset.version for opset in self.model.opset_import
277
+ opset.domain: opset.version for opset in model.opset_import
277
278
  }
278
279
 
279
280
  id_count = 0
280
281
  architecture = self.get_model_architecture(
281
- self.model,
282
+ model,
282
283
  output_name_to_shape,
283
284
  domain_to_version,
284
285
  )
285
286
  w_and_b = self.get_model_w_and_b(
286
- self.model,
287
+ model,
287
288
  output_name_to_shape,
288
289
  id_count,
289
290
  domain_to_version,
290
291
  )
291
292
  except InvalidModelError:
292
293
  raise
293
- except (ValueError, TypeError, RuntimeError, OSError, onnx.ONNXException) as e:
294
+ except (ValueError, TypeError, RuntimeError, OSError) as e:
294
295
  raise LayerAnalysisError(model_type=self.model_type, reason=str(e)) from e
295
296
  except Exception as e:
296
297
  raise LayerAnalysisError(model_type=self.model_type, reason=str(e)) from e
@@ -557,6 +558,7 @@ class ONNXConverter(ModelConverter):
557
558
  output_shapes = {
558
559
  out_name: output_name_to_shape.get(out_name, []) for out_name in outputs
559
560
  }
561
+
560
562
  return ONNXLayer(
561
563
  id=layer_id,
562
564
  name=name,
@@ -605,6 +607,7 @@ class ONNXConverter(ModelConverter):
605
607
  np_data = onnx.numpy_helper.to_array(node, constant_dtype)
606
608
  except (ValueError, TypeError, onnx.ONNXException, Exception) as e:
607
609
  raise SerializationError(
610
+ model_type=self.model_type,
608
611
  tensor_name=node.name,
609
612
  reason=f"Failed to convert tensor: {e!s}",
610
613
  ) from e
@@ -1040,38 +1043,36 @@ class ONNXConverter(ModelConverter):
1040
1043
  ``rescale_config``.
1041
1044
  """
1042
1045
  inferred_model = shape_inference.infer_shapes(self.model)
1043
-
1044
- scaling = BaseOpQuantizer.get_scaling(
1045
- scale_base=getattr(self, "scale_base", 2),
1046
- scale_exponent=(getattr(self, "scale_exponent", 18)),
1047
- )
1046
+ scale_base = getattr(self, "scale_base", 2)
1047
+ scale_exponent = getattr(self, "scale_exponent", 18)
1048
1048
 
1049
1049
  # Check the model and print Y"s shape information
1050
1050
  self._onnx_check_model_safely(inferred_model)
1051
1051
  output_name_to_shape = extract_shape_dict(inferred_model)
1052
- (architecture, w_and_b) = self.analyze_layers(output_name_to_shape)
1053
- for w in w_and_b:
1052
+ scaled_and_transformed_model = self.op_quantizer.apply_pre_analysis_transforms(
1053
+ inferred_model,
1054
+ scale_exponent=scale_exponent,
1055
+ scale_base=scale_base,
1056
+ )
1057
+ # Get layers in correct format
1058
+ (architecture, w_and_b) = self.analyze_layers(
1059
+ scaled_and_transformed_model,
1060
+ output_name_to_shape,
1061
+ )
1062
+
1063
+ def _convert_tensor_to_int_list(w: ONNXLayer) -> list:
1054
1064
  try:
1055
- w_and_b_array = np.asarray(w.tensor)
1056
- except (ValueError, TypeError, Exception) as e:
1065
+ arr = np.asarray(w.tensor).astype(np.int64)
1066
+ return arr.tolist()
1067
+ except Exception as e:
1057
1068
  raise SerializationError(
1058
1069
  tensor_name=getattr(w, "name", None),
1070
+ model_type=self.model_type,
1059
1071
  reason=f"cannot convert to ndarray: {e}",
1060
1072
  ) from e
1061
1073
 
1062
- try:
1063
- # TODO @jsgold-1: We need a better way to distinguish bias tensors from weight tensors # noqa: FIX002, TD003,E501
1064
- if "bias" in w.name:
1065
- w_and_b_scaled = w_and_b_array * scaling * scaling
1066
- else:
1067
- w_and_b_scaled = w_and_b_array * scaling
1068
- w_and_b_out = w_and_b_scaled.astype(np.int64).tolist()
1069
- w.tensor = w_and_b_out
1070
- except (ValueError, TypeError, OverflowError, Exception) as e:
1071
- raise SerializationError(
1072
- tensor_name=getattr(w, "name", None),
1073
- reason=str(e),
1074
- ) from e
1074
+ for w in w_and_b:
1075
+ w.tensor = _convert_tensor_to_int_list(w)
1075
1076
 
1076
1077
  inputs = []
1077
1078
  outputs = []
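
Note: this hunk moves the fixed-point scaling of weights and biases out of the converter and into op_quantizer.apply_pre_analysis_transforms; the converter now only casts the already-scaled tensors to int64 lists. The sketch below reconstructs the convention visible in the removed code, assuming BaseOpQuantizer.get_scaling returned scale_base ** scale_exponent; the exact behaviour of apply_pre_analysis_transforms is not shown in this diff.

# Sketch of the fixed-point convention in the removed code: weights are scaled
# by S = scale_base ** scale_exponent, bias tensors by S**2 (they are added to
# products of two scaled operands), then truncated to int64.
import numpy as np

scale_base, scale_exponent = 2, 18
S = scale_base ** scale_exponent

def scale_tensor(name: str, tensor) -> list:
    arr = np.asarray(tensor, dtype=np.float64)
    factor = S * S if "bias" in name else S
    return (arr * factor).astype(np.int64).tolist()

# e.g. scale_tensor("fc1.weight", [[0.5, -0.25]]) -> [[131072, -65536]]
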
@@ -1,16 +1,17 @@
1
1
  import importlib
2
2
  import pkgutil
3
- import os
3
+ from pathlib import Path
4
4
 
5
5
  # Get the package name of the current module
6
6
  package_name = __name__
7
7
 
8
8
  # Dynamically import all .py files in this package directory (except __init__.py)
9
- package_dir = os.path.dirname(__file__)
9
+ package_dir = Path(__file__).parent.as_posix()
10
10
 
11
- __all__ = []
11
+
12
+ __all__: list[str] = []
12
13
 
13
14
  for _, module_name, is_pkg in pkgutil.iter_modules([package_dir]):
14
15
  if not is_pkg and (module_name != "custom_helpers"):
15
16
  importlib.import_module(f"{package_name}.{module_name}")
16
- __all__.append(module_name)
17
+ __all__.append(module_name) # noqa: PYI056
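
Note: functionally this __init__.py still imports every sibling module so that each module's @onnx_op decorator runs and registers its custom op with onnxruntime_extensions; the hunk only swaps os.path for pathlib and annotates __all__. The usage sketch below shows the standard onnxruntime_extensions pattern for making such ops resolvable in a session; the session options and model path are illustrative assumptions, not taken from the package.

# Illustrative: importing the package triggers the @onnx_op registrations,
# after which an InferenceSession can resolve ai.onnx.contrib ops.
import onnxruntime as ort
from onnxruntime_extensions import get_library_path

import python.core.model_processing.onnx_custom_ops  # noqa: F401  (registers ops)

opts = ort.SessionOptions()
opts.register_custom_ops_library(get_library_path())
session = ort.InferenceSession("quantized_model.onnx", opts)  # assumed file name
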
@@ -0,0 +1,64 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+ from onnxruntime_extensions import PyCustomOpDef, onnx_op
5
+
6
+ from .custom_helpers import rescaling
7
+
8
+
9
+ @onnx_op(
10
+ op_type="Int64BatchNorm",
11
+ domain="ai.onnx.contrib",
12
+ inputs=[
13
+ PyCustomOpDef.dt_int64, # X (int64)
14
+ PyCustomOpDef.dt_int64, # mul (int64 scaled multiplier)
15
+ PyCustomOpDef.dt_int64, # add (int64 scaled adder)
16
+ PyCustomOpDef.dt_int64, # scaling_factor
17
+ ],
18
+ outputs=[PyCustomOpDef.dt_int64],
19
+ attrs={"rescale": PyCustomOpDef.dt_int64},
20
+ )
21
+ def int64_batchnorm(
22
+ x: np.ndarray,
23
+ mul: np.ndarray,
24
+ add: np.ndarray,
25
+ scaling_factor: np.ndarray | None = None,
26
+ rescale: int | None = None,
27
+ ) -> np.ndarray:
28
+ """
29
+ Int64 BatchNorm (folded into affine transform).
30
+
31
+ Computes:
32
+ Y = X * mul + add
33
+ where mul/add are already scaled to int64.
34
+
35
+ Parameters
36
+ ----------
37
+ x : Input int64 tensor
38
+ mul : Per-channel int64 scale multipliers
39
+ add : Per-channel int64 bias terms
40
+ scaling_factor: factor to rescale
41
+ rescale : Optional flag to apply post-scaling
42
+
43
+ Returns
44
+ -------
45
+ numpy.ndarray (int64)
46
+ """
47
+ try:
48
+ # Broadcasting shapes must match batchnorm layout: NCHW
49
+ # Typically mul/add have shape [C]
50
+ dims_x = len(x.shape)
51
+ dim_ones = (1,) * (dims_x - 2)
52
+ mul = mul.reshape(-1, *dim_ones)
53
+ add = add.reshape(-1, *dim_ones)
54
+
55
+ y = x * mul + add
56
+
57
+ if rescale is not None:
58
+ y = rescaling(scaling_factor, rescale, y)
59
+
60
+ return y.astype(np.int64)
61
+
62
+ except Exception as e:
63
+ msg = f"Int64BatchNorm failed: {e}"
64
+ raise RuntimeError(msg) from e
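
Note: Int64BatchNorm expects the batch-norm parameters already folded into a per-channel multiplier and adder. The folding itself is standard (mul = gamma / sqrt(var + eps), add = beta - mean * mul); how JSTprove's quantizer scales those to int64 lives in the quantizer layer and is not shown here, so the quantization step below is a hedged guess assuming x and mul each carry one factor of S.

# Standard BatchNorm folding into an affine (mul, add) pair, plus a guessed
# int64 quantization step (the package's exact convention is in the quantizer
# layer, not this custom op).
import numpy as np

def fold_batchnorm(gamma, beta, mean, var, eps=1e-5):
    mul = gamma / np.sqrt(var + eps)
    add = beta - mean * mul
    return mul, add

S = 2 ** 18  # assumed fixed-point scale
gamma, beta = np.array([1.0, 0.5]), np.array([0.0, 0.1])
mean, var = np.array([0.2, -0.1]), np.array([1.0, 4.0])

mul_f, add_f = fold_batchnorm(gamma, beta, mean, var)
mul_i = np.round(mul_f * S).astype(np.int64)      # feeds the "mul" input
add_i = np.round(add_f * S * S).astype(np.int64)  # "add" at S**2 if x*mul is at S**2 before rescale
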
@@ -0,0 +1,66 @@
1
+ import numpy as np
2
+ from onnxruntime_extensions import PyCustomOpDef, onnx_op
3
+
4
+ from .custom_helpers import rescaling
5
+
6
+
7
+ @onnx_op(
8
+ op_type="Int64Mul",
9
+ domain="ai.onnx.contrib",
10
+ inputs=[
11
+ PyCustomOpDef.dt_int64,
12
+ PyCustomOpDef.dt_int64,
13
+ PyCustomOpDef.dt_int64, # Scalar
14
+ ],
15
+ outputs=[PyCustomOpDef.dt_int64],
16
+ attrs={
17
+ "rescale": PyCustomOpDef.dt_int64,
18
+ },
19
+ )
20
+ def int64_mul(
21
+ a: np.ndarray,
22
+ b: np.ndarray,
23
+ scaling_factor: np.ndarray | None = None,
24
+ rescale: int | None = None,
25
+ ) -> np.ndarray:
26
+ """
27
+ Performs a Mul (hadamard product) operation on int64 input tensors.
28
+
29
+ This function is registered as a custom ONNX operator via onnxruntime_extensions
30
+ and is used in the JSTprove quantized inference pipeline.
31
+ It applies Mul with the rescaling the outputs back to the original scale.
32
+
33
+ Parameters
34
+ ----------
35
+ a : np.ndarray
36
+ First input tensor with dtype int64.
37
+ b : np.ndarray
38
+ Second input tensor with dtype int64.
39
+ scaling_factor : Scaling factor for rescaling the output.
40
+ Optional scalar tensor for rescaling when rescale=1.
41
+ rescale : int, optional
42
+ Whether to apply rescaling (0=no, 1=yes).
43
+
44
+ Returns
45
+ -------
46
+ numpy.ndarray
47
+ Mul tensor with dtype int64.
48
+
49
+ Notes
50
+ -----
51
+ - This op is part of the `ai.onnx.contrib` custom domain.
52
+ - ONNX Runtime Extensions is required to register this op.
53
+
54
+ References
55
+ ----------
56
+ For more information on the Mul operation, please refer to the
57
+ ONNX standard Mul operator documentation:
58
+ https://onnx.ai/onnx/operators/onnx__Mul.html
59
+ """
60
+ try:
61
+ result = a * b
62
+ result = rescaling(scaling_factor, rescale, result)
63
+ return result.astype(np.int64)
64
+ except Exception as e:
65
+ msg = f"Int64Mul failed: {e}"
66
+ raise RuntimeError(msg) from e
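
Note: the Int64Mul docstring says the output is rescaled back to the original scale; the reason is that multiplying two fixed-point values doubles the scale. A small worked example under the same single-factor assumption follows; the rescaling helper's internals are not shown in this diff, so the integer division below is an assumed stand-in for it.

# Why Int64Mul needs rescaling: with a, b stored as round(value * S),
# a * b carries S**2, so one factor of S must be divided out again.
import numpy as np

S = 2 ** 18
a_f, b_f = 1.5, -0.75
a_i = np.int64(round(a_f * S))   #  393216
b_i = np.int64(round(b_f * S))   # -196608
prod = a_i * b_i                 # scale S**2
rescaled = prod // S             # assumed behaviour of rescaling(...)
print(rescaled / S)              # -1.125 == a_f * b_f
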
@@ -31,7 +31,7 @@ class InvalidParamError(QuantizationError):
31
31
  quantization the quantization process.
32
32
  """
33
33
 
34
- def __init__( # noqa: PLR0913
34
+ def __init__(
35
35
  self: QuantizationError,
36
36
  node_name: str,
37
37
  op_type: str,
@@ -151,7 +151,7 @@ class InvalidConfigError(QuantizationError):
151
151
  def __init__(
152
152
  self: QuantizationError,
153
153
  key: str,
154
- value: str | float | bool | None,
154
+ value: str | float | bool | None, # noqa: FBT001
155
155
  expected: str | None = None,
156
156
  ) -> None:
157
157
  """Initialize InvalidConfigError with context about the bad config.