JSTprove 1.0.0-py3-none-macosx_11_0_arm64.whl → 1.2.0-py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/METADATA +3 -3
- {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/RECORD +60 -25
- python/core/binaries/onnx_generic_circuit_1-2-0 +0 -0
- python/core/circuit_models/generic_onnx.py +43 -9
- python/core/circuits/base.py +231 -71
- python/core/model_processing/converters/onnx_converter.py +114 -59
- python/core/model_processing/onnx_custom_ops/batchnorm.py +64 -0
- python/core/model_processing/onnx_custom_ops/maxpool.py +1 -1
- python/core/model_processing/onnx_custom_ops/mul.py +66 -0
- python/core/model_processing/onnx_custom_ops/relu.py +1 -1
- python/core/model_processing/onnx_quantizer/layers/add.py +54 -0
- python/core/model_processing/onnx_quantizer/layers/base.py +188 -1
- python/core/model_processing/onnx_quantizer/layers/batchnorm.py +224 -0
- python/core/model_processing/onnx_quantizer/layers/constant.py +1 -1
- python/core/model_processing/onnx_quantizer/layers/conv.py +20 -68
- python/core/model_processing/onnx_quantizer/layers/gemm.py +20 -66
- python/core/model_processing/onnx_quantizer/layers/maxpool.py +53 -43
- python/core/model_processing/onnx_quantizer/layers/mul.py +53 -0
- python/core/model_processing/onnx_quantizer/layers/relu.py +20 -35
- python/core/model_processing/onnx_quantizer/layers/sub.py +54 -0
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +43 -1
- python/core/utils/general_layer_functions.py +17 -12
- python/core/utils/model_registry.py +6 -3
- python/scripts/gen_and_bench.py +2 -2
- python/tests/circuit_e2e_tests/other_e2e_test.py +202 -9
- python/tests/circuit_parent_classes/test_circuit.py +561 -38
- python/tests/circuit_parent_classes/test_onnx_converter.py +22 -13
- python/tests/onnx_quantizer_tests/__init__.py +1 -0
- python/tests/onnx_quantizer_tests/layers/__init__.py +13 -0
- python/tests/onnx_quantizer_tests/layers/add_config.py +102 -0
- python/tests/onnx_quantizer_tests/layers/base.py +279 -0
- python/tests/onnx_quantizer_tests/layers/batchnorm_config.py +190 -0
- python/tests/onnx_quantizer_tests/layers/constant_config.py +39 -0
- python/tests/onnx_quantizer_tests/layers/conv_config.py +154 -0
- python/tests/onnx_quantizer_tests/layers/factory.py +142 -0
- python/tests/onnx_quantizer_tests/layers/flatten_config.py +61 -0
- python/tests/onnx_quantizer_tests/layers/gemm_config.py +160 -0
- python/tests/onnx_quantizer_tests/layers/maxpool_config.py +82 -0
- python/tests/onnx_quantizer_tests/layers/mul_config.py +102 -0
- python/tests/onnx_quantizer_tests/layers/relu_config.py +61 -0
- python/tests/onnx_quantizer_tests/layers/reshape_config.py +61 -0
- python/tests/onnx_quantizer_tests/layers/sub_config.py +102 -0
- python/tests/onnx_quantizer_tests/layers_tests/__init__.py +0 -0
- python/tests/onnx_quantizer_tests/layers_tests/base_test.py +94 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py +115 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_e2e.py +196 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_error_cases.py +59 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +198 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +267 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py +109 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_validation.py +45 -0
- python/tests/onnx_quantizer_tests/test_base_layer.py +228 -0
- python/tests/onnx_quantizer_tests/test_exceptions.py +99 -0
- python/tests/onnx_quantizer_tests/test_onnx_op_quantizer.py +246 -0
- python/tests/onnx_quantizer_tests/test_registered_quantizers.py +121 -0
- python/tests/onnx_quantizer_tests/testing_helper_functions.py +17 -0
- python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/WHEEL +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/entry_points.txt +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/top_level.txt +0 -0
python/core/model_processing/converters/onnx_converter.py
@@ -5,6 +5,10 @@ import logging
 from dataclasses import asdict, dataclass
 from importlib.metadata import version as get_version
 from pathlib import Path
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from onnxruntime import NodeArg
 
 import numpy as np
 import onnx
@@ -18,6 +22,7 @@ from onnxruntime_extensions import get_library_path
 
 import python.core.model_processing.onnx_custom_ops  # noqa: F401
 from python.core import PACKAGE_NAME
+from python.core.circuits.errors import CircuitConfigurationError
 from python.core.model_processing.converters.base import ModelConverter, ModelType
 from python.core.model_processing.errors import (
     InferenceError,
@@ -242,6 +247,7 @@ class ONNXConverter(ModelConverter):
 
     def analyze_layers(
         self: ONNXConverter,
+        model: onnx.ModelProto,
        output_name_to_shape: dict[str, list[int]] | None = None,
    ) -> tuple[list[ONNXLayer], list[ONNXLayer]]:
        """Analyze the onnx model graph into
@@ -263,29 +269,29 @@
             id_count = 0
             # Apply shape inference on the model
             if not output_name_to_shape:
-                inferred_model = shape_inference.infer_shapes(
+                inferred_model = shape_inference.infer_shapes(model)
                 self._onnx_check_model_safely(inferred_model)
 
                 output_name_to_shape = extract_shape_dict(inferred_model)
             domain_to_version = {
-                opset.domain: opset.version for opset in
+                opset.domain: opset.version for opset in model.opset_import
             }
 
             id_count = 0
             architecture = self.get_model_architecture(
-
+                model,
                 output_name_to_shape,
                 domain_to_version,
             )
             w_and_b = self.get_model_w_and_b(
-
+                model,
                 output_name_to_shape,
                 id_count,
                 domain_to_version,
             )
         except InvalidModelError:
             raise
-        except (ValueError, TypeError, RuntimeError, OSError
+        except (ValueError, TypeError, RuntimeError, OSError) as e:
             raise LayerAnalysisError(model_type=self.model_type, reason=str(e)) from e
         except Exception as e:
             raise LayerAnalysisError(model_type=self.model_type, reason=str(e)) from e
@@ -508,9 +514,9 @@
                 opts,
                 providers=["CPUExecutionProvider"],
             )
-        except (OSError,
+        except (OSError, RuntimeError, Exception) as e:
             raise InferenceError(
-                model_path,
+                model_path=model_path,
                 model_type=self.model_type,
                 reason=str(e),
             ) from e
@@ -552,6 +558,7 @@
         output_shapes = {
             out_name: output_name_to_shape.get(out_name, []) for out_name in outputs
         }
+
         return ONNXLayer(
             id=layer_id,
             name=name,
@@ -600,6 +607,7 @@
             np_data = onnx.numpy_helper.to_array(node, constant_dtype)
         except (ValueError, TypeError, onnx.ONNXException, Exception) as e:
             raise SerializationError(
+                model_type=self.model_type,
                 tensor_name=node.name,
                 reason=f"Failed to convert tensor: {e!s}",
             ) from e
@@ -877,8 +885,10 @@
             OSError,
             Exception,
         ) as e:
-            msg =
-
+            msg = (
+                "Quantization failed for model"
+                f" '{getattr(self, 'model_file_name', 'unknown')}': {e!s}"
+            )
             raise ModelConversionError(
                 msg,
                 model_type=self.model_type,
@@ -1033,38 +1043,36 @@
         ``rescale_config``.
         """
         inferred_model = shape_inference.infer_shapes(self.model)
-
-
-            scale_base=getattr(self, "scale_base", 2),
-            scale_exponent=(getattr(self, "scale_exponent", 18)),
-        )
+        scale_base = getattr(self, "scale_base", 2)
+        scale_exponent = getattr(self, "scale_exponent", 18)
 
         # Check the model and print Y's shape information
         self._onnx_check_model_safely(inferred_model)
         output_name_to_shape = extract_shape_dict(inferred_model)
-
-
+        scaled_and_transformed_model = self.op_quantizer.apply_pre_analysis_transforms(
+            inferred_model,
+            scale_exponent=scale_exponent,
+            scale_base=scale_base,
+        )
+        # Get layers in correct format
+        (architecture, w_and_b) = self.analyze_layers(
+            scaled_and_transformed_model,
+            output_name_to_shape,
+        )
+
+        def _convert_tensor_to_int_list(w: ONNXLayer) -> list:
             try:
-
-
+                arr = np.asarray(w.tensor).astype(np.int64)
+                return arr.tolist()
+            except Exception as e:
                 raise SerializationError(
                     tensor_name=getattr(w, "name", None),
+                    model_type=self.model_type,
                     reason=f"cannot convert to ndarray: {e}",
                 ) from e
 
-
-
-                if "bias" in w.name:
-                    w_and_b_scaled = w_and_b_array * scaling * scaling
-                else:
-                    w_and_b_scaled = w_and_b_array * scaling
-                w_and_b_out = w_and_b_scaled.astype(np.int64).tolist()
-                w.tensor = w_and_b_out
-            except (ValueError, TypeError, OverflowError, Exception) as e:
-                raise SerializationError(
-                    tensor_name=getattr(w, "name", None),
-                    reason=str(e),
-                ) from e
+        for w in w_and_b:
+            w.tensor = _convert_tensor_to_int_list(w)
 
         inputs = []
         outputs = []
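
The removed block here scaled weights by the scaling factor directly (and biases by its square, since a bias is added to the product of two scaled operands); the new flow expects apply_pre_analysis_transforms to hand back integer-valued tensors, so _convert_tensor_to_int_list only casts and serializes. A minimal sketch of the fixed-point convention, assuming only the defaults visible in this hunk (scale_base=2, scale_exponent=18); the variable names are illustrative:

    scale_base, scale_exponent = 2, 18
    scaling = scale_base ** scale_exponent  # 2**18 == 262144

    weight = 0.03125                     # float weight from the model
    quantized = round(weight * scaling)  # 8192, well within int64 range
    recovered = quantized / scaling      # 0.03125, exact for dyadic values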
@@ -1118,45 +1126,92 @@
             rescale_config=getattr(self, "rescale_config", {}),
         )
 
+    def _process_single_input_for_get_outputs(
+        self: ONNXConverter,
+        value: np.ndarray | torch.Tensor,
+        input_def: NodeArg,
+    ) -> np.ndarray:
+        """Process a single input tensor according to dtype and scale settings."""
+        value = torch.as_tensor(value)
+
+        if value.dtype in (
+            torch.int8,
+            torch.int16,
+            torch.int32,
+            torch.int64,
+            torch.uint8,
+        ):
+            value = value.double()
+            value = value / BaseOpQuantizer.get_scaling(
+                scale_base=self.scale_base,
+                scale_exponent=self.scale_exponent,
+            )
+
+        if input_def.type == "tensor(double)":
+            return np.asarray(value).astype(np.float64)
+        return np.asarray(value)
+
     def get_outputs(
         self: ONNXConverter,
-        inputs: np.ndarray | torch.Tensor,
+        inputs: np.ndarray | torch.Tensor | dict[str, np.ndarray | torch.Tensor],
     ) -> list[np.ndarray]:
         """Run the currently loaded (quantized) model via ONNX Runtime.
 
         Args:
-            inputs
+            inputs: Single tensor/array or a dict of named inputs.
 
         Returns:
-
+            list[np.ndarray]: List of output arrays from ONNX Runtime inference.
         """
+
+        def _raise_type_error(inputs: np.ndarray | torch.Tensor) -> None:
+            msg = (
+                "Expected np.ndarray, torch.Tensor, or dict "
+                f"for inputs, got {type(inputs)}"
+            )
+            raise TypeError(msg)
+
+        def _raise_value_error(msg: str) -> None:
+            raise ValueError(msg)
+
+        def _raise_no_scale_configs() -> None:
+            raise CircuitConfigurationError(
+                missing_attributes=["scale_base", "scale_exponent"],
+            )
+
+        scale_base = getattr(self, "scale_base", None)
+        scale_exponent = getattr(self, "scale_exponent", None)
+
         try:
-
-
-
-
-
-
-
-
-
-
-
-            ):
-                inputs
-
-
-
+            input_defs = self.ort_sess.get_inputs()
+            output_defs = self.ort_sess.get_outputs()
+            output_names = [out.name for out in output_defs]
+
+            if scale_base is None or scale_exponent is None:
+                _raise_no_scale_configs()
+
+            # Normalize inputs into a dict
+            if isinstance(inputs, (np.ndarray, torch.Tensor)):
+                input_name = input_defs[0].name
+                inputs = {input_name: inputs}
+            elif not isinstance(inputs, dict):
+                _raise_type_error(inputs)
+
+            # Process inputs
+            processed_inputs = {}
+            for input_def in input_defs:
+                name = input_def.name
+                if name not in inputs:
+                    _raise_value_error(
+                        f"Missing required input '{name}' in provided inputs",
+                    )
+                processed_inputs[name] = self._process_single_input_for_get_outputs(
+                    inputs[name],
+                    input_def,
                 )
-
-
-
-                {input_name: np.asarray(inputs).astype(np.float64)},
-            )
-            return self.ort_sess.run(
-                [output_name],
-                {input_name: np.asarray(inputs)},
-            )
+
+            return self.ort_sess.run(output_names, processed_inputs)
+
         except (RuntimeError, ValueError, TypeError, Exception) as e:
             raise InferenceError(
                 model_path=getattr(self, "quantized_model_path", None),
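
The reworked get_outputs accepts either a single tensor (mapped onto the session's first input) or a dict keyed by input name, rescales integer inputs by the configured scaling, and runs all session outputs at once. A hedged usage sketch; the converter construction is elided, and converter, input_a, and input_b are illustrative names:

    import numpy as np

    x = np.zeros((1, 3, 32, 32), dtype=np.int64)

    # Single input: internally wrapped as {first_input_name: x}.
    outputs = converter.get_outputs(x)

    # Named inputs: every session input must be present, otherwise the
    # new code raises ValueError; unset scale_base/scale_exponent now
    # raise CircuitConfigurationError before inference is attempted.
    outputs = converter.get_outputs({"input_a": x, "input_b": x})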
python/core/model_processing/onnx_custom_ops/batchnorm.py
@@ -0,0 +1,64 @@
+from __future__ import annotations
+
+import numpy as np
+from onnxruntime_extensions import PyCustomOpDef, onnx_op
+
+from .custom_helpers import rescaling
+
+
+@onnx_op(
+    op_type="Int64BatchNorm",
+    domain="ai.onnx.contrib",
+    inputs=[
+        PyCustomOpDef.dt_int64,  # X (int64)
+        PyCustomOpDef.dt_int64,  # mul (int64 scaled multiplier)
+        PyCustomOpDef.dt_int64,  # add (int64 scaled adder)
+        PyCustomOpDef.dt_int64,  # scaling_factor
+    ],
+    outputs=[PyCustomOpDef.dt_int64],
+    attrs={"rescale": PyCustomOpDef.dt_int64},
+)
+def int64_batchnorm(
+    x: np.ndarray,
+    mul: np.ndarray,
+    add: np.ndarray,
+    scaling_factor: np.ndarray | None = None,
+    rescale: int | None = None,
+) -> np.ndarray:
+    """
+    Int64 BatchNorm (folded into an affine transform).
+
+    Computes:
+        Y = X * mul + add
+    where mul/add are already scaled to int64.
+
+    Parameters
+    ----------
+    x : Input int64 tensor
+    mul : Per-channel int64 scale multipliers
+    add : Per-channel int64 bias terms
+    scaling_factor : Factor to rescale by
+    rescale : Optional flag to apply post-scaling
+
+    Returns
+    -------
+    numpy.ndarray (int64)
+    """
+    try:
+        # Broadcast shapes must match the batchnorm layout (NCHW);
+        # typically mul/add have shape [C]
+        dims_x = len(x.shape)
+        dim_ones = (1,) * (dims_x - 2)
+        mul = mul.reshape(-1, *dim_ones)
+        add = add.reshape(-1, *dim_ones)
+
+        y = x * mul + add
+
+        if rescale is not None:
+            y = rescaling(scaling_factor, rescale, y)
+
+        return y.astype(np.int64)
+
+    except Exception as e:
+        msg = f"Int64BatchNorm failed: {e}"
+        raise RuntimeError(msg) from e
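
For context, batchnorm folds into an affine transform: y = gamma * (x - mean) / sqrt(var + eps) + beta collapses to y = x * mul + add. A minimal numpy sketch of how the int64 mul/add operands this op consumes could be produced; the actual folding lives in the quantizer, and eps and the parameter values below are illustrative:

    import numpy as np

    gamma = np.array([1.0, 0.5])
    beta = np.array([0.1, -0.2])
    mean = np.array([0.0, 1.0])
    var = np.array([1.0, 4.0])
    eps = 1e-5
    scaling = 2 ** 18  # default scale_base ** scale_exponent

    mul_f = gamma / np.sqrt(var + eps)  # per-channel multiplier
    add_f = beta - mean * mul_f         # per-channel bias
    mul = np.round(mul_f * scaling).astype(np.int64)
    # add is summed with x * mul, which carries the scale twice on this
    # convention, hence the squared factor (mirroring the old bias handling).
    add = np.round(add_f * scaling * scaling).astype(np.int64)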
python/core/model_processing/onnx_custom_ops/mul.py
@@ -0,0 +1,66 @@
+import numpy as np
+from onnxruntime_extensions import PyCustomOpDef, onnx_op
+
+from .custom_helpers import rescaling
+
+
+@onnx_op(
+    op_type="Int64Mul",
+    domain="ai.onnx.contrib",
+    inputs=[
+        PyCustomOpDef.dt_int64,
+        PyCustomOpDef.dt_int64,
+        PyCustomOpDef.dt_int64,  # Scalar
+    ],
+    outputs=[PyCustomOpDef.dt_int64],
+    attrs={
+        "rescale": PyCustomOpDef.dt_int64,
+    },
+)
+def int64_mul(
+    a: np.ndarray,
+    b: np.ndarray,
+    scaling_factor: np.ndarray | None = None,
+    rescale: int | None = None,
+) -> np.ndarray:
+    """
+    Performs a Mul (Hadamard product) operation on int64 input tensors.
+
+    This function is registered as a custom ONNX operator via onnxruntime_extensions
+    and is used in the JSTprove quantized inference pipeline.
+    It applies Mul, rescaling the outputs back to the original scale.
+
+    Parameters
+    ----------
+    a : np.ndarray
+        First input tensor with dtype int64.
+    b : np.ndarray
+        Second input tensor with dtype int64.
+    scaling_factor : np.ndarray, optional
+        Scalar tensor used to rescale the output when rescale=1.
+    rescale : int, optional
+        Whether to apply rescaling (0=no, 1=yes).
+
+    Returns
+    -------
+    numpy.ndarray
+        Product tensor with dtype int64.
+
+    Notes
+    -----
+    - This op is part of the `ai.onnx.contrib` custom domain.
+    - ONNX Runtime Extensions is required to register this op.
+
+    References
+    ----------
+    For more information on the Mul operation, please refer to the
+    ONNX standard Mul operator documentation:
+    https://onnx.ai/onnx/operators/onnx__Mul.html
+    """
+    try:
+        result = a * b
+        result = rescaling(scaling_factor, rescale, result)
+        return result.astype(np.int64)
+    except Exception as e:
+        msg = f"Int64Mul failed: {e}"
+        raise RuntimeError(msg) from e
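
Why the rescale attribute exists: two fixed-point operands each carry one factor of the scale, so their product carries it twice, and one factor must be divided back out. A worked sketch under the same 2**18 convention; the exact rounding of the rescaling helper is not shown in this diff, so plain integer division stands in for it:

    import numpy as np

    scaling = 2 ** 18
    a_f, b_f = 1.5, -2.25               # real values
    a = np.int64(round(a_f * scaling))  # fixed-point encodings
    b = np.int64(round(b_f * scaling))

    raw = a * b                         # now scaled by scaling**2
    rescaled = raw // scaling           # back to a single factor of scaling
    assert rescaled / scaling == a_f * b_f  # exact for these dyadic values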
python/core/model_processing/onnx_quantizer/layers/add.py
@@ -0,0 +1,54 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, ClassVar
+
+if TYPE_CHECKING:
+    import onnx
+
+from python.core.model_processing.onnx_quantizer.layers.base import (
+    BaseOpQuantizer,
+    QuantizerBase,
+    ScaleConfig,
+)
+
+
+class QuantizeAdd(QuantizerBase):
+    OP_TYPE = "Add"
+    DOMAIN = ""
+    USE_WB = True
+    USE_SCALING = False
+    SCALE_PLAN: ClassVar = {0: 1, 1: 1}
+
+
+class AddQuantizer(BaseOpQuantizer, QuantizeAdd):
+    """
+    Quantizer for ONNX Add layers.
+
+    - Uses the standard ONNX Add node in the standard domain, and
+      makes the relevant additional changes to the graph.
+    """
+
+    def __init__(
+        self: AddQuantizer,
+        new_initializers: list[onnx.TensorProto] | None = None,
+    ) -> None:
+        super().__init__()
+        # Only replace if the caller provided something
+        if new_initializers is not None:
+            self.new_initializers = new_initializers
+
+    def quantize(
+        self: AddQuantizer,
+        node: onnx.NodeProto,
+        graph: onnx.GraphProto,
+        scale_config: ScaleConfig,
+        initializer_map: dict[str, onnx.TensorProto],
+    ) -> list[onnx.NodeProto]:
+        return QuantizeAdd.quantize(self, node, graph, scale_config, initializer_map)
+
+    def check_supported(
+        self: AddQuantizer,
+        node: onnx.NodeProto,
+        initializer_map: dict[str, onnx.TensorProto] | None = None,
+    ) -> None:
+        pass