JSTprove 1.1.0__py3-none-macosx_11_0_arm64.whl → 1.3.0__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jstprove-1.1.0.dist-info → jstprove-1.3.0.dist-info}/METADATA +3 -3
- {jstprove-1.1.0.dist-info → jstprove-1.3.0.dist-info}/RECORD +40 -26
- python/core/binaries/onnx_generic_circuit_1-3-0 +0 -0
- python/core/circuits/base.py +29 -12
- python/core/circuits/errors.py +1 -2
- python/core/model_processing/converters/base.py +3 -3
- python/core/model_processing/converters/onnx_converter.py +28 -27
- python/core/model_processing/onnx_custom_ops/__init__.py +5 -4
- python/core/model_processing/onnx_custom_ops/batchnorm.py +64 -0
- python/core/model_processing/onnx_custom_ops/mul.py +66 -0
- python/core/model_processing/onnx_quantizer/exceptions.py +2 -2
- python/core/model_processing/onnx_quantizer/layers/base.py +101 -0
- python/core/model_processing/onnx_quantizer/layers/batchnorm.py +224 -0
- python/core/model_processing/onnx_quantizer/layers/clip.py +92 -0
- python/core/model_processing/onnx_quantizer/layers/max.py +49 -0
- python/core/model_processing/onnx_quantizer/layers/min.py +54 -0
- python/core/model_processing/onnx_quantizer/layers/mul.py +53 -0
- python/core/model_processing/onnx_quantizer/layers/sub.py +54 -0
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +43 -0
- python/core/model_templates/circuit_template.py +48 -38
- python/core/utils/errors.py +1 -1
- python/core/utils/scratch_tests.py +29 -23
- python/scripts/gen_and_bench.py +2 -2
- python/tests/circuit_e2e_tests/circuit_model_developer_test.py +18 -14
- python/tests/circuit_e2e_tests/helper_fns_for_tests.py +11 -13
- python/tests/circuit_parent_classes/test_ort_custom_layers.py +35 -53
- python/tests/onnx_quantizer_tests/layers/base.py +1 -3
- python/tests/onnx_quantizer_tests/layers/batchnorm_config.py +190 -0
- python/tests/onnx_quantizer_tests/layers/clip_config.py +127 -0
- python/tests/onnx_quantizer_tests/layers/max_config.py +100 -0
- python/tests/onnx_quantizer_tests/layers/min_config.py +94 -0
- python/tests/onnx_quantizer_tests/layers/mul_config.py +102 -0
- python/tests/onnx_quantizer_tests/layers/sub_config.py +102 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +6 -5
- python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +8 -1
- python/tests/onnx_quantizer_tests/test_registered_quantizers.py +17 -8
- python/core/binaries/onnx_generic_circuit_1-1-0 +0 -0
- {jstprove-1.1.0.dist-info → jstprove-1.3.0.dist-info}/WHEEL +0 -0
- {jstprove-1.1.0.dist-info → jstprove-1.3.0.dist-info}/entry_points.txt +0 -0
- {jstprove-1.1.0.dist-info → jstprove-1.3.0.dist-info}/licenses/LICENSE +0 -0
- {jstprove-1.1.0.dist-info → jstprove-1.3.0.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: JSTprove
|
|
3
|
-
Version: 1.1.0
|
|
3
|
+
Version: 1.3.0
|
|
4
4
|
Summary: Zero-knowledge proofs of ML inference on ONNX models
|
|
5
5
|
Author: Inference Labs Inc
|
|
6
6
|
Requires-Python: >=3.10
|
|
@@ -45,7 +45,7 @@ Dynamic: license-file
|
|
|
45
45
|
Zero-knowledge proofs of ML inference on **ONNX** models — powered by [Polyhedra Network’s **Expander**](https://github.com/PolyhedraZK/Expander) (GKR/sum-check prover) and [**Expander Compiler Collection (ECC)**](https://github.com/PolyhedraZK/ExpanderCompilerCollection).
|
|
46
46
|
|
|
47
47
|
* 🎯 **You bring ONNX** → we quantize, compile to a circuit, generate a witness, prove, and verify — via a simple CLI.
|
|
48
|
-
* ✅ Supported ops (current): **Conv2D**, **GEMM/MatMul (FC)**, **ReLU**, **MaxPool2D**, **Add**.
|
|
48
|
+
* ✅ Supported ops (current): **Conv2D**, **GEMM/MatMul (FC)**, **ReLU**, **MaxPool2D**, **Add**, **Mul**, **Sub**, **BatchNorm**.
|
|
49
49
|
* 🧰 CLI details: see **[docs/cli.md](docs/cli.md)**
|
|
50
50
|
|
|
51
51
|
👉 Just want to see it in action? Jump to [Quickstart (LeNet demo)](#quickstart-lenet-demo).<br>
|
|
@@ -85,7 +85,7 @@ You provide an **ONNX** model and inputs; JSTprove handles **quantization**, **c
|
|
|
85
85
|
### High-level architecture
|
|
86
86
|
|
|
87
87
|
* **Python pipeline:** Converts **ONNX → quantized ONNX**, prepares I/O, drives the Rust runner, exposes the **CLI**.
|
|
88
|
-
* **Rust crate:** `rust/jstprove_circuits` implements layer circuits (Conv2D, ReLU, MaxPool2D, GEMM/FC) and a runner.
|
|
88
|
+
* **Rust crate:** `rust/jstprove_circuits` implements layer circuits (Conv2D, ReLU, MaxPool2D, GEMM/FC, BatchNorm) and a runner.
|
|
89
89
|
* **Circuit frontend:** [ECC](https://github.com/PolyhedraZK/ExpanderCompilerCollection) Rust API for arithmetic circuits.
|
|
90
90
|
* **Prover backend:** [Expander](https://github.com/PolyhedraZK/Expander) (GKR/sum-check prover/verification).
|
|
91
91
|
|
|
@@ -1,49 +1,57 @@
|
|
|
1
|
-
jstprove-1.
|
|
1
|
+
jstprove-1.3.0.dist-info/licenses/LICENSE,sha256=UXQRcYRUH-PfN27n3P-FMaZFY6jr9jFPKcwT7CWbljw,1160
|
|
2
2
|
python/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
3
|
python/core/__init__.py,sha256=RlfbqGAaUulKl44QGMCkkGJBQZ8R_AgC5bU5zS7BjnA,97
|
|
4
4
|
python/core/binaries/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
5
5
|
python/core/binaries/expander-exec,sha256=C_1JcezdfLp9sFOQ2z3wp2gcq1k8zjIR09CxJKGGIuM,7095168
|
|
6
|
-
python/core/binaries/onnx_generic_circuit_1-
|
|
6
|
+
python/core/binaries/onnx_generic_circuit_1-3-0,sha256=qbHC9SV_NuNv-vZs5MwXV0NxXkqTTZlmLH59WYCVrC8,3221088
|
|
7
7
|
python/core/circuit_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
8
8
|
python/core/circuit_models/generic_onnx.py,sha256=P65UZkfVBTE6YhaQ951S6QoTHPuU5ntDt8QL5pXghvw,8787
|
|
9
9
|
python/core/circuit_models/simple_circuit.py,sha256=igQrZtQyreyHc26iAgCyDb0TuD2bJAoumYhc1pYPDzQ,4682
|
|
10
10
|
python/core/circuits/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
|
-
python/core/circuits/base.py,sha256=
|
|
12
|
-
python/core/circuits/errors.py,sha256=
|
|
11
|
+
python/core/circuits/base.py,sha256=_XSs2LFyBMZEkBSvRp53zc-XaGOopNxYg2xHQz9sqt0,41991
|
|
12
|
+
python/core/circuits/errors.py,sha256=JDNa23wMwNQDTFY0IpDpHDMZ9gOdjDdQmB4GBhL_DCg,5913
|
|
13
13
|
python/core/circuits/zk_model_base.py,sha256=5ggOaJjs2_MJvn-PO1cPN3i7U-XR4L-0zJGYuLVKOLc,820
|
|
14
14
|
python/core/model_processing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
15
15
|
python/core/model_processing/errors.py,sha256=uh2YFjuuU5JM3anMtSTLAH-zjlNAKStmLDZqRUgBWS8,4611
|
|
16
16
|
python/core/model_processing/converters/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
17
|
-
python/core/model_processing/converters/base.py,sha256=
|
|
18
|
-
python/core/model_processing/converters/onnx_converter.py,sha256
|
|
19
|
-
python/core/model_processing/onnx_custom_ops/__init__.py,sha256=
|
|
17
|
+
python/core/model_processing/converters/base.py,sha256=o6bNwmqD9sOM9taqMb0ed6804RugQiU3va0rY_EA5SE,4265
|
|
18
|
+
python/core/model_processing/converters/onnx_converter.py,sha256=-eXdF6tfluFRxGgnQtJQ8R2309aYX-8z8HzMxk_Qv8I,44340
|
|
19
|
+
python/core/model_processing/onnx_custom_ops/__init__.py,sha256=ZKUC4ToRxgEEMHcTyERATVEN0KSDs-9cM1T-tTw3I1g,525
|
|
20
|
+
python/core/model_processing/onnx_custom_ops/batchnorm.py,sha256=8kg4iGGdt6B_fIJkpt4v5eNFpoHa4bjTB0NnCSmKFvE,1693
|
|
20
21
|
python/core/model_processing/onnx_custom_ops/conv.py,sha256=6jJm3fcGWzcU4RjVgf179mPFCqsl4C3AR7bqQTffDgA,3464
|
|
21
22
|
python/core/model_processing/onnx_custom_ops/custom_helpers.py,sha256=2WdnHw9NAoN_6wjIBoAQDyL6wEIlZOqo6ysCZp5DpZs,1844
|
|
22
23
|
python/core/model_processing/onnx_custom_ops/gemm.py,sha256=bnEUXhqQCEcH4TIfbMTsCTtAlAlRzFvl4jj8g2QZFWU,2674
|
|
23
24
|
python/core/model_processing/onnx_custom_ops/maxpool.py,sha256=Sd3BwqpGLSVU2iuAAIXAHdI3WO27Aa3g3r29HPiECvM,2319
|
|
25
|
+
python/core/model_processing/onnx_custom_ops/mul.py,sha256=w6X1sl1HnzoUJx2Mm_LaoXGTpvtwXxr3zZDPySVHBcM,1888
|
|
24
26
|
python/core/model_processing/onnx_custom_ops/onnx_helpers.py,sha256=utnJuc5sgb_z1LgxuY9y2cQbMpdEJ8xOOrcP8DhfDCM,5686
|
|
25
27
|
python/core/model_processing/onnx_custom_ops/relu.py,sha256=pZsPXC_r0FPggURKDphh8P1IRXY0w4hH7ExBmYTlWjE,1202
|
|
26
28
|
python/core/model_processing/onnx_quantizer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
27
|
-
python/core/model_processing/onnx_quantizer/exceptions.py,sha256=
|
|
28
|
-
python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py,sha256=
|
|
29
|
+
python/core/model_processing/onnx_quantizer/exceptions.py,sha256=vzxBRbpvk4ZZbgacDISnqmQQKj7Ls46V08ilHnhaJy0,5645
|
|
30
|
+
python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py,sha256=5I67frJn4j2T1LTvODHixQK4VaqazJFJ0T1BCvqLPgg,9655
|
|
29
31
|
python/core/model_processing/onnx_quantizer/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
30
32
|
python/core/model_processing/onnx_quantizer/layers/add.py,sha256=AGxzqMa0jABIEKOIgPqEAA7EpZtynQtnD9nxI2NHc0s,1409
|
|
31
|
-
python/core/model_processing/onnx_quantizer/layers/base.py,sha256=
|
|
33
|
+
python/core/model_processing/onnx_quantizer/layers/base.py,sha256=zUUAZpXCtbxbhHzuYczTNZPe-xWr6TxpmoAIDe4kCo4,21176
|
|
34
|
+
python/core/model_processing/onnx_quantizer/layers/batchnorm.py,sha256=KSBDPHd52f5Qyf-cnIDFPmfzssaJgMPiTmpIWEdM41U,7718
|
|
35
|
+
python/core/model_processing/onnx_quantizer/layers/clip.py,sha256=HrhiLtqC3cIAvU0wRCqp8_8ZSFH8a3F1Jf_qkXlY44s,3043
|
|
32
36
|
python/core/model_processing/onnx_quantizer/layers/constant.py,sha256=l1IvgvXkmFMiaBsym8wchPF-y1ZH-c5PmFUy92IXWok,3694
|
|
33
37
|
python/core/model_processing/onnx_quantizer/layers/conv.py,sha256=TlUpCRO6PPqH7MPkIrEiEcVfzuiN1WMYEiNIjhYXtWM,4451
|
|
34
38
|
python/core/model_processing/onnx_quantizer/layers/gemm.py,sha256=7fCUMv8OLVZ45a2lYjA2XNvcW3By7lSbX7zeForNK-0,3950
|
|
39
|
+
python/core/model_processing/onnx_quantizer/layers/max.py,sha256=3gUxrdXwcVAtgR-_j4xQ0085Wj0oEBLT897TImxF2d4,1343
|
|
35
40
|
python/core/model_processing/onnx_quantizer/layers/maxpool.py,sha256=PJ8hZPPBpfWV_RZdySl50-BU8TATjcg8Tg_mrAVS1Ic,4916
|
|
41
|
+
python/core/model_processing/onnx_quantizer/layers/min.py,sha256=cQbXzGOApR6HUJZMARXy87W8IbUC562jnAQm8J8ynQI,1709
|
|
42
|
+
python/core/model_processing/onnx_quantizer/layers/mul.py,sha256=qHsmnYPH-c5uiFeDCvV6e1xSgmIXJ64Sjvh0LYDYEqQ,1396
|
|
36
43
|
python/core/model_processing/onnx_quantizer/layers/relu.py,sha256=d-5fyeKNLTgKKnqCwURpxkjl7QdbJQpuovtCFBM03FA,1685
|
|
44
|
+
python/core/model_processing/onnx_quantizer/layers/sub.py,sha256=M7D98TZBNP9-2R9MX6mcpYlrWFxTiX9JCs3XNcg1U-Q,1409
|
|
37
45
|
python/core/model_templates/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
38
|
-
python/core/model_templates/circuit_template.py,sha256=
|
|
46
|
+
python/core/model_templates/circuit_template.py,sha256=OAqMRshi9OiJYoqpjkg5tUfNf18MfZmhsxxD6SANm_4,2106
|
|
39
47
|
python/core/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
40
48
|
python/core/utils/benchmarking_helpers.py,sha256=0nT38SCrjP_BlvJODsc9twF9ZmIFg_1sAvSyeNfv4mQ,5235
|
|
41
49
|
python/core/utils/constants.py,sha256=Qu5_6OUe1XIsL-IY5_4923eN7x1-SPv6ohQonztAobA,102
|
|
42
|
-
python/core/utils/errors.py,sha256=
|
|
50
|
+
python/core/utils/errors.py,sha256=Uf57cRKpot_u5Yr8HRmjLmInkdd_x0x5YpTGBncZgl4,3722
|
|
43
51
|
python/core/utils/general_layer_functions.py,sha256=tg2WWhmR-4TlKn8OeCu1qNbLf8qdKVP3jl9mhZn_sTg,9781
|
|
44
52
|
python/core/utils/helper_functions.py,sha256=3JwJa4wHoUBteukDw4bAetqMsQLeJ0_sJ0qIdKy7GCY,37097
|
|
45
53
|
python/core/utils/model_registry.py,sha256=aZg_9LEqsBXK84oxQ8A3NGZl-9aGnLgfR-kgxkOwV50,4895
|
|
46
|
-
python/core/utils/scratch_tests.py,sha256=
|
|
54
|
+
python/core/utils/scratch_tests.py,sha256=o2VDTk8QBKA3UHHE-h7Ghtoge6kGG7G-8qwvesuTFFc,2281
|
|
47
55
|
python/core/utils/witness_utils.py,sha256=ukvbF6EaHMPzRQVZad9wQ9gISRwBGQ1hEAHzc5TpGuw,9488
|
|
48
56
|
python/frontend/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
49
57
|
python/frontend/cli.py,sha256=lkvhzQC6bv0AgWUypg_cH-JT574r89qgTIsgHDT9GRg,3106
|
|
@@ -62,47 +70,53 @@ python/frontend/commands/bench/model.py,sha256=SaIWXAXZbWGbrNqEo5bs4NwgZfMOmmxaC
|
|
|
62
70
|
python/frontend/commands/bench/sweep.py,sha256=rl-QBS9eXgQkuPJBhsU4CohfE1PdJvnM8NRhNU7ztQw,5279
|
|
63
71
|
python/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
64
72
|
python/scripts/benchmark_runner.py,sha256=sjbqaLrdjt94AoyQXAxT4FhsN6aRu5idTRQ5uHmZOWM,28593
|
|
65
|
-
python/scripts/gen_and_bench.py,sha256=
|
|
73
|
+
python/scripts/gen_and_bench.py,sha256=V36x7djYmHlveAJgYzMlXwnmF0gAGO3-1mg9PWOmpj8,16249
|
|
66
74
|
python/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
67
75
|
python/tests/test_cli.py,sha256=OiAyG3aBpukk0i5FFWbiKaF42wf-7By-UWDHNjwtsqo,27042
|
|
68
76
|
python/tests/circuit_e2e_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
69
|
-
python/tests/circuit_e2e_tests/circuit_model_developer_test.py,sha256=
|
|
70
|
-
python/tests/circuit_e2e_tests/helper_fns_for_tests.py,sha256=
|
|
77
|
+
python/tests/circuit_e2e_tests/circuit_model_developer_test.py,sha256=8hl8SKw7obXplo0jsiKoKIZLxlu1_HhXvGDeSBDBars,39456
|
|
78
|
+
python/tests/circuit_e2e_tests/helper_fns_for_tests.py,sha256=uEThqTsRdNJivHwAv-aJIUtSPlmVHdhMZqZSH1OqhDE,5177
|
|
71
79
|
python/tests/circuit_e2e_tests/other_e2e_test.py,sha256=amWRa1tIBHdQpd9-XS7vBXG0tkdV_9K9fH-FT5LFh7E,11301
|
|
72
80
|
python/tests/circuit_parent_classes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
73
81
|
python/tests/circuit_parent_classes/test_circuit.py,sha256=5vgcZHD2wY_pIRFNAhEZuBJD4uw2QyTck75Z9CJaACE,45968
|
|
74
82
|
python/tests/circuit_parent_classes/test_onnx_converter.py,sha256=sJ0o8sducNUtmYKmsqfx7WEsIEd6oNbnWk71rXS_nIU,6575
|
|
75
|
-
python/tests/circuit_parent_classes/test_ort_custom_layers.py,sha256=
|
|
83
|
+
python/tests/circuit_parent_classes/test_ort_custom_layers.py,sha256=PBKjt6Mu3lRco4ijD2BLwAHPWRFic-OUwWPVsvBoEpU,3042
|
|
76
84
|
python/tests/onnx_quantizer_tests/__init__.py,sha256=IZPGWHgjoay3gM1p2WJNh5cnZ79EP2VP-bcKy8AfJjY,18
|
|
77
85
|
python/tests/onnx_quantizer_tests/test_base_layer.py,sha256=Ro7k-eUbGCyfIZ-OVNjLlCIz3mb02uHFWboFuWOdXKs,6526
|
|
78
86
|
python/tests/onnx_quantizer_tests/test_exceptions.py,sha256=pwhARalEXx7REkcnIVZPi-4J1wgzgZN4xG-wLsx4rTs,3473
|
|
79
87
|
python/tests/onnx_quantizer_tests/test_onnx_op_quantizer.py,sha256=m6mNe1KDRFIE2P0YURTIAim9-Di0BoPPAaaOOlorDIk,7367
|
|
80
|
-
python/tests/onnx_quantizer_tests/test_registered_quantizers.py,sha256=
|
|
88
|
+
python/tests/onnx_quantizer_tests/test_registered_quantizers.py,sha256=lw_jYSbQ9ZM9P-jSt_LFSuve9vQ22cLtSui0W3zGqpo,4209
|
|
81
89
|
python/tests/onnx_quantizer_tests/testing_helper_functions.py,sha256=N0fQv2pYzUCVZ7wkcR8gEKs5zTXT1hWrK-HKSTQYvYU,534
|
|
82
90
|
python/tests/onnx_quantizer_tests/layers/__init__.py,sha256=xP-RmW6LfIANgK1s9Q0KZet2yvNr-3c6YIVLAAQqGUY,404
|
|
83
91
|
python/tests/onnx_quantizer_tests/layers/add_config.py,sha256=T3tGddupDtrvLck2SL2yETDblNtv0aU7Tl7fNyZUhO4,4133
|
|
84
|
-
python/tests/onnx_quantizer_tests/layers/base.py,sha256=
|
|
92
|
+
python/tests/onnx_quantizer_tests/layers/base.py,sha256=3nqmU2PgOdK_mPkz-YHg3idgr-PXYbu5kCIY-Uic5yo,9317
|
|
93
|
+
python/tests/onnx_quantizer_tests/layers/batchnorm_config.py,sha256=P-sZuHAdEfNczcgTeLjqJnEbpqN3dKTsbqvY4-SBqiQ,8231
|
|
94
|
+
python/tests/onnx_quantizer_tests/layers/clip_config.py,sha256=-OuhnUgz6xY4iW1jUR7W-J__Ie9lXI9vplmzp8qXqRc,4973
|
|
85
95
|
python/tests/onnx_quantizer_tests/layers/constant_config.py,sha256=RdrKNMNZjI3Sk5o8WLNqmBUyYVJRWgtFbQ6oFWMwyQk,1193
|
|
86
96
|
python/tests/onnx_quantizer_tests/layers/conv_config.py,sha256=H0ioW4H3ei5IK4tKhrA0ffThxJ4K5oO9jIs9A0T0VaM,6005
|
|
87
97
|
python/tests/onnx_quantizer_tests/layers/factory.py,sha256=WLLEP9ECmSpTliSjhtdWOHcX1xOi6HM10S9Y4re1A74,4844
|
|
88
98
|
python/tests/onnx_quantizer_tests/layers/flatten_config.py,sha256=Xln5Hh6gyeM5gGRCjLGvIL-u08NEs1tXSF32urCqPfE,2110
|
|
89
99
|
python/tests/onnx_quantizer_tests/layers/gemm_config.py,sha256=t7nJY-Wnj6YUD821-jaWzgrQVPa6ytwER3hFMsvyY6Y,7294
|
|
100
|
+
python/tests/onnx_quantizer_tests/layers/max_config.py,sha256=vzR8-2wbPGcH0GMmAJ_sXSEdMtZOjVNGufU__N3Jfyw,3906
|
|
90
101
|
python/tests/onnx_quantizer_tests/layers/maxpool_config.py,sha256=XfTPk_ZQXEzaCjHHymSLVv2HS-PKH1rS9IuyyoEtM78,3176
|
|
102
|
+
python/tests/onnx_quantizer_tests/layers/min_config.py,sha256=izKtCaMXoQHiAfmcGlJRQdKMQz3Su8n0p2mEn0y56Do,3774
|
|
103
|
+
python/tests/onnx_quantizer_tests/layers/mul_config.py,sha256=_Oy4b97ORxFlF3w0BmJ94hNA968HQx2AvwYiASrGPxw,4135
|
|
91
104
|
python/tests/onnx_quantizer_tests/layers/relu_config.py,sha256=_aHuddDApLUBOa0FiR9h4fNfmMSnH5r4JzOMLW0KaTk,2197
|
|
92
105
|
python/tests/onnx_quantizer_tests/layers/reshape_config.py,sha256=fZchSqIAy76m7j97wVC_UI6slSpv8nbwukhkbGR2sRE,2203
|
|
106
|
+
python/tests/onnx_quantizer_tests/layers/sub_config.py,sha256=IxF18mG9kjlEiKYSNG912CEcBxOFGxIWoRAwjvBXiRo,4133
|
|
93
107
|
python/tests/onnx_quantizer_tests/layers_tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
94
108
|
python/tests/onnx_quantizer_tests/layers_tests/base_test.py,sha256=UgbcT97tgcuTtO1pOADpww9bz_JElKiI2mxLJYKyF1k,2992
|
|
95
109
|
python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py,sha256=Vxn4LEWHZeGa_vS1-7ptFqSSBb0D-3BG-ETocP4pvsI,3651
|
|
96
110
|
python/tests/onnx_quantizer_tests/layers_tests/test_e2e.py,sha256=40779aaHgdryVwLlIO18F1d7uSLSXdJUG5Uj_5-xD4U,6712
|
|
97
111
|
python/tests/onnx_quantizer_tests/layers_tests/test_error_cases.py,sha256=t5c_zqO4Ex3HIFWcykX4PTftdKN7UWnEOF5blShL0Ik,1881
|
|
98
|
-
python/tests/onnx_quantizer_tests/layers_tests/test_integration.py,sha256=
|
|
99
|
-
python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py,sha256=
|
|
112
|
+
python/tests/onnx_quantizer_tests/layers_tests/test_integration.py,sha256=xNt2STeXB33NcpteDThwGTSW1Hm15POf8a4aPBSVrvI,7254
|
|
113
|
+
python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py,sha256=DatmgvibQazP100B4NHDu7u-O2-f90juPKvPOXuPnXo,9491
|
|
100
114
|
python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py,sha256=RfnIIiYbgPbU3620H6MPvSxE3MNR2G1yPELwdWV3mK4,4107
|
|
101
115
|
python/tests/onnx_quantizer_tests/layers_tests/test_validation.py,sha256=jz-WtIEP-jjUklOOAnznwPUXbf07U2PAMGrhzMWP0JU,1371
|
|
102
116
|
python/tests/utils_testing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
103
117
|
python/tests/utils_testing/test_helper_functions.py,sha256=xmeGQieh4LE9U-CDKBlHhSWqH0cAmmDU3qXNbDkkvms,27192
|
|
104
|
-
jstprove-1.
|
|
105
|
-
jstprove-1.
|
|
106
|
-
jstprove-1.
|
|
107
|
-
jstprove-1.
|
|
108
|
-
jstprove-1.
|
|
118
|
+
jstprove-1.3.0.dist-info/METADATA,sha256=CqGuzrQWy_MUYtOcy-0i8mB_2eAKUvQ1R3tjEX-3N4o,14100
|
|
119
|
+
jstprove-1.3.0.dist-info/WHEEL,sha256=jc2C2uw104ioj1TL9cE0YO67_kdAwX4W8JgYPomxr5M,105
|
|
120
|
+
jstprove-1.3.0.dist-info/entry_points.txt,sha256=nGcTSO-4q08gPl1IoWdrPaiY7IbO7XvmXKkd34dYHc8,49
|
|
121
|
+
jstprove-1.3.0.dist-info/top_level.txt,sha256=J-z0poNcsv31IHB413--iOY8LoHBKiTHeybHX3abokI,7
|
|
122
|
+
jstprove-1.3.0.dist-info/RECORD,,
|
|
Binary file
|
python/core/circuits/base.py
CHANGED
|
@@ -4,13 +4,12 @@ import logging
|
|
|
4
4
|
from pathlib import Path
|
|
5
5
|
from typing import TYPE_CHECKING, Any
|
|
6
6
|
|
|
7
|
-
|
|
7
|
+
import numpy as np
|
|
8
8
|
|
|
9
9
|
from python.core.utils.errors import ShapeMismatchError
|
|
10
10
|
from python.core.utils.witness_utils import compare_witness_to_io, load_witness
|
|
11
11
|
|
|
12
12
|
if TYPE_CHECKING:
|
|
13
|
-
import numpy as np
|
|
14
13
|
import torch
|
|
15
14
|
|
|
16
15
|
from python.core.circuits.errors import (
|
|
@@ -775,18 +774,18 @@ class Circuit:
|
|
|
775
774
|
def reshape_inputs_for_inference(
|
|
776
775
|
self: Circuit,
|
|
777
776
|
inputs: dict[str],
|
|
778
|
-
) -> ndarray | dict[str, ndarray]:
|
|
777
|
+
) -> np.ndarray | dict[str, np.ndarray]:
|
|
779
778
|
"""
|
|
780
779
|
Reshape input tensors to match the model's expected input shape.
|
|
781
780
|
|
|
782
781
|
Parameters
|
|
783
782
|
----------
|
|
784
|
-
inputs : dict[str] or ndarray
|
|
783
|
+
inputs : dict[str] or np.ndarray
|
|
785
784
|
Input tensors or a dictionary of tensors.
|
|
786
785
|
|
|
787
786
|
Returns
|
|
788
787
|
-------
|
|
789
|
-
ndarray or dict[str, ndarray]
|
|
788
|
+
np.ndarray or dict[str, np.ndarray]
|
|
790
789
|
Reshaped input(s) ready for inference.
|
|
791
790
|
"""
|
|
792
791
|
|
|
@@ -801,15 +800,33 @@ class Circuit:
|
|
|
801
800
|
if isinstance(inputs, dict):
|
|
802
801
|
if len(inputs) == 1:
|
|
803
802
|
only_key = next(iter(inputs))
|
|
804
|
-
|
|
803
|
+
value = np.asarray(inputs[only_key])
|
|
804
|
+
|
|
805
|
+
# If shape is a dict, extract the shape for this key
|
|
806
|
+
if isinstance(shape, dict):
|
|
807
|
+
key_shape = shape.get(only_key, None)
|
|
808
|
+
if key_shape is None:
|
|
809
|
+
raise CircuitConfigurationError(
|
|
810
|
+
missing_attributes=[f"input_shape[{only_key!r}]"],
|
|
811
|
+
)
|
|
812
|
+
shape = key_shape
|
|
813
|
+
|
|
814
|
+
# From here on, treat it as a regular reshape
|
|
815
|
+
inputs = value
|
|
805
816
|
else:
|
|
806
817
|
return self._reshape_dict_inputs(inputs, shape)
|
|
807
818
|
|
|
808
819
|
# --- Regular reshape ---
|
|
820
|
+
if not isinstance(shape, (list, tuple)):
|
|
821
|
+
msg = (
|
|
822
|
+
f"Expected list or tuple shape for reshape, got {type(shape).__name__}"
|
|
823
|
+
)
|
|
824
|
+
raise CircuitInputError(msg)
|
|
825
|
+
|
|
809
826
|
try:
|
|
810
|
-
return asarray(inputs).reshape(shape)
|
|
827
|
+
return np.asarray(inputs).reshape(shape)
|
|
811
828
|
except Exception as e:
|
|
812
|
-
raise ShapeMismatchError(shape, list(asarray(inputs).shape)) from e
|
|
829
|
+
raise ShapeMismatchError(shape, list(np.asarray(inputs).shape)) from e
|
|
813
830
|
|
|
814
831
|
def _reshape_dict_inputs(
|
|
815
832
|
self: Circuit,
|
|
@@ -824,7 +841,7 @@ class Circuit:
|
|
|
824
841
|
)
|
|
825
842
|
raise CircuitInputError(msg, parameter="shape", expected="dict")
|
|
826
843
|
for key, value in inputs.items():
|
|
827
|
-
tensor = asarray(value)
|
|
844
|
+
tensor = np.asarray(value)
|
|
828
845
|
try:
|
|
829
846
|
inputs[key] = tensor.reshape(shape[key])
|
|
830
847
|
except Exception as e:
|
|
@@ -867,16 +884,16 @@ class Circuit:
|
|
|
867
884
|
value = inputs[key]
|
|
868
885
|
|
|
869
886
|
# --- handle unsupported input types BEFORE entering try ---
|
|
870
|
-
if not isinstance(value, (ndarray, list, tuple)):
|
|
887
|
+
if not isinstance(value, (np.ndarray, list, tuple)):
|
|
871
888
|
msg = f"Unsupported input type for key '{key}': {type(value).__name__}"
|
|
872
889
|
raise CircuitProcessingError(message=msg)
|
|
873
890
|
|
|
874
891
|
try:
|
|
875
892
|
# Convert to tensor, flatten, and back to list
|
|
876
|
-
if isinstance(value, ndarray):
|
|
893
|
+
if isinstance(value, np.ndarray):
|
|
877
894
|
flattened = value.flatten().tolist()
|
|
878
895
|
else:
|
|
879
|
-
flattened = asarray(value).flatten().tolist()
|
|
896
|
+
flattened = np.asarray(value).flatten().tolist()
|
|
880
897
|
except Exception as e:
|
|
881
898
|
msg = f"Failed to flatten input '{key}' (type {type(value).__name__})"
|
|
882
899
|
raise CircuitProcessingError(message=msg) from e
|
python/core/circuits/errors.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
# python/core/utils/exceptions.py
|
|
2
1
|
from __future__ import annotations
|
|
3
2
|
|
|
4
3
|
from python.core.utils.helper_functions import RunType
|
|
@@ -68,7 +67,7 @@ class CircuitInputError(CircuitError):
|
|
|
68
67
|
actual (any): Actual value encountered (optional).
|
|
69
68
|
"""
|
|
70
69
|
|
|
71
|
-
def __init__(
|
|
70
|
+
def __init__(
|
|
72
71
|
self: CircuitInputError,
|
|
73
72
|
message: str | None = None,
|
|
74
73
|
parameter: str | None = None,
|
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from abc import ABC, abstractmethod
|
|
4
4
|
from enum import Enum
|
|
5
|
-
from typing import TYPE_CHECKING
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
6
|
|
|
7
7
|
if TYPE_CHECKING:
|
|
8
8
|
import numpy as np
|
|
@@ -16,10 +16,10 @@ class ModelType(Enum):
|
|
|
16
16
|
|
|
17
17
|
ONNXLayerDict = dict[
|
|
18
18
|
str,
|
|
19
|
-
|
|
19
|
+
int | str | list[str] | dict[str, list[int]] | list | None | dict,
|
|
20
20
|
]
|
|
21
21
|
|
|
22
|
-
CircuitParamsDict = dict[str,
|
|
22
|
+
CircuitParamsDict = dict[str, int | dict[str, bool]]
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
class ModelConverter(ABC):
|
|
@@ -247,6 +247,7 @@ class ONNXConverter(ModelConverter):
|
|
|
247
247
|
|
|
248
248
|
def analyze_layers(
|
|
249
249
|
self: ONNXConverter,
|
|
250
|
+
model: onnx.ModelProto,
|
|
250
251
|
output_name_to_shape: dict[str, list[int]] | None = None,
|
|
251
252
|
) -> tuple[list[ONNXLayer], list[ONNXLayer]]:
|
|
252
253
|
"""Analyze the onnx model graph into
|
|
@@ -268,29 +269,29 @@ class ONNXConverter(ModelConverter):
|
|
|
268
269
|
id_count = 0
|
|
269
270
|
# Apply shape inference on the model
|
|
270
271
|
if not output_name_to_shape:
|
|
271
|
-
inferred_model = shape_inference.infer_shapes(
|
|
272
|
+
inferred_model = shape_inference.infer_shapes(model)
|
|
272
273
|
self._onnx_check_model_safely(inferred_model)
|
|
273
274
|
|
|
274
275
|
output_name_to_shape = extract_shape_dict(inferred_model)
|
|
275
276
|
domain_to_version = {
|
|
276
|
-
opset.domain: opset.version for opset in
|
|
277
|
+
opset.domain: opset.version for opset in model.opset_import
|
|
277
278
|
}
|
|
278
279
|
|
|
279
280
|
id_count = 0
|
|
280
281
|
architecture = self.get_model_architecture(
|
|
281
|
-
|
|
282
|
+
model,
|
|
282
283
|
output_name_to_shape,
|
|
283
284
|
domain_to_version,
|
|
284
285
|
)
|
|
285
286
|
w_and_b = self.get_model_w_and_b(
|
|
286
|
-
|
|
287
|
+
model,
|
|
287
288
|
output_name_to_shape,
|
|
288
289
|
id_count,
|
|
289
290
|
domain_to_version,
|
|
290
291
|
)
|
|
291
292
|
except InvalidModelError:
|
|
292
293
|
raise
|
|
293
|
-
except (ValueError, TypeError, RuntimeError, OSError
|
|
294
|
+
except (ValueError, TypeError, RuntimeError, OSError) as e:
|
|
294
295
|
raise LayerAnalysisError(model_type=self.model_type, reason=str(e)) from e
|
|
295
296
|
except Exception as e:
|
|
296
297
|
raise LayerAnalysisError(model_type=self.model_type, reason=str(e)) from e
|
|
@@ -557,6 +558,7 @@ class ONNXConverter(ModelConverter):
|
|
|
557
558
|
output_shapes = {
|
|
558
559
|
out_name: output_name_to_shape.get(out_name, []) for out_name in outputs
|
|
559
560
|
}
|
|
561
|
+
|
|
560
562
|
return ONNXLayer(
|
|
561
563
|
id=layer_id,
|
|
562
564
|
name=name,
|
|
@@ -605,6 +607,7 @@ class ONNXConverter(ModelConverter):
|
|
|
605
607
|
np_data = onnx.numpy_helper.to_array(node, constant_dtype)
|
|
606
608
|
except (ValueError, TypeError, onnx.ONNXException, Exception) as e:
|
|
607
609
|
raise SerializationError(
|
|
610
|
+
model_type=self.model_type,
|
|
608
611
|
tensor_name=node.name,
|
|
609
612
|
reason=f"Failed to convert tensor: {e!s}",
|
|
610
613
|
) from e
|
|
@@ -1040,38 +1043,36 @@ class ONNXConverter(ModelConverter):
|
|
|
1040
1043
|
``rescale_config``.
|
|
1041
1044
|
"""
|
|
1042
1045
|
inferred_model = shape_inference.infer_shapes(self.model)
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
scale_base=getattr(self, "scale_base", 2),
|
|
1046
|
-
scale_exponent=(getattr(self, "scale_exponent", 18)),
|
|
1047
|
-
)
|
|
1046
|
+
scale_base = getattr(self, "scale_base", 2)
|
|
1047
|
+
scale_exponent = getattr(self, "scale_exponent", 18)
|
|
1048
1048
|
|
|
1049
1049
|
# Check the model and print Y's shape information
|
|
1050
1050
|
self._onnx_check_model_safely(inferred_model)
|
|
1051
1051
|
output_name_to_shape = extract_shape_dict(inferred_model)
|
|
1052
|
-
|
|
1053
|
-
|
|
1052
|
+
scaled_and_transformed_model = self.op_quantizer.apply_pre_analysis_transforms(
|
|
1053
|
+
inferred_model,
|
|
1054
|
+
scale_exponent=scale_exponent,
|
|
1055
|
+
scale_base=scale_base,
|
|
1056
|
+
)
|
|
1057
|
+
# Get layers in correct format
|
|
1058
|
+
(architecture, w_and_b) = self.analyze_layers(
|
|
1059
|
+
scaled_and_transformed_model,
|
|
1060
|
+
output_name_to_shape,
|
|
1061
|
+
)
|
|
1062
|
+
|
|
1063
|
+
def _convert_tensor_to_int_list(w: ONNXLayer) -> list:
|
|
1054
1064
|
try:
|
|
1055
|
-
|
|
1056
|
-
|
|
1065
|
+
arr = np.asarray(w.tensor).astype(np.int64)
|
|
1066
|
+
return arr.tolist()
|
|
1067
|
+
except Exception as e:
|
|
1057
1068
|
raise SerializationError(
|
|
1058
1069
|
tensor_name=getattr(w, "name", None),
|
|
1070
|
+
model_type=self.model_type,
|
|
1059
1071
|
reason=f"cannot convert to ndarray: {e}",
|
|
1060
1072
|
) from e
|
|
1061
1073
|
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
if "bias" in w.name:
|
|
1065
|
-
w_and_b_scaled = w_and_b_array * scaling * scaling
|
|
1066
|
-
else:
|
|
1067
|
-
w_and_b_scaled = w_and_b_array * scaling
|
|
1068
|
-
w_and_b_out = w_and_b_scaled.astype(np.int64).tolist()
|
|
1069
|
-
w.tensor = w_and_b_out
|
|
1070
|
-
except (ValueError, TypeError, OverflowError, Exception) as e:
|
|
1071
|
-
raise SerializationError(
|
|
1072
|
-
tensor_name=getattr(w, "name", None),
|
|
1073
|
-
reason=str(e),
|
|
1074
|
-
) from e
|
|
1074
|
+
for w in w_and_b:
|
|
1075
|
+
w.tensor = _convert_tensor_to_int_list(w)
|
|
1075
1076
|
|
|
1076
1077
|
inputs = []
|
|
1077
1078
|
outputs = []
|
|
@@ -1,16 +1,17 @@
|
|
|
1
1
|
import importlib
|
|
2
2
|
import pkgutil
|
|
3
|
-
import
|
|
3
|
+
from pathlib import Path
|
|
4
4
|
|
|
5
5
|
# Get the package name of the current module
|
|
6
6
|
package_name = __name__
|
|
7
7
|
|
|
8
8
|
# Dynamically import all .py files in this package directory (except __init__.py)
|
|
9
|
-
package_dir =
|
|
9
|
+
package_dir = Path(__file__).parent.as_posix()
|
|
10
10
|
|
|
11
|
-
|
|
11
|
+
|
|
12
|
+
__all__: list[str] = []
|
|
12
13
|
|
|
13
14
|
for _, module_name, is_pkg in pkgutil.iter_modules([package_dir]):
|
|
14
15
|
if not is_pkg and (module_name != "custom_helpers"):
|
|
15
16
|
importlib.import_module(f"{package_name}.{module_name}")
|
|
16
|
-
__all__.append(module_name)
|
|
17
|
+
__all__.append(module_name) # noqa: PYI056
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from onnxruntime_extensions import PyCustomOpDef, onnx_op
|
|
5
|
+
|
|
6
|
+
from .custom_helpers import rescaling
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@onnx_op(
|
|
10
|
+
op_type="Int64BatchNorm",
|
|
11
|
+
domain="ai.onnx.contrib",
|
|
12
|
+
inputs=[
|
|
13
|
+
PyCustomOpDef.dt_int64, # X (int64)
|
|
14
|
+
PyCustomOpDef.dt_int64, # mul (int64 scaled multiplier)
|
|
15
|
+
PyCustomOpDef.dt_int64, # add (int64 scaled adder)
|
|
16
|
+
PyCustomOpDef.dt_int64, # scaling_factor
|
|
17
|
+
],
|
|
18
|
+
outputs=[PyCustomOpDef.dt_int64],
|
|
19
|
+
attrs={"rescale": PyCustomOpDef.dt_int64},
|
|
20
|
+
)
|
|
21
|
+
def int64_batchnorm(
|
|
22
|
+
x: np.ndarray,
|
|
23
|
+
mul: np.ndarray,
|
|
24
|
+
add: np.ndarray,
|
|
25
|
+
scaling_factor: np.ndarray | None = None,
|
|
26
|
+
rescale: int | None = None,
|
|
27
|
+
) -> np.ndarray:
|
|
28
|
+
"""
|
|
29
|
+
Int64 BatchNorm (folded into affine transform).
|
|
30
|
+
|
|
31
|
+
Computes:
|
|
32
|
+
Y = X * mul + add
|
|
33
|
+
where mul/add are already scaled to int64.
|
|
34
|
+
|
|
35
|
+
Parameters
|
|
36
|
+
----------
|
|
37
|
+
x : Input int64 tensor
|
|
38
|
+
mul : Per-channel int64 scale multipliers
|
|
39
|
+
add : Per-channel int64 bias terms
|
|
40
|
+
scaling_factor: factor to rescale
|
|
41
|
+
rescale : Optional flag to apply post-scaling
|
|
42
|
+
|
|
43
|
+
Returns
|
|
44
|
+
-------
|
|
45
|
+
numpy.ndarray (int64)
|
|
46
|
+
"""
|
|
47
|
+
try:
|
|
48
|
+
# Broadcasting shapes must match batchnorm layout: NCHW
|
|
49
|
+
# Typically mul/add have shape [C]
|
|
50
|
+
dims_x = len(x.shape)
|
|
51
|
+
dim_ones = (1,) * (dims_x - 2)
|
|
52
|
+
mul = mul.reshape(-1, *dim_ones)
|
|
53
|
+
add = add.reshape(-1, *dim_ones)
|
|
54
|
+
|
|
55
|
+
y = x * mul + add
|
|
56
|
+
|
|
57
|
+
if rescale is not None:
|
|
58
|
+
y = rescaling(scaling_factor, rescale, y)
|
|
59
|
+
|
|
60
|
+
return y.astype(np.int64)
|
|
61
|
+
|
|
62
|
+
except Exception as e:
|
|
63
|
+
msg = f"Int64BatchNorm failed: {e}"
|
|
64
|
+
raise RuntimeError(msg) from e
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from onnxruntime_extensions import PyCustomOpDef, onnx_op
|
|
3
|
+
|
|
4
|
+
from .custom_helpers import rescaling
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@onnx_op(
|
|
8
|
+
op_type="Int64Mul",
|
|
9
|
+
domain="ai.onnx.contrib",
|
|
10
|
+
inputs=[
|
|
11
|
+
PyCustomOpDef.dt_int64,
|
|
12
|
+
PyCustomOpDef.dt_int64,
|
|
13
|
+
PyCustomOpDef.dt_int64, # Scalar
|
|
14
|
+
],
|
|
15
|
+
outputs=[PyCustomOpDef.dt_int64],
|
|
16
|
+
attrs={
|
|
17
|
+
"rescale": PyCustomOpDef.dt_int64,
|
|
18
|
+
},
|
|
19
|
+
)
|
|
20
|
+
def int64_mul(
|
|
21
|
+
a: np.ndarray,
|
|
22
|
+
b: np.ndarray,
|
|
23
|
+
scaling_factor: np.ndarray | None = None,
|
|
24
|
+
rescale: int | None = None,
|
|
25
|
+
) -> np.ndarray:
|
|
26
|
+
"""
|
|
27
|
+
Performs a Mul (hadamard product) operation on int64 input tensors.
|
|
28
|
+
|
|
29
|
+
This function is registered as a custom ONNX operator via onnxruntime_extensions
|
|
30
|
+
and is used in the JSTprove quantized inference pipeline.
|
|
31
|
+
It applies Mul and then rescales the outputs back to the original scale.
|
|
32
|
+
|
|
33
|
+
Parameters
|
|
34
|
+
----------
|
|
35
|
+
a : np.ndarray
|
|
36
|
+
First input tensor with dtype int64.
|
|
37
|
+
b : np.ndarray
|
|
38
|
+
Second input tensor with dtype int64.
|
|
39
|
+
scaling_factor : Scaling factor for rescaling the output.
|
|
40
|
+
Optional scalar tensor for rescaling when rescale=1.
|
|
41
|
+
rescale : int, optional
|
|
42
|
+
Whether to apply rescaling (0=no, 1=yes).
|
|
43
|
+
|
|
44
|
+
Returns
|
|
45
|
+
-------
|
|
46
|
+
numpy.ndarray
|
|
47
|
+
Mul tensor with dtype int64.
|
|
48
|
+
|
|
49
|
+
Notes
|
|
50
|
+
-----
|
|
51
|
+
- This op is part of the `ai.onnx.contrib` custom domain.
|
|
52
|
+
- ONNX Runtime Extensions is required to register this op.
|
|
53
|
+
|
|
54
|
+
References
|
|
55
|
+
----------
|
|
56
|
+
For more information on the Mul operation, please refer to the
|
|
57
|
+
ONNX standard Mul operator documentation:
|
|
58
|
+
https://onnx.ai/onnx/operators/onnx__Mul.html
|
|
59
|
+
"""
|
|
60
|
+
try:
|
|
61
|
+
result = a * b
|
|
62
|
+
result = rescaling(scaling_factor, rescale, result)
|
|
63
|
+
return result.astype(np.int64)
|
|
64
|
+
except Exception as e:
|
|
65
|
+
msg = f"Int64Mul failed: {e}"
|
|
66
|
+
raise RuntimeError(msg) from e
|
|
@@ -31,7 +31,7 @@ class InvalidParamError(QuantizationError):
|
|
|
31
31
|
quantization the quantization process.
|
|
32
32
|
"""
|
|
33
33
|
|
|
34
|
-
def __init__(
|
|
34
|
+
def __init__(
|
|
35
35
|
self: QuantizationError,
|
|
36
36
|
node_name: str,
|
|
37
37
|
op_type: str,
|
|
@@ -151,7 +151,7 @@ class InvalidConfigError(QuantizationError):
|
|
|
151
151
|
def __init__(
|
|
152
152
|
self: QuantizationError,
|
|
153
153
|
key: str,
|
|
154
|
-
value: str | float | bool | None,
|
|
154
|
+
value: str | float | bool | None, # noqa: FBT001
|
|
155
155
|
expected: str | None = None,
|
|
156
156
|
) -> None:
|
|
157
157
|
"""Initialize InvalidConfigError with context about the bad config.
|