emx-onnx-cgen 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of emx-onnx-cgen might be problematic. Click here for more details.

Files changed (99) hide show
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +34 -0
  3. emx_onnx_cgen/cli.py +372 -64
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +3932 -1398
  6. emx_onnx_cgen/codegen/emitter.py +5 -0
  7. emx_onnx_cgen/compiler.py +169 -343
  8. emx_onnx_cgen/ir/context.py +87 -0
  9. emx_onnx_cgen/ir/model.py +1 -0
  10. emx_onnx_cgen/ir/op_base.py +193 -0
  11. emx_onnx_cgen/ir/op_context.py +65 -0
  12. emx_onnx_cgen/ir/ops/__init__.py +130 -0
  13. emx_onnx_cgen/ir/ops/elementwise.py +146 -0
  14. emx_onnx_cgen/ir/ops/misc.py +421 -0
  15. emx_onnx_cgen/ir/ops/nn.py +580 -0
  16. emx_onnx_cgen/ir/ops/reduce.py +95 -0
  17. emx_onnx_cgen/lowering/__init__.py +79 -1
  18. emx_onnx_cgen/lowering/adagrad.py +114 -0
  19. emx_onnx_cgen/lowering/arg_reduce.py +1 -1
  20. emx_onnx_cgen/lowering/attention.py +1 -1
  21. emx_onnx_cgen/lowering/average_pool.py +1 -1
  22. emx_onnx_cgen/lowering/batch_normalization.py +1 -1
  23. emx_onnx_cgen/lowering/cast.py +1 -1
  24. emx_onnx_cgen/lowering/common.py +406 -11
  25. emx_onnx_cgen/lowering/concat.py +1 -1
  26. emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
  27. emx_onnx_cgen/lowering/conv.py +1 -1
  28. emx_onnx_cgen/lowering/conv_transpose.py +301 -0
  29. emx_onnx_cgen/lowering/cumsum.py +1 -1
  30. emx_onnx_cgen/lowering/depth_space.py +1 -1
  31. emx_onnx_cgen/lowering/dropout.py +1 -1
  32. emx_onnx_cgen/lowering/einsum.py +153 -0
  33. emx_onnx_cgen/lowering/elementwise.py +152 -4
  34. emx_onnx_cgen/lowering/expand.py +1 -1
  35. emx_onnx_cgen/lowering/eye_like.py +1 -1
  36. emx_onnx_cgen/lowering/flatten.py +1 -1
  37. emx_onnx_cgen/lowering/gather.py +1 -1
  38. emx_onnx_cgen/lowering/gather_elements.py +2 -4
  39. emx_onnx_cgen/lowering/gather_nd.py +79 -0
  40. emx_onnx_cgen/lowering/gemm.py +1 -1
  41. emx_onnx_cgen/lowering/global_max_pool.py +59 -0
  42. emx_onnx_cgen/lowering/grid_sample.py +1 -1
  43. emx_onnx_cgen/lowering/group_normalization.py +1 -1
  44. emx_onnx_cgen/lowering/hardmax.py +53 -0
  45. emx_onnx_cgen/lowering/identity.py +7 -6
  46. emx_onnx_cgen/lowering/instance_normalization.py +1 -1
  47. emx_onnx_cgen/lowering/layer_normalization.py +1 -1
  48. emx_onnx_cgen/lowering/logsoftmax.py +6 -2
  49. emx_onnx_cgen/lowering/lp_normalization.py +1 -1
  50. emx_onnx_cgen/lowering/lp_pool.py +141 -0
  51. emx_onnx_cgen/lowering/lrn.py +1 -1
  52. emx_onnx_cgen/lowering/lstm.py +1 -1
  53. emx_onnx_cgen/lowering/matmul.py +7 -8
  54. emx_onnx_cgen/lowering/maxpool.py +1 -1
  55. emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
  56. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +13 -13
  57. emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
  58. emx_onnx_cgen/lowering/nonzero.py +42 -0
  59. emx_onnx_cgen/lowering/one_hot.py +120 -0
  60. emx_onnx_cgen/lowering/pad.py +1 -1
  61. emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
  62. emx_onnx_cgen/lowering/quantize_linear.py +126 -0
  63. emx_onnx_cgen/lowering/range.py +1 -1
  64. emx_onnx_cgen/lowering/reduce.py +6 -7
  65. emx_onnx_cgen/lowering/registry.py +24 -5
  66. emx_onnx_cgen/lowering/reshape.py +224 -52
  67. emx_onnx_cgen/lowering/resize.py +1 -1
  68. emx_onnx_cgen/lowering/rms_normalization.py +1 -1
  69. emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
  70. emx_onnx_cgen/lowering/scatter_nd.py +82 -0
  71. emx_onnx_cgen/lowering/shape.py +6 -25
  72. emx_onnx_cgen/lowering/size.py +1 -1
  73. emx_onnx_cgen/lowering/slice.py +1 -1
  74. emx_onnx_cgen/lowering/softmax.py +6 -2
  75. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
  76. emx_onnx_cgen/lowering/split.py +1 -1
  77. emx_onnx_cgen/lowering/squeeze.py +6 -6
  78. emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
  79. emx_onnx_cgen/lowering/tile.py +1 -1
  80. emx_onnx_cgen/lowering/topk.py +134 -0
  81. emx_onnx_cgen/lowering/transpose.py +1 -1
  82. emx_onnx_cgen/lowering/trilu.py +89 -0
  83. emx_onnx_cgen/lowering/unsqueeze.py +6 -6
  84. emx_onnx_cgen/lowering/variadic.py +1 -1
  85. emx_onnx_cgen/lowering/where.py +1 -1
  86. emx_onnx_cgen/onnx_import.py +4 -0
  87. emx_onnx_cgen/onnxruntime_utils.py +11 -0
  88. emx_onnx_cgen/ops.py +4 -0
  89. emx_onnx_cgen/runtime/evaluator.py +785 -43
  90. emx_onnx_cgen/testbench.py +23 -0
  91. emx_onnx_cgen/verification.py +31 -0
  92. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/METADATA +33 -6
  93. emx_onnx_cgen-0.3.1.dist-info/RECORD +107 -0
  94. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/WHEEL +1 -1
  95. shared/scalar_functions.py +60 -17
  96. shared/ulp.py +65 -0
  97. emx_onnx_cgen-0.2.0.dist-info/RECORD +0 -76
  98. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/entry_points.txt +0 -0
  99. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,421 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ from shared.scalar_types import ScalarType
