emx-onnx-cgen 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of emx-onnx-cgen might be problematic. Click here for more details.

Files changed (76) hide show
  1. emx_onnx_cgen/__init__.py +6 -0
  2. emx_onnx_cgen/__main__.py +9 -0
  3. emx_onnx_cgen/_build_info.py +3 -0
  4. emx_onnx_cgen/cli.py +328 -0
  5. emx_onnx_cgen/codegen/__init__.py +25 -0
  6. emx_onnx_cgen/codegen/c_emitter.py +9044 -0
  7. emx_onnx_cgen/compiler.py +601 -0
  8. emx_onnx_cgen/dtypes.py +40 -0
  9. emx_onnx_cgen/errors.py +14 -0
  10. emx_onnx_cgen/ir/__init__.py +3 -0
  11. emx_onnx_cgen/ir/model.py +55 -0
  12. emx_onnx_cgen/lowering/__init__.py +3 -0
  13. emx_onnx_cgen/lowering/arg_reduce.py +99 -0
  14. emx_onnx_cgen/lowering/attention.py +421 -0
  15. emx_onnx_cgen/lowering/average_pool.py +229 -0
  16. emx_onnx_cgen/lowering/batch_normalization.py +116 -0
  17. emx_onnx_cgen/lowering/cast.py +70 -0
  18. emx_onnx_cgen/lowering/common.py +72 -0
  19. emx_onnx_cgen/lowering/concat.py +31 -0
  20. emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
  21. emx_onnx_cgen/lowering/conv.py +192 -0
  22. emx_onnx_cgen/lowering/cumsum.py +118 -0
  23. emx_onnx_cgen/lowering/depth_space.py +114 -0
  24. emx_onnx_cgen/lowering/dropout.py +46 -0
  25. emx_onnx_cgen/lowering/elementwise.py +164 -0
  26. emx_onnx_cgen/lowering/expand.py +151 -0
  27. emx_onnx_cgen/lowering/eye_like.py +43 -0
  28. emx_onnx_cgen/lowering/flatten.py +60 -0
  29. emx_onnx_cgen/lowering/gather.py +48 -0
  30. emx_onnx_cgen/lowering/gather_elements.py +60 -0
  31. emx_onnx_cgen/lowering/gemm.py +139 -0
  32. emx_onnx_cgen/lowering/grid_sample.py +149 -0
  33. emx_onnx_cgen/lowering/group_normalization.py +68 -0
  34. emx_onnx_cgen/lowering/identity.py +43 -0
  35. emx_onnx_cgen/lowering/instance_normalization.py +50 -0
  36. emx_onnx_cgen/lowering/layer_normalization.py +110 -0
  37. emx_onnx_cgen/lowering/logsoftmax.py +47 -0
  38. emx_onnx_cgen/lowering/lp_normalization.py +45 -0
  39. emx_onnx_cgen/lowering/lrn.py +104 -0
  40. emx_onnx_cgen/lowering/lstm.py +355 -0
  41. emx_onnx_cgen/lowering/matmul.py +120 -0
  42. emx_onnx_cgen/lowering/maxpool.py +195 -0
  43. emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
  44. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
  45. emx_onnx_cgen/lowering/pad.py +287 -0
  46. emx_onnx_cgen/lowering/range.py +104 -0
  47. emx_onnx_cgen/lowering/reduce.py +544 -0
  48. emx_onnx_cgen/lowering/registry.py +51 -0
  49. emx_onnx_cgen/lowering/reshape.py +188 -0
  50. emx_onnx_cgen/lowering/resize.py +445 -0
  51. emx_onnx_cgen/lowering/rms_normalization.py +67 -0
  52. emx_onnx_cgen/lowering/shape.py +78 -0
  53. emx_onnx_cgen/lowering/size.py +33 -0
  54. emx_onnx_cgen/lowering/slice.py +425 -0
  55. emx_onnx_cgen/lowering/softmax.py +47 -0
  56. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
  57. emx_onnx_cgen/lowering/split.py +150 -0
  58. emx_onnx_cgen/lowering/squeeze.py +161 -0
  59. emx_onnx_cgen/lowering/tile.py +81 -0
  60. emx_onnx_cgen/lowering/transpose.py +46 -0
  61. emx_onnx_cgen/lowering/unsqueeze.py +157 -0
  62. emx_onnx_cgen/lowering/variadic.py +95 -0
  63. emx_onnx_cgen/lowering/where.py +73 -0
  64. emx_onnx_cgen/onnx_import.py +261 -0
  65. emx_onnx_cgen/ops.py +565 -0
  66. emx_onnx_cgen/runtime/__init__.py +1 -0
  67. emx_onnx_cgen/runtime/evaluator.py +2206 -0
  68. emx_onnx_cgen/validation.py +76 -0
  69. emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
  70. emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
  71. emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
  72. emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
  73. emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
  74. shared/__init__.py +2 -0
  75. shared/scalar_functions.py +2405 -0
  76. shared/scalar_types.py +243 -0
