JSTprove 1.0.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of JSTprove might be problematic. Click here for more details.

Files changed (81)
  1. jstprove-1.0.0.dist-info/METADATA +397 -0
  2. jstprove-1.0.0.dist-info/RECORD +81 -0
  3. jstprove-1.0.0.dist-info/WHEEL +5 -0
  4. jstprove-1.0.0.dist-info/entry_points.txt +2 -0
  5. jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
  6. jstprove-1.0.0.dist-info/top_level.txt +1 -0
  7. python/__init__.py +0 -0
  8. python/core/__init__.py +3 -0
  9. python/core/binaries/__init__.py +0 -0
  10. python/core/binaries/expander-exec +0 -0
  11. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  12. python/core/circuit_models/__init__.py +0 -0
  13. python/core/circuit_models/generic_onnx.py +231 -0
  14. python/core/circuit_models/simple_circuit.py +133 -0
  15. python/core/circuits/__init__.py +0 -0
  16. python/core/circuits/base.py +1000 -0
  17. python/core/circuits/errors.py +188 -0
  18. python/core/circuits/zk_model_base.py +25 -0
  19. python/core/model_processing/__init__.py +0 -0
  20. python/core/model_processing/converters/__init__.py +0 -0
  21. python/core/model_processing/converters/base.py +143 -0
  22. python/core/model_processing/converters/onnx_converter.py +1181 -0
  23. python/core/model_processing/errors.py +147 -0
  24. python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
  25. python/core/model_processing/onnx_custom_ops/conv.py +111 -0
  26. python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
  27. python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
  28. python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
  29. python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
  30. python/core/model_processing/onnx_custom_ops/relu.py +43 -0
  31. python/core/model_processing/onnx_quantizer/__init__.py +0 -0
  32. python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
  33. python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
  34. python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
  35. python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
  36. python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
  37. python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
  38. python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
  39. python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
  40. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
  41. python/core/model_templates/__init__.py +0 -0
  42. python/core/model_templates/circuit_template.py +57 -0
  43. python/core/utils/__init__.py +0 -0
  44. python/core/utils/benchmarking_helpers.py +163 -0
  45. python/core/utils/constants.py +4 -0
  46. python/core/utils/errors.py +117 -0
  47. python/core/utils/general_layer_functions.py +268 -0
  48. python/core/utils/helper_functions.py +1138 -0
  49. python/core/utils/model_registry.py +166 -0
  50. python/core/utils/scratch_tests.py +66 -0
  51. python/core/utils/witness_utils.py +291 -0
  52. python/frontend/__init__.py +0 -0
  53. python/frontend/cli.py +115 -0
  54. python/frontend/commands/__init__.py +17 -0
  55. python/frontend/commands/args.py +100 -0
  56. python/frontend/commands/base.py +199 -0
  57. python/frontend/commands/bench/__init__.py +54 -0
  58. python/frontend/commands/bench/list.py +42 -0
  59. python/frontend/commands/bench/model.py +172 -0
  60. python/frontend/commands/bench/sweep.py +212 -0
  61. python/frontend/commands/compile.py +58 -0
  62. python/frontend/commands/constants.py +5 -0
  63. python/frontend/commands/model_check.py +53 -0
  64. python/frontend/commands/prove.py +50 -0
  65. python/frontend/commands/verify.py +73 -0
  66. python/frontend/commands/witness.py +64 -0
  67. python/scripts/__init__.py +0 -0
  68. python/scripts/benchmark_runner.py +833 -0
  69. python/scripts/gen_and_bench.py +482 -0
  70. python/tests/__init__.py +0 -0
  71. python/tests/circuit_e2e_tests/__init__.py +0 -0
  72. python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
  73. python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
  74. python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
  75. python/tests/circuit_parent_classes/__init__.py +0 -0
  76. python/tests/circuit_parent_classes/test_circuit.py +969 -0
  77. python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
  78. python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
  79. python/tests/test_cli.py +1021 -0
  80. python/tests/utils_testing/__init__.py +0 -0
  81. python/tests/utils_testing/test_helper_functions.py +891 -0