6
+
7
+ from ...errors import ShapeInferenceError
8
+ from ..op_base import BroadcastingOpBase, RenderableOpBase
9
+ from ..op_context import OpContext
10
+
11
+
12
@dataclass(frozen=True)
class CastOp(RenderableOpBase):
    """Immutable IR record for a Cast operation.

    Field order defines the generated constructor signature, so it must
    not be rearranged.
    """

    # presumably tensor names resolved through the OpContext — TODO confirm
    input0: str
    output: str
    shape: tuple[int, ...]
    input_dtype: ScalarType
    dtype: ScalarType

    def infer_types(self, ctx: OpContext) -> None:
        """Look up the dtypes of the input and the output in ``ctx``.

        Both results are discarded; the lookups are issued for whatever
        effect ``ctx.dtype`` itself has (NOTE(review): likely validation
        that a dtype is registered — confirm against OpContext).
        """
        for tensor_name in (self.input0, self.output):
            ctx.dtype(tensor_name)

    def infer_shapes(self, ctx: OpContext) -> None:
        """Give the output the same shape ``ctx`` reports for the input."""
        ctx.set_shape(self.output, ctx.shape(self.input0))
27
+
28
@dataclass(frozen=True)
class QuantizeLinearOp(RenderableOpBase):
    """Immutable IR record for a QuantizeLinear operation.

    ``zero_point`` and ``axis`` are annotated as optional and may be
    ``None``. Field order is the constructor's positional order.
    """

    # --- operands (presumably tensor names — TODO confirm) ---
    input0: str
    scale: str
    zero_point: str | None
    output: str

    # --- static attributes ---
    input_shape: tuple[int, ...]
    axis: int | None

    # --- element types ---
    dtype: ScalarType
    input_dtype: ScalarType
    scale_dtype: ScalarType