emx_onnx_cgen/ops.py ADDED
@@ -0,0 +1,565 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Callable, Mapping
4
+ from dataclasses import dataclass
5
+ from enum import Enum
6
+ import math
7
+
8
+ import numpy as np
9
+
10
+ from shared.scalar_functions import ScalarFunction
11
+ from shared.scalar_types import ScalarType
12
+
13
+ from .errors import UnsupportedOpError
14
+
15
+
16
+ _NP_ERF = getattr(np, "erf", None)
17
+ if _NP_ERF is None:
18
+ _NP_ERF = np.vectorize(math.erf, otypes=[float])
19
+
20
+
21
class OperatorKind(str, Enum):
    """Classifies how the ``operator`` text of a binary spec is written.

    Members subclass ``str``, so each one compares equal to its plain
    string value.
    """

    # Used in this module for symbols placed between operands ("+", "&&").
    INFIX = "infix"
    # Used in this module for named functions ("fmax", "pow", "remainder").
    FUNC = "func"
    # Used in this module for templates with {left}/{right} placeholders.
    EXPR = "expr"
25
+
26
+
27
@dataclass(frozen=True)
class BinaryOpSpec:
    """Immutable record describing one binary operator.

    It pairs the operator's textual form with a callable that evaluates
    the same operation on two array operands.
    """

    # Operator text; how to interpret it is given by ``kind``.
    operator: str
    # Whether ``operator`` is an infix symbol, a function name or a template.
    kind: OperatorKind
    # Callable taking (left, right) and returning the combined result.
    apply: Callable[[np.ndarray, np.ndarray], np.ndarray]
32
+
33
+
34
+ BINARY_OP_TYPES = {
35
+ "Add",
36
+ "And",
37
+ "BitShift",
38
+ "BitwiseAnd",
39
+ "BitwiseOr",
40
+ "BitwiseXor",
41
+ "Div",
42
+ "Equal",
43
+ "Greater",
44
+ "GreaterOrEqual",
45
+ "Less",
46
+ "LessOrEqual",
47
+ "Max",
48
+ "Mean",
49
+ "Min",
50
+ "Mod",
51
+ "Mul",
52
+ "Or",
53
+ "PRelu",
54
+ "Pow",
55
+ "Sub",
56
+ "Sum",
57
+ "Xor",
58
+ }
59
+
60
+ COMPARE_OP_TYPES = {
61
+ "Equal",
62
+ "Greater",
63
+ "GreaterOrEqual",
64
+ "Less",
65
+ "LessOrEqual",
66
+ }
67
+
68
+ UNARY_OP_TYPES = {
69
+ "Abs",
70
+ "Acos",
71
+ "Acosh",
72
+ "Asin",
73
+ "Asinh",
74
+ "Atan",
75
+ "Atanh",
76
+ "BitwiseNot",
77
+ "Ceil",
78
+ "Cos",
79
+ "Cosh",
80
+ "Elu",
81
+ "Erf",
82
+ "Exp",
83
+ "Floor",
84
+ "Gelu",
85
+ "HardSigmoid",
86
+ "HardSwish",
87
+ "Identity",
88
+ "LeakyRelu",
89
+ "Log",
90
+ "Neg",
91
+ "Not",
92
+ "Reciprocal",
93
+ "Relu",
94
+ "Round",
95
+ "Selu",
96
+ "Sigmoid",
97
+ "Sign",
98
+ "Sin",
99
+ "Sinh",
100
+ "Softplus",
101
+ "Softsign",
102
+ "Sqrt",
103
+ "Tan",
104
+ "Tanh",
105
+ "ThresholdedRelu",
106
+ }
107
+
108
+
109
def _format_float_literal(value: float, dtype: ScalarType) -> str:
    """Render ``value`` as a C floating-point literal for ``dtype``.

    Args:
        value: The number to render.
        dtype: Target scalar type. ``F16`` and ``F32`` get an ``f`` suffix;
            every other type is rendered without a suffix.

    Returns:
        The literal text. Non-finite values are rendered as the C99 macros
        ``NAN`` / ``INFINITY`` / ``-INFINITY``.
    """
    value = float(value)
    # "inf.0f" / "nan.0" are not valid C, so use the <math.h> macros instead.
    # NOTE(review): assumes the generated translation unit includes <math.h>
    # -- confirm against the emitter.
    if math.isnan(value):
        return "NAN"
    if math.isinf(value):
        return "INFINITY" if value > 0 else "-INFINITY"

    is_single = dtype in {ScalarType.F16, ScalarType.F32}
    formatted = f"{value:.9g}"
    # Nine significant digits always round-trip a float32 but can truncate a
    # double. Widen only when the short form loses information, so literals
    # that were already exact keep their previous spelling.
    if not is_single and float(formatted) != value:
        formatted = f"{value:.17g}"
    # A bare integer such as "1" would be an int literal in C.
    if "e" not in formatted and "E" not in formatted and "." not in formatted:
        formatted = f"{formatted}.0"
    if is_single:
        return f"{formatted}f"
    return formatted