@@ -0,0 +1,147 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ from python.core.model_processing.converters.base import ModelType
7
+
8
+
9
class ModelConversionError(Exception):
    """Base class for all model conversion errors.

    Attributes:
        message: Human-readable description of what went wrong.
        model_type: The model type involved; its ``.value`` is interpolated
            into the rendered message.
        context: Extra key/value details appended to the rendered message.
            Defaults to an empty dict.
    """

    def __init__(
        self: ModelConversionError,
        message: str,
        model_type: ModelType,
        context: dict | None = None,
    ) -> None:
        self.message = message
        self.model_type = model_type
        self.context = context or {}
        # Hand the fully formatted text to Exception so ``e.args[0]`` matches
        # ``str(e)``.
        super().__init__(self.__str__())

    def __str__(self: ModelConversionError) -> str:
        """Render ``Error converting <type> model: <message> (Context: k=v, ...)``.

        The context suffix is only present when ``context`` is non-empty.
        """
        msg = f"Error converting {self.model_type.value} model: {self.message}"
        if self.context:
            ctx_str = ", ".join(f"{k}={v}" for k, v in self.context.items())
            # Leading space keeps the suffix from running into the message.
            msg += f" (Context: {ctx_str})"
        return msg
29
+
30
+
31
class ModelLoadError(ModelConversionError):
    """Raised when an ONNX model cannot be loaded.

    The message names the model type and source path, optionally followed by
    a reason; the path is also stored as ``context["file_path"]``.
    """

    def __init__(
        self: ModelLoadError,
        file_path: str,
        model_type: ModelType,
        reason: str = "",
    ) -> None:
        segments = [f"Failed to load {model_type.value} model from '{file_path}'."]
        if reason:
            segments.append(f"Reason: {reason}")
        super().__init__(
            " ".join(segments),
            model_type,
            context={"file_path": file_path},
        )
44
+
45
+
46
class ModelSaveError(ModelConversionError):
    """Raised when saving a model fails.

    The message names the model type and destination path, optionally
    followed by a reason; the path is also stored as ``context["file_path"]``.
    """

    def __init__(
        self: ModelSaveError,
        file_path: str,
        model_type: ModelType,
        reason: str = "",
    ) -> None:
        reason_suffix = f" Reason: {reason}" if reason else ""
        text = f"Failed to save {model_type.value} model to '{file_path}'.{reason_suffix}"
        super().__init__(text, model_type, context={"file_path": file_path})
59
+
60
+
61
class InferenceError(ModelConversionError):
    """Raised when inference via ONNX Runtime fails.

    The optional model path and reason are appended to the message. The path
    is always recorded as ``context["model_path"]``, even when it is None.
    """

    def __init__(
        self: InferenceError,
        model_type: ModelType,
        model_path: str | None = None,
        reason: str = "",
    ) -> None:
        # No leading space here: the base class already puts one after "model:".
        message = f"{model_type.value} inference failed."
        if model_path:
            message += f" Model: '{model_path}'."
        if reason:
            message += f" Reason: {reason}"
        super().__init__(message, model_type, context={"model_path": model_path})
76
+
77
+
78
class LayerAnalysisError(ModelConversionError):
    """Raised when analyzing model layers fails.

    The offending layer name (if known) and reason are appended to the
    message; the layer name is also stored as ``context["layer_name"]``.
    """

    def __init__(
        self: LayerAnalysisError,
        model_type: ModelType,
        layer_name: str | None = None,
        reason: str = "",
    ) -> None:
        pieces = ["Layer analysis failed."]
        if layer_name:
            pieces.append(f"Problematic layer: '{layer_name}'.")
        if reason:
            pieces.append(f"Reason: {reason}")
        super().__init__(
            " ".join(pieces),
            model_type,
            context={"layer_name": layer_name},
        )