39
+
40
@dataclass(frozen=True)
class ConcatOp(RenderableOpBase):
    """Immutable IR record for a Concat operation over several inputs.

    Field order is the constructor's positional order.
    """

    inputs: tuple[str, ...]
    output: str
    axis: int
    # presumably one shape per entry of ``inputs`` — TODO confirm
    input_shapes: tuple[tuple[int, ...], ...]
    output_shape: tuple[int, ...]
    dtype: ScalarType
48
+
49
@dataclass(frozen=True)
class GatherElementsOp(RenderableOpBase):
    """Immutable IR record for a GatherElements operation.

    Field order is the constructor's positional order.
    """

    # --- operands ---
    data: str
    indices: str
    output: str

    # --- static attributes ---
    axis: int

    # --- shapes ---
    data_shape: tuple[int, ...]
    indices_shape: tuple[int, ...]
    output_shape: tuple[int, ...]

    # --- element types ---
    dtype: ScalarType
    indices_dtype: ScalarType
60
+
61
@dataclass(frozen=True)
class GatherOp(RenderableOpBase):
    """Immutable IR record for a Gather operation.

    Declares the same fields as ``GatherElementsOp``; field order is the
    constructor's positional order.
    """

    # --- operands ---
    data: str
    indices: str
    output: str

    # --- static attributes ---
    axis: int

    # --- shapes ---
    data_shape: tuple[int, ...]
    indices_shape: tuple[int, ...]
    output_shape: tuple[int, ...]

    # --- element types ---
    dtype: ScalarType
    indices_dtype: ScalarType
72
+
73
@dataclass(frozen=True)
class GatherNDOp(RenderableOpBase):
    """Immutable IR record for a GatherND operation.

    Field order is the constructor's positional order.
    """

    # --- operands ---
    data: str
    indices: str
    output: str

    # --- static attributes ---
    batch_dims: int

    # --- shapes ---
    data_shape: tuple[int, ...]
    indices_shape: tuple[int, ...]
    output_shape: tuple[int, ...]

    # --- element types ---
    dtype: ScalarType
    indices_dtype: ScalarType
84
+
85
@dataclass(frozen=True)
class ScatterNDOp(RenderableOpBase):
    """Immutable IR record for a ScatterND operation.

    Field order is the constructor's positional order.
    """

    # --- operands ---
    data: str
    indices: str
    updates: str
    output: str

    # --- shapes ---
    data_shape: tuple[int, ...]
    indices_shape: tuple[int, ...]
    updates_shape: tuple[int, ...]
    output_shape: tuple[int, ...]

    # --- static attributes ---
    # NOTE(review): accepted values are not visible in this module.
    reduction: str

    # --- element types ---
    dtype: ScalarType
    indices_dtype: ScalarType
98
+
99
@dataclass(frozen=True)
class TensorScatterOp(RenderableOpBase):
    """Immutable IR record for a TensorScatter operation.

    ``write_indices`` and its shape/dtype companions are annotated as
    optional and may be ``None``. Field order is the constructor's
    positional order.
    """

    # --- operands ---
    past_cache: str
    update: str
    write_indices: str | None
    output: str

    # --- shapes ---
    past_cache_shape: tuple[int, ...]
    update_shape: tuple[int, ...]
    output_shape: tuple[int, ...]
    write_indices_shape: tuple[int, ...] | None

    # --- static attributes ---
    axis: int
    # NOTE(review): accepted values are not visible in this module.
    mode: str

    # --- element types ---
    dtype: ScalarType
    write_indices_dtype: ScalarType | None