116
+
117
+
118
+ UNARY_SYMBOLS_BOOL = {
119
+ ScalarFunction.POSITIVE: "identity",
120
+ ScalarFunction.LOGICAL_NOT: "!",
121
+ ScalarFunction.BITWISE_NOT: "bitwise_not",
122
+ }
123
+
124
+ UNARY_SYMBOLS_INT64 = {
125
+ ScalarFunction.ABS: "llabs",
126
+ ScalarFunction.BITWISE_NOT: "bitwise_not",
127
+ ScalarFunction.POSITIVE: "identity",
128
+ ScalarFunction.NEG: "neg",
129
+ ScalarFunction.ROUND: "round",
130
+ ScalarFunction.SIGN: "sign",
131
+ }
132
+
133
+ UNARY_SYMBOLS_INT32 = {
134
+ ScalarFunction.ABS: "abs",
135
+ ScalarFunction.BITWISE_NOT: "bitwise_not",
136
+ ScalarFunction.POSITIVE: "identity",
137
+ ScalarFunction.NEG: "neg",
138
+ ScalarFunction.ROUND: "round",
139
+ ScalarFunction.SIGN: "sign",
140
+ }
141
+
142
+ UNARY_SYMBOLS_INT16 = {
143
+ ScalarFunction.ABS: "abs",
144
+ ScalarFunction.BITWISE_NOT: "bitwise_not",
145
+ ScalarFunction.POSITIVE: "identity",
146
+ ScalarFunction.NEG: "neg",
147
+ ScalarFunction.ROUND: "round",
148
+ ScalarFunction.SIGN: "sign",
149
+ }
150
+
151
+ UNARY_SYMBOLS_INT8 = {
152
+ ScalarFunction.ABS: "abs",
153
+ ScalarFunction.BITWISE_NOT: "bitwise_not",
154
+ ScalarFunction.POSITIVE: "identity",
155
+ ScalarFunction.NEG: "neg",
156
+ ScalarFunction.ROUND: "round",
157
+ ScalarFunction.SIGN: "sign",
158
+ }
159
+
160
+ UNARY_SYMBOLS_DOUBLE = {
161
+ ScalarFunction.ABS: "fabs",
162
+ ScalarFunction.ACOS: "acos",
163
+ ScalarFunction.ACOSH: "acosh",
164
+ ScalarFunction.ASIN: "asin",
165
+ ScalarFunction.ASINH: "asinh",
166
+ ScalarFunction.ATAN: "atan",
167
+ ScalarFunction.CEIL: "ceil",
168
+ ScalarFunction.COS: "cos",
169
+ ScalarFunction.COSH: "cosh",
170
+ ScalarFunction.ELU: "elu",
171
+ ScalarFunction.ERF: "erf",
172
+ ScalarFunction.EXP: "exp",
173
+ ScalarFunction.FLOOR: "floor",
174
+ ScalarFunction.GELU: "gelu",
175
+ ScalarFunction.HARDSIGMOID: "hardsigmoid",
176
+ ScalarFunction.HARDSWISH: "hardswish",
177
+ ScalarFunction.LEAKY_RELU: "leaky_relu",
178
+ ScalarFunction.POSITIVE: "identity",
179
+ ScalarFunction.LOG: "log",
180
+ ScalarFunction.NEG: "neg",
181
+ ScalarFunction.RECIPROCAL: "reciprocal",
182
+ ScalarFunction.RELU: "relu",
183
+ ScalarFunction.ROUND: "round",
184
+ ScalarFunction.SELU: "selu",
185
+ ScalarFunction.SIGMOID: "sigmoid",
186
+ ScalarFunction.SIGN: "sign",
187
+ ScalarFunction.SIN: "sin",
188
+ ScalarFunction.SINH: "sinh",
189
+ ScalarFunction.SOFTPLUS: "softplus",
190
+ ScalarFunction.SOFTSIGN: "softsign",
191
+ ScalarFunction.SQRT: "sqrt",
192
+ ScalarFunction.TAN: "tan",
193
+ ScalarFunction.TANH: "tanh",
194
+ ScalarFunction.THRESHOLDED_RELU: "thresholded_relu",
195
+ ScalarFunction.ATANH: "atanh",
196
+ }
197
+
198
+ UNARY_SYMBOLS_FLOAT = {
199
+ ScalarFunction.ABS: "fabsf",
200
+ ScalarFunction.ACOS: "acosf",
201
+ ScalarFunction.ACOSH: "acoshf",
202
+ ScalarFunction.ASIN: "asinf",
203
+ ScalarFunction.ASINH: "asinhf",
204
+ ScalarFunction.ATAN: "atanf",
205
+ ScalarFunction.CEIL: "ceilf",
206
+ ScalarFunction.COS: "cosf",
207
+ ScalarFunction.COSH: "coshf",
208
+ ScalarFunction.ELU: "elu",
209
+ ScalarFunction.ERF: "erff",
210
+ ScalarFunction.EXP: "expf",
211
+ ScalarFunction.FLOOR: "floorf",
212
+ ScalarFunction.GELU: "gelu",
213
+ ScalarFunction.HARDSIGMOID: "hardsigmoid",
214
+ ScalarFunction.HARDSWISH: "hardswish",
215
+ ScalarFunction.LEAKY_RELU: "leaky_relu",
216
+ ScalarFunction.POSITIVE: "identity",
217
+ ScalarFunction.LOG: "logf",
218
+ ScalarFunction.NEG: "neg",
219
+ ScalarFunction.RECIPROCAL: "reciprocal",
220
+ ScalarFunction.RELU: "relu",
221
+ ScalarFunction.ROUND: "round",
222
+ ScalarFunction.SELU: "selu",
223
+ ScalarFunction.SIGMOID: "sigmoid",
224
+ ScalarFunction.SIGN: "sign",
225
+ ScalarFunction.SIN: "sinf",
226
+ ScalarFunction.SINH: "sinhf",
227
+ ScalarFunction.SOFTPLUS: "softplus",
228
+ ScalarFunction.SOFTSIGN: "softsign",
229
+ ScalarFunction.SQRT: "sqrtf",
230
+ ScalarFunction.TAN: "tanf",
231
+ ScalarFunction.TANH: "tanhf",
232
+ ScalarFunction.THRESHOLDED_RELU: "thresholded_relu",
233
+ ScalarFunction.ATANH: "atanhf",
234
+ }
235
+
236
+ BINARY_SPECS_BOOL = {
237
+ ScalarFunction.LOGICAL_AND: BinaryOpSpec(
238
+ "&&", OperatorKind.INFIX, lambda left, right: np.logical_and(left, right)
239
+ ),
240
+ ScalarFunction.LOGICAL_OR: BinaryOpSpec(
241
+ "||", OperatorKind.INFIX, lambda left, right: np.logical_or(left, right)
242
+ ),
243
+ ScalarFunction.LOGICAL_XOR: BinaryOpSpec(
244
+ "!=", OperatorKind.INFIX, lambda left, right: np.logical_xor(left, right)
245
+ ),
246
+ }
247
+
248
+ COMPARE_SPECS = {
249
+ ScalarFunction.EQ: BinaryOpSpec("==", OperatorKind.INFIX, np.equal),
250
+ ScalarFunction.GT: BinaryOpSpec(">", OperatorKind.INFIX, np.greater),
251
+ ScalarFunction.GE: BinaryOpSpec(">=", OperatorKind.INFIX, np.greater_equal),
252
+ ScalarFunction.LT: BinaryOpSpec("<", OperatorKind.INFIX, np.less),
253
+ ScalarFunction.LE: BinaryOpSpec("<=", OperatorKind.INFIX, np.less_equal),
254
+ }
255
+
256
def _int_trunc_div(left: np.ndarray, right: np.ndarray) -> np.ndarray:
    """Divide integers rounding toward zero, like the C ``/`` operator.

    NumPy's ``//`` rounds toward negative infinity, which differs from C
    whenever the operands have opposite signs (``-7 // 2 == -4`` but
    ``-7 / 2 == -3`` in C).
    """
    # ``np.fmod`` keeps the dividend's sign, so ``left - fmod`` is an exact
    # multiple of ``right`` and the floor division below cannot round.
    return np.floor_divide(left - np.fmod(left, right), right)


