emx-onnx-cgen 0.3.8__py3-none-any.whl → 0.4.2.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of emx-onnx-cgen might be problematic. Click here for more details.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +2 -2
- emx_onnx_cgen/cli.py +1025 -162
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +2081 -458
- emx_onnx_cgen/compiler.py +157 -75
- emx_onnx_cgen/determinism.py +39 -0
- emx_onnx_cgen/ir/context.py +25 -15
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/ir/op_base.py +32 -7
- emx_onnx_cgen/ir/ops/__init__.py +20 -0
- emx_onnx_cgen/ir/ops/elementwise.py +138 -22
- emx_onnx_cgen/ir/ops/misc.py +95 -0
- emx_onnx_cgen/ir/ops/nn.py +361 -38
- emx_onnx_cgen/ir/ops/reduce.py +1 -16
- emx_onnx_cgen/lowering/__init__.py +9 -0
- emx_onnx_cgen/lowering/arg_reduce.py +0 -4
- emx_onnx_cgen/lowering/average_pool.py +157 -27
- emx_onnx_cgen/lowering/bernoulli.py +73 -0
- emx_onnx_cgen/lowering/common.py +48 -0
- emx_onnx_cgen/lowering/concat.py +41 -7
- emx_onnx_cgen/lowering/conv.py +19 -8
- emx_onnx_cgen/lowering/conv_integer.py +103 -0
- emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
- emx_onnx_cgen/lowering/elementwise.py +140 -43
- emx_onnx_cgen/lowering/gather.py +11 -2
- emx_onnx_cgen/lowering/gemm.py +7 -124
- emx_onnx_cgen/lowering/global_max_pool.py +0 -5
- emx_onnx_cgen/lowering/gru.py +323 -0
- emx_onnx_cgen/lowering/hamming_window.py +104 -0
- emx_onnx_cgen/lowering/hardmax.py +1 -37
- emx_onnx_cgen/lowering/identity.py +7 -6
- emx_onnx_cgen/lowering/logsoftmax.py +1 -35
- emx_onnx_cgen/lowering/lp_pool.py +15 -4
- emx_onnx_cgen/lowering/matmul.py +3 -105
- emx_onnx_cgen/lowering/optional_has_element.py +28 -0
- emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
- emx_onnx_cgen/lowering/reduce.py +0 -5
- emx_onnx_cgen/lowering/reshape.py +7 -16
- emx_onnx_cgen/lowering/shape.py +14 -8
- emx_onnx_cgen/lowering/slice.py +14 -4
- emx_onnx_cgen/lowering/softmax.py +1 -35
- emx_onnx_cgen/lowering/split.py +37 -3
- emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
- emx_onnx_cgen/lowering/tile.py +38 -1
- emx_onnx_cgen/lowering/topk.py +1 -5
- emx_onnx_cgen/lowering/transpose.py +9 -3
- emx_onnx_cgen/lowering/unsqueeze.py +11 -16
- emx_onnx_cgen/lowering/upsample.py +151 -0
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +0 -5
- emx_onnx_cgen/onnx_import.py +578 -14
- emx_onnx_cgen/ops.py +3 -0
- emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
- emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
- emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
- emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
- emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
- emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
- emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
- emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
- emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
- emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
- emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
- emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
- emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
- emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
- emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
- emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
- emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
- emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
- emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
- emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
- emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
- emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
- emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
- emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
- emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
- emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
- emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
- emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
- emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
- emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
- emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
- emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
- emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
- emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
- emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
- emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
- emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
- emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
- emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
- emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
- emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
- emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
- emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
- emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
- emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
- emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
- emx_onnx_cgen/templates/range_op.c.j2 +8 -0
- emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
- emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
- emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
- emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
- emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
- emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
- emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
- emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
- emx_onnx_cgen/templates/size_op.c.j2 +4 -0
- emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
- emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
- emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
- emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
- emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
- emx_onnx_cgen/templates/split_op.c.j2 +18 -0
- emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
- emx_onnx_cgen/templates/testbench.c.j2 +161 -0
- emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
- emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
- emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
- emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
- emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
- emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
- emx_onnx_cgen/templates/where_op.c.j2 +9 -0
- emx_onnx_cgen/verification.py +45 -5
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/METADATA +33 -15
- emx_onnx_cgen-0.4.2.dev0.dist-info/RECORD +190 -0
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/WHEEL +1 -1
- emx_onnx_cgen/runtime/__init__.py +0 -1
- emx_onnx_cgen/runtime/evaluator.py +0 -2955
- emx_onnx_cgen-0.3.8.dist-info/RECORD +0 -107
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/onnx_import.py
CHANGED
|
@@ -41,21 +41,19 @@ def _unsupported_value_type(value_info: onnx.ValueInfoProto) -> UnsupportedOpErr
|
|
|
41
41
|
)
|
|
42
42
|
|
|
43
43
|
|
|
44
|
-
def
|
|
45
|
-
|
|
44
|
+
def _tensor_type_from_proto(
|
|
45
|
+
tensor_type: onnx.TypeProto.Tensor,
|
|
46
|
+
name: str,
|
|
46
47
|
*,
|
|
47
48
|
dim_param_override: tuple[str | None, ...] | None = None,
|
|
48
49
|
) -> TensorType:
|
|
49
|
-
if value_info.type.WhichOneof("value") != "tensor_type":
|
|
50
|
-
raise _unsupported_value_type(value_info)
|
|
51
|
-
tensor_type = value_info.type.tensor_type
|
|
52
50
|
if not tensor_type.HasField("elem_type"):
|
|
53
|
-
raise ShapeInferenceError(f"Missing elem_type for tensor '{
|
|
51
|
+
raise ShapeInferenceError(f"Missing elem_type for tensor '{name}'")
|
|
54
52
|
dtype = scalar_type_from_onnx(tensor_type.elem_type)
|
|
55
53
|
if dtype is None:
|
|
56
54
|
raise UnsupportedOpError(
|
|
57
55
|
"Unsupported elem_type "
|
|
58
|
-
f"{_format_elem_type(tensor_type.elem_type)} for tensor '{
|
|
56
|
+
f"{_format_elem_type(tensor_type.elem_type)} for tensor '{name}'."
|
|
59
57
|
)
|
|
60
58
|
shape = []
|
|
61
59
|
dim_params = []
|
|
@@ -72,7 +70,7 @@ def _tensor_type(
|
|
|
72
70
|
if dim_param:
|
|
73
71
|
shape.append(1)
|
|
74
72
|
continue
|
|
75
|
-
raise ShapeInferenceError(f"Dynamic dim for tensor '{
|
|
73
|
+
raise ShapeInferenceError(f"Dynamic dim for tensor '{name}'")
|
|
76
74
|
shape.append(dim.dim_value)
|
|
77
75
|
return TensorType(
|
|
78
76
|
dtype=dtype,
|
|
@@ -81,6 +79,40 @@ def _tensor_type(
|
|
|
81
79
|
)
|
|
82
80
|
|
|
83
81
|
|
|
82
|
+
def _value_type(
|
|
83
|
+
value_info: onnx.ValueInfoProto,
|
|
84
|
+
*,
|
|
85
|
+
dim_param_override: tuple[str | None, ...] | None = None,
|
|
86
|
+
) -> TensorType:
|
|
87
|
+
value_kind = value_info.type.WhichOneof("value")
|
|
88
|
+
if value_kind == "tensor_type":
|
|
89
|
+
return _tensor_type_from_proto(
|
|
90
|
+
value_info.type.tensor_type,
|
|
91
|
+
value_info.name,
|
|
92
|
+
dim_param_override=dim_param_override,
|
|
93
|
+
)
|
|
94
|
+
if value_kind == "optional_type":
|
|
95
|
+
elem_type = value_info.type.optional_type.elem_type
|
|
96
|
+
elem_kind = elem_type.WhichOneof("value")
|
|
97
|
+
if elem_kind != "tensor_type":
|
|
98
|
+
raise UnsupportedOpError(
|
|
99
|
+
f"Unsupported optional element type '{elem_kind}' for '{value_info.name}'. "
|
|
100
|
+
"Hint: export the model with optional tensor inputs/outputs."
|
|
101
|
+
)
|
|
102
|
+
tensor_type = _tensor_type_from_proto(
|
|
103
|
+
elem_type.tensor_type,
|
|
104
|
+
value_info.name,
|
|
105
|
+
dim_param_override=dim_param_override,
|
|
106
|
+
)
|
|
107
|
+
return TensorType(
|
|
108
|
+
dtype=tensor_type.dtype,
|
|
109
|
+
shape=tensor_type.shape,
|
|
110
|
+
dim_params=tensor_type.dim_params,
|
|
111
|
+
is_optional=True,
|
|
112
|
+
)
|
|
113
|
+
raise _unsupported_value_type(value_info)
|
|
114
|
+
|
|
115
|
+
|
|
84
116
|
def _values(
|
|
85
117
|
value_infos: Iterable[onnx.ValueInfoProto],
|
|
86
118
|
*,
|
|
@@ -90,7 +122,7 @@ def _values(
|
|
|
90
122
|
return tuple(
|
|
91
123
|
Value(
|
|
92
124
|
name=vi.name,
|
|
93
|
-
type=
|
|
125
|
+
type=_value_type(
|
|
94
126
|
vi, dim_param_override=dim_param_by_name.get(vi.name)
|
|
95
127
|
),
|
|
96
128
|
)
|
|
@@ -103,8 +135,18 @@ def _collect_dim_params(
|
|
|
103
135
|
) -> dict[str, tuple[str | None, ...]]:
|
|
104
136
|
dim_params: dict[str, tuple[str | None, ...]] = {}
|
|
105
137
|
for value_info in value_infos:
|
|
138
|
+
value_kind = value_info.type.WhichOneof("value")
|
|
139
|
+
if value_kind == "tensor_type":
|
|
140
|
+
tensor_type = value_info.type.tensor_type
|
|
141
|
+
elif value_kind == "optional_type":
|
|
142
|
+
elem_type = value_info.type.optional_type.elem_type
|
|
143
|
+
if elem_type.WhichOneof("value") != "tensor_type":
|
|
144
|
+
continue
|
|
145
|
+
tensor_type = elem_type.tensor_type
|
|
146
|
+
else:
|
|
147
|
+
continue
|
|
106
148
|
dims = []
|
|
107
|
-
for dim in
|
|
149
|
+
for dim in tensor_type.shape.dim:
|
|
108
150
|
dim_param = dim.dim_param if dim.HasField("dim_param") else ""
|
|
109
151
|
dims.append(dim_param or None)
|
|
110
152
|
if any(dims):
|
|
@@ -112,6 +154,61 @@ def _collect_dim_params(
|
|
|
112
154
|
return dim_params
|
|
113
155
|
|
|
114
156
|
|
|
157
|
+
def _value_info_complete(value_info: onnx.ValueInfoProto) -> bool:
|
|
158
|
+
value_kind = value_info.type.WhichOneof("value")
|
|
159
|
+
if value_kind == "tensor_type":
|
|
160
|
+
tensor_type = value_info.type.tensor_type
|
|
161
|
+
elif value_kind == "optional_type":
|
|
162
|
+
elem_type = value_info.type.optional_type.elem_type
|
|
163
|
+
if elem_type.WhichOneof("value") != "tensor_type":
|
|
164
|
+
return False
|
|
165
|
+
tensor_type = elem_type.tensor_type
|
|
166
|
+
else:
|
|
167
|
+
return False
|
|
168
|
+
if not tensor_type.HasField("elem_type"):
|
|
169
|
+
return False
|
|
170
|
+
if not tensor_type.HasField("shape"):
|
|
171
|
+
return False
|
|
172
|
+
for dim in tensor_type.shape.dim:
|
|
173
|
+
if dim.HasField("dim_value"):
|
|
174
|
+
continue
|
|
175
|
+
if dim.HasField("dim_param"):
|
|
176
|
+
continue
|
|
177
|
+
return False
|
|
178
|
+
return True
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def _needs_shape_inference(model: onnx.ModelProto) -> bool:
|
|
182
|
+
graph = model.graph
|
|
183
|
+
value_info_by_name = {
|
|
184
|
+
value_info.name: value_info for value_info in graph.value_info
|
|
185
|
+
}
|
|
186
|
+
output_names = {value_info.name for value_info in graph.output}
|
|
187
|
+
initializer_names = {initializer.name for initializer in graph.initializer}
|
|
188
|
+
initializer_names.update(
|
|
189
|
+
sparse_init.name for sparse_init in graph.sparse_initializer
|
|
190
|
+
)
|
|
191
|
+
for node in graph.node:
|
|
192
|
+
for output in node.output:
|
|
193
|
+
if not output:
|
|
194
|
+
continue
|
|
195
|
+
if output in output_names or output in value_info_by_name:
|
|
196
|
+
continue
|
|
197
|
+
return True
|
|
198
|
+
for value_info in graph.value_info:
|
|
199
|
+
if not _value_info_complete(value_info):
|
|
200
|
+
return True
|
|
201
|
+
for value_info in graph.output:
|
|
202
|
+
if not _value_info_complete(value_info):
|
|
203
|
+
return True
|
|
204
|
+
for value_info in graph.input:
|
|
205
|
+
if value_info.name in initializer_names:
|
|
206
|
+
continue
|
|
207
|
+
if not _value_info_complete(value_info):
|
|
208
|
+
return True
|
|
209
|
+
return False
|
|
210
|
+
|
|
211
|
+
|
|
115
212
|
def _initializer(value: onnx.TensorProto) -> Initializer:
|
|
116
213
|
dtype = scalar_type_from_onnx(value.data_type)
|
|
117
214
|
if dtype is None:
|
|
@@ -136,6 +233,471 @@ def _node_attrs(node: onnx.NodeProto) -> dict[str, object]:
|
|
|
136
233
|
return {attr.name: helper.get_attribute_value(attr) for attr in node.attribute}
|
|
137
234
|
|
|
138
235
|
|
|
236
|
+
def _find_value_info(
|
|
237
|
+
graph: onnx.GraphProto, name: str
|
|
238
|
+
) -> onnx.ValueInfoProto | None:
|
|
239
|
+
for value_info in graph.input:
|
|
240
|
+
if value_info.name == name:
|
|
241
|
+
return value_info
|
|
242
|
+
for value_info in graph.value_info:
|
|
243
|
+
if value_info.name == name:
|
|
244
|
+
return value_info
|
|
245
|
+
for value_info in graph.output:
|
|
246
|
+
if value_info.name == name:
|
|
247
|
+
return value_info
|
|
248
|
+
return None
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def _tensor_shape_from_value_info(
|
|
252
|
+
graph: onnx.GraphProto, name: str
|
|
253
|
+
) -> tuple[int, ...]:
|
|
254
|
+
value_info = _find_value_info(graph, name)
|
|
255
|
+
if value_info is None:
|
|
256
|
+
for initializer in graph.initializer:
|
|
257
|
+
if initializer.name == name:
|
|
258
|
+
return tuple(int(dim) for dim in initializer.dims)
|
|
259
|
+
raise ShapeInferenceError(
|
|
260
|
+
f"Missing shape for '{name}' in Scan expansion. "
|
|
261
|
+
"Hint: run ONNX shape inference or export with static shapes."
|
|
262
|
+
)
|
|
263
|
+
tensor_type = value_info.type.tensor_type
|
|
264
|
+
if not tensor_type.HasField("shape"):
|
|
265
|
+
raise ShapeInferenceError(
|
|
266
|
+
f"Missing shape for '{name}' in Scan expansion. "
|
|
267
|
+
"Hint: run ONNX shape inference or export with static shapes."
|
|
268
|
+
)
|
|
269
|
+
dims: list[int] = []
|
|
270
|
+
for dim in tensor_type.shape.dim:
|
|
271
|
+
if not dim.HasField("dim_value"):
|
|
272
|
+
raise ShapeInferenceError(
|
|
273
|
+
f"Dynamic dim for '{name}' in Scan expansion. "
|
|
274
|
+
"Hint: export with static shapes."
|
|
275
|
+
)
|
|
276
|
+
dims.append(int(dim.dim_value))
|
|
277
|
+
return tuple(dims)
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def _scan_attr_ints(
|
|
281
|
+
attrs: dict[str, object],
|
|
282
|
+
key: str,
|
|
283
|
+
*,
|
|
284
|
+
default: tuple[int, ...],
|
|
285
|
+
) -> tuple[int, ...]:
|
|
286
|
+
value = attrs.get(key)
|
|
287
|
+
if value is None:
|
|
288
|
+
return default
|
|
289
|
+
return tuple(int(item) for item in value)
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def _onnx_opset_version(model: onnx.ModelProto) -> int | None:
|
|
293
|
+
for opset in model.opset_import:
|
|
294
|
+
if opset.domain in {"", "ai.onnx"}:
|
|
295
|
+
return int(opset.version)
|
|
296
|
+
return None
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def _scan_expected_axis(is_opset8: bool) -> int:
|
|
300
|
+
return 1 if is_opset8 else 0
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def _scan_axes_and_directions(
|
|
304
|
+
attrs: dict[str, object],
|
|
305
|
+
*,
|
|
306
|
+
num_scan_inputs: int,
|
|
307
|
+
scan_output_count: int,
|
|
308
|
+
is_opset8: bool,
|
|
309
|
+
) -> None:
|
|
310
|
+
default_axis = _scan_expected_axis(is_opset8)
|
|
311
|
+
scan_input_axes = _scan_attr_ints(
|
|
312
|
+
attrs,
|
|
313
|
+
"scan_input_axes",
|
|
314
|
+
default=(default_axis,) * num_scan_inputs,
|
|
315
|
+
)
|
|
316
|
+
scan_output_axes = _scan_attr_ints(
|
|
317
|
+
attrs,
|
|
318
|
+
"scan_output_axes",
|
|
319
|
+
default=(default_axis,) * scan_output_count,
|
|
320
|
+
)
|
|
321
|
+
scan_input_directions = _scan_attr_ints(
|
|
322
|
+
attrs,
|
|
323
|
+
"scan_input_directions",
|
|
324
|
+
default=(0,) * num_scan_inputs,
|
|
325
|
+
)
|
|
326
|
+
scan_output_directions = _scan_attr_ints(
|
|
327
|
+
attrs,
|
|
328
|
+
"scan_output_directions",
|
|
329
|
+
default=(0,) * scan_output_count,
|
|
330
|
+
)
|
|
331
|
+
if any(axis != default_axis for axis in scan_input_axes):
|
|
332
|
+
raise UnsupportedOpError(
|
|
333
|
+
f"Scan only supports scan_input_axes={default_axis}"
|
|
334
|
+
)
|
|
335
|
+
if any(axis != default_axis for axis in scan_output_axes):
|
|
336
|
+
raise UnsupportedOpError(
|
|
337
|
+
f"Scan only supports scan_output_axes={default_axis}"
|
|
338
|
+
)
|
|
339
|
+
if any(direction != 0 for direction in scan_input_directions):
|
|
340
|
+
raise UnsupportedOpError(
|
|
341
|
+
"Scan only supports scan_input_directions=0"
|
|
342
|
+
)
|
|
343
|
+
if any(direction != 0 for direction in scan_output_directions):
|
|
344
|
+
raise UnsupportedOpError(
|
|
345
|
+
"Scan only supports scan_output_directions=0"
|
|
346
|
+
)
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
def _scan_sequence_length(
|
|
350
|
+
graph: onnx.GraphProto,
|
|
351
|
+
scan_input_names: list[str],
|
|
352
|
+
*,
|
|
353
|
+
is_opset8: bool,
|
|
354
|
+
) -> tuple[int, int | None]:
|
|
355
|
+
scan_input_shapes = [
|
|
356
|
+
_tensor_shape_from_value_info(graph, name)
|
|
357
|
+
for name in scan_input_names
|
|
358
|
+
]
|
|
359
|
+
if not scan_input_shapes:
|
|
360
|
+
raise UnsupportedOpError("Scan requires scan inputs")
|
|
361
|
+
if is_opset8:
|
|
362
|
+
if any(len(shape) < 2 for shape in scan_input_shapes):
|
|
363
|
+
raise UnsupportedOpError(
|
|
364
|
+
"Scan opset 8 inputs must include batch and sequence dims"
|
|
365
|
+
)
|
|
366
|
+
batch_size = scan_input_shapes[0][0]
|
|
367
|
+
sequence_len = scan_input_shapes[0][1]
|
|
368
|
+
if batch_size != 1:
|
|
369
|
+
raise UnsupportedOpError(
|
|
370
|
+
"Scan opset 8 currently supports batch size 1 only"
|
|
371
|
+
)
|
|
372
|
+
if sequence_len <= 0:
|
|
373
|
+
raise UnsupportedOpError("Scan requires positive sequence length")
|
|
374
|
+
if any(
|
|
375
|
+
shape[0] != batch_size or shape[1] != sequence_len
|
|
376
|
+
for shape in scan_input_shapes
|
|
377
|
+
):
|
|
378
|
+
raise UnsupportedOpError(
|
|
379
|
+
"Scan inputs must share the same batch and sequence length"
|
|
380
|
+
)
|
|
381
|
+
return sequence_len, batch_size
|
|
382
|
+
sequence_len = scan_input_shapes[0][0]
|
|
383
|
+
if sequence_len <= 0:
|
|
384
|
+
raise UnsupportedOpError("Scan requires positive sequence length")
|
|
385
|
+
if any(shape[0] != sequence_len for shape in scan_input_shapes):
|
|
386
|
+
raise UnsupportedOpError(
|
|
387
|
+
"Scan inputs must share the same sequence length"
|
|
388
|
+
)
|
|
389
|
+
return sequence_len, None
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
def _scan_body_initializers(
|
|
393
|
+
body: onnx.GraphProto,
|
|
394
|
+
*,
|
|
395
|
+
prefix: str,
|
|
396
|
+
new_initializers: list[onnx.TensorProto],
|
|
397
|
+
) -> dict[str, str]:
|
|
398
|
+
initializer_map: dict[str, str] = {}
|
|
399
|
+
for initializer in body.initializer:
|
|
400
|
+
new_name = f"{prefix}_init_{initializer.name}"
|
|
401
|
+
initializer_map[initializer.name] = new_name
|
|
402
|
+
array = numpy_helper.to_array(initializer)
|
|
403
|
+
new_initializers.append(numpy_helper.from_array(array, name=new_name))
|
|
404
|
+
return initializer_map
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
def _scan_state_inputs(
|
|
408
|
+
graph: onnx.GraphProto,
|
|
409
|
+
*,
|
|
410
|
+
prefix: str,
|
|
411
|
+
state_input_names: list[str],
|
|
412
|
+
new_nodes: list[onnx.NodeProto],
|
|
413
|
+
is_opset8: bool,
|
|
414
|
+
batch_size: int | None,
|
|
415
|
+
) -> list[str]:
|
|
416
|
+
state_names = list(state_input_names)
|
|
417
|
+
if is_opset8 and state_input_names:
|
|
418
|
+
for state_index, state_name in enumerate(state_input_names):
|
|
419
|
+
state_shape = _tensor_shape_from_value_info(graph, state_name)
|
|
420
|
+
if not state_shape:
|
|
421
|
+
raise UnsupportedOpError(
|
|
422
|
+
"Scan opset 8 state inputs must be tensors"
|
|
423
|
+
)
|
|
424
|
+
if batch_size is not None and state_shape[0] != batch_size:
|
|
425
|
+
raise UnsupportedOpError(
|
|
426
|
+
"Scan opset 8 state inputs must match batch size"
|
|
427
|
+
)
|
|
428
|
+
squeezed_name = f"{prefix}_state{state_index}_squeezed"
|
|
429
|
+
new_nodes.append(
|
|
430
|
+
helper.make_node(
|
|
431
|
+
"Squeeze",
|
|
432
|
+
inputs=[state_name],
|
|
433
|
+
outputs=[squeezed_name],
|
|
434
|
+
name=f"{squeezed_name}_node",
|
|
435
|
+
axes=[0],
|
|
436
|
+
)
|
|
437
|
+
)
|
|
438
|
+
state_names[state_index] = squeezed_name
|
|
439
|
+
return state_names
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
def _scan_iteration_inputs(
|
|
443
|
+
*,
|
|
444
|
+
prefix: str,
|
|
445
|
+
iter_index: int,
|
|
446
|
+
scan_input_names: list[str],
|
|
447
|
+
new_nodes: list[onnx.NodeProto],
|
|
448
|
+
is_opset8: bool,
|
|
449
|
+
) -> list[str]:
|
|
450
|
+
scan_iter_inputs: list[str] = []
|
|
451
|
+
slice_axis = _scan_expected_axis(is_opset8)
|
|
452
|
+
squeeze_axes = [0, 1] if is_opset8 else [0]
|
|
453
|
+
for scan_index, scan_name in enumerate(scan_input_names):
|
|
454
|
+
slice_out = f"{prefix}_iter{iter_index}_scan{scan_index}_slice"
|
|
455
|
+
squeeze_out = f"{prefix}_iter{iter_index}_scan{scan_index}_value"
|
|
456
|
+
new_nodes.append(
|
|
457
|
+
helper.make_node(
|
|
458
|
+
"Slice",
|
|
459
|
+
inputs=[scan_name],
|
|
460
|
+
outputs=[slice_out],
|
|
461
|
+
name=f"{slice_out}_node",
|
|
462
|
+
starts=[iter_index],
|
|
463
|
+
ends=[iter_index + 1],
|
|
464
|
+
axes=[slice_axis],
|
|
465
|
+
)
|
|
466
|
+
)
|
|
467
|
+
new_nodes.append(
|
|
468
|
+
helper.make_node(
|
|
469
|
+
"Squeeze",
|
|
470
|
+
inputs=[slice_out],
|
|
471
|
+
outputs=[squeeze_out],
|
|
472
|
+
name=f"{squeeze_out}_node",
|
|
473
|
+
axes=squeeze_axes,
|
|
474
|
+
)
|
|
475
|
+
)
|
|
476
|
+
scan_iter_inputs.append(squeeze_out)
|
|
477
|
+
return scan_iter_inputs
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
def _expand_scan_nodes(model: onnx.ModelProto) -> tuple[onnx.ModelProto, bool]:
|
|
481
|
+
graph = model.graph
|
|
482
|
+
opset_version = _onnx_opset_version(model)
|
|
483
|
+
if opset_version is None:
|
|
484
|
+
return model, False
|
|
485
|
+
|
|
486
|
+
new_nodes: list[onnx.NodeProto] = []
|
|
487
|
+
new_initializers: list[onnx.TensorProto] = []
|
|
488
|
+
scan_index = 0
|
|
489
|
+
expanded = False
|
|
490
|
+
is_opset8 = opset_version <= 8
|
|
491
|
+
|
|
492
|
+
for node in graph.node:
|
|
493
|
+
if node.op_type != "Scan":
|
|
494
|
+
new_nodes.append(node)
|
|
495
|
+
continue
|
|
496
|
+
|
|
497
|
+
expanded = True
|
|
498
|
+
scan_index += 1
|
|
499
|
+
attrs = _node_attrs(node)
|
|
500
|
+
body = attrs.get("body")
|
|
501
|
+
if not isinstance(body, onnx.GraphProto):
|
|
502
|
+
raise UnsupportedOpError("Scan requires a body graph")
|
|
503
|
+
num_scan_inputs = int(attrs.get("num_scan_inputs", 0))
|
|
504
|
+
if num_scan_inputs <= 0:
|
|
505
|
+
raise UnsupportedOpError("Scan requires num_scan_inputs")
|
|
506
|
+
input_names = list(node.input)
|
|
507
|
+
if is_opset8:
|
|
508
|
+
if not input_names:
|
|
509
|
+
raise UnsupportedOpError("Scan in opset 8 requires inputs")
|
|
510
|
+
sequence_lens = input_names.pop(0)
|
|
511
|
+
if sequence_lens:
|
|
512
|
+
raise UnsupportedOpError(
|
|
513
|
+
"Scan sequence_lens input is not supported"
|
|
514
|
+
)
|
|
515
|
+
num_state_inputs = len(input_names) - num_scan_inputs
|
|
516
|
+
if num_state_inputs < 0:
|
|
517
|
+
raise UnsupportedOpError("Scan input count is invalid")
|
|
518
|
+
if len(body.input) != num_state_inputs + num_scan_inputs:
|
|
519
|
+
raise UnsupportedOpError(
|
|
520
|
+
"Scan body input count must match state and scan inputs"
|
|
521
|
+
)
|
|
522
|
+
if len(body.output) != len(node.output):
|
|
523
|
+
raise UnsupportedOpError(
|
|
524
|
+
"Scan body output count must match Scan outputs"
|
|
525
|
+
)
|
|
526
|
+
scan_output_count = len(node.output) - num_state_inputs
|
|
527
|
+
_scan_axes_and_directions(
|
|
528
|
+
attrs,
|
|
529
|
+
num_scan_inputs=num_scan_inputs,
|
|
530
|
+
scan_output_count=scan_output_count,
|
|
531
|
+
is_opset8=is_opset8,
|
|
532
|
+
)
|
|
533
|
+
|
|
534
|
+
state_input_names = input_names[:num_state_inputs]
|
|
535
|
+
scan_input_names = input_names[num_state_inputs:]
|
|
536
|
+
sequence_len, batch_size = _scan_sequence_length(
|
|
537
|
+
graph,
|
|
538
|
+
scan_input_names,
|
|
539
|
+
is_opset8=is_opset8,
|
|
540
|
+
)
|
|
541
|
+
|
|
542
|
+
prefix = node.name or f"scan_{scan_index}"
|
|
543
|
+
initializer_map = _scan_body_initializers(
|
|
544
|
+
body,
|
|
545
|
+
prefix=prefix,
|
|
546
|
+
new_initializers=new_initializers,
|
|
547
|
+
)
|
|
548
|
+
|
|
549
|
+
state_names = _scan_state_inputs(
|
|
550
|
+
graph,
|
|
551
|
+
prefix=prefix,
|
|
552
|
+
state_input_names=state_input_names,
|
|
553
|
+
new_nodes=new_nodes,
|
|
554
|
+
is_opset8=is_opset8,
|
|
555
|
+
batch_size=batch_size,
|
|
556
|
+
)
|
|
557
|
+
scan_output_buffers: list[list[str]] = [
|
|
558
|
+
[] for _ in range(scan_output_count)
|
|
559
|
+
]
|
|
560
|
+
|
|
561
|
+
for iter_index in range(sequence_len):
|
|
562
|
+
scan_iter_inputs = _scan_iteration_inputs(
|
|
563
|
+
prefix=prefix,
|
|
564
|
+
iter_index=iter_index,
|
|
565
|
+
scan_input_names=scan_input_names,
|
|
566
|
+
new_nodes=new_nodes,
|
|
567
|
+
is_opset8=is_opset8,
|
|
568
|
+
)
|
|
569
|
+
name_map: dict[str, str] = {}
|
|
570
|
+
for index, value in enumerate(body.input[:num_state_inputs]):
|
|
571
|
+
name_map[value.name] = state_names[index]
|
|
572
|
+
for index, value in enumerate(
|
|
573
|
+
body.input[num_state_inputs : num_state_inputs + num_scan_inputs]
|
|
574
|
+
):
|
|
575
|
+
name_map[value.name] = scan_iter_inputs[index]
|
|
576
|
+
for original, mapped in initializer_map.items():
|
|
577
|
+
name_map[original] = mapped
|
|
578
|
+
|
|
579
|
+
for body_node in body.node:
|
|
580
|
+
body_attrs = _node_attrs(body_node)
|
|
581
|
+
mapped_inputs = [
|
|
582
|
+
name_map.get(input_name, input_name)
|
|
583
|
+
for input_name in body_node.input
|
|
584
|
+
]
|
|
585
|
+
mapped_outputs: list[str] = []
|
|
586
|
+
for output_name in body_node.output:
|
|
587
|
+
if not output_name:
|
|
588
|
+
mapped_outputs.append("")
|
|
589
|
+
continue
|
|
590
|
+
mapped_name = (
|
|
591
|
+
f"{prefix}_iter{iter_index}_{output_name}"
|
|
592
|
+
)
|
|
593
|
+
name_map[output_name] = mapped_name
|
|
594
|
+
mapped_outputs.append(mapped_name)
|
|
595
|
+
new_nodes.append(
|
|
596
|
+
helper.make_node(
|
|
597
|
+
body_node.op_type,
|
|
598
|
+
inputs=mapped_inputs,
|
|
599
|
+
outputs=mapped_outputs,
|
|
600
|
+
name=(
|
|
601
|
+
f"{prefix}_iter{iter_index}_{body_node.name}"
|
|
602
|
+
if body_node.name
|
|
603
|
+
else ""
|
|
604
|
+
),
|
|
605
|
+
domain=body_node.domain,
|
|
606
|
+
**body_attrs,
|
|
607
|
+
)
|
|
608
|
+
)
|
|
609
|
+
|
|
610
|
+
for index, output in enumerate(body.output[:num_state_inputs]):
|
|
611
|
+
mapped_output = name_map.get(output.name)
|
|
612
|
+
if mapped_output is None:
|
|
613
|
+
raise UnsupportedOpError(
|
|
614
|
+
"Scan body did not produce a required state output"
|
|
615
|
+
)
|
|
616
|
+
state_names[index] = mapped_output
|
|
617
|
+
|
|
618
|
+
for output_index, output in enumerate(
|
|
619
|
+
body.output[
|
|
620
|
+
num_state_inputs : num_state_inputs + scan_output_count
|
|
621
|
+
]
|
|
622
|
+
):
|
|
623
|
+
mapped_output = name_map.get(output.name)
|
|
624
|
+
if mapped_output is None:
|
|
625
|
+
raise UnsupportedOpError(
|
|
626
|
+
"Scan body did not produce a required scan output"
|
|
627
|
+
)
|
|
628
|
+
unsqueeze_out = (
|
|
629
|
+
f"{prefix}_iter{iter_index}_scanout{output_index}"
|
|
630
|
+
)
|
|
631
|
+
unsqueeze_axes = [0, 1] if is_opset8 else [0]
|
|
632
|
+
new_nodes.append(
|
|
633
|
+
helper.make_node(
|
|
634
|
+
"Unsqueeze",
|
|
635
|
+
inputs=[mapped_output],
|
|
636
|
+
outputs=[unsqueeze_out],
|
|
637
|
+
name=f"{unsqueeze_out}_node",
|
|
638
|
+
axes=unsqueeze_axes,
|
|
639
|
+
)
|
|
640
|
+
)
|
|
641
|
+
scan_output_buffers[output_index].append(unsqueeze_out)
|
|
642
|
+
|
|
643
|
+
for index, output_name in enumerate(node.output[:num_state_inputs]):
|
|
644
|
+
state_value = state_names[index]
|
|
645
|
+
if is_opset8:
|
|
646
|
+
expanded_state = f"{prefix}_state_output_{index}_expanded"
|
|
647
|
+
new_nodes.append(
|
|
648
|
+
helper.make_node(
|
|
649
|
+
"Unsqueeze",
|
|
650
|
+
inputs=[state_value],
|
|
651
|
+
outputs=[expanded_state],
|
|
652
|
+
name=f"{expanded_state}_node",
|
|
653
|
+
axes=[0],
|
|
654
|
+
)
|
|
655
|
+
)
|
|
656
|
+
state_value = expanded_state
|
|
657
|
+
if state_value == output_name:
|
|
658
|
+
continue
|
|
659
|
+
new_nodes.append(
|
|
660
|
+
helper.make_node(
|
|
661
|
+
"Identity",
|
|
662
|
+
inputs=[state_value],
|
|
663
|
+
outputs=[output_name],
|
|
664
|
+
name=f"{prefix}_state_output_{index}",
|
|
665
|
+
)
|
|
666
|
+
)
|
|
667
|
+
|
|
668
|
+
for output_index, output_name in enumerate(
|
|
669
|
+
node.output[num_state_inputs : num_state_inputs + scan_output_count]
|
|
670
|
+
):
|
|
671
|
+
buffer = scan_output_buffers[output_index]
|
|
672
|
+
concat_axis = _scan_expected_axis(is_opset8)
|
|
673
|
+
if len(buffer) == 1:
|
|
674
|
+
new_nodes.append(
|
|
675
|
+
helper.make_node(
|
|
676
|
+
"Identity",
|
|
677
|
+
inputs=buffer,
|
|
678
|
+
outputs=[output_name],
|
|
679
|
+
name=f"{prefix}_scan_output_{output_index}",
|
|
680
|
+
)
|
|
681
|
+
)
|
|
682
|
+
else:
|
|
683
|
+
new_nodes.append(
|
|
684
|
+
helper.make_node(
|
|
685
|
+
"Concat",
|
|
686
|
+
inputs=buffer,
|
|
687
|
+
outputs=[output_name],
|
|
688
|
+
name=f"{prefix}_scan_output_{output_index}",
|
|
689
|
+
axis=concat_axis,
|
|
690
|
+
)
|
|
691
|
+
)
|
|
692
|
+
|
|
693
|
+
if expanded:
|
|
694
|
+
del graph.node[:]
|
|
695
|
+
graph.node.extend(new_nodes)
|
|
696
|
+
if new_initializers:
|
|
697
|
+
graph.initializer.extend(new_initializers)
|
|
698
|
+
return model, expanded
|
|
699
|
+
|
|
700
|
+
|
|
139
701
|
def _constant_initializer(node: onnx.NodeProto) -> Initializer:
|
|
140
702
|
if len(node.output) != 1:
|
|
141
703
|
raise UnsupportedOpError("Constant must have exactly one output")
|
|
@@ -209,16 +771,18 @@ def _constant_initializer(node: onnx.NodeProto) -> Initializer:
|
|
|
209
771
|
|
|
210
772
|
|
|
211
773
|
def import_onnx(model: onnx.ModelProto) -> Graph:
|
|
774
|
+
model, _ = _expand_scan_nodes(model)
|
|
212
775
|
dim_param_by_name = _collect_dim_params(
|
|
213
776
|
tuple(model.graph.input) + tuple(model.graph.output)
|
|
214
777
|
)
|
|
215
778
|
opset_imports = tuple(
|
|
216
779
|
(opset.domain, opset.version) for opset in model.opset_import
|
|
217
780
|
)
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
781
|
+
if _needs_shape_inference(model):
|
|
782
|
+
try:
|
|
783
|
+
model = shape_inference.infer_shapes(model, data_prop=True)
|
|
784
|
+
except Exception as exc: # pragma: no cover - onnx inference errors
|
|
785
|
+
raise ShapeInferenceError("ONNX shape inference failed") from exc
|
|
222
786
|
graph = model.graph
|
|
223
787
|
base_initializers = [_initializer(value) for value in graph.initializer]
|
|
224
788
|
constant_initializers: list[Initializer] = []
|
emx_onnx_cgen/ops.py
CHANGED
|
@@ -554,6 +554,9 @@ def unary_op_symbol(function: ScalarFunction, *, dtype: ScalarType) -> str | Non
|
|
|
554
554
|
def apply_binary_op(
|
|
555
555
|
op_spec: BinaryOpSpec, left: np.ndarray, right: np.ndarray
|
|
556
556
|
) -> np.ndarray:
|
|
557
|
+
if op_spec.apply is np.power:
|
|
558
|
+
with np.errstate(invalid="ignore", divide="ignore", over="ignore"):
|
|
559
|
+
return op_spec.apply(left, right)
|
|
557
560
|
return op_spec.apply(left, right)
|
|
558
561
|
|
|
559
562
|
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
|
|
2
|
+
const {{ c_type }} r = {{ rate }}[0] / ({{ one_literal }} + ({{ c_type }}){{ timestep }}[0] * {{ decay_factor_literal }});
|
|
3
|
+
{% for tensor in tensors %}
|
|
4
|
+
{% for dim in tensor.shape %}
|
|
5
|
+
for (idx_t {{ tensor.loop_vars[loop.index0] }} = 0; {{ tensor.loop_vars[loop.index0] }} < {{ dim }}; ++{{ tensor.loop_vars[loop.index0] }}) {
|
|
6
|
+
{% endfor %}
|
|
7
|
+
{{ c_type }} g_regularized = {{ norm_coefficient_literal }} * {{ tensor.input_expr }} + {{ tensor.grad_expr }};
|
|
8
|
+
{{ c_type }} h_new = {{ tensor.acc_expr }} + g_regularized * g_regularized;
|
|
9
|
+
{{ tensor.acc_output_expr }} = h_new;
|
|
10
|
+
{{ c_type }} h_adaptive = {{ sqrt_fn }}(h_new) + {{ epsilon_literal }};
|
|
11
|
+
{{ tensor.output_expr }} = {{ tensor.input_expr }} - r * g_regularized / h_adaptive;
|
|
12
|
+
{% for _ in tensor.shape %}
|
|
13
|
+
}
|
|
14
|
+
{% endfor %}
|
|
15
|
+
{% endfor %}
|
|
16
|
+
}
|