113
+
114
@dataclass(frozen=True)
class TransposeOp(RenderableOpBase):
    """Immutable IR record for a Transpose operation.

    Field order is the constructor's positional order.
    """

    input0: str
    output: str
    perm: tuple[int, ...]
    input_shape: tuple[int, ...]
    output_shape: tuple[int, ...]
    dtype: ScalarType
    input_dtype: ScalarType

    def infer_shapes(self, ctx: OpContext) -> None:
        """Permute the input's shape by ``perm`` and record it for the output.

        Raises:
            ShapeInferenceError: if ``perm`` has a different number of
                entries than the input shape reported by ``ctx``.
        """
        source_shape = ctx.shape(self.input0)
        rank_matches = len(self.perm) == len(source_shape)
        if not rank_matches:
            raise ShapeInferenceError(
                "Transpose perm rank must match input rank, "
                f"got perm {self.perm} for input shape {source_shape}"
            )
        # Only the length of perm is checked here; each entry is used
        # directly as an index into the input shape.
        permuted = tuple(source_shape[dim] for dim in self.perm)
        ctx.set_shape(self.output, permuted)
133
+
134
@dataclass(frozen=True)
class ReshapeOp(RenderableOpBase):
    """Immutable IR record for a Reshape operation.

    ``output_shape`` may be ``None``; in that case shape inference falls
    back to whatever shape the context already holds for the output.
    Field order is the constructor's positional order.
    """

    input0: str
    output: str
    input_shape: tuple[int, ...]
    output_shape: tuple[int, ...] | None
    dtype: ScalarType
    input_dtype: ScalarType

    def infer_shapes(self, ctx: OpContext) -> None:
        """Record the output shape in ``ctx``.

        Uses ``self.output_shape`` when it is set, otherwise the shape
        ``ctx`` already reports for the output.
        """
        # The input shape is queried first but its value is not used.
        # NOTE(review): call kept in case ctx.shape has side effects such
        # as validation — confirm against OpContext.
        ctx.shape(self.input0)
        if self.output_shape is not None:
            resolved_shape = self.output_shape
        else:
            resolved_shape = ctx.shape(self.output)
        ctx.set_shape(self.output, resolved_shape)
151
+
152
@dataclass(frozen=True)
class EyeLikeOp(RenderableOpBase):
    """Immutable IR record for an EyeLike operation.

    Field order is the constructor's positional order.
    """

    input0: str
    output: str
    output_shape: tuple[int, ...]
    # presumably the diagonal offset, as in ONNX EyeLike — TODO confirm
    k: int
    dtype: ScalarType
    input_dtype: ScalarType
160
+
161
@dataclass(frozen=True)
class TriluOp(RenderableOpBase):
    """Immutable IR record for a Trilu operation.

    ``k_input`` and its shape/dtype companions are annotated as optional
    and may be ``None``. Field order is the constructor's positional
    order.
    """

    # --- operands and shapes ---
    input0: str
    output: str
    input_shape: tuple[int, ...]
    output_shape: tuple[int, ...]

    # --- static attributes ---
    upper: bool
    k_value: int

    # --- optional k operand ---
    k_input: str | None
    k_input_shape: tuple[int, ...] | None
    k_input_dtype: ScalarType | None

    # --- element types ---
    dtype: ScalarType
    input_dtype: ScalarType
174
+
175
@dataclass(frozen=True)
class TileOp(RenderableOpBase):
    """Immutable IR record for a Tile operation.

    Field order is the constructor's positional order.
    """

    # --- operands and shapes ---
    input0: str
    output: str
    input_shape: tuple[int, ...]
    output_shape: tuple[int, ...]

    # --- static attributes ---
    repeats: tuple[int, ...]
    input_strides: tuple[int, ...]

    # --- element types ---
    dtype: ScalarType
    input_dtype: ScalarType