# Binary operators for integer element types. Each entry pairs the C spelling
# with a NumPy callable that must produce the same result.
BINARY_SPECS_INT = {
    ScalarFunction.ADD: BinaryOpSpec(
        "+", OperatorKind.INFIX, lambda left, right: left + right
    ),
    ScalarFunction.BITWISE_AND: BinaryOpSpec(
        "&", OperatorKind.INFIX, lambda left, right: left & right
    ),
    ScalarFunction.BITWISE_OR: BinaryOpSpec(
        "|", OperatorKind.INFIX, lambda left, right: left | right
    ),
    ScalarFunction.BITWISE_XOR: BinaryOpSpec(
        "^", OperatorKind.INFIX, lambda left, right: left ^ right
    ),
    ScalarFunction.BITWISE_LEFT_SHIFT: BinaryOpSpec(
        "<<", OperatorKind.INFIX, np.left_shift
    ),
    ScalarFunction.BITWISE_RIGHT_SHIFT: BinaryOpSpec(
        ">>", OperatorKind.INFIX, np.right_shift
    ),
    # C integer division truncates toward zero; mirror that instead of
    # using NumPy's flooring ``//``.
    ScalarFunction.DIV: BinaryOpSpec(
        "/", OperatorKind.INFIX, _int_trunc_div
    ),
    ScalarFunction.FMOD: BinaryOpSpec(
        "%", OperatorKind.INFIX, np.fmod
    ),
    ScalarFunction.REMAINDER: BinaryOpSpec(
        "remainder", OperatorKind.FUNC, np.mod
    ),
    ScalarFunction.MAXIMUM: BinaryOpSpec(
        "maximum", OperatorKind.FUNC, np.maximum
    ),
    ScalarFunction.MINIMUM: BinaryOpSpec(
        "minimum", OperatorKind.FUNC, np.minimum
    ),
    ScalarFunction.POW: BinaryOpSpec("pow", OperatorKind.FUNC, np.power),
    ScalarFunction.SUB: BinaryOpSpec(
        "-", OperatorKind.INFIX, lambda left, right: left - right
    ),
    ScalarFunction.MUL: BinaryOpSpec(
        "*", OperatorKind.INFIX, lambda left, right: left * right
    ),
}
298
+
299
+
300
def _mean_binary_spec(dtype: ScalarType) -> BinaryOpSpec:
    """Build the two-operand mean spec for ``dtype``.

    The operator text is an expression template with ``{left}`` and
    ``{right}`` placeholders, multiplied by a 0.5 literal formatted for
    ``dtype``.
    """
    half_literal = _format_float_literal(0.5, dtype)
    template = "({left} + {right}) * " + half_literal

    def _average(left, right):
        return (left + right) * 0.5

    return BinaryOpSpec(template, OperatorKind.EXPR, _average)
