JSTprove 1.0.0__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of JSTprove might be problematic. Click here for more details.
- jstprove-1.0.0.dist-info/METADATA +397 -0
- jstprove-1.0.0.dist-info/RECORD +81 -0
- jstprove-1.0.0.dist-info/WHEEL +5 -0
- jstprove-1.0.0.dist-info/entry_points.txt +2 -0
- jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
- jstprove-1.0.0.dist-info/top_level.txt +1 -0
- python/__init__.py +0 -0
- python/core/__init__.py +3 -0
- python/core/binaries/__init__.py +0 -0
- python/core/binaries/expander-exec +0 -0
- python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
- python/core/circuit_models/__init__.py +0 -0
- python/core/circuit_models/generic_onnx.py +231 -0
- python/core/circuit_models/simple_circuit.py +133 -0
- python/core/circuits/__init__.py +0 -0
- python/core/circuits/base.py +1000 -0
- python/core/circuits/errors.py +188 -0
- python/core/circuits/zk_model_base.py +25 -0
- python/core/model_processing/__init__.py +0 -0
- python/core/model_processing/converters/__init__.py +0 -0
- python/core/model_processing/converters/base.py +143 -0
- python/core/model_processing/converters/onnx_converter.py +1181 -0
- python/core/model_processing/errors.py +147 -0
- python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
- python/core/model_processing/onnx_custom_ops/conv.py +111 -0
- python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
- python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
- python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
- python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
- python/core/model_processing/onnx_custom_ops/relu.py +43 -0
- python/core/model_processing/onnx_quantizer/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
- python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
- python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
- python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
- python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
- python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
- python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
- python/core/model_templates/__init__.py +0 -0
- python/core/model_templates/circuit_template.py +57 -0
- python/core/utils/__init__.py +0 -0
- python/core/utils/benchmarking_helpers.py +163 -0
- python/core/utils/constants.py +4 -0
- python/core/utils/errors.py +117 -0
- python/core/utils/general_layer_functions.py +268 -0
- python/core/utils/helper_functions.py +1138 -0
- python/core/utils/model_registry.py +166 -0
- python/core/utils/scratch_tests.py +66 -0
- python/core/utils/witness_utils.py +291 -0
- python/frontend/__init__.py +0 -0
- python/frontend/cli.py +115 -0
- python/frontend/commands/__init__.py +17 -0
- python/frontend/commands/args.py +100 -0
- python/frontend/commands/base.py +199 -0
- python/frontend/commands/bench/__init__.py +54 -0
- python/frontend/commands/bench/list.py +42 -0
- python/frontend/commands/bench/model.py +172 -0
- python/frontend/commands/bench/sweep.py +212 -0
- python/frontend/commands/compile.py +58 -0
- python/frontend/commands/constants.py +5 -0
- python/frontend/commands/model_check.py +53 -0
- python/frontend/commands/prove.py +50 -0
- python/frontend/commands/verify.py +73 -0
- python/frontend/commands/witness.py +64 -0
- python/scripts/__init__.py +0 -0
- python/scripts/benchmark_runner.py +833 -0
- python/scripts/gen_and_bench.py +482 -0
- python/tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
- python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
- python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
- python/tests/circuit_parent_classes/__init__.py +0 -0
- python/tests/circuit_parent_classes/test_circuit.py +969 -0
- python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
- python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
- python/tests/test_cli.py +1021 -0
- python/tests/utils_testing/__init__.py +0 -0
- python/tests/utils_testing/test_helper_functions.py +891 -0
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
from python.core.model_processing.converters.base import ModelType
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ModelConversionError(Exception):
    """Base class for all model conversion errors."""

    def __init__(
        self: ModelConversionError,
        message: str,
        model_type: ModelType,
        context: dict | None = None,
    ) -> None:
        """Store error details and initialize the base Exception.

        Args:
            message: Human-readable description of the failure.
            model_type: The model format being converted; its ``.value``
                is embedded in the rendered message.
            context: Optional extra key/value details (e.g. file paths).
        """
        self.message = message
        self.model_type = model_type
        self.context = context or {}
        # Pass the fully rendered text to Exception so args/str() agree.
        super().__init__(self.__str__())

    def __str__(self: ModelConversionError) -> str:
        """Render the error as 'Error converting <type> model: <message>'."""
        msg = f"Error converting {self.model_type.value} model: {self.message}"
        if self.context:
            ctx_str = ", ".join(f"{k}={v}" for k, v in self.context.items())
            # Bug fix: a space was missing before "(Context:", which glued
            # the context suffix onto the message text.
            msg += f" (Context: {ctx_str})"
        return msg
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class ModelLoadError(ModelConversionError):
    """Raised when an ONNX model cannot be loaded."""

    def __init__(
        self: ModelLoadError,
        file_path: str,
        model_type: ModelType,
        reason: str = "",
    ) -> None:
        """Compose the load-failure message and delegate to the base class."""
        parts = [f"Failed to load {model_type.value} model from '{file_path}'."]
        if reason:
            parts.append(f" Reason: {reason}")
        super().__init__(
            "".join(parts),
            model_type,
            context={"file_path": file_path},
        )
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class ModelSaveError(ModelConversionError):
    """Raised when saving a model fails."""

    def __init__(
        self: ModelSaveError,
        file_path: str,
        model_type: ModelType,
        reason: str = "",
    ) -> None:
        """Compose the save-failure message and delegate to the base class."""
        pieces = [f"Failed to save {model_type.value} model to '{file_path}'."]
        if reason:
            pieces.append(f" Reason: {reason}")
        super().__init__(
            "".join(pieces),
            model_type,
            context={"file_path": file_path},
        )
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class InferenceError(ModelConversionError):
    """Raised when inference via ONNX Runtime fails."""

    def __init__(
        self: InferenceError,
        model_type: ModelType,
        model_path: str | None = None,
        reason: str = "",
    ) -> None:
        """Compose the inference-failure message and delegate upward.

        Args:
            model_type: The model format whose inference failed.
            model_path: Optional path of the model that was being executed.
            reason: Optional underlying cause appended to the message.
        """
        # Bug fix: the message previously began with a stray leading space
        # (f" {model_type.value} inference failed.").
        message = f"{model_type.value} inference failed."
        if model_path:
            message += f" Model: '{model_path}'."
        if reason:
            message += f" Reason: {reason}"
        super().__init__(message, model_type, context={"model_path": model_path})
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class LayerAnalysisError(ModelConversionError):
    """Raised when analyzing model layers fails."""

    def __init__(
        self: LayerAnalysisError,
        model_type: ModelType,
        layer_name: str | None = None,
        reason: str = "",
    ) -> None:
        """Compose the layer-analysis failure message and delegate upward."""
        fragments = ["Layer analysis failed."]
        if layer_name:
            fragments.append(f" Problematic layer: '{layer_name}'.")
        if reason:
            fragments.append(f" Reason: {reason}")
        super().__init__(
            "".join(fragments),
            model_type,
            context={"layer_name": layer_name},
        )
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class IOInfoExtractionError(ModelConversionError):
    """Raised when extracting input/output info fails."""

    def __init__(
        self: IOInfoExtractionError,
        model_type: ModelType,
        model_path: str | None = None,
        reason: str = "",
    ) -> None:
        """Compose the I/O-extraction failure message and delegate upward."""
        message = "Failed to extract input/output info from model."
        message += f" Model: '{model_path}'." if model_path else ""
        message += f" Reason: {reason}" if reason else ""
        super().__init__(
            message,
            model_type,
            context={"model_path": model_path},
        )
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class InvalidModelError(ModelConversionError):
    """Raised when an ONNX model fails validation checks (onnx.checker)."""

    def __init__(
        self: InvalidModelError,
        model_type: ModelType,
        model_path: str | None = None,
        reason: str = "",
    ) -> None:
        """Compose the validation-failure message and delegate upward."""
        text = f"The {model_type.value} model is invalid."
        if model_path:
            text = f"{text} Model: '{model_path}'."
        if reason:
            text = f"{text} Reason: {reason}"
        super().__init__(
            message=text,
            model_type=model_type,
            context={"model_path": model_path},
        )
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class SerializationError(ModelConversionError):
    """Raised when model data cannot be serialized to the required format."""

    def __init__(
        self: SerializationError,
        model_type: ModelType,
        tensor_name: str | None = None,
        reason: str = "",
    ) -> None:
        """Compose the serialization-failure message and delegate upward."""
        message = "Failed to serialize model data."
        message += f" Tensor: '{tensor_name}'." if tensor_name else ""
        message += f" Reason: {reason}" if reason else ""
        super().__init__(
            message,
            model_type,
            context={"tensor_name": tensor_name},
        )
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import importlib
import pkgutil
import os