185
+
186
@dataclass(frozen=True)
class PadOp(RenderableOpBase):
    """Immutable IR record for a Pad operation.

    Many fields are annotated as optional and may be ``None``.
    NOTE(review): the static ``pads_*`` tuples and the ``*_input`` names
    look like alternative ways to supply the same data — confirm with
    the lowering code. Field order is the constructor's positional order.
    """

    # --- operands and shapes ---
    input0: str
    output: str
    input_shape: tuple[int, ...]
    output_shape: tuple[int, ...]

    # --- padding amounts ---
    pads_begin: tuple[int, ...] | None
    pads_end: tuple[int, ...] | None
    pads_input: str | None
    pads_shape: tuple[int, ...] | None
    pads_dtype: ScalarType | None
    pads_axis_map: tuple[int | None, ...] | None
    pads_values: tuple[int, ...] | None

    # --- optional axes operand ---
    axes_input: str | None
    axes_shape: tuple[int, ...] | None
    axes_dtype: ScalarType | None

    # --- padding mode and fill value ---
    mode: str
    value: float | int | bool
    value_input: str | None
    value_shape: tuple[int, ...] | None

    # --- element types and layout ---
    dtype: ScalarType
    input_dtype: ScalarType
    input_strides: tuple[int, ...]
209
+
210
@dataclass(frozen=True)
class DepthToSpaceOp(RenderableOpBase):
    """Immutable IR record for a DepthToSpace operation.

    Field order is the constructor's positional order.
    """

    # --- operands and shapes ---
    input0: str
    output: str
    input_shape: tuple[int, ...]
    output_shape: tuple[int, ...]

    # --- static attributes ---
    blocksize: int
    # NOTE(review): accepted values are not visible in this module.
    mode: str

    # --- element types ---
    dtype: ScalarType
    input_dtype: ScalarType
220
+
221
@dataclass(frozen=True)
class SpaceToDepthOp(RenderableOpBase):
    """Immutable IR record for a SpaceToDepth operation.

    Unlike ``DepthToSpaceOp`` it declares no ``mode`` field. Field order
    is the constructor's positional order.
    """

    # --- operands and shapes ---
    input0: str
    output: str
    input_shape: tuple[int, ...]
    output_shape: tuple[int, ...]

    # --- static attributes ---
    blocksize: int

    # --- element types ---
    dtype: ScalarType
    input_dtype: ScalarType
230
+
231
@dataclass(frozen=True)
class SliceOp(RenderableOpBase):
    """Immutable IR record for a Slice operation.

    Every slicing parameter is annotated as optional and may be ``None``.
    There is no static ``ends`` tuple, only ``ends_input``.
    NOTE(review): the static tuples and the ``*_input`` names look like
    alternative ways to supply the same data — confirm with the lowering
    code. Field order is the constructor's positional order.
    """

    # --- operands and shapes ---
    input0: str
    output: str
    input_shape: tuple[int, ...]
    output_shape: tuple[int, ...]

    # --- statically known slicing parameters ---
    starts: tuple[int, ...] | None
    steps: tuple[int, ...] | None
    axes: tuple[int, ...] | None

    # --- slicing parameters supplied as operands ---
    starts_input: str | None
    ends_input: str | None
    axes_input: str | None
    steps_input: str | None

    # --- shapes of those operands ---
    starts_shape: tuple[int, ...] | None
    ends_shape: tuple[int, ...] | None
    axes_shape: tuple[int, ...] | None
    steps_shape: tuple[int, ...] | None

    # --- element types of those operands ---
    starts_dtype: ScalarType | None
    ends_dtype: ScalarType | None
    axes_dtype: ScalarType | None
    steps_dtype: ScalarType | None

    # --- data element types ---
    dtype: ScalarType
    input_dtype: ScalarType