306
+
307
+
308
def _prelu_binary_spec(dtype: ScalarType) -> BinaryOpSpec:
    """Build the PRelu spec for ``dtype``.

    The operator text is a ternary-expression template over ``{left}``
    (the input) and ``{right}`` (the slope): positive inputs pass through
    and all others are multiplied by the slope.
    """
    zero = _format_float_literal(0.0, dtype)
    template = "({left} > " + zero + " ? {left} : {right} * {left})"

    def _prelu(left, right):
        return np.where(left > 0.0, left, right * left)

    return BinaryOpSpec(template, OperatorKind.EXPR, _prelu)
315
+
316
+
317
+ BINARY_SPECS_DOUBLE = {
318
+ ScalarFunction.ADD: BinaryOpSpec(
319
+ "+", OperatorKind.INFIX, lambda left, right: left + right
320
+ ),
321
+ ScalarFunction.DIV: BinaryOpSpec(
322
+ "/", OperatorKind.INFIX, lambda left, right: left / right
323
+ ),
324
+ ScalarFunction.MAXIMUM: BinaryOpSpec("fmax", OperatorKind.FUNC, np.maximum),
325
+ ScalarFunction.MEAN: _mean_binary_spec(ScalarType.F64),
326
+ ScalarFunction.MINIMUM: BinaryOpSpec("fmin", OperatorKind.FUNC, np.minimum),
327
+ ScalarFunction.MUL: BinaryOpSpec(
328
+ "*", OperatorKind.INFIX, lambda left, right: left * right
329
+ ),
330
+ ScalarFunction.REMAINDER: BinaryOpSpec(
331
+ "remainder", OperatorKind.FUNC, np.remainder
332
+ ),
333
+ ScalarFunction.POW: BinaryOpSpec("pow", OperatorKind.FUNC, np.power),
334
+ ScalarFunction.PRELU: _prelu_binary_spec(ScalarType.F64),
335
+ ScalarFunction.SUB: BinaryOpSpec(
336
+ "-", OperatorKind.INFIX, lambda left, right: left - right
337
+ ),
338
+ }
339
+
340
+ BINARY_SPECS_FLOAT = {
341
+ ScalarFunction.ADD: BinaryOpSpec(
342
+ "+", OperatorKind.INFIX, lambda left, right: left + right
343
+ ),
344
+ ScalarFunction.DIV: BinaryOpSpec(
345
+ "/", OperatorKind.INFIX, lambda left, right: left / right
346
+ ),
347
+ ScalarFunction.MAXIMUM: BinaryOpSpec("fmaxf", OperatorKind.FUNC, np.maximum),
348
+ ScalarFunction.MEAN: _mean_binary_spec(ScalarType.F32),
349
+ ScalarFunction.MINIMUM: BinaryOpSpec("fminf", OperatorKind.FUNC, np.minimum),
350
+ ScalarFunction.MUL: BinaryOpSpec(
351
+ "*", OperatorKind.INFIX, lambda left, right: left * right
352
+ ),
353
+ ScalarFunction.REMAINDER: BinaryOpSpec(
354
+ "remainder", OperatorKind.FUNC, np.remainder
355
+ ),
356
+ ScalarFunction.POW: BinaryOpSpec("powf", OperatorKind.FUNC, np.power),
357
+ ScalarFunction.PRELU: _prelu_binary_spec(ScalarType.F32),
358
+ ScalarFunction.SUB: BinaryOpSpec(
359
+ "-", OperatorKind.INFIX, lambda left, right: left - right
360
+ ),
361
+ }
362
+
363
+ UNARY_SYMBOLS_BY_DTYPE = {
364
+ ScalarType.BOOL: UNARY_SYMBOLS_BOOL,
365
+ ScalarType.I64: UNARY_SYMBOLS_INT64,
366
+ ScalarType.I32: UNARY_SYMBOLS_INT32,
367
+ ScalarType.I16: UNARY_SYMBOLS_INT16,
368
+ ScalarType.I8: UNARY_SYMBOLS_INT8,
369
+ ScalarType.F64: UNARY_SYMBOLS_DOUBLE,
370
+ ScalarType.F32: UNARY_SYMBOLS_FLOAT,
371
+ ScalarType.F16: UNARY_SYMBOLS_FLOAT,
372
+ }
373
+
374
+ BINARY_SPECS_BY_DTYPE = {
375
+ ScalarType.BOOL: BINARY_SPECS_BOOL,
376
+ ScalarType.I64: BINARY_SPECS_INT,
377
+ ScalarType.I32: BINARY_SPECS_INT,
378
+ ScalarType.I16: BINARY_SPECS_INT,
379
+ ScalarType.I8: BINARY_SPECS_INT,
380
+ ScalarType.U64: BINARY_SPECS_INT,
381
+ ScalarType.U32: BINARY_SPECS_INT,
382
+ ScalarType.U16: BINARY_SPECS_INT,
383
+ ScalarType.U8: BINARY_SPECS_INT,
384
+ ScalarType.F64: BINARY_SPECS_DOUBLE,
385
+ ScalarType.F32: BINARY_SPECS_FLOAT,
386
+ ScalarType.F16: BINARY_SPECS_FLOAT,
387
+ }
388
+
389
+ UNARY_APPLY_FUNCS = {
390
+ "acosf": np.arccos,
391
+ "acos": np.arccos,
392
+ "acoshf": np.arccosh,
393
+ "acosh": np.arccosh,
394
+ "fabsf": np.abs,
395
+ "fabs": np.abs,
396
+ "abs": np.abs,
397
+ "llabs": np.abs,
398
+ "asinf": np.arcsin,
399
+ "asin": np.arcsin,
400
+ "asinhf": np.arcsinh,
401
+ "asinh": np.arcsinh,
402
+ "atanf": np.arctan,
403
+ "atan": np.arctan,
404
+ "bitwise_not": np.bitwise_not,
405
+ "!": np.logical_not,
406
+ "identity": lambda value: value,
407
+ "ceilf": np.ceil,
408
+ "ceil": np.ceil,
409
+ "cosf": np.cos,
410
+ "cos": np.cos,
411
+ "coshf": np.cosh,
412
+ "cosh": np.cosh,
413
+ "elu": lambda value: np.where(value > 0.0, value, np.exp(value) - 1.0),
414
+ "erff": _NP_ERF,
415
+ "erf": _NP_ERF,
416
+ "expf": np.exp,
417
+ "exp": np.exp,
418
+ "floorf": np.floor,
419
+ "floor": np.floor,
420
+ "gelu": lambda value: 0.5
421
+ * value
422
+ * (1.0 + _NP_ERF(value / np.sqrt(2.0))),
423
+ "hardsigmoid": lambda value: np.clip(value * 0.2 + 0.5, 0.0, 1.0),
424
+ "hardswish": lambda value: value
425
+ * np.clip(value + 3.0, 0.0, 6.0)
426
+ / 6.0,
427
+ "leaky_relu": lambda value: np.where(value > 0.0, value, 0.01 * value),
428
+ "logf": np.log,
429
+ "log": np.log,
430
+ "neg": lambda value: -value,
431
+ "reciprocal": lambda value: 1.0 / value,
432
+ "relu": lambda value: np.maximum(value, 0),
433
+ "round": np.round,
434
+ "selu": lambda value: np.where(
435
+ value > 0.0,
436
+ 1.0507009873554805 * value,
437
+ 1.0507009873554805
438
+ * 1.6732632423543772
439
+ * (np.exp(value) - 1.0),
440
+ ),
441
+ "sigmoid": lambda value: 1.0 / (1.0 + np.exp(-value)),
442
+ "sign": np.sign,
443
+ "sinf": np.sin,
444
+ "sin": np.sin,
445
+ "sqrtf": np.sqrt,
446
+ "sqrt": np.sqrt,
447
+ "softplus": lambda value: np.where(
448
+ value > 20.0, value, np.log1p(np.exp(value))
449
+ ),
450
+ "softsign": lambda value: value / (1.0 + np.abs(value)),
451
+ "sinhf": np.sinh,
452
+ "sinh": np.sinh,
453
+ "tanf": np.tan,
454
+ "tan": np.tan,
455
+ "tanhf": np.tanh,
456
+ "tanh": np.tanh,
457
+ "thresholded_relu": lambda value: np.where(
458
+ value > 1.0, value, 0.0
459
+ ),
460
+ "atanhf": np.arctanh,
461
+ "atanh": np.arctanh,
462
+ }
463
+
464
+ COMPARE_FUNCTIONS = {
465
+ ScalarFunction.EQ,
466
+ ScalarFunction.GT,
467
+ ScalarFunction.GE,
468
+ ScalarFunction.LT,
469
+ ScalarFunction.LE,
470
+ }
471
+
472
+ UNARY_ATTR_DEFAULTS: Mapping[str, Mapping[str, object]] = {
473
+ "Elu": {"alpha": 1.0},
474
+ "Gelu": {"approximate": "none"},
475
+ "HardSigmoid": {"alpha": 0.2, "beta": 0.5},
476
+ "LeakyRelu": {"alpha": 0.01},
477
+ "Selu": {"alpha": 1.6732632423543772, "gamma": 1.0507009873554805},
478
+ "Softplus": {"beta": 1.0, "threshold": 20.0},
479
+ "ThresholdedRelu": {"alpha": 1.0},
480
+ }
481
+
482
+
483
def validate_unary_attrs(op_type: str, attrs: Mapping[str, object]) -> None:
    """Reject unary-op attributes that differ from the supported defaults.

    Ops without an entry in ``UNARY_ATTR_DEFAULTS``, and calls with no
    attributes, are accepted without further checks.

    Raises:
        UnsupportedOpError: If ``attrs`` has a key the op does not list, or
            a listed key whose value differs from its default. Strings are
            compared exactly; numbers within an absolute tolerance of 1e-6.
    """
    expected = UNARY_ATTR_DEFAULTS.get(op_type)
    if not attrs or expected is None:
        return

    # Unknown attribute names are reported before any value is inspected.
    for name in attrs:
        if name not in expected:
            raise UnsupportedOpError(
                f"{op_type} does not support attribute {name}"
            )

    for name, expected_value in expected.items():
        if name in attrs:
            mismatch = f"{op_type} only supports {name}={expected_value}"
            actual = attrs[name]
            if isinstance(expected_value, str):
                if str(actual) != expected_value:
                    raise UnsupportedOpError(mismatch)
            else:
                try:
                    actual_number = float(actual)
                except (TypeError, ValueError) as exc:
                    raise UnsupportedOpError(mismatch) from exc
                matches = math.isclose(
                    actual_number, float(expected_value), abs_tol=1e-6
                )
                if not matches:
                    raise UnsupportedOpError(mismatch)