93
+
94
+
95
class IOInfoExtractionError(ModelConversionError):
    """Raised when extracting input/output info fails.

    The model path (if known) and reason are appended to the message; the
    path is also stored as ``context["model_path"]``.
    """

    def __init__(
        self: IOInfoExtractionError,
        model_type: ModelType,
        model_path: str | None = None,
        reason: str = "",
    ) -> None:
        pieces = ["Failed to extract input/output info from model."]
        if model_path:
            pieces.append(f"Model: '{model_path}'.")
        if reason:
            pieces.append(f"Reason: {reason}")
        super().__init__(
            " ".join(pieces),
            model_type,
            context={"model_path": model_path},
        )
110
+
111
+
112
class InvalidModelError(ModelConversionError):
    """Raised when an ONNX model fails validation checks (onnx.checker).

    The model path (if known) and reason are appended to the message; the
    path is also stored as ``context["model_path"]``.
    """

    def __init__(
        self: InvalidModelError,
        model_type: ModelType,
        model_path: str | None = None,
        reason: str = "",
    ) -> None:
        path_part = f" Model: '{model_path}'." if model_path else ""
        reason_part = f" Reason: {reason}" if reason else ""
        super().__init__(
            message=f"The {model_type.value} model is invalid.{path_part}{reason_part}",
            model_type=model_type,
            context={"model_path": model_path},
        )
131
+
132
+
133
class SerializationError(ModelConversionError):
    """Raised when model data cannot be serialized to the required format.

    The tensor name (if known) and reason are appended to the message; the
    tensor name is also stored as ``context["tensor_name"]``.
    """

    def __init__(
        self: SerializationError,
        model_type: ModelType,
        tensor_name: str | None = None,
        reason: str = "",
    ) -> None:
        pieces = ["Failed to serialize model data."]
        if tensor_name:
            pieces.append(f"Tensor: '{tensor_name}'.")
        if reason:
            pieces.append(f"Reason: {reason}")
        super().__init__(
            " ".join(pieces),
            model_type,
            context={"tensor_name": tensor_name},
        )
@@ -0,0 +1,16 @@
1
+ import importlib
2
+ import pkgutil
3
+ import os
4
+
5
# Eagerly import every plain (non-package) module that sits next to this
# file, so that any @onnx_op decorators they contain run on package import.
# ``custom_helpers`` is skipped because it only holds shared utilities; note
# that ``onnx_helpers`` is NOT skipped and is imported and exported too.
# (pkgutil.iter_modules never reports ``__init__`` itself.)
package_name = __name__
package_dir = os.path.dirname(__file__)

__all__ = []

for _, module_name, is_pkg in pkgutil.iter_modules([package_dir]):
    if is_pkg or module_name == "custom_helpers":
        continue
    importlib.import_module(f"{package_name}.{module_name}")
    __all__.append(module_name)