254
+
255
@dataclass(frozen=True)
class ResizeOp(RenderableOpBase):
    """Immutable IR record for a Resize operation.

    The ``scales``/``sizes``/``roi`` operands and their shape, dtype and
    axes companions are annotated as optional and may be ``None``.
    This class declares ``dtype`` only; there is no ``input_dtype``
    field. Field order is the constructor's positional order.
    """

    # --- operands and shapes ---
    input0: str
    output: str
    input_shape: tuple[int, ...]
    output_shape: tuple[int, ...]

    # --- static scales ---
    scales: tuple[float, ...]

    # --- optional operands ---
    scales_input: str | None
    sizes_input: str | None
    roi_input: str | None

    axes: tuple[int, ...]

    # --- shapes of the optional operands ---
    scales_shape: tuple[int, ...] | None
    sizes_shape: tuple[int, ...] | None
    roi_shape: tuple[int, ...] | None

    # --- element types of the optional operands ---
    scales_dtype: ScalarType | None
    sizes_dtype: ScalarType | None
    roi_dtype: ScalarType | None

    # --- axes associated with the optional operands ---
    scales_axes: tuple[int, ...] | None
    sizes_axes: tuple[int, ...] | None
    roi_axes: tuple[int, ...] | None

    # --- interpolation settings (names mirror ONNX Resize attributes) ---
    mode: str
    coordinate_transformation_mode: str
    nearest_mode: str
    cubic_coeff_a: float
    exclude_outside: bool
    extrapolation_value: float
    antialias: bool
    keep_aspect_ratio_policy: str

    # --- element type ---
    dtype: ScalarType
284
+
285
@dataclass(frozen=True)
class GridSampleOp(RenderableOpBase):
    """Immutable IR record for a GridSample operation.

    Field order is the constructor's positional order.
    """

    # --- operands ---
    input0: str
    grid: str
    output: str

    # --- shapes ---
    input_shape: tuple[int, ...]
    grid_shape: tuple[int, ...]
    output_shape: tuple[int, ...]

    # --- spatial layout ---
    spatial_rank: int
    input_spatial: tuple[int, ...]
    output_spatial: tuple[int, ...]

    # --- sampling settings (names mirror ONNX GridSample attributes) ---
    mode: str
    padding_mode: str
    align_corners: bool

    # --- element types ---
    dtype: ScalarType
    grid_dtype: ScalarType
301
+
302
@dataclass(frozen=True)
class ConstantOfShapeOp(RenderableOpBase):
    """Immutable IR record for a ConstantOfShape operation.

    Field order is the constructor's positional order.
    """

    input0: str
    output: str
    input_shape: tuple[int, ...]
    # presumably the shape of the produced tensor — TODO confirm
    shape: tuple[int, ...]
    # presumably the fill value — TODO confirm
    value: float | int | bool
    dtype: ScalarType
    input_dtype: ScalarType
311
+
312
@dataclass(frozen=True)
class ShapeOp(RenderableOpBase):
    """Immutable IR record for a Shape operation.

    Field order is the constructor's positional order.
    """

    input0: str
    output: str
    input_shape: tuple[int, ...]
    output_shape: tuple[int, ...]
    # presumably the precomputed output contents — TODO confirm
    values: tuple[int, ...]
    dtype: ScalarType
    input_dtype: ScalarType
321
+
322
@dataclass(frozen=True)
class SizeOp(RenderableOpBase):
    """Immutable IR record for a Size operation.

    Field order is the constructor's positional order.
    """

    input0: str
    output: str
    input_shape: tuple[int, ...]
    output_shape: tuple[int, ...]
    # presumably the precomputed element count — TODO confirm
    value: int
    dtype: ScalarType
    input_dtype: ScalarType
331
+
332
@dataclass(frozen=True)
class NonZeroOp(RenderableOpBase):
    """Immutable IR record for a NonZero operation.

    Field order is the constructor's positional order.
    """

    # --- operands and shapes ---
    input0: str
    output: str
    input_shape: tuple[int, ...]
    output_shape: tuple[int, ...]

    # --- element types ---
    dtype: ScalarType
    input_dtype: ScalarType