512
+
513
+
514
def binary_op_symbol(
    function: ScalarFunction,
    attrs: Mapping[str, object] | None = None,
    *,
    dtype: ScalarType,
    validate_attrs: bool = True,
) -> BinaryOpSpec | None:
    """Look up the binary operator spec for ``function`` on ``dtype``.

    Comparison functions resolve first and ignore ``dtype``. Otherwise the
    per-dtype table is consulted. ``FMOD`` on a floating-point dtype is
    built on demand as ``fmodf`` (F16/F32) or ``fmod``.

    Args:
        function: The scalar function to resolve.
        attrs: Optional node attributes; only ``"fmod"`` is read.
        dtype: Element type of the operands.
        validate_attrs: When true, floating-point ``FMOD`` requires
            ``fmod == 1``.

    Returns:
        The matching spec, or ``None`` when the combination is not known.

    Raises:
        UnsupportedOpError: For floating-point ``FMOD`` when validation is
            enabled and the ``fmod`` attribute is not 1.
    """
    if function in COMPARE_SPECS:
        return COMPARE_SPECS[function]

    table = BINARY_SPECS_BY_DTYPE.get(dtype)
    resolved = None if table is None else table.get(function)
    if resolved is not None:
        return resolved

    # Only floating-point FMOD is synthesised beyond the static tables.
    if not dtype.is_float or function != ScalarFunction.FMOD:
        return None

    fmod_flag = int(attrs.get("fmod", 0)) if attrs is not None else 0
    if validate_attrs and fmod_flag != 1:
        raise UnsupportedOpError(
            "Mod only supports fmod=1 for floating point types"
        )
    if dtype in {ScalarType.F16, ScalarType.F32}:
        c_name = "fmodf"
    else:
        c_name = "fmod"
    return BinaryOpSpec(c_name, OperatorKind.FUNC, np.fmod)