@@ -0,0 +1,111 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as f
6
+ from onnxruntime_extensions import PyCustomOpDef, onnx_op
7
+
8
+ from .custom_helpers import parse_attr, rescaling
9
+
10
+
11
@onnx_op(
    op_type="Int64Conv",
    domain="ai.onnx.contrib",
    inputs=[
        PyCustomOpDef.dt_int64,  # X
        PyCustomOpDef.dt_int64,  # W
        PyCustomOpDef.dt_int64,  # B
        PyCustomOpDef.dt_int64,  # scaling factor
    ],
    outputs=[PyCustomOpDef.dt_int64],
    attrs={
        "auto_pad": PyCustomOpDef.dt_string,
        "strides": PyCustomOpDef.dt_string,
        "pads": PyCustomOpDef.dt_string,
        "dilations": PyCustomOpDef.dt_string,
        "group": PyCustomOpDef.dt_int64,
        "kernel_shape": PyCustomOpDef.dt_string,
        "rescale": PyCustomOpDef.dt_int64,
    },
)
def int64_conv(
    x: np.ndarray,
    w: np.ndarray,
    b: np.ndarray | None = None,
    scaling_factor: np.ndarray | None = None,
    auto_pad: str | None = None,
    dilations: str | None = None,
    group: int | None = None,
    kernel_shape: str | None = None,
    pads: str | None = None,
    strides: str | None = None,
    rescale: int | None = None,
) -> np.ndarray:
    """
    Performs a 2-D convolution on int64 input tensors.

    This function is registered as a custom ONNX operator via onnxruntime_extensions
    and is used in the JSTprove quantized inference pipeline. It parses ONNX-style
    convolution attributes, applies the convolution with ``torch.nn.functional.conv2d``
    and optionally rescales the result.

    Parameters
    ----------
    x : Input tensor with dtype int64.
    w : Convolution weight tensor with dtype int64.
    b : Optional bias tensor with dtype int64; ``None`` means no bias.
    scaling_factor : Scaling factor for rescaling the output.
    auto_pad : Accepted for ONNX compatibility but ignored; explicit ``pads`` are used.
    dilations : Comma-separated dilation values (default: ``[1, 1]``).
    group : Group count for the convolution (default: 1 when ``None``).
    kernel_shape : Comma-separated kernel shape. It is validated but not used;
        the kernel size comes from ``w``.
    pads : Comma-separated ONNX pads ``[x1_begin, x2_begin, x1_end, x2_end]``
        (default: ``[0, 0, 0, 0]``). Asymmetric pads are applied explicitly
        before the convolution.
    strides : Comma-separated stride values (default: ``[1, 1]``).
    rescale : 1 to floor-divide the output by ``scaling_factor``, 0 to leave it.

    Returns
    -------
    numpy.ndarray
        Convolved tensor with dtype int64.

    Raises
    ------
    RuntimeError
        Wraps any failure (attribute parsing, torch errors, invalid ``rescale``).

    Notes
    -----
    - This op is part of the `ai.onnx.contrib` custom domain.
    - ONNX Runtime Extensions is required to register this op.

    References
    ----------
    For more information on the convolution operation, please refer to the
    ONNX standard Conv operator documentation:
    https://onnx.ai/onnx/operators/onnx__Conv.html
    """
    _ = auto_pad  # Accepted for schema compatibility only.
    try:
        strides = parse_attr(strides, [1, 1])
        dilations = parse_attr(dilations, [1, 1])
        pads = parse_attr(pads, [0, 0, 0, 0])
        # Parsed only so a malformed attribute is reported; w defines the kernel.
        parse_attr(kernel_shape, [3, 3])
        # ONNX's default group is 1, and torch rejects groups=None.
        groups = 1 if group is None else int(group)

        x_t = torch.from_numpy(x)
        w_t = torch.from_numpy(w)
        # The bias input is optional; torch accepts bias=None.
        b_t = None if b is None else torch.from_numpy(b)

        # ONNX pads are [x1_begin, x2_begin, x1_end, x2_end], but conv2d only
        # pads symmetrically, so asymmetric pads are applied up front.
        pad_begin = list(pads[:2])
        pad_end = list(pads[2:4]) if len(pads) >= 4 else pad_begin
        if pad_begin == pad_end:
            padding = pad_begin
        else:
            # f.pad lists pads from the last dim inward: (left, right, top, bottom).
            x_t = f.pad(
                x_t,
                (pad_begin[1], pad_end[1], pad_begin[0], pad_end[0]),
            )
            padding = [0, 0]

        result = (
            f.conv2d(
                x_t,
                w_t,
                bias=b_t,
                stride=strides,
                padding=padding,
                dilation=dilations,
                groups=groups,
            )
            .numpy()
            .astype(np.int64)
        )
        result = rescaling(scaling_factor, rescale, result)
        return result.astype(np.int64)

    except Exception as e:
        msg = f"Int64Conv failed: {e}"
        raise RuntimeError(msg) from e