340
+
341
@dataclass(frozen=True)
class NonMaxSuppressionOp(RenderableOpBase):
    """Immutable IR record for a NonMaxSuppression operation.

    The three threshold/limit operands and their dtype/shape companions
    are annotated as optional and may be ``None``. Field order is the
    constructor's positional order.
    """

    # --- required operands ---
    boxes: str
    scores: str

    # --- optional operands ---
    max_output_boxes_per_class: str | None
    iou_threshold: str | None
    score_threshold: str | None

    output: str

    # --- shapes ---
    boxes_shape: tuple[int, ...]
    scores_shape: tuple[int, ...]
    output_shape: tuple[int, ...]

    # --- static attributes ---
    center_point_box: int

    # --- element types ---
    boxes_dtype: ScalarType
    output_dtype: ScalarType

    # --- dtype/shape of each optional operand ---
    max_output_dtype: ScalarType | None
    max_output_shape: tuple[int, ...] | None
    iou_threshold_dtype: ScalarType | None
    iou_threshold_shape: tuple[int, ...] | None
    score_threshold_dtype: ScalarType | None
    score_threshold_shape: tuple[int, ...] | None
361
+
362
@dataclass(frozen=True)
class ExpandOp(BroadcastingOpBase):
    """Immutable IR record for an Expand operation.

    Derives from ``BroadcastingOpBase`` rather than ``RenderableOpBase``.
    Field order is the constructor's positional order.
    """

    # --- operands and shapes ---
    input0: str
    output: str
    input_shape: tuple[int, ...]
    output_shape: tuple[int, ...]

    # --- broadcasting layout ---
    # presumably the input shape padded to the output rank — TODO confirm
    input_shape_padded: tuple[int, ...]
    input_strides: tuple[int, ...]

    # --- element types ---
    dtype: ScalarType
    input_dtype: ScalarType
372
+
373
@dataclass(frozen=True)
class CumSumOp(RenderableOpBase):
    """Immutable IR record for a CumSum operation.

    ``axis_input``, ``axis_input_dtype`` and ``axis`` are all annotated
    as optional and may be ``None``. Field order is the constructor's
    positional order; note that ``output`` comes after the axis fields.
    """

    input0: str

    # --- axis, as an operand and/or a static value ---
    axis_input: str | None
    axis_input_dtype: ScalarType | None
    axis: int | None

    output: str
    input_shape: tuple[int, ...]

    # --- element types ---
    dtype: ScalarType
    input_dtype: ScalarType

    # --- flags (names mirror ONNX CumSum attributes) ---
    exclusive: bool
    reverse: bool
385
+
386
@dataclass(frozen=True)
class RangeOp(RenderableOpBase):
    """Immutable IR record for a Range operation.

    Field order is the constructor's positional order.
    """

    # --- operands ---
    start: str
    limit: str
    delta: str
    output: str

    # --- result extent ---
    output_shape: tuple[int, ...]
    # presumably the number of produced elements — TODO confirm
    length: int

    # --- element types ---
    dtype: ScalarType
    input_dtype: ScalarType
396
+
397
@dataclass(frozen=True)
class OneHotOp(RenderableOpBase):
    """Immutable IR record for a OneHot operation.

    Field order is the constructor's positional order.
    """

    # --- operands ---
    indices: str
    depth: str
    values: str
    output: str

    # --- static attributes ---
    axis: int

    # --- shapes ---
    indices_shape: tuple[int, ...]
    values_shape: tuple[int, ...]
    output_shape: tuple[int, ...]
    # presumably the size of the one-hot dimension — TODO confirm
    depth_dim: int

    # --- element types ---
    dtype: ScalarType
    indices_dtype: ScalarType
    depth_dtype: ScalarType
411
+
412
@dataclass(frozen=True)
class SplitOp(RenderableOpBase):
    """Immutable IR record for a Split operation with several outputs.

    Field order is the constructor's positional order.
    """

    input0: str
    outputs: tuple[str, ...]
    input_shape: tuple[int, ...]
    # presumably one shape per entry of ``outputs`` — TODO confirm
    output_shapes: tuple[tuple[int, ...], ...]

    # --- static attributes ---
    axis: int
    split_sizes: tuple[int, ...]

    # --- element types ---
    dtype: ScalarType
    input_dtype: ScalarType