# Name of this package; used to build absolute module paths for importlib.
package_name = __name__

# Directory of this package, scanned for sibling modules below.
package_dir = os.path.dirname(__file__)

# Public API: filled in below with every module that gets imported.
__all__ = []

# Import every non-package module in this directory so that their
# import-time side effects run (the sibling modules register custom ONNX
# ops via @onnx_op at import time). "custom_helpers" is skipped because it
# only holds shared helper functions, not op definitions.
for _, module_name, is_pkg in pkgutil.iter_modules([package_dir]):
    if not is_pkg and (module_name != "custom_helpers"):
        importlib.import_module(f"{package_name}.{module_name}")
        __all__.append(module_name)
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn.functional as f
|
|
6
|
+
from onnxruntime_extensions import PyCustomOpDef, onnx_op
|
|
7
|
+
|
|
8
|
+
from .custom_helpers import parse_attr, rescaling
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@onnx_op(
    op_type="Int64Conv",
    domain="ai.onnx.contrib",
    inputs=[
        PyCustomOpDef.dt_int64,  # X
        PyCustomOpDef.dt_int64,  # W
        PyCustomOpDef.dt_int64,  # B
        PyCustomOpDef.dt_int64,  # scaling factor
    ],
    outputs=[PyCustomOpDef.dt_int64],
    attrs={
        "auto_pad": PyCustomOpDef.dt_string,
        "strides": PyCustomOpDef.dt_string,
        "pads": PyCustomOpDef.dt_string,
        "dilations": PyCustomOpDef.dt_string,
        "group": PyCustomOpDef.dt_int64,
        "kernel_shape": PyCustomOpDef.dt_string,
        "rescale": PyCustomOpDef.dt_int64,
    },
)
def int64_conv(
    x: np.ndarray,
    w: np.ndarray,
    b: np.ndarray | None = None,
    scaling_factor: np.ndarray | None = None,
    auto_pad: str | None = None,
    dilations: str | None = None,
    group: int | None = None,
    kernel_shape: str | None = None,
    pads: str | None = None,
    strides: str | None = None,
    rescale: int | None = None,
) -> np.ndarray:
    """
    Performs a convolution on int64 input tensors.

    This function is registered as a custom ONNX operator via
    onnxruntime_extensions and is used in the JSTprove quantized inference
    pipeline. It parses ONNX-style convolution attributes, applies the
    convolution and optionally rescales the result.

    Parameters
    ----------
    x : Input tensor with dtype int64.
    w : Convolution weight tensor with dtype int64.
    b : Optional bias tensor with dtype int64.
    scaling_factor : Scaling factor for rescaling the output.
    auto_pad : Optional ONNX auto padding type (`SAME_UPPER`, `SAME_LOWER`,
        `VALID`); accepted but currently ignored.
    dilations : Dilation values (default: `[1, 1]`).
    group : Group value for the convolution (default: 1).
    kernel_shape : Kernel shape (default: `[3, 3]`).
    pads : Padding values (default: `[0, 0, 0, 0]`).
    strides : Stride values (default: `[1, 1]`).
    rescale : Optional flag to apply output rescaling or not.

    Returns
    -------
    numpy.ndarray
        Convolved tensor with dtype int64.

    Raises
    ------
    RuntimeError
        If attribute parsing, the convolution, or rescaling fails.

    Notes
    -----
    - This op is part of the `ai.onnx.contrib` custom domain.
    - ONNX Runtime Extensions is required to register this op.

    References
    ----------
    For more information on the convolution operation, please refer to the
    ONNX standard Conv operator documentation:
    https://onnx.ai/onnx/operators/onnx__Conv.html
    """
    _ = auto_pad  # Accepted for schema compatibility; padding comes from `pads`.
    try:
        strides = parse_attr(strides, [1, 1])
        dilations = parse_attr(dilations, [1, 1])
        pads = parse_attr(pads, [0, 0, 0, 0])
        kernel_shape = parse_attr(kernel_shape, [3, 3])

        x = torch.from_numpy(x)
        w = torch.from_numpy(w)
        # Bug fix: the bias is optional, but torch.from_numpy(None) raises;
        # pass None through so conv2d runs without a bias.
        b = torch.from_numpy(b) if b is not None else None

        result = (
            f.conv2d(
                x,
                w,
                bias=b,
                stride=strides,
                # NOTE(review): only the first two pad values are used, which
                # assumes symmetric ONNX padding (begin == end) — confirm.
                padding=pads[:2],
                dilation=dilations,
                # ONNX defines the Conv default group as 1; conv2d rejects None.
                groups=group if group is not None else 1,
            )
            .numpy()
            .astype(np.int64)
        )
        result = rescaling(scaling_factor, rescale, result)
        return result.astype(np.int64)

    except Exception as e:
        msg = f"Int64Conv failed: {e}"
        raise RuntimeError(msg) from e
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TypeVar
|
|
4
|
+
|
|
5
|
+
T = TypeVar("T")
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def rescaling(scaling_factor: int, rescale: int, y: int) -> int:
    """Applies integer rescaling to a value based on the given scaling factor.

    Args:
        scaling_factor (int): The divisor to apply when rescaling.
            Must be provided if `rescale` is 1.
        rescale (int): Whether to apply rescaling (0 -> no rescaling,
            1 -> rescaling).
        y (int): The value to be rescaled.

    Raises:
        ValueError: If `rescale` is 1 but `scaling_factor` is not provided.
        ValueError: If `rescale` is not 0 or 1.

    Returns:
        int: `y // scaling_factor` when `rescale` is 1, otherwise `y`
        unchanged.
    """
    # Doc fix: the docstring previously claimed NotImplementedError was
    # raised, but the code raises ValueError in both failure cases.
    if rescale == 1:
        if scaling_factor is None:
            msg = "scaling_factor must be specified when rescale=1"
            raise ValueError(msg)
        # Floor division keeps the result integral (also works elementwise
        # when y is a numpy array, as in the custom-op callers).
        return y // scaling_factor
    if rescale == 0:
        return y
    msg = f"Rescale must be 0 or 1, got {rescale}"
    raise ValueError(msg)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def parse_attr(attr: str, default: T) -> T:
    """Parse a comma-separated attribute string into a list of integers.

    Args:
        attr (str): Comma-separated integers (e.g. ``"1, 2, 3"``), or None.
        default (T): Value returned unchanged when `attr` is None.

    Raises:
        ValueError: If `attr` is a string but cannot be parsed into integers.

    Returns:
        T: The parsed list of integers, or `default` when `attr` is None.
    """
    if attr is None:
        return default
    try:
        return [int(piece.strip()) for piece in attr.split(",")]
    except ValueError as e:
        msg = f"Invalid attribute format: {attr}"
        raise ValueError(msg) from e
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from onnxruntime_extensions import PyCustomOpDef, onnx_op
|
|
5
|
+
|
|
6
|
+
from .custom_helpers import rescaling
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@onnx_op(
    op_type="Int64Gemm",
    domain="ai.onnx.contrib",
    inputs=[
        PyCustomOpDef.dt_int64,  # X
        PyCustomOpDef.dt_int64,  # W
        PyCustomOpDef.dt_int64,  # B
        PyCustomOpDef.dt_int64,  # Scalar
    ],
    outputs=[PyCustomOpDef.dt_int64],
    attrs={
        "alpha": PyCustomOpDef.dt_float,
        "beta": PyCustomOpDef.dt_float,
        "transA": PyCustomOpDef.dt_int64,
        "transB": PyCustomOpDef.dt_int64,
        "rescale": PyCustomOpDef.dt_int64,
    },
)
def int64_gemm7(
    a: np.ndarray,
    b: np.ndarray,
    c: np.ndarray | None = None,
    scaling_factor: np.ndarray | None = None,
    alpha: float | None = None,
    beta: float | None = None,
    transA: int | None = None,  # noqa: N803
    transB: int | None = None,  # noqa: N803
    rescale: int | None = None,
) -> np.ndarray:
    """
    Performs a Gemm (alternatively: Linear layer) on int64 input tensors.

    This function is registered as a custom ONNX operator via
    onnxruntime_extensions and is used in the JSTprove quantized inference
    pipeline. It parses ONNX-style gemm attributes, applies gemm and
    optionally rescales the result.

    Parameters
    ----------
    a : Input tensor with dtype int64.
    b : Gemm weight tensor with dtype int64.
    c : Optional bias tensor with dtype int64.
    scaling_factor : Scaling factor for rescaling the output.
    alpha : alpha value for the Gemm operation (default: 1, per ONNX).
    beta : beta value for the Gemm operation (default: 1, per ONNX).
    transA : Transpose the a matrix before the Gemm operation.
    transB : Transpose the b matrix before the Gemm operation.
    rescale : Optional flag to apply output rescaling or not.

    Returns
    -------
    numpy.ndarray
        Gemm tensor with dtype int64.

    Raises
    ------
    RuntimeError
        If the matrix product or rescaling fails.

    Notes
    -----
    - This op is part of the `ai.onnx.contrib` custom domain.
    - ONNX Runtime Extensions is required to register this op.

    References
    ----------
    For more information on the gemm operation, please refer to the
    ONNX standard Gemm operator documentation:
    https://onnx.ai/onnx/operators/onnx__Gemm.html
    """
    try:
        # Bug fix: int(None) raised TypeError when alpha/beta were absent.
        # ONNX defines the Gemm defaults as alpha = beta = 1.0; the values
        # are truncated to int to stay in the integer domain.
        alpha = 1 if alpha is None else int(alpha)
        beta = 1 if beta is None else int(beta)

        a = a.T if transA else a
        b = b.T if transB else b

        result = alpha * (a @ b)

        if c is not None:
            result += beta * c

        result = rescaling(scaling_factor, rescale, result)
        return result.astype(np.int64)

    except Exception as e:
        msg = f"Int64Gemm failed: {e}"
        raise RuntimeError(msg) from e
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn.functional as f
|
|
6
|
+
from onnxruntime_extensions import PyCustomOpDef, onnx_op
|
|
7
|
+
|
|
8
|
+
from .custom_helpers import parse_attr
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@onnx_op(
    op_type="Int64MaxPool",
    domain="ai.onnx.contrib",
    inputs=[PyCustomOpDef.dt_int64],  # input tensor
    outputs=[PyCustomOpDef.dt_int64],
    attrs={
        "strides": PyCustomOpDef.dt_string,
        "pads": PyCustomOpDef.dt_string,
        "kernel_shape": PyCustomOpDef.dt_string,
        "dilations": PyCustomOpDef.dt_string,
    },
)
def int64_maxpool(
    x: np.ndarray,
    strides: str | None = None,
    pads: str | None = None,
    kernel_shape: str | None = None,
    dilations: str | None = None,
) -> np.ndarray:
    """
    Performs a MaxPool operation on int64 input tensors.

    This function is registered as a custom ONNX operator via
    onnxruntime_extensions and is used in the JSTprove quantized inference
    pipeline. It parses ONNX-style maxpool attributes and applies maxpool.

    Parameters
    ----------
    x : Input tensor with dtype int64.
    strides : Stride values (default: `[1, 1]`).
    pads : Padding values (default: `[0, 0]`; doc fix — the code default
        was previously documented as `[0, 0, 0, 0]`).
    kernel_shape : Kernel shape (default: `[2, 2]`).
    dilations : Dilation values (default: `[1, 1]`).

    Returns
    -------
    numpy.ndarray
        Maxpooled tensor with dtype int64.

    Raises
    ------
    RuntimeError
        If attribute parsing or the pooling operation fails.

    Notes
    -----
    - This op is part of the `ai.onnx.contrib` custom domain.
    - ONNX Runtime Extensions is required to register this op.

    References
    ----------
    For more information on the maxpool operation, please refer to the
    ONNX standard MaxPool operator documentation:
    https://onnx.ai/onnx/operators/onnx__MaxPool.html
    """
    try:
        strides = parse_attr(strides, [1, 1])
        pads = parse_attr(pads, [0, 0])
        kernel_size = parse_attr(kernel_shape, [2, 2])
        dilations = parse_attr(dilations, [1, 1])

        x = torch.from_numpy(x)
        result = f.max_pool2d(
            x,
            kernel_size=kernel_size,
            stride=strides,
            # NOTE(review): only the first two pad values are used — assumes
            # symmetric ONNX padding; confirm behavior for asymmetric pads.
            padding=pads[:2],
            dilation=dilations,
        )
        return result.numpy().astype(np.int64)
    except Exception as e:
        # Bug fix: the error message previously said "Int64Gemm failed"
        # (copy-paste from gemm.py); report the correct op name.
        msg = f"Int64MaxPool failed: {e}"
        raise RuntimeError(msg) from e
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import onnx
|
|
6
|
+
from onnx import AttributeProto, numpy_helper
|
|
7
|
+
|
|
8
|
+
# Dispatch table mapping ONNX AttributeProto type codes to functions that
# convert the raw attribute into a Python-native value. Used by
# parse_attribute below; types missing from this table are unsupported.
ATTRIBUTE_PARSERS = {
    AttributeProto.FLOAT: lambda a: a.f,
    AttributeProto.INT: lambda a: a.i,
    # Byte strings are decoded as UTF-8; undecodable bytes are substituted
    # (errors="replace") rather than raising.
    AttributeProto.STRING: lambda a: a.s.decode("utf-8", errors="replace"),
    AttributeProto.FLOATS: lambda a: list(a.floats),
    AttributeProto.INTS: lambda a: list(a.ints),
    AttributeProto.STRINGS: lambda a: [
        s.decode("utf-8", errors="replace") for s in a.strings
    ],
    # Tensor payloads are materialized into nested Python lists.
    AttributeProto.TENSOR: lambda a: numpy_helper.to_array(a.t).tolist(),
    AttributeProto.TENSORS: lambda a: [
        numpy_helper.to_array(t).tolist() for t in a.tensors
    ],
}
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def parse_attribute(
    attr: AttributeProto,
) -> float | int | str | list[int] | list[float] | list[str]:
    """Convert a single ONNX attribute into a Python-native value.

    Args:
        attr (AttributeProto): The ONNX attribute to convert.

    Raises:
        ValueError: If no parser is registered for the attribute's type.

    Returns:
        Any: The attribute's value as a plain Python type.
    """
    parser = ATTRIBUTE_PARSERS.get(attr.type)
    if parser is not None:
        return parser(attr)
    msg = f"Unsupported attribute type: {attr.type}"
    raise ValueError(msg)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def parse_attributes(attrs: list[AttributeProto]) -> dict[str, Any]:
    """Parse multiple ONNX attributes into a dictionary.

    Args:
        attrs (list[AttributeProto]): List of ONNX attributes.

    Returns:
        dict[str, Any]: Mapping of attribute names to their parsed values.
    """
    parsed: dict[str, Any] = {}
    for attr in attrs:
        parsed[attr.name] = parse_attribute(attr)
    return parsed
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def extract_shape_dict(inferred_model: onnx.ModelProto) -> dict[str, list[int]]:
    """Extract shape information from an ONNX model's graph.

    Args:
        inferred_model (onnx.ModelProto): The shape-inferred ONNX model.
            Annotation fix: the function reads ``inferred_model.graph``, so
            it requires a ModelProto, not the previously annotated GraphProto.

    Returns:
        dict[str, list[int]]: Mapping from tensor names to their shape
        dimensions. Unknown dimensions are returned as 1.
    """
    value_info = {}
    graph = inferred_model.graph
    # Outputs and inputs are appended after value_info so their entries
    # overwrite any duplicate names from intermediate value_info.
    all_info = list(graph.value_info) + list(graph.output) + list(graph.input)
    for vi in all_info:
        if vi.type.HasField("tensor_type"):
            shape = [
                # TODO@jsgold-1: figure out how to deal with bad value # noqa: FIX002, TD003, E501
                d.dim_value if d.HasField("dim_value") else 1
                for d in vi.type.tensor_type.shape.dim
            ]
            value_info[vi.name] = shape
    return value_info
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def replace_input_references(
    graph: onnx.GraphProto,
    old_output: str,
    new_output: str,
) -> None:
    """Replace all references to an input tensor in an ONNX graph.

    Args:
        graph (onnx.GraphProto): The ONNX graph to modify in place.
        old_output (str): The original tensor name to replace.
        new_output (str): The new tensor name.
    """
    for node in graph.node:
        # Slice assignment updates the repeated field in place.
        node.input[:] = [
            new_output if name == old_output else name for name in node.input
        ]
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def extract_attributes(node: onnx.NodeProto) -> dict:
    """Extract all attributes from an ONNX node into a Python dictionary.

    Args:
        node (onnx.NodeProto): The ONNX node to extract attributes from.

    Raises:
        ValueError: If an attribute type is unsupported.

    Returns:
        dict: Mapping of attribute names to Python-native values. Note that
        INTS attributes are serialized to a comma-separated string (the
        format the string-typed custom-op attributes expect), not a list.
    """
    attrs = {}
    for attr in node.attribute:
        name = attr.name
        val = onnx.helper.get_attribute_value(attr)

        if attr.type == AttributeProto.FLOAT:
            attrs[name] = float(val)
        elif attr.type == AttributeProto.INT:
            attrs[name] = int(val)
        elif attr.type == AttributeProto.FLOATS:
            # Kept as floats; callers whose ops expect ints must convert.
            attrs[name] = [float(x) for x in val]
        elif attr.type == AttributeProto.INTS:
            attrs[name] = ",".join(str(v) for v in val)
        elif attr.type == AttributeProto.STRING:
            attrs[name] = val.decode("utf-8") if isinstance(val, bytes) else val
        else:
            # Bug fix: a previous `elif attr.type == AttributeProto.BOOL`
            # branch referenced an enum member that does not exist in ONNX's
            # AttributeProto (booleans are stored as INT), so reaching it
            # raised AttributeError instead of the intended ValueError.
            msg = f"Unsupported attribute type: {attr.name} (type={attr.type})"
            raise ValueError(msg)
    return attrs
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def get_input_shapes(onnx_model: onnx.ModelProto) -> dict:
    """Get the input tensor shapes from an ONNX model.

    Args:
        onnx_model (onnx.ModelProto): The ONNX model.

    Returns:
        dict: Mapping from input tensor names to their shape dimensions.
    """
    # Each input's shape is read from its tensor type information.
    return {
        model_in.name: [
            dim.dim_value for dim in model_in.type.tensor_type.shape.dim
        ]
        for model_in in onnx_model.graph.input
    }
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def get_attribute_ints(
    node: onnx.NodeProto,
    name: str,
    default: list[int] | None = None,
) -> list[int]:
    """Retrieve a list of integer values from an ONNX node's attribute.

    Args:
        node (onnx.NodeProto): The ONNX node.
        name (str): Name of the attribute to retrieve.
        default (list[int], optional): Fallback returned when the attribute
            is missing. Defaults to None, which yields an empty list.

    Returns:
        list[int]: The attribute's integers, or the fallback if not found.
    """
    match = next(
        (
            attr
            for attr in node.attribute
            if attr.name == name and attr.type == onnx.AttributeProto.INTS
        ),
        None,
    )
    if match is not None:
        return list(match.ints)
    return [] if default is None else default
|