@@ -0,0 +1,56 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TypeVar
4
+
5
# Type variable for the caller-supplied fallback that parse_attr returns
# when the attribute is not provided.
T = TypeVar("T")
6
+
7
+
8
def rescaling(scaling_factor: int, rescale: int, y: int) -> int:
    """Optionally floor-divide ``y`` by ``scaling_factor``.

    Args:
        scaling_factor (int): Divisor applied when ``rescale`` is 1. Must not
            be None in that case.
        rescale (int): 0 -> return ``y`` unchanged; 1 -> return
            ``y // scaling_factor``. Any other value is rejected.
        y (int): Value to rescale. Anything supporting ``//`` works; the
            custom ops pass numpy arrays here.

    Raises:
        ValueError: If ``rescale`` is 1 but ``scaling_factor`` is None.
        ValueError: If ``rescale`` is neither 0 nor 1 (including None).

    Returns:
        int: ``y // scaling_factor`` when rescaling, otherwise ``y`` itself.
        Floor division rounds toward negative infinity (e.g. -17 // 4 == -5).
    """
    if rescale == 1:
        if scaling_factor is None:
            msg = "scaling_factor must be specified when rescale=1"
            raise ValueError(msg)
        return y // scaling_factor
    if rescale == 0:
        return y
    msg = f"Rescale must be 0 or 1, got {rescale}"
    raise ValueError(msg)
33
+
34
+
35
def parse_attr(attr: str | None, default: T) -> list[int] | T:
    """Parse a comma-separated attribute string into a list of integers.

    Args:
        attr (str | None): Comma-separated integers such as ``"1, 2, 3"``.
            ``None`` or a blank string (what an empty ONNX INTS attribute
            becomes once joined with commas) counts as "not provided".
        default (T): Value returned when ``attr`` is not provided.

    Raises:
        ValueError: If ``attr`` is non-blank but an element is not an integer.

    Returns:
        list[int] | T: The parsed integers, or ``default`` when absent.
    """
    if attr is None or not attr.strip():
        return default
    try:
        return [int(x.strip()) for x in attr.split(",")]
    except ValueError as e:
        msg = f"Invalid attribute format: {attr}"
        raise ValueError(msg) from e
@@ -0,0 +1,91 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+ from onnxruntime_extensions import PyCustomOpDef, onnx_op
5
+
6
+ from .custom_helpers import rescaling
7
+
8
+
9
@onnx_op(
    op_type="Int64Gemm",
    domain="ai.onnx.contrib",
    inputs=[
        PyCustomOpDef.dt_int64,  # X
        PyCustomOpDef.dt_int64,  # W
        PyCustomOpDef.dt_int64,  # B
        PyCustomOpDef.dt_int64,  # Scalar
    ],
    outputs=[PyCustomOpDef.dt_int64],
    attrs={
        "alpha": PyCustomOpDef.dt_float,
        "beta": PyCustomOpDef.dt_float,
        "transA": PyCustomOpDef.dt_int64,
        "transB": PyCustomOpDef.dt_int64,
        "rescale": PyCustomOpDef.dt_int64,
    },
)
def int64_gemm7(
    a: np.ndarray,
    b: np.ndarray,
    c: np.ndarray | None = None,
    scaling_factor: np.ndarray | None = None,
    alpha: float | None = None,
    beta: float | None = None,
    transA: int | None = None,  # noqa: N803
    transB: int | None = None,  # noqa: N803
    rescale: int | None = None,
) -> np.ndarray:
    """
    Performs a Gemm (alternatively: Linear layer) on int64 input tensors.

    This function is registered as a custom ONNX operator via onnxruntime_extensions
    and is used in the JSTprove quantized inference pipeline. It computes
    ``alpha * (A' @ B') + beta * C`` and optionally rescales the result.

    Parameters
    ----------
    a : Input tensor with dtype int64.
    b : Gemm weight tensor with dtype int64.
    c : Optional bias tensor with dtype int64; skipped when ``None``.
    scaling_factor : Scaling factor for rescaling the output.
    alpha : Multiplier for ``a @ b``. Defaults to 1.0 (the ONNX default) when
        ``None``. It is truncated to an int, so fractional values lose precision.
    beta : Multiplier for ``c``. Defaults to 1.0 when ``None``, and is
        truncated to an int like ``alpha``.
    transA : Truthy to transpose ``a`` before the multiplication.
    transB : Truthy to transpose ``b`` before the multiplication.
    rescale : 1 to floor-divide the output by ``scaling_factor``, 0 to leave it.

    Returns
    -------
    numpy.ndarray
        Gemm tensor with dtype int64.

    Raises
    ------
    RuntimeError
        Wraps any failure (shape mismatch, invalid ``rescale``, ...).

    Notes
    -----
    - This op is part of the `ai.onnx.contrib` custom domain.
    - ONNX Runtime Extensions is required to register this op.

    References
    ----------
    For more information on the gemm operation, please refer to the
    ONNX standard Gemm operator documentation:
    https://onnx.ai/onnx/operators/onnx__Gemm.html
    """
    try:
        # ONNX Gemm defaults alpha and beta to 1.0 when the attribute is absent;
        # int(None) would otherwise fail.
        alpha = int(1.0 if alpha is None else alpha)
        beta = int(1.0 if beta is None else beta)

        a = a.T if transA else a
        b = b.T if transB else b

        result = alpha * (a @ b)

        if c is not None:
            result = result + beta * c

        result = rescaling(scaling_factor, rescale, result)
        return result.astype(np.int64)

    except Exception as e:
        msg = f"Int64Gemm failed: {e}"
        raise RuntimeError(msg) from e