544
+
545
+
546
def unary_op_symbol(function: ScalarFunction, *, dtype: ScalarType) -> str | None:
    """Return the unary symbol registered for ``function`` on ``dtype``.

    Returns ``None`` when the dtype has no symbol table or the table has
    no entry for ``function``.
    """
    symbols = UNARY_SYMBOLS_BY_DTYPE.get(dtype)
    if symbols is None:
        return None
    return symbols.get(function)
548
+
549
+
550
def apply_binary_op(
    op_spec: BinaryOpSpec, left: np.ndarray, right: np.ndarray
) -> np.ndarray:
    """Evaluate ``op_spec`` on two operands by calling its ``apply``."""
    evaluate = op_spec.apply
    return evaluate(left, right)
554
+
555
+
556
def apply_unary_op(
    function: ScalarFunction, value: np.ndarray, *, dtype: ScalarType
) -> np.ndarray:
    """Evaluate the unary ``function`` on ``value`` for the given ``dtype``.

    The function is first resolved to its per-dtype symbol, then to the
    callable registered for that symbol in ``UNARY_APPLY_FUNCS``.

    Raises:
        UnsupportedOpError: If no symbol exists for the function/dtype
            pair, or no callable is registered for the symbol.
    """
    symbol = unary_op_symbol(function, dtype=dtype)
    if symbol is None:
        raise UnsupportedOpError(f"Unsupported unary op {function.value}")
    evaluate = UNARY_APPLY_FUNCS.get(symbol)
    if evaluate is None:
        raise UnsupportedOpError(f"Unsupported unary op {symbol}")
    return evaluate(value)
@@ -0,0 +1 @@
1
+ """Runtime helpers for evaluating ONNX graphs."""