JSTprove 1.0.0__py3-none-macosx_11_0_arm64.whl → 1.2.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/METADATA +3 -3
  2. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/RECORD +60 -25
  3. python/core/binaries/onnx_generic_circuit_1-2-0 +0 -0
  4. python/core/circuit_models/generic_onnx.py +43 -9
  5. python/core/circuits/base.py +231 -71
  6. python/core/model_processing/converters/onnx_converter.py +114 -59
  7. python/core/model_processing/onnx_custom_ops/batchnorm.py +64 -0
  8. python/core/model_processing/onnx_custom_ops/maxpool.py +1 -1
  9. python/core/model_processing/onnx_custom_ops/mul.py +66 -0
  10. python/core/model_processing/onnx_custom_ops/relu.py +1 -1
  11. python/core/model_processing/onnx_quantizer/layers/add.py +54 -0
  12. python/core/model_processing/onnx_quantizer/layers/base.py +188 -1
  13. python/core/model_processing/onnx_quantizer/layers/batchnorm.py +224 -0
  14. python/core/model_processing/onnx_quantizer/layers/constant.py +1 -1
  15. python/core/model_processing/onnx_quantizer/layers/conv.py +20 -68
  16. python/core/model_processing/onnx_quantizer/layers/gemm.py +20 -66
  17. python/core/model_processing/onnx_quantizer/layers/maxpool.py +53 -43
  18. python/core/model_processing/onnx_quantizer/layers/mul.py +53 -0
  19. python/core/model_processing/onnx_quantizer/layers/relu.py +20 -35
  20. python/core/model_processing/onnx_quantizer/layers/sub.py +54 -0
  21. python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +43 -1
  22. python/core/utils/general_layer_functions.py +17 -12
  23. python/core/utils/model_registry.py +6 -3
  24. python/scripts/gen_and_bench.py +2 -2
  25. python/tests/circuit_e2e_tests/other_e2e_test.py +202 -9
  26. python/tests/circuit_parent_classes/test_circuit.py +561 -38
  27. python/tests/circuit_parent_classes/test_onnx_converter.py +22 -13
  28. python/tests/onnx_quantizer_tests/__init__.py +1 -0
  29. python/tests/onnx_quantizer_tests/layers/__init__.py +13 -0
  30. python/tests/onnx_quantizer_tests/layers/add_config.py +102 -0
  31. python/tests/onnx_quantizer_tests/layers/base.py +279 -0
  32. python/tests/onnx_quantizer_tests/layers/batchnorm_config.py +190 -0
  33. python/tests/onnx_quantizer_tests/layers/constant_config.py +39 -0
  34. python/tests/onnx_quantizer_tests/layers/conv_config.py +154 -0
  35. python/tests/onnx_quantizer_tests/layers/factory.py +142 -0
  36. python/tests/onnx_quantizer_tests/layers/flatten_config.py +61 -0
  37. python/tests/onnx_quantizer_tests/layers/gemm_config.py +160 -0
  38. python/tests/onnx_quantizer_tests/layers/maxpool_config.py +82 -0
  39. python/tests/onnx_quantizer_tests/layers/mul_config.py +102 -0
  40. python/tests/onnx_quantizer_tests/layers/relu_config.py +61 -0
  41. python/tests/onnx_quantizer_tests/layers/reshape_config.py +61 -0
  42. python/tests/onnx_quantizer_tests/layers/sub_config.py +102 -0
  43. python/tests/onnx_quantizer_tests/layers_tests/__init__.py +0 -0
  44. python/tests/onnx_quantizer_tests/layers_tests/base_test.py +94 -0
  45. python/tests/onnx_quantizer_tests/layers_tests/test_check_model.py +115 -0
  46. python/tests/onnx_quantizer_tests/layers_tests/test_e2e.py +196 -0
  47. python/tests/onnx_quantizer_tests/layers_tests/test_error_cases.py +59 -0
  48. python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +198 -0
  49. python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +267 -0
  50. python/tests/onnx_quantizer_tests/layers_tests/test_scalability.py +109 -0
  51. python/tests/onnx_quantizer_tests/layers_tests/test_validation.py +45 -0
  52. python/tests/onnx_quantizer_tests/test_base_layer.py +228 -0
  53. python/tests/onnx_quantizer_tests/test_exceptions.py +99 -0
  54. python/tests/onnx_quantizer_tests/test_onnx_op_quantizer.py +246 -0
  55. python/tests/onnx_quantizer_tests/test_registered_quantizers.py +121 -0
  56. python/tests/onnx_quantizer_tests/testing_helper_functions.py +17 -0
  57. python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
  58. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/WHEEL +0 -0
  59. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/entry_points.txt +0 -0
  60. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/licenses/LICENSE +0 -0
  61. {jstprove-1.0.0.dist-info → jstprove-1.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,99 @@
1
+ import pytest
2
+
3
+ from python.core.model_processing.onnx_quantizer.exceptions import (
4
+ REPORTING_URL,
5
+ InvalidParamError,
6
+ QuantizationError,
7
+ UnsupportedOpError,
8
+ )
9
+
10
+
11
@pytest.mark.unit
def test_quantization_error_message() -> None:
    """Check QuantizationError's text: fixed support banner, custom message, reporting link."""
    detail = "Something went wrong."
    with pytest.raises(QuantizationError) as raised:
        raise QuantizationError(detail)

    rendered = str(raised.value)
    expected_fragments = (
        "This model is not supported by JSTprove.",
        detail,
        REPORTING_URL,
        "Submit model support requests via the JSTprove channel:",
    )
    for fragment in expected_fragments:
        assert fragment in rendered
24
+
25
+
26
@pytest.mark.unit
def test_invalid_param_error_basic() -> None:
    """InvalidParamError without attr_key/expected omits those bracketed sections.

    Also checks that its message contains the text a QuantizationError renders
    for an empty custom message (the generic QuantizationError portion).
    """
    with pytest.raises(InvalidParamError) as exc_info:
        raise InvalidParamError(
            node_name="Conv_1",
            op_type="Conv",
            message="Missing 'strides' attribute.",
        )
    err_msg = str(exc_info.value)
    assert "Invalid parameters in node 'Conv_1'" in err_msg
    assert "(op_type='Conv')" in err_msg
    assert "Missing 'strides' attribute." in err_msg
    assert "[Attribute:" not in err_msg
    assert "[Expected:" not in err_msg

    # Assert contains generic error message from quantization error.
    # str() of an exception does not require raising it first.
    assert str(QuantizationError("")) in err_msg
47
+
48
+
49
@pytest.mark.unit
def test_invalid_param_error_with_attr_and_expected() -> None:
    """InvalidParamError renders ``[Attribute: ...]`` and ``[Expected: ...]`` when given.

    Also checks that its message contains the text a QuantizationError renders
    for an empty custom message (the generic QuantizationError portion).
    """
    with pytest.raises(InvalidParamError) as exc_info:
        raise InvalidParamError(
            node_name="MaxPool_3",
            op_type="MaxPool",
            message="Kernel shape is invalid.",
            attr_key="kernel_shape",
            expected="a list of 2 positive integers",
        )
    err_msg = str(exc_info.value)
    assert "Invalid parameters in node 'MaxPool_3'" in err_msg
    assert "[Attribute: kernel_shape]" in err_msg
    assert "[Expected: a list of 2 positive integers]" in err_msg

    # Assert contains generic error message from quantization error.
    # str() of an exception does not require raising it first.
    assert str(QuantizationError("")) in err_msg
69
+
70
+
71
@pytest.mark.unit
def test_unsupported_op_error_with_node() -> None:
    """UnsupportedOpError names both the op type and the node when a node is given.

    Also checks that its message contains the text a QuantizationError renders
    for an empty custom message (the generic QuantizationError portion).
    """
    with pytest.raises(UnsupportedOpError) as exc_info:
        raise UnsupportedOpError(op_type="Resize", node_name="Resize_42")
    err_msg = str(exc_info.value)
    assert "Unsupported op type: 'Resize'" in err_msg
    assert "in node 'Resize_42'" in err_msg
    assert "documentation for supported layers" in err_msg

    # Assert contains generic error message from quantization error.
    # str() of an exception does not require raising it first.
    assert str(QuantizationError("")) in err_msg
85
+
86
+
87
@pytest.mark.unit
def test_unsupported_op_error_without_node() -> None:
    """UnsupportedOpError without node_name omits the "in node" clause.

    Also checks that its message contains the text a QuantizationError renders
    for an empty custom message (the generic QuantizationError portion).
    """
    with pytest.raises(UnsupportedOpError) as exc_info:
        raise UnsupportedOpError(op_type="Upsample")
    err_msg = str(exc_info.value)
    assert "Unsupported op type: 'Upsample'" in err_msg
    assert "in node" not in err_msg

    # Assert contains generic error message from quantization error.
    # str() of an exception does not require raising it first.
    assert str(QuantizationError("")) in err_msg
@@ -0,0 +1,246 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any
4
+
5
+ import onnx
6
+ import pytest
7
+ from onnx import GraphProto, ModelProto, NodeProto, TensorProto, helper
8
+
9
+ from python.core.model_processing.onnx_quantizer.exceptions import (
10
+ MissingHandlerError,
11
+ QuantizationError,
12
+ UnsupportedOpError,
13
+ )
14
+
15
+ if TYPE_CHECKING:
16
+ from python.core.model_processing.onnx_quantizer.layers.base import ScaleConfig
17
+
18
+ # Optional: mock layers if needed
19
+ from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import (
20
+ ONNXOpQuantizer,
21
+ )
22
+
23
+
24
+ # Mocks
25
class MockHandler:
    """Minimal stand-in for a layer quantizer that records which hooks were invoked."""

    def __init__(self: MockHandler) -> None:
        # Flags flipped by the matching methods so tests can assert on them.
        self.called_supported = False
        self.called_quantize = False

    def quantize(
        self: MockHandler,
        node: NodeProto,
        graph: GraphProto,
        scale_config: ScaleConfig,
        initializer_map: dict[str, TensorProto],
    ) -> list[NodeProto]:
        """Record the call and return the input node unchanged, wrapped in a list."""
        del graph, scale_config, initializer_map  # part of the handler signature only
        self.called_quantize = True
        return [node]

    def check_supported(
        self: MockHandler,
        node: NodeProto,
        initializer_map: dict[str, TensorProto],
    ) -> None:
        """Record the call; reject any node named ``"bad_node"`` with ValueError."""
        del initializer_map  # part of the handler signature only
        self.called_supported = True
        if node.name != "bad_node":
            return
        msg = "Invalid node parameters"
        raise ValueError(msg)
51
+
52
+
53
+ # Fixtures
54
@pytest.fixture
def quantizer() -> ONNXOpQuantizer:
    """Provide a freshly constructed ONNXOpQuantizer."""
    fresh = ONNXOpQuantizer()
    return fresh
57
+
58
+
59
@pytest.fixture
def dummy_node() -> NodeProto:
    """An unnamed ``FakeOp`` node mapping input ``x`` to output ``y``."""
    node_inputs, node_outputs = ["x"], ["y"]
    return helper.make_node("FakeOp", inputs=node_inputs, outputs=node_outputs)
62
+
63
+
64
@pytest.fixture
def valid_node() -> NodeProto:
    """A ``Dummy`` node named ``good_node`` (accepted by MockHandler.check_supported)."""
    return helper.make_node(
        op_type="Dummy",
        inputs=["x"],
        outputs=["y"],
        name="good_node",
    )
67
+
68
+
69
@pytest.fixture
def invalid_node() -> NodeProto:
    """A ``Dummy`` node named ``bad_node`` (rejected by MockHandler.check_supported)."""
    return helper.make_node(
        op_type="Dummy",
        inputs=["x"],
        outputs=["y"],
        name="bad_node",
    )
72
+
73
+
74
@pytest.fixture
def dummy_model(valid_node: NodeProto, invalid_node: NodeProto) -> ModelProto:
    """Model holding the good and bad ``Dummy`` nodes plus a FLOAT initializer ``x``.

    The graph declares no inputs or outputs.
    """
    x_init = helper.make_tensor("x", TensorProto.FLOAT, [1], [0.5])
    test_graph = helper.make_graph(
        nodes=[valid_node, invalid_node],
        name="test_graph",
        inputs=[],
        outputs=[],
        initializer=[x_init],
    )
    return helper.make_model(test_graph)
84
+
85
+
86
+ # Tests
87
+
88
+
89
@pytest.mark.unit
def test_check_model_raises_on_unsupported_op() -> None:
    """check_model raises UnsupportedOpError for a graph containing ``UnsupportedOp``."""
    lone_node = helper.make_node("UnsupportedOp", ["x"], ["y"])
    model = helper.make_model(helper.make_graph([lone_node], "test_graph", [], []))

    quantizer = ONNXOpQuantizer()
    with pytest.raises(UnsupportedOpError):
        quantizer.check_model(model)
99
+
100
+
101
@pytest.mark.unit
def test_check_layer_invokes_check_supported() -> None:
    """check_layer calls the registered handler's check_supported hook."""
    handler = MockHandler()
    quantizer = ONNXOpQuantizer()
    quantizer.register("FakeOp", handler)

    quantizer.check_layer(helper.make_node("FakeOp", ["x"], ["y"]), {})

    assert handler.called_supported
113
+
114
+
115
@pytest.mark.unit
def test_get_initializer_map_returns_correct_dict() -> None:
    """get_initializer_map exposes a graph initializer under its name, unchanged."""
    weight = helper.make_tensor(
        name="W",
        data_type=TensorProto.FLOAT,
        dims=[1],
        vals=[1.0],
    )
    model = helper.make_model(helper.make_graph([], "test_graph", [], [], [weight]))

    init_map = ONNXOpQuantizer().get_initializer_map(model)

    assert "W" in init_map
    stored = init_map["W"]
    assert stored == weight
    # Metadata and value survive the lookup.
    assert stored.data_type == TensorProto.FLOAT
    assert stored.dims == [1]
    assert onnx.numpy_helper.to_array(stored)[0] == 1.0
137
+
138
+
139
@pytest.mark.unit
def test_quantize_with_unregistered_op_warns(dummy_node: NodeProto) -> None:
    """Quantizing a ``FakeOp`` node raises UnsupportedOpError naming the op.

    Despite "warns" in the name, this test expects an exception, not a warning.
    """
    empty_graph = helper.make_graph([], "g", [], [])
    quantizer = ONNXOpQuantizer()
    # The pattern has no regex metacharacters, so re.search acts as a substring test.
    with pytest.raises(UnsupportedOpError, match="Unsupported op type: 'FakeOp'"):
        quantizer.quantize(dummy_node, empty_graph, 1, 1, {}, rescale=False)
148
+
149
+
150
# Could be unit or integration?
@pytest.mark.unit
def test_check_model_raises_unsupported(dummy_model: ModelProto) -> None:
    """check_model raises UnsupportedOpError naming an op type with no handler."""
    quantizer = ONNXOpQuantizer()
    # Only "Dummy" has a handler, so any other op type is unsupported.
    quantizer.handlers = {"Dummy": MockHandler()}

    # Append a node whose op type ("FakeOp") has no registered handler.
    # NOTE(review): dummy_model also contains "bad_node", for which
    # MockHandler.check_supported raises ValueError; this test relies on
    # check_model reporting the unsupported op instead -- confirm ordering.
    dummy_model.graph.node.append(helper.make_node("FakeOp", ["a"], ["b"]))

    with pytest.raises(UnsupportedOpError) as excinfo:
        quantizer.check_model(dummy_model)

    assert "FakeOp" in str(excinfo.value)
163
+
164
+
165
@pytest.mark.unit
def test_check_layer_missing_handler(valid_node: NodeProto) -> None:
    """check_layer raises MissingHandlerError, with the generic banner, for an unregistered op."""
    quantizer = ONNXOpQuantizer()
    with pytest.raises(MissingHandlerError) as caught:
        quantizer.check_layer(valid_node, {})

    message = str(caught.value)
    assert QuantizationError("").GENERIC_MESSAGE in message
    assert "No quantization handler registered for operator type 'Dummy'." in message
175
+
176
+
177
@pytest.mark.unit
def test_check_layer_with_bad_handler(invalid_node: NodeProto) -> None:
    """A ValueError from the handler's check_supported propagates out of check_layer."""
    quantizer = ONNXOpQuantizer()
    mock_handler = MockHandler()
    quantizer.handlers = {"Dummy": mock_handler}

    # MockHandler.check_supported raises this error for nodes named "bad_node".
    with pytest.raises(ValueError, match="Invalid node parameters"):
        quantizer.check_layer(invalid_node, {})
185
+
186
+
187
@pytest.mark.unit
def test_get_initializer_map_extracts_all() -> None:
    """Every graph initializer appears in the map with its name, dtype, dims and value."""
    expected_values = {"a": 1.0, "b": 2.0}
    tensors = [
        helper.make_tensor(name, TensorProto.FLOAT, [1], [value])
        for name, value in expected_values.items()
    ]
    model = helper.make_model(helper.make_graph([], "g", [], [], initializer=tensors))

    init_map = ONNXOpQuantizer().get_initializer_map(model)

    for name, value in expected_values.items():
        assert init_map[name].float_data[0] == value

    assert len(init_map) == len(expected_values)
    for name in expected_values:
        stored = init_map[name]
        assert stored.name == name
        assert stored.data_type == TensorProto.FLOAT
        assert stored.dims == [1]

    # Cross-check the raw field via numpy_helper as well.
    for name, value in expected_values.items():
        assert onnx.numpy_helper.to_array(init_map[name])[0] == value
213
+
214
+
215
@pytest.mark.unit
def test_check_layer_skips_handler_without_check_supported() -> None:
    """check_layer does not raise for a handler that defines no check_supported."""

    class NoCheckHandler:
        # Deliberately defines only ``quantize`` -- no ``check_supported``.
        # ``*args: Any`` / ``**kwargs: Any`` annotate each argument, not the
        # tuple/dict containers (PEP 484).
        def quantize(self, *args: Any, **kwargs: Any) -> None:
            """Accept any arguments and do nothing."""

    quantizer = ONNXOpQuantizer()
    quantizer.register("NoCheckOp", NoCheckHandler())

    node = helper.make_node("NoCheckOp", ["x"], ["y"])
    # Should not raise
    quantizer.check_layer(node, {})
227
+
228
+
229
@pytest.mark.unit
def test_register_overwrites_handler() -> None:
    """Registering a second handler for the same op type replaces the first."""
    first, second = MockHandler(), MockHandler()
    quantizer = ONNXOpQuantizer()
    for handler in (first, second):
        quantizer.register("Dummy", handler)

    assert quantizer.handlers["Dummy"] is second
239
+
240
+
241
@pytest.mark.unit
def test_check_empty_model() -> None:
    """check_model accepts a graph with no nodes."""
    empty_graph = helper.make_graph([], "empty", [], [])
    # Must complete without raising.
    ONNXOpQuantizer().check_model(helper.make_model(empty_graph))
@@ -0,0 +1,121 @@
1
+ # This file performs very basic integration tests on each registered quantizer
2
+
3
+ import numpy as np
4
+ import onnx
5
+ import pytest
6
+ from onnx import helper
7
+
8
+ from python.core.model_processing.onnx_quantizer.layers.base import ScaleConfig
9
+ from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import (
10
+ ONNXOpQuantizer,
11
+ )
12
+ from python.tests.onnx_quantizer_tests import TEST_RNG_SEED
13
+
14
+
15
@pytest.fixture
def dummy_graph() -> onnx.GraphProto:
    """Provide an empty GraphProto to pass as the ``graph`` argument of handlers."""
    graph = onnx.GraphProto()
    return graph
18
+
19
+
20
def mock_initializer_map(input_names: list[str]) -> dict[str, onnx.TensorProto]:
    """Build a FLOAT [2, 2] initializer per name, filled with seeded random values.

    A new generator seeded with TEST_RNG_SEED is created on every call, so the
    result is reproducible; later names in ``input_names`` receive later draws.
    """
    rng = np.random.default_rng(TEST_RNG_SEED)
    initializers = {}
    for name in input_names:
        values = rng.random(4, dtype=np.float32).tolist()
        initializers[name] = onnx.helper.make_tensor(
            name=name,
            data_type=onnx.TensorProto.FLOAT,
            dims=[2, 2],  # minimal shape
            vals=values,
        )
    return initializers
31
+
32
+
33
def get_required_input_names(op_type: str) -> list[str]:
    """Return the names of the non-optional inputs in ``op_type``'s ONNX schema.

    Schema inputs without a name fall back to ``input{i}``, where ``i`` is the
    input's position in the schema. If no schema is registered for ``op_type``
    (``onnx.defs.get_schema`` raises ``SchemaError``), a single placeholder name
    ``"input0"`` is returned instead.
    """
    optional = onnx.defs.OpSchema.FormalParameterOption.Optional
    try:
        schema = onnx.defs.get_schema(op_type)
    except onnx.defs.SchemaError:
        # Narrow catch: a missing schema is expected for unknown ops; anything
        # else is a real error and should surface.
        return ["input0"]  # fallback
    return [
        inp.name or f"input{i}"
        for i, inp in enumerate(schema.inputs)
        if inp.option != optional
    ]
43
+
44
+
45
def validate_quantized_node(node_result: onnx.NodeProto, op_type: str) -> None:
    """Validate a single quantized node.

    First asserts basic structure: a NodeProto with an op_type and at least one
    output. Then builds a throwaway single-node model around it and runs
    ``onnx.checker.check_model`` on that model. The model declares a FLOAT
    ``[1]`` value-info for every distinct, non-empty input/output name and adds
    an ``ai.onnx.contrib`` v1 opset import. A checker ValidationError is
    reported via ``pytest.fail`` naming ``op_type``.
    """
    assert isinstance(node_result, onnx.NodeProto), f"Invalid node type for {op_type}"
    assert node_result.op_type, f"Missing op_type for {op_type}"
    assert node_result.output, f"Missing outputs for {op_type}"

    # Create a minimal model with custom opset for validation
    temp_graph = onnx.GraphProto()
    temp_graph.name = "temp_graph"

    # Add dummy inputs/outputs to satisfy graph requirements. An empty name marks
    # an omitted optional tensor in ONNX; declaring it as a graph input/output
    # would make the checker reject the empty value-info name. dict.fromkeys
    # dedupes while preserving order.
    for inp in dict.fromkeys(name for name in node_result.input if name):
        temp_graph.input.append(
            onnx.helper.make_tensor_value_info(inp, onnx.TensorProto.FLOAT, [1]),
        )
    for out in dict.fromkeys(name for name in node_result.output if name):
        temp_graph.output.append(
            onnx.helper.make_tensor_value_info(out, onnx.TensorProto.FLOAT, [1]),
        )

    temp_graph.node.append(node_result)
    temp_model = onnx.helper.make_model(temp_graph)
    temp_model.opset_import.append(
        onnx.helper.make_operatorsetid(domain="ai.onnx.contrib", version=1),
    )

    try:
        onnx.checker.check_model(temp_model)
    except onnx.checker.ValidationError as e:
        pytest.fail(f"ONNX node validation failed for {op_type}: {e}")
86
+
87
+
88
@pytest.mark.integration
@pytest.mark.parametrize("op_type", [*ONNXOpQuantizer().handlers])
def test_registered_quantizer_quantize(
    op_type: str,
    dummy_graph: onnx.GraphProto,
) -> None:
    """Smoke-test each registered handler's ``quantize`` on a minimal node.

    The node's inputs come from get_required_input_names, each backed by a
    random FLOAT initializer. The result may be a single node or a non-empty
    list of nodes, and every returned node is validated.
    """
    handler = ONNXOpQuantizer().handlers[op_type]

    input_names = get_required_input_names(op_type)
    initializers = mock_initializer_map(input_names)
    node = helper.make_node(
        op_type=op_type,
        inputs=input_names,
        outputs=["dummy_output"],
    )

    result = handler.quantize(
        node=node,
        graph=dummy_graph,
        scale_config=ScaleConfig(exponent=10, base=2, rescale=True),
        initializer_map=initializers,
    )
    assert result is not None

    if not isinstance(result, list):
        # Single-node result.
        assert result.input, f"Missing inputs for {op_type}"
        validate_quantized_node(result, op_type)
        return

    assert len(result) > 0, f"Quantize returned empty list for {op_type}"
    for quantized in result:
        validate_quantized_node(quantized, op_type)
@@ -0,0 +1,17 @@
1
+ from onnx import ModelProto
2
+
3
+
4
+ # Helper to extract input shapes
5
+ def get_input_shapes(model: ModelProto) -> dict:
6
+ input_shapes = {}
7
+ for inp in model.graph.input:
8
+ shape = []
9
+ for dim in inp.type.tensor_type.shape.dim:
10
+ if dim.HasField("dim_value"):
11
+ shape.append(int(dim.dim_value))
12
+ elif dim.dim_param:
13
+ shape.append(1) # Default for dynamic dims
14
+ else:
15
+ shape.append(1)
16
+ input_shapes[inp.name] = shape
17
+ return input_shapes