@@ -0,0 +1,79 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as f
6
+ from onnxruntime_extensions import PyCustomOpDef, onnx_op
7
+
8
+ from .custom_helpers import parse_attr
9
+
10
+
11
@onnx_op(
    op_type="Int64MaxPool",
    domain="ai.onnx.contrib",
    inputs=[PyCustomOpDef.dt_int64],  # input tensor
    outputs=[PyCustomOpDef.dt_int64],
    attrs={
        "strides": PyCustomOpDef.dt_string,
        "pads": PyCustomOpDef.dt_string,
        "kernel_shape": PyCustomOpDef.dt_string,
        "dilations": PyCustomOpDef.dt_string,
    },
)
def int64_maxpool(
    x: np.ndarray,
    strides: str | None = None,
    pads: str | None = None,
    kernel_shape: str | None = None,
    dilations: str | None = None,
) -> np.ndarray:
    """
    Performs a 2-D MaxPool operation on int64 input tensors.

    This function is registered as a custom ONNX operator via onnxruntime_extensions
    and is used in the JSTprove quantized inference pipeline. It parses ONNX-style
    maxpool attributes and applies ``torch.nn.functional.max_pool2d``.

    Parameters
    ----------
    x : Input tensor with dtype int64.
    kernel_shape : Comma-separated kernel shape (default: `[2, 2]`).
    pads : Comma-separated padding values (default: `[0, 0]`). Only the first
        two entries are used, so asymmetric ONNX pads
        ``[x1_begin, x2_begin, x1_end, x2_end]`` are not reproduced.
    strides : Comma-separated stride values (default: `[1, 1]`).
    dilations : Comma-separated dilation values (default: `[1, 1]`).

    Returns
    -------
    numpy.ndarray
        Maxpool tensor with dtype int64.

    Raises
    ------
    RuntimeError
        Wraps any failure from attribute parsing or torch.

    Notes
    -----
    - This op is part of the `ai.onnx.contrib` custom domain.
    - ONNX Runtime Extensions is required to register this op.

    References
    ----------
    For more information on the maxpool operation, please refer to the
    ONNX standard MaxPool operator documentation:
    https://onnx.ai/onnx/operators/onnx__MaxPool.html
    """
    try:
        strides = parse_attr(strides, [1, 1])
        pads = parse_attr(pads, [0, 0])
        kernel_size = parse_attr(kernel_shape, [2, 2])
        dilations = parse_attr(dilations, [1, 1])

        x = torch.from_numpy(x)
        # torch pads symmetrically, so only the "begin" pads are forwarded.
        result = f.max_pool2d(
            x,
            kernel_size=kernel_size,
            stride=strides,
            padding=pads[:2],
            dilation=dilations,
        )
        return result.numpy().astype(np.int64)
    except Exception as e:
        # Name this op (not Gemm) so failures are attributed correctly.
        msg = f"Int64MaxPool failed: {e}"
        raise RuntimeError(msg) from e
