emx-onnx-cgen 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of emx-onnx-cgen might be problematic. Click here for more details.

Files changed (76)
  1. emx_onnx_cgen/__init__.py +6 -0
  2. emx_onnx_cgen/__main__.py +9 -0
  3. emx_onnx_cgen/_build_info.py +3 -0
  4. emx_onnx_cgen/cli.py +328 -0
  5. emx_onnx_cgen/codegen/__init__.py +25 -0
  6. emx_onnx_cgen/codegen/c_emitter.py +9044 -0
  7. emx_onnx_cgen/compiler.py +601 -0
  8. emx_onnx_cgen/dtypes.py +40 -0
  9. emx_onnx_cgen/errors.py +14 -0
  10. emx_onnx_cgen/ir/__init__.py +3 -0
  11. emx_onnx_cgen/ir/model.py +55 -0
  12. emx_onnx_cgen/lowering/__init__.py +3 -0
  13. emx_onnx_cgen/lowering/arg_reduce.py +99 -0
  14. emx_onnx_cgen/lowering/attention.py +421 -0
  15. emx_onnx_cgen/lowering/average_pool.py +229 -0
  16. emx_onnx_cgen/lowering/batch_normalization.py +116 -0
  17. emx_onnx_cgen/lowering/cast.py +70 -0
  18. emx_onnx_cgen/lowering/common.py +72 -0
  19. emx_onnx_cgen/lowering/concat.py +31 -0
  20. emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
  21. emx_onnx_cgen/lowering/conv.py +192 -0
  22. emx_onnx_cgen/lowering/cumsum.py +118 -0
  23. emx_onnx_cgen/lowering/depth_space.py +114 -0
  24. emx_onnx_cgen/lowering/dropout.py +46 -0
  25. emx_onnx_cgen/lowering/elementwise.py +164 -0
  26. emx_onnx_cgen/lowering/expand.py +151 -0
  27. emx_onnx_cgen/lowering/eye_like.py +43 -0
  28. emx_onnx_cgen/lowering/flatten.py +60 -0
  29. emx_onnx_cgen/lowering/gather.py +48 -0
  30. emx_onnx_cgen/lowering/gather_elements.py +60 -0
  31. emx_onnx_cgen/lowering/gemm.py +139 -0
  32. emx_onnx_cgen/lowering/grid_sample.py +149 -0
  33. emx_onnx_cgen/lowering/group_normalization.py +68 -0
  34. emx_onnx_cgen/lowering/identity.py +43 -0
  35. emx_onnx_cgen/lowering/instance_normalization.py +50 -0
  36. emx_onnx_cgen/lowering/layer_normalization.py +110 -0
  37. emx_onnx_cgen/lowering/logsoftmax.py +47 -0
  38. emx_onnx_cgen/lowering/lp_normalization.py +45 -0
  39. emx_onnx_cgen/lowering/lrn.py +104 -0
  40. emx_onnx_cgen/lowering/lstm.py +355 -0
  41. emx_onnx_cgen/lowering/matmul.py +120 -0
  42. emx_onnx_cgen/lowering/maxpool.py +195 -0
  43. emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
  44. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
  45. emx_onnx_cgen/lowering/pad.py +287 -0
  46. emx_onnx_cgen/lowering/range.py +104 -0
  47. emx_onnx_cgen/lowering/reduce.py +544 -0
  48. emx_onnx_cgen/lowering/registry.py +51 -0
  49. emx_onnx_cgen/lowering/reshape.py +188 -0
  50. emx_onnx_cgen/lowering/resize.py +445 -0
  51. emx_onnx_cgen/lowering/rms_normalization.py +67 -0
  52. emx_onnx_cgen/lowering/shape.py +78 -0
  53. emx_onnx_cgen/lowering/size.py +33 -0
  54. emx_onnx_cgen/lowering/slice.py +425 -0
  55. emx_onnx_cgen/lowering/softmax.py +47 -0
  56. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
  57. emx_onnx_cgen/lowering/split.py +150 -0
  58. emx_onnx_cgen/lowering/squeeze.py +161 -0
  59. emx_onnx_cgen/lowering/tile.py +81 -0
  60. emx_onnx_cgen/lowering/transpose.py +46 -0
  61. emx_onnx_cgen/lowering/unsqueeze.py +157 -0
  62. emx_onnx_cgen/lowering/variadic.py +95 -0
  63. emx_onnx_cgen/lowering/where.py +73 -0
  64. emx_onnx_cgen/onnx_import.py +261 -0
  65. emx_onnx_cgen/ops.py +565 -0
  66. emx_onnx_cgen/runtime/__init__.py +1 -0
  67. emx_onnx_cgen/runtime/evaluator.py +2206 -0
  68. emx_onnx_cgen/validation.py +76 -0
  69. emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
  70. emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
  71. emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
  72. emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
  73. emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
  74. shared/__init__.py +2 -0
  75. shared/scalar_functions.py +2405 -0
  76. shared/scalar_types.py +243 -0