@@ -0,0 +1,173 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ import onnx
6
+ from onnx import AttributeProto, numpy_helper
7
+
8
def _utf8(raw: bytes) -> str:
    """Decode ONNX byte strings, substituting U+FFFD for undecodable bytes."""
    return raw.decode("utf-8", errors="replace")


def _tensor_to_list(tensor: onnx.TensorProto) -> Any:
    """Convert a TensorProto into (possibly nested) Python lists via numpy."""
    return numpy_helper.to_array(tensor).tolist()


# Dispatch table: AttributeProto.type -> callable producing a plain Python value.
# Types not listed here (e.g. GRAPH) are rejected by parse_attribute.
ATTRIBUTE_PARSERS = {
    AttributeProto.FLOAT: lambda a: a.f,
    AttributeProto.INT: lambda a: a.i,
    AttributeProto.STRING: lambda a: _utf8(a.s),
    AttributeProto.FLOATS: lambda a: list(a.floats),
    AttributeProto.INTS: lambda a: list(a.ints),
    AttributeProto.STRINGS: lambda a: [_utf8(s) for s in a.strings],
    AttributeProto.TENSOR: lambda a: _tensor_to_list(a.t),
    AttributeProto.TENSORS: lambda a: [_tensor_to_list(t) for t in a.tensors],
}
22
+
23
+
24
def parse_attribute(
    attr: AttributeProto,
) -> float | int | str | list[int] | list[float] | list[str]:
    """Convert a single ONNX attribute into a native Python value.

    Dispatches on ``attr.type`` through the module's ATTRIBUTE_PARSERS table.

    Args:
        attr (AttributeProto): Attribute to convert.

    Raises:
        ValueError: If ``attr.type`` has no registered parser.

    Returns:
        Any: The converted value (scalar, string, or list; tensors become
        nested lists).
    """
    if attr.type not in ATTRIBUTE_PARSERS:
        msg = f"Unsupported attribute type: {attr.type}"
        raise ValueError(msg)
    return ATTRIBUTE_PARSERS[attr.type](attr)
43
+
44
+
45
def parse_attributes(attrs: list[AttributeProto]) -> dict[str, Any]:
    """Convert a sequence of ONNX attributes into a name -> value mapping.

    Each attribute goes through ``parse_attribute``; if a name repeats, the
    last occurrence wins.

    Args:
        attrs (list[AttributeProto]): Attributes to convert.

    Returns:
        dict[str, Any]: Attribute names mapped to their converted values.
    """
    parsed: dict[str, Any] = {}
    for attribute in attrs:
        parsed[attribute.name] = parse_attribute(attribute)
    return parsed
55
+
56
+
57
def extract_shape_dict(inferred_model: onnx.ModelProto) -> dict[str, list[int]]:
    """Map every tensor-typed value in a model's graph to its static shape.

    Despite the parameter name, this takes the whole model (it reads
    ``inferred_model.graph``), typically after ONNX shape inference.

    Values come from ``graph.value_info``, ``graph.output`` and
    ``graph.input``, in that order. When a name appears more than once, the
    later collection wins (input over output over value_info). Values whose
    type is not a tensor type are skipped.

    Args:
        inferred_model (onnx.ModelProto): Model whose graph is inspected.

    Returns:
        dict[str, list[int]]: Tensor name -> list of dimensions. Dimensions
        without a concrete ``dim_value`` (e.g. symbolic) are reported as 1.
    """
    graph = inferred_model.graph
    shapes: dict[str, list[int]] = {}
    for vi in (*graph.value_info, *graph.output, *graph.input):
        if not vi.type.HasField("tensor_type"):
            continue
        # TODO@jsgold-1: figure out how to deal with bad value  # noqa: FIX002, TD003
        shapes[vi.name] = [
            d.dim_value if d.HasField("dim_value") else 1
            for d in vi.type.tensor_type.shape.dim
        ]
    return shapes