shared/scalar_types.py ADDED
@@ -0,0 +1,243 @@
1
+ from __future__ import annotations
2
+
3
+ from enum import Enum
4
+
5
+ import numpy as np
6
+
7
+
8
class ScalarFunctionError(RuntimeError):
    """Raised when a scalar type or scalar-function request cannot be satisfied."""
10
+
11
+
12
class ScalarType(str, Enum):
    """Scalar dtype descriptor bridging ONNX, NumPy, torch and emitted C.

    Each member's enum value is a short suffix string (``"f32"``, ``"i8"``,
    ...), so members compare equal to their suffix and can be recovered with
    ``ScalarType("f32")``.  ``__new__`` attaches per-member metadata used by
    the code generator: the ONNX dtype name, the C type spelling, the NumPy
    dtype, C source literals for zero/min/max, and classification flags.
    """

    def __new__(
        cls,
        suffix: str,
        onnx_name: str,
        c_type: str,
        np_dtype: "np.dtype",
        zero_literal: str,
        min_literal: str,
        max_literal: str,
        is_float: bool,
        is_signed: bool,
        is_bool: bool,
        bits: "int | None",
    ) -> "ScalarType":
        # The member's canonical enum value is the suffix string itself.
        obj = str.__new__(cls, suffix)
        obj._value_ = suffix
        obj.suffix = suffix
        obj.onnx_name = onnx_name
        obj.c_type = c_type
        # np.dtype() is idempotent, so this accepts either an np.dtype or
        # anything np.dtype() can coerce.
        obj.np_dtype = np.dtype(np_dtype)
        obj.zero_literal = zero_literal
        obj.min_literal = min_literal
        obj.max_literal = max_literal
        obj.is_float = is_float
        obj.is_signed = is_signed
        obj.is_bool = is_bool
        # bits is None only for BOOL, whose C storage width is unspecified.
        obj.bits = bits
        return obj

    # Tuple order: suffix, onnx_name, c_type, np_dtype,
    # zero_literal, min_literal, max_literal,
    # is_float, is_signed, is_bool, bits
    F16 = (
        "f16",
        "float16",
        "_Float16",
        np.dtype("float16"),
        "0.0f",
        "-INFINITY",
        "INFINITY",
        True,
        True,
        False,
        16,
    )
    F32 = (
        "f32",
        "float",
        "float",
        np.dtype("float32"),
        "0.0f",
        "-INFINITY",
        "INFINITY",
        True,
        True,
        False,
        32,
    )
    F64 = (
        "f64",
        "double",
        "double",
        np.dtype("float64"),
        "0.0",
        "-INFINITY",
        "INFINITY",
        True,
        True,
        False,
        64,
    )
    I8 = (
        "i8",
        "int8",
        "int8_t",
        np.dtype("int8"),
        "0",
        "INT8_MIN",
        "INT8_MAX",
        False,
        True,
        False,
        8,
    )
    I16 = (
        "i16",
        "int16",
        "int16_t",
        np.dtype("int16"),
        "0",
        "INT16_MIN",
        "INT16_MAX",
        False,
        True,
        False,
        16,
    )
    I32 = (
        "i32",
        "int32",
        "int32_t",
        np.dtype("int32"),
        "0",
        "INT32_MIN",
        "INT32_MAX",
        False,
        True,
        False,
        32,
    )
    I64 = (
        "i64",
        "int64",
        "int64_t",
        np.dtype("int64"),
        "0",
        "INT64_MIN",
        "INT64_MAX",
        False,
        True,
        False,
        64,
    )
    U8 = (
        "u8",
        "uint8",
        "uint8_t",
        np.dtype("uint8"),
        "0",
        "0",
        "UINT8_MAX",
        False,
        False,
        False,
        8,
    )
    U16 = (
        "u16",
        "uint16",
        "uint16_t",
        np.dtype("uint16"),
        "0",
        "0",
        "UINT16_MAX",
        False,
        False,
        False,
        16,
    )
    U32 = (
        "u32",
        "uint32",
        "uint32_t",
        np.dtype("uint32"),
        "0",
        "0",
        "UINT32_MAX",
        False,
        False,
        False,
        32,
    )
    U64 = (
        "u64",
        "uint64",
        "uint64_t",
        np.dtype("uint64"),
        "0",
        "0",
        "UINT64_MAX",
        False,
        False,
        False,
        64,
    )
    BOOL = (
        "bool",
        "bool",
        "bool",
        np.dtype("bool"),
        "false",
        "false",
        "true",
        False,
        False,
        True,
        None,
    )

    @property
    def is_integer(self) -> bool:
        """True for the integer members (everything except floats and bool)."""
        return not self.is_float and not self.is_bool

    @classmethod
    def from_torch_dtype(cls, dtype: object) -> "ScalarType":
        """Map a torch/NumPy dtype (or its name) to a :class:`ScalarType`.

        Accepts a ``ScalarType`` (returned unchanged), a string such as
        ``"float32"`` or ``"torch.float32"``, or any object exposing a
        ``name`` attribute (e.g. ``np.dtype``) / a useful ``str()`` (e.g.
        ``torch.dtype``).  Torch's documented short aliases (``half``,
        ``float``, ``double``, ``short``, ``int``, ``long``, ``byte``) are
        also recognized.

        Raises:
            ScalarFunctionError: if the dtype is not supported.
        """
        if isinstance(dtype, ScalarType):
            return dtype
        if isinstance(dtype, str):
            dtype_name = dtype
        else:
            # np.dtype has .name; torch.dtype str()s as "torch.float32".
            dtype_name = getattr(dtype, "name", None) or str(dtype)
        normalized = dtype_name.lower()
        if normalized.startswith("torch."):
            normalized = normalized[len("torch.") :]
        # NOTE: built per call on purpose — a dict class attribute inside an
        # Enum body would be treated as a member and break class creation.
        mapping = {
            "float16": cls.F16,
            "half": cls.F16,
            "float32": cls.F32,
            "float": cls.F32,
            "float64": cls.F64,
            "double": cls.F64,
            "int8": cls.I8,
            "int16": cls.I16,
            "short": cls.I16,
            "int32": cls.I32,
            "int": cls.I32,
            "int64": cls.I64,
            "long": cls.I64,
            "uint8": cls.U8,
            "byte": cls.U8,
            "uint16": cls.U16,
            "uint32": cls.U32,
            "uint64": cls.U64,
            "bool": cls.BOOL,
        }
        try:
            return mapping[normalized]
        except KeyError as exc:
            raise ScalarFunctionError(
                f"unsupported dtype for scalar functions: {dtype_name}"
            ) from exc

    @classmethod
    def from_onnx_name(cls, name: str) -> "ScalarType":
        """Map an ONNX dtype name (e.g. ``"float16"``) to a :class:`ScalarType`.

        A ``ScalarType`` passed in is returned unchanged.

        Raises:
            ScalarFunctionError: if the name is not a known ONNX dtype.
        """
        if isinstance(name, ScalarType):
            return name
        mapping = {scalar.onnx_name: scalar for scalar in cls}
        try:
            return mapping[name]
        except KeyError as exc:
            raise ScalarFunctionError(f"unsupported ONNX dtype: {name}") from exc