79
+
80
+
81
def replace_input_references(
    graph: onnx.GraphProto,
    old_output: str,
    new_output: str,
) -> None:
    """Rewire every node input that consumes ``old_output`` to ``new_output``.

    The graph is mutated in place. Only nodes listed directly in
    ``graph.node`` are visited.

    Args:
        graph (onnx.GraphProto): Graph whose node inputs are rewritten.
        old_output (str): Tensor name to look for.
        new_output (str): Tensor name to substitute.
    """
    for node in graph.node:
        hits = [idx for idx, name in enumerate(node.input) if name == old_output]
        for idx in hits:
            node.input[idx] = new_output
97
+
98
+
99
def extract_attributes(node: onnx.NodeProto) -> dict:
    """Convert a node's attributes into the plain values the custom ops expect.

    Conversion rules by attribute type:

    * FLOAT -> ``float``
    * INT -> ``int`` (ONNX also uses INT for boolean flags)
    * FLOATS -> ``list[float]``
    * INTS -> comma-separated string such as ``"1,2,3"``. The custom ops
      declare list attributes as ``dt_string`` and split them back apart
      with ``parse_attr``.
    * STRING -> ``str`` (UTF-8 decoded)

    Args:
        node (onnx.NodeProto): The ONNX node to extract attributes from.

    Raises:
        ValueError: If an attribute has any other type (e.g. TENSOR, GRAPH,
            STRINGS).

    Returns:
        dict: Mapping of attribute names to converted values.
    """
    attrs: dict = {}
    for attr in node.attribute:
        name = attr.name
        val = onnx.helper.get_attribute_value(attr)

        if attr.type == AttributeProto.FLOAT:
            attrs[name] = float(val)
        elif attr.type == AttributeProto.INT:
            attrs[name] = int(val)
        elif attr.type == AttributeProto.FLOATS:
            # Kept as floats; ops that need integer lists must convert them.
            attrs[name] = [float(x) for x in val]
        elif attr.type == AttributeProto.INTS:
            attrs[name] = ",".join(str(v) for v in val)
        elif attr.type == AttributeProto.STRING:
            attrs[name] = val.decode("utf-8") if isinstance(val, bytes) else val
        else:
            # AttributeProto defines no BOOL type (booleans are encoded as INT),
            # so every remaining type is reported as unsupported here.
            msg = f"Unsupported attribute type: {attr.name} (type={attr.type})"
            raise ValueError(msg)
    return attrs
134
+
135
+
136
def get_input_shapes(onnx_model: onnx.ModelProto) -> dict:
    """Collect the declared shape of each graph input.

    Unlike ``extract_shape_dict``, no substitution is made for unknown
    dimensions: ``dim_value`` is returned as-is, which for a protobuf dim with
    no concrete value is the field default 0.

    Args:
        onnx_model (onnx.ModelProto): The ONNX model.

    Returns:
        dict: Input tensor name -> list of dimension values. For a repeated
        name, the last input wins.
    """
    return {
        graph_input.name: [
            dim.dim_value for dim in graph_input.type.tensor_type.shape.dim
        ]
        for graph_input in onnx_model.graph.input
    }
152
+
153
+
154
def get_attribute_ints(
    node: onnx.NodeProto,
    name: str,
    default: list[int] | None = None,
) -> list[int]:
    """Return the integer list stored in a node's INTS attribute ``name``.

    The first attribute that matches both the name and the INTS type is used.
    Attributes with the right name but another type are ignored.

    Args:
        node (onnx.NodeProto): The ONNX node.
        name (str): Name of the attribute to retrieve.
        default (list[int], optional): Returned when no matching attribute
            exists. ``None`` (the default) is turned into an empty list.

    Returns:
        list[int]: A fresh list of the attribute's integers, or the fallback.
    """
    candidates = (
        list(attr.ints)
        for attr in node.attribute
        if attr.name == name and attr.type == onnx.AttributeProto.INTS
    )
    found = next(candidates, None)
    if found is not None:
        return found
    return [] if default is None else default
+ return default if default is not None else []