emx-onnx-cgen 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of emx-onnx-cgen might be problematic. Click here for more details.

Files changed (76) hide show
  1. emx_onnx_cgen/__init__.py +6 -0
  2. emx_onnx_cgen/__main__.py +9 -0
  3. emx_onnx_cgen/_build_info.py +3 -0
  4. emx_onnx_cgen/cli.py +328 -0
  5. emx_onnx_cgen/codegen/__init__.py +25 -0
  6. emx_onnx_cgen/codegen/c_emitter.py +9044 -0
  7. emx_onnx_cgen/compiler.py +601 -0
  8. emx_onnx_cgen/dtypes.py +40 -0
  9. emx_onnx_cgen/errors.py +14 -0
  10. emx_onnx_cgen/ir/__init__.py +3 -0
  11. emx_onnx_cgen/ir/model.py +55 -0
  12. emx_onnx_cgen/lowering/__init__.py +3 -0
  13. emx_onnx_cgen/lowering/arg_reduce.py +99 -0
  14. emx_onnx_cgen/lowering/attention.py +421 -0
  15. emx_onnx_cgen/lowering/average_pool.py +229 -0
  16. emx_onnx_cgen/lowering/batch_normalization.py +116 -0
  17. emx_onnx_cgen/lowering/cast.py +70 -0
  18. emx_onnx_cgen/lowering/common.py +72 -0
  19. emx_onnx_cgen/lowering/concat.py +31 -0
  20. emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
  21. emx_onnx_cgen/lowering/conv.py +192 -0
  22. emx_onnx_cgen/lowering/cumsum.py +118 -0
  23. emx_onnx_cgen/lowering/depth_space.py +114 -0
  24. emx_onnx_cgen/lowering/dropout.py +46 -0
  25. emx_onnx_cgen/lowering/elementwise.py +164 -0
  26. emx_onnx_cgen/lowering/expand.py +151 -0
  27. emx_onnx_cgen/lowering/eye_like.py +43 -0
  28. emx_onnx_cgen/lowering/flatten.py +60 -0
  29. emx_onnx_cgen/lowering/gather.py +48 -0
  30. emx_onnx_cgen/lowering/gather_elements.py +60 -0
  31. emx_onnx_cgen/lowering/gemm.py +139 -0
  32. emx_onnx_cgen/lowering/grid_sample.py +149 -0
  33. emx_onnx_cgen/lowering/group_normalization.py +68 -0
  34. emx_onnx_cgen/lowering/identity.py +43 -0
  35. emx_onnx_cgen/lowering/instance_normalization.py +50 -0
  36. emx_onnx_cgen/lowering/layer_normalization.py +110 -0
  37. emx_onnx_cgen/lowering/logsoftmax.py +47 -0
  38. emx_onnx_cgen/lowering/lp_normalization.py +45 -0
  39. emx_onnx_cgen/lowering/lrn.py +104 -0
  40. emx_onnx_cgen/lowering/lstm.py +355 -0
  41. emx_onnx_cgen/lowering/matmul.py +120 -0
  42. emx_onnx_cgen/lowering/maxpool.py +195 -0
  43. emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
  44. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
  45. emx_onnx_cgen/lowering/pad.py +287 -0
  46. emx_onnx_cgen/lowering/range.py +104 -0
  47. emx_onnx_cgen/lowering/reduce.py +544 -0
  48. emx_onnx_cgen/lowering/registry.py +51 -0
  49. emx_onnx_cgen/lowering/reshape.py +188 -0
  50. emx_onnx_cgen/lowering/resize.py +445 -0
  51. emx_onnx_cgen/lowering/rms_normalization.py +67 -0
  52. emx_onnx_cgen/lowering/shape.py +78 -0
  53. emx_onnx_cgen/lowering/size.py +33 -0
  54. emx_onnx_cgen/lowering/slice.py +425 -0
  55. emx_onnx_cgen/lowering/softmax.py +47 -0
  56. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
  57. emx_onnx_cgen/lowering/split.py +150 -0
  58. emx_onnx_cgen/lowering/squeeze.py +161 -0
  59. emx_onnx_cgen/lowering/tile.py +81 -0
  60. emx_onnx_cgen/lowering/transpose.py +46 -0
  61. emx_onnx_cgen/lowering/unsqueeze.py +157 -0
  62. emx_onnx_cgen/lowering/variadic.py +95 -0
  63. emx_onnx_cgen/lowering/where.py +73 -0
  64. emx_onnx_cgen/onnx_import.py +261 -0
  65. emx_onnx_cgen/ops.py +565 -0
  66. emx_onnx_cgen/runtime/__init__.py +1 -0
  67. emx_onnx_cgen/runtime/evaluator.py +2206 -0
  68. emx_onnx_cgen/validation.py +76 -0
  69. emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
  70. emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
  71. emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
  72. emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
  73. emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
  74. shared/__init__.py +2 -0
  75. shared/scalar_functions.py +2405 -0
  76. shared/scalar_types.py +243 -0
@@ -0,0 +1,2405 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ import math
6
+ from typing import Callable, Dict, List, Mapping, Set
7
+
8
+ from shared.scalar_types import ScalarFunctionError, ScalarType
9
+
10
+
11
@dataclass(frozen=True)
class _ScalarTypeInfo:
    """Immutable description of one scalar dtype as seen by the C generators."""

    # Project-level scalar type this record describes.
    scalar_type: ScalarType
    # C type name spliced into generated signatures and declarations.
    c_type: str
    # Text placed directly in front of the function name in generated C.
    prefix: str
    # Short dtype tag; the helpers in this module compare it against "f32".
    suffix: str
    # Category flags consulted by ScalarFunction.supports_dtype.
    is_float: bool
    is_bool: bool
    is_signed: bool
    # When true, _cast_value wraps expressions in a cast back to c_type.
    is_small_int: bool
    # NOTE(review): presumably the bit width, None where not applicable - confirm.
    bits: int | None
22
+
23
+
24
@dataclass(frozen=True)
class _GeneratedScalar:
    """Result of generating one scalar helper: C source plus bookkeeping sets."""

    # Lines of C source forming the function definition.
    lines: List[str]
    # NOTE(review): presumably the other scalar functions this one calls - confirm.
    deps: Set[ScalarFunctionKey]
    # NOTE(review): presumably header names the emitted code requires - confirm.
    includes: Set[str]
29
+
30
+
31
+ def _scalar_function_spec(
32
+ value: str,
33
+ *,
34
+ supports_float: bool = True,
35
+ supports_signed_int: bool = True,
36
+ supports_unsigned_int: bool = True,
37
+ supports_bool: bool = True,
38
+ int_from_f32_arity: int | None = None,
39
+ bool_from_f32_arity: int | None = None,
40
+ ) -> tuple[
41
+ str,
42
+ bool,
43
+ bool,
44
+ bool,
45
+ bool,
46
+ int | None,
47
+ int | None,
48
+ ]:
49
+ return (
50
+ value,
51
+ supports_float,
52
+ supports_signed_int,
53
+ supports_unsigned_int,
54
+ supports_bool,
55
+ int_from_f32_arity,
56
+ bool_from_f32_arity,
57
+ )
58
+
59
+
60
def _common_unary_from_f32_spec(value: str) -> tuple[
    str, bool, bool, bool, bool, int | None, int | None
]:
    """Spec supporting every dtype, with both ``*_from_f32_arity`` fields set to 1."""
    arity = 1
    return _scalar_function_spec(
        value,
        int_from_f32_arity=arity,
        bool_from_f32_arity=arity,
    )
64
+
65
+
66
def _common_binary_from_f32_spec(value: str) -> tuple[
    str, bool, bool, bool, bool, int | None, int | None
]:
    """Spec supporting every dtype, with both ``*_from_f32_arity`` fields set to 2."""
    arity = 2
    return _scalar_function_spec(
        value,
        int_from_f32_arity=arity,
        bool_from_f32_arity=arity,
    )
70
+
71
+
72
def _bool_unary_from_f32_spec(
    value: str, *, supports_unsigned_int: bool = True
) -> tuple[str, bool, bool, bool, bool, int | None, int | None]:
    """Spec with ``bool_from_f32_arity`` set to 1 and no integer f32 arity.

    Unsigned-integer support can be switched off by the caller; all other
    dtype flags keep their defaults.
    """
    overrides = {
        "supports_unsigned_int": supports_unsigned_int,
        "bool_from_f32_arity": 1,
    }
    return _scalar_function_spec(value, **overrides)
80
+
81
+
82
def _bool_binary_from_f32_spec(value: str) -> tuple[
    str, bool, bool, bool, bool, int | None, int | None
]:
    """Spec with ``bool_from_f32_arity`` set to 2 and no integer f32 arity."""
    binary_arity = 2
    return _scalar_function_spec(value, bool_from_f32_arity=binary_arity)
86
+
87
+
88
def _no_float_spec(value: str) -> tuple[str, bool, bool, bool, bool, int | None, int | None]:
    """Spec for an op available on integer and bool dtypes but not on floats."""
    spec = _scalar_function_spec(value, supports_float=False)
    return spec
90
+
91
+
92
def _int_only_spec(value: str) -> tuple[str, bool, bool, bool, bool, int | None, int | None]:
    """Spec for an op available on signed and unsigned integers only."""
    spec = _scalar_function_spec(
        value,
        supports_float=False,
        supports_bool=False,
    )
    return spec
94
+
95
+
96
def _conversion_spec(value: str) -> tuple[str, bool, bool, bool, bool, int | None, int | None]:
    """Spec with every dtype-support flag switched off (used by CONVERT_FROM_*)."""
    disabled = dict.fromkeys(
        (
            "supports_float",
            "supports_signed_int",
            "supports_unsigned_int",
            "supports_bool",
        ),
        False,
    )
    return _scalar_function_spec(value, **disabled)
104
+
105
+
106
class ScalarFunction(str, Enum):
    """Closed set of scalar operations known to the generator.

    Each member's value is the op name string. The tuple produced by the
    ``*_spec`` helpers is unpacked into ``__new__``, which also attaches the
    dtype-support flags and the optional ``*_from_f32_arity`` values to the
    member as attributes.
    """

    def __new__(
        cls,
        value: str,
        supports_float: bool,
        supports_signed_int: bool,
        supports_unsigned_int: bool,
        supports_bool: bool,
        int_from_f32_arity: int | None = None,
        bool_from_f32_arity: int | None = None,
    ) -> "ScalarFunction":
        member = str.__new__(cls, value)
        # Only the name is the enum value; everything else is metadata.
        member._value_ = value
        member.supports_float = supports_float
        member.supports_signed_int = supports_signed_int
        member.supports_unsigned_int = supports_unsigned_int
        member.supports_bool = supports_bool
        member.int_from_f32_arity = int_from_f32_arity
        member.bool_from_f32_arity = bool_from_f32_arity
        return member

    ABS = _bool_unary_from_f32_spec("abs")
    ABSOLUTE = _bool_unary_from_f32_spec("absolute")
    ACOS = _common_unary_from_f32_spec("acos")
    ACOSH = _common_unary_from_f32_spec("acosh")
    ADD = _bool_binary_from_f32_spec("add")
    ANGLE = _common_unary_from_f32_spec("angle")
    ARCCOS = _common_unary_from_f32_spec("arccos")
    ARCSIN = _common_unary_from_f32_spec("arcsin")
    ARCSINH = _common_unary_from_f32_spec("arcsinh")
    ARCTAN = _common_unary_from_f32_spec("arctan")
    ASIN = _common_unary_from_f32_spec("asin")
    ASINH = _common_unary_from_f32_spec("asinh")
    ATAN = _common_unary_from_f32_spec("atan")
    ATAN2 = _common_binary_from_f32_spec("atan2")
    ATANH = _common_unary_from_f32_spec("atanh")
    AFFINE = _common_unary_from_f32_spec("affine")
    BITWISE_AND = _no_float_spec("bitwise_and")
    BITWISE_LEFT_SHIFT = _int_only_spec("bitwise_left_shift")
    BITWISE_NOT = _no_float_spec("bitwise_not")
    BITWISE_OR = _no_float_spec("bitwise_or")
    BITWISE_RIGHT_SHIFT = _int_only_spec("bitwise_right_shift")
    BITWISE_XOR = _no_float_spec("bitwise_xor")
    CBRT = _common_unary_from_f32_spec("cbrt")
    CEIL = _bool_unary_from_f32_spec("ceil")
    CELU = _common_unary_from_f32_spec("celu")
    CLAMP_MAX = _bool_binary_from_f32_spec("clamp_max")
    CLAMP_MIN = _bool_binary_from_f32_spec("clamp_min")
    CONJ = _bool_unary_from_f32_spec("conj", supports_unsigned_int=False)
    CONJ_PHYSICAL = _bool_unary_from_f32_spec("conj_physical", supports_unsigned_int=False)
    COPYSIGN = _bool_binary_from_f32_spec("copysign")
    COS = _common_unary_from_f32_spec("cos")
    COSH = _common_unary_from_f32_spec("cosh")
    DEG2RAD = _common_unary_from_f32_spec("deg2rad")
    DIGAMMA = _common_unary_from_f32_spec("digamma")
    DIV = _bool_binary_from_f32_spec("div")
    ELU = _common_unary_from_f32_spec("elu")
    EQ = _scalar_function_spec("eq")
    ERF = _common_unary_from_f32_spec("erf")
    ERFC = _common_unary_from_f32_spec("erfc")
    ERFINV = _common_unary_from_f32_spec("erfinv")
    EXP = _common_unary_from_f32_spec("exp")
    EXP2 = _common_unary_from_f32_spec("exp2")
    EXPM1 = _common_unary_from_f32_spec("expm1")
    FLOOR = _bool_unary_from_f32_spec("floor")
    FLOOR_DIVIDE = _bool_binary_from_f32_spec("floor_divide")
    FMAX = _bool_binary_from_f32_spec("fmax")
    FMIN = _bool_binary_from_f32_spec("fmin")
    FMOD = _bool_binary_from_f32_spec("fmod")
    FRAC = _bool_unary_from_f32_spec("frac", supports_unsigned_int=False)
    GE = _scalar_function_spec("ge")
    GELU = _common_unary_from_f32_spec("gelu")
    GT = _scalar_function_spec("gt")
    HARDSIGMOID = _common_unary_from_f32_spec("hardsigmoid")
    HARDSWISH = _common_unary_from_f32_spec("hardswish")
    HEAVISIDE = _common_binary_from_f32_spec("heaviside")
    HYPOT = _common_binary_from_f32_spec("hypot")
    I0 = _common_unary_from_f32_spec("i0")
    ISFINITE = _common_unary_from_f32_spec("isfinite")
    ISINF = _common_unary_from_f32_spec("isinf")
    ISNAN = _common_unary_from_f32_spec("isnan")
    ISNEGINF = _common_unary_from_f32_spec("isneginf")
    ISPOSINF = _common_unary_from_f32_spec("isposinf")
    LDEXP = _common_binary_from_f32_spec("ldexp")
    LE = _scalar_function_spec("le")
    LEAKY_RELU = _common_unary_from_f32_spec("leaky_relu")
    LGAMMA = _common_unary_from_f32_spec("lgamma")
    LOG = _common_unary_from_f32_spec("log")
    LOG10 = _common_unary_from_f32_spec("log10")
    LOG1P = _common_unary_from_f32_spec("log1p")
    LOG2 = _common_unary_from_f32_spec("log2")
    LOG_SIGMOID = _common_unary_from_f32_spec("log_sigmoid")
    LOGADDEXP = _common_binary_from_f32_spec("logaddexp")
    LOGADDEXP2 = _common_binary_from_f32_spec("logaddexp2")
    LOGICAL_AND = _scalar_function_spec("logical_and")
    LOGICAL_NOT = _scalar_function_spec("logical_not")
    LOGICAL_OR = _scalar_function_spec("logical_or")
    LOGICAL_XOR = _scalar_function_spec("logical_xor")
    LOGIT = _common_unary_from_f32_spec("logit")
    LT = _scalar_function_spec("lt")
    MAXIMUM = _bool_binary_from_f32_spec("maximum")
    MEAN = _scalar_function_spec(
        "mean",
        supports_signed_int=False,
        supports_unsigned_int=False,
        supports_bool=False,
    )
    MINIMUM = _bool_binary_from_f32_spec("minimum")
    MISH = _common_unary_from_f32_spec("mish")
    MUL = _bool_binary_from_f32_spec("mul")
    NAN_TO_NUM = _common_unary_from_f32_spec("nan_to_num")
    NE = _scalar_function_spec("ne")
    NEG = _bool_unary_from_f32_spec("neg")
    NEXTAFTER = _common_binary_from_f32_spec("nextafter")
    POSITIVE = _bool_unary_from_f32_spec("positive", supports_unsigned_int=False)
    POW = _common_binary_from_f32_spec("pow")
    PRELU = _scalar_function_spec(
        "prelu",
        supports_signed_int=False,
        supports_unsigned_int=False,
        supports_bool=False,
    )
    RAD2DEG = _common_unary_from_f32_spec("rad2deg")
    SCALED_TANH = _common_unary_from_f32_spec("scaled_tanh")
    REAL = _bool_unary_from_f32_spec("real", supports_unsigned_int=False)
    RECIPROCAL = _bool_unary_from_f32_spec("reciprocal")
    RELU = _bool_unary_from_f32_spec("relu")
    RELU6 = _common_unary_from_f32_spec("relu6")
    REMAINDER = _bool_binary_from_f32_spec("remainder")
    ROUND = _bool_unary_from_f32_spec("round")
    RSQRT = _common_unary_from_f32_spec("rsqrt")
    SELU = _common_unary_from_f32_spec("selu")
    SGN = _bool_unary_from_f32_spec("sgn", supports_unsigned_int=False)
    SIGMOID = _common_unary_from_f32_spec("sigmoid")
    SIGN = _bool_unary_from_f32_spec("sign", supports_unsigned_int=False)
    SILU = _common_unary_from_f32_spec("silu")
    SIN = _common_unary_from_f32_spec("sin")
    SINC = _common_unary_from_f32_spec("sinc")
    SINH = _common_unary_from_f32_spec("sinh")
    SOFTPLUS = _common_unary_from_f32_spec("softplus")
    SOFTSIGN = _scalar_function_spec(
        "softsign",
        supports_signed_int=False,
        supports_unsigned_int=False,
        supports_bool=False,
    )
    SQRT = _common_unary_from_f32_spec("sqrt")
    SQUARE = _bool_unary_from_f32_spec("square", supports_unsigned_int=False)
    SHRINK = _common_unary_from_f32_spec("shrink")
    SUB = _bool_binary_from_f32_spec("sub")
    SWISH = _common_unary_from_f32_spec("swish")
    TAN = _common_unary_from_f32_spec("tan")
    TANH = _common_unary_from_f32_spec("tanh")
    THRESHOLDED_RELU = _scalar_function_spec(
        "thresholded_relu",
        supports_signed_int=False,
        supports_unsigned_int=False,
        supports_bool=False,
    )
    TRUNC = _bool_unary_from_f32_spec("trunc", supports_unsigned_int=False)
    XLOGY = _common_binary_from_f32_spec("xlogy")
    CONVERT_FROM_F32 = _conversion_spec("convert_from_f32")
    CONVERT_FROM_F64 = _conversion_spec("convert_from_f64")
    CONVERT_FROM_I8 = _conversion_spec("convert_from_i8")
    CONVERT_FROM_I16 = _conversion_spec("convert_from_i16")
    CONVERT_FROM_I32 = _conversion_spec("convert_from_i32")
    CONVERT_FROM_I64 = _conversion_spec("convert_from_i64")
    CONVERT_FROM_U8 = _conversion_spec("convert_from_u8")
    CONVERT_FROM_U16 = _conversion_spec("convert_from_u16")
    CONVERT_FROM_U32 = _conversion_spec("convert_from_u32")
    CONVERT_FROM_U64 = _conversion_spec("convert_from_u64")
    CONVERT_FROM_BOOL = _conversion_spec("convert_from_bool")

    def supports_dtype(self, dtype_info: _ScalarTypeInfo) -> bool:
        """Return the support flag matching the category of *dtype_info*.

        Categories are checked in the order float, bool, signed; anything
        else is treated as an unsigned integer.
        """
        if dtype_info.is_float:
            supported = self.supports_float
        elif dtype_info.is_bool:
            supported = self.supports_bool
        elif dtype_info.is_signed:
            supported = self.supports_signed_int
        else:
            supported = self.supports_unsigned_int
        return supported

    @classmethod
    def from_op_name(cls, op_name: str) -> "ScalarFunction":
        """Look up a member by its op name.

        Raises:
            ScalarFunctionError: if no member has *op_name* as its value.
        """
        try:
            function = cls(op_name)
        except ValueError as exc:
            raise ScalarFunctionError(
                f"unknown scalar function op name: {op_name}"
            ) from exc
        return function

    @classmethod
    def from_onnx_op(cls, op_type: str) -> "ScalarFunction":
        """Look up the member registered for an ONNX op type.

        The name is first passed through ``_normalize_op_name``; the
        normalized name is what appears in the error message.

        Raises:
            ScalarFunctionError: if the op type has no table entry.
        """
        onnx_name = _normalize_op_name(op_type)
        try:
            function = _ONNX_OP_TO_SCALAR_FUNCTION[onnx_name]
        except KeyError as exc:
            raise ScalarFunctionError(
                f"unsupported ONNX scalar op: {onnx_name}"
            ) from exc
        return function
308
+
309
+
310
@dataclass(frozen=True)
class ScalarFunctionKey:
    """Hashable identity of one generated scalar function.

    Combines the operation, the scalar type it returns and an optional tuple
    of float parameters (empty by default).
    """

    function: ScalarFunction
    return_type: ScalarType
    params: tuple[float, ...] = ()

    @classmethod
    def for_torch_dtype(
        cls, function: ScalarFunction, dtype: object
    ) -> "ScalarFunctionKey":
        """Build a parameterless key whose return type is derived from *dtype*.

        The dtype is translated with ``ScalarType.from_torch_dtype``.
        """
        scalar_type = ScalarType.from_torch_dtype(dtype)
        return cls(function=function, return_type=scalar_type)
321
+
322
+
323
def _conversion_key_from_alias(
    dtype_info: _ScalarTypeInfo, alias: str
) -> ScalarFunctionKey:
    """Translate the ``from_f32`` / ``to_f32`` aliases into a conversion key.

    ``from_f32`` yields CONVERT_FROM_F32 returning the dtype of *dtype_info*;
    ``to_f32`` yields CONVERT_FROM_BOOL returning F32.

    Raises:
        ScalarFunctionError: for any other alias.
    """
    if alias == "from_f32":
        function = ScalarFunction.CONVERT_FROM_F32
        return_type = dtype_info.scalar_type
    elif alias == "to_f32":
        # NOTE(review): the source dtype is fixed to bool regardless of
        # dtype_info - presumably this alias is only requested for bool
        # helpers; confirm against the callers.
        function = ScalarFunction.CONVERT_FROM_BOOL
        return_type = ScalarType.F32
    else:
        raise ScalarFunctionError(f"unknown conversion alias: {alias}")
    return ScalarFunctionKey(function=function, return_type=return_type, params=())
339
+
340
+
341
def _scalar_key_from_op(
    dtype_info: _ScalarTypeInfo, op_name: str
) -> ScalarFunctionKey:
    """Build the parameterless key for *op_name* evaluated on *dtype_info*.

    The name is normalized first. The two conversion aliases are delegated to
    ``_conversion_key_from_alias``; every other name must be a
    ``ScalarFunction`` value.
    """
    name = _normalize_op_name(op_name)
    if name == "from_f32" or name == "to_f32":
        return _conversion_key_from_alias(dtype_info, name)
    function = ScalarFunction.from_op_name(name)
    return ScalarFunctionKey(
        function=function,
        return_type=dtype_info.scalar_type,
        params=(),
    )
352
+
353
+
354
# Alternative op spellings mapped to the name used for lookups; applied by
# _normalize_op_name before an op name is resolved.
_OP_ALIASES = {
    "absolute": "abs",
    "arccos": "acos",
    "arcsin": "asin",
    "arcsinh": "asinh",
    "arctan": "atan",
}
361
+
362
# ONNX op type -> ScalarFunction member; consulted by
# ScalarFunction.from_onnx_op. Several ONNX ops share one member
# ("Add" and "Sum" both resolve to ADD).
_ONNX_OP_TO_SCALAR_FUNCTION = {
    "Abs": ScalarFunction.ABS,
    "Acos": ScalarFunction.ACOS,
    "Acosh": ScalarFunction.ACOSH,
    "Add": ScalarFunction.ADD,
    "And": ScalarFunction.LOGICAL_AND,
    "Asin": ScalarFunction.ASIN,
    "Asinh": ScalarFunction.ASINH,
    "Atan": ScalarFunction.ATAN,
    "Atanh": ScalarFunction.ATANH,
    "BitwiseAnd": ScalarFunction.BITWISE_AND,
    "BitwiseNot": ScalarFunction.BITWISE_NOT,
    "BitwiseOr": ScalarFunction.BITWISE_OR,
    "BitwiseXor": ScalarFunction.BITWISE_XOR,
    "Ceil": ScalarFunction.CEIL,
    "Celu": ScalarFunction.CELU,
    "Cos": ScalarFunction.COS,
    "Cosh": ScalarFunction.COSH,
    "Div": ScalarFunction.DIV,
    "Elu": ScalarFunction.ELU,
    "Equal": ScalarFunction.EQ,
    "Erf": ScalarFunction.ERF,
    "Exp": ScalarFunction.EXP,
    "Floor": ScalarFunction.FLOOR,
    "Gelu": ScalarFunction.GELU,
    "Greater": ScalarFunction.GT,
    "GreaterOrEqual": ScalarFunction.GE,
    "HardSigmoid": ScalarFunction.HARDSIGMOID,
    "HardSwish": ScalarFunction.HARDSWISH,
    # NOTE(review): POSITIVE is declared without unsigned-int support, so
    # Identity on unsigned tensors would be rejected if routed through here -
    # confirm it is handled by a dedicated lowering instead.
    "Identity": ScalarFunction.POSITIVE,
    "LeakyRelu": ScalarFunction.LEAKY_RELU,
    "Less": ScalarFunction.LT,
    "LessOrEqual": ScalarFunction.LE,
    "Log": ScalarFunction.LOG,
    "Max": ScalarFunction.MAXIMUM,
    "Mean": ScalarFunction.MEAN,
    "Min": ScalarFunction.MINIMUM,
    # NOTE(review): ONNX Mod only has C fmod semantics when its `fmod`
    # attribute is 1 - confirm the caller accounts for the default case.
    "Mod": ScalarFunction.FMOD,
    "Mul": ScalarFunction.MUL,
    "Neg": ScalarFunction.NEG,
    "Not": ScalarFunction.LOGICAL_NOT,
    "Or": ScalarFunction.LOGICAL_OR,
    "PRelu": ScalarFunction.PRELU,
    "Pow": ScalarFunction.POW,
    "Reciprocal": ScalarFunction.RECIPROCAL,
    "Relu": ScalarFunction.RELU,
    "Round": ScalarFunction.ROUND,
    "Selu": ScalarFunction.SELU,
    "Sigmoid": ScalarFunction.SIGMOID,
    "Sign": ScalarFunction.SIGN,
    "Sin": ScalarFunction.SIN,
    "Sinh": ScalarFunction.SINH,
    "Softplus": ScalarFunction.SOFTPLUS,
    "Softsign": ScalarFunction.SOFTSIGN,
    "Shrink": ScalarFunction.SHRINK,
    "Sqrt": ScalarFunction.SQRT,
    "Sub": ScalarFunction.SUB,
    "Sum": ScalarFunction.ADD,
    "Swish": ScalarFunction.SWISH,
    "Tan": ScalarFunction.TAN,
    "Tanh": ScalarFunction.TANH,
    "ThresholdedRelu": ScalarFunction.THRESHOLDED_RELU,
    "Xor": ScalarFunction.LOGICAL_XOR,
}
426
+
427
+
428
# C math names that _math_fn never decorates with the f32 "f" suffix; in C
# these are type-generic classification macros from <math.h>.
_NO_SUFFIX_MATH = {"isfinite", "isnan", "isinf", "signbit"}
429
+
430
+
431
+ def _float_literal(value: float, dtype_info: _ScalarTypeInfo) -> str:
432
+ if dtype_info.suffix == "f32":
433
+ if value == int(value):
434
+ return f"{int(value)}.0f"
435
+ literal = f"{value}"
436
+ if "e" in literal or "E" in literal:
437
+ return f"{literal}f"
438
+ if "." not in literal:
439
+ literal = f"{literal}.0"
440
+ return f"{literal}f"
441
+ if value == int(value):
442
+ return f"{int(value)}.0"
443
+ literal = f"{value}"
444
+ if "." not in literal and "e" not in literal and "E" not in literal:
445
+ literal = f"{literal}.0"
446
+ return literal
447
+
448
+
449
+ def _param_suffix(params: tuple[float, ...]) -> str:
450
+ if not params:
451
+ return ""
452
+ parts: list[str] = []
453
+ for value in params:
454
+ if math.isnan(value):
455
+ encoded = "nan"
456
+ elif math.isinf(value):
457
+ encoded = "neg_inf" if value < 0 else "inf"
458
+ else:
459
+ encoded = format(value, ".17g")
460
+ if encoded == "-0":
461
+ encoded = "0"
462
+ encoded = encoded.replace("e-", "e_neg").replace("e+", "e")
463
+ encoded = encoded.replace("-", "neg").replace(".", "p")
464
+ parts.append(encoded)
465
+ return "__" + "_".join(parts)
466
+
467
+
468
def _math_fn(base: str, dtype_info: _ScalarTypeInfo) -> str:
    """Return the C math function name to call for *dtype_info*.

    For f32 the name gets an ``f`` suffix unless it is listed in
    ``_NO_SUFFIX_MATH``; every other dtype uses *base* unchanged.
    """
    wants_f_suffix = dtype_info.suffix == "f32" and base not in _NO_SUFFIX_MATH
    return f"{base}f" if wants_f_suffix else base
472
+
473
+
474
def _normalize_op_name(op_name: str) -> str:
    """Replace an aliased op name by its ``_OP_ALIASES`` target, else pass it through."""
    try:
        return _OP_ALIASES[op_name]
    except KeyError:
        return op_name
476
+
477
+
478
+ def _cast_value(expr: str, dtype_info: _ScalarTypeInfo) -> str:
479
+ if dtype_info.is_small_int:
480
+ return f"({dtype_info.c_type})({expr})"
481
+ return expr
482
+
483
+
484
def _simple_unary(dtype_info: _ScalarTypeInfo, name: str, expr: str) -> _GeneratedScalar:
    """Emit a one-argument ``static inline`` C function that returns *expr*.

    The function is named ``<prefix><name>``, takes a single argument ``a``
    and both takes and returns the dtype's C type.
    """
    c_type = dtype_info.c_type
    signature = f"static inline {c_type} {dtype_info.prefix}{name}({c_type} a) {{"
    body = f" return {expr};"
    return _GeneratedScalar(lines=[signature, body, "}"], deps=set(), includes=set())
491
+
492
+
493
def _simple_binary(dtype_info: _ScalarTypeInfo, name: str, expr: str) -> _GeneratedScalar:
    """Emit a two-argument ``static inline`` C function that returns *expr*.

    The function is named ``<prefix><name>``, takes arguments ``a`` and ``b``
    and uses the dtype's C type for both arguments and the result.
    """
    c_type = dtype_info.c_type
    signature = (
        f"static inline {c_type} {dtype_info.prefix}{name}({c_type} a, {c_type} b) {{"
    )
    body = f" return {expr};"
    return _GeneratedScalar(lines=[signature, body, "}"], deps=set(), includes=set())
500
+
501
+
502
def _float_unary_math(dtype_info: _ScalarTypeInfo, name: str, base: str) -> _GeneratedScalar:
    """Emit a unary wrapper that forwards ``a`` to the C math function *base*."""
    c_function = _math_fn(base, dtype_info)
    return _simple_unary(dtype_info, name, f"{c_function}(a)")
504
+
505
+
506
def _float_binary_math(dtype_info: _ScalarTypeInfo, name: str, base: str) -> _GeneratedScalar:
    """Emit a binary wrapper that forwards ``a, b`` to the C math function *base*."""
    c_function = _math_fn(base, dtype_info)
    return _simple_binary(dtype_info, name, f"{c_function}(a, b)")
508
+
509
+
510
def _float_isfinite(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``isfinite`` returning 1.0 or 0.0 in the float dtype itself."""
    true_value, false_value = (
        _float_literal(flag, dtype_info) for flag in (1.0, 0.0)
    )
    expression = f"isfinite(a) ? {true_value} : {false_value}"
    return _simple_unary(dtype_info, "isfinite", expression)
514
+
515
+
516
def _float_isnan(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``isnan`` returning 1.0 or 0.0 in the float dtype itself."""
    true_value, false_value = (
        _float_literal(flag, dtype_info) for flag in (1.0, 0.0)
    )
    expression = f"isnan(a) ? {true_value} : {false_value}"
    return _simple_unary(dtype_info, "isnan", expression)
520
+
521
+
522
def _float_isinf(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``isinf`` returning 1.0 or 0.0 in the float dtype itself."""
    true_value, false_value = (
        _float_literal(flag, dtype_info) for flag in (1.0, 0.0)
    )
    expression = f"isinf(a) ? {true_value} : {false_value}"
    return _simple_unary(dtype_info, "isinf", expression)
526
+
527
+
528
def _float_isneginf(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``isneginf``: 1.0 when ``a`` is infinite with the sign bit set, else 0.0."""
    true_value, false_value = (
        _float_literal(flag, dtype_info) for flag in (1.0, 0.0)
    )
    expression = f"(isinf(a) && signbit(a)) ? {true_value} : {false_value}"
    return _simple_unary(dtype_info, "isneginf", expression)
534
+
535
+
536
def _float_isposinf(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``isposinf``: 1.0 when ``a`` is infinite with the sign bit clear, else 0.0."""
    true_value, false_value = (
        _float_literal(flag, dtype_info) for flag in (1.0, 0.0)
    )
    expression = f"(isinf(a) && !signbit(a)) ? {true_value} : {false_value}"
    return _simple_unary(dtype_info, "isposinf", expression)
542
+
543
+
544
def _float_comparison(
    dtype_info: _ScalarTypeInfo, name: str, op: str
) -> _GeneratedScalar:
    """Emit a binary comparison ``a <op> b`` yielding 1.0 or 0.0 in the float dtype."""
    true_value, false_value = (
        _float_literal(flag, dtype_info) for flag in (1.0, 0.0)
    )
    expression = f"a {op} b ? {true_value} : {false_value}"
    return _simple_binary(dtype_info, name, expression)
550
+
551
+
552
def _float_logical_binary(
    dtype_info: _ScalarTypeInfo, name: str, expr: str
) -> _GeneratedScalar:
    """Emit a binary function mapping the C condition *expr* to 1.0 or 0.0."""
    true_value, false_value = (
        _float_literal(flag, dtype_info) for flag in (1.0, 0.0)
    )
    expression = f"{expr} ? {true_value} : {false_value}"
    return _simple_binary(dtype_info, name, expression)
558
+
559
+
560
def _float_logical_not(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``logical_not``: 1.0 when ``a`` equals zero, otherwise 0.0."""
    true_value, false_value = (
        _float_literal(flag, dtype_info) for flag in (1.0, 0.0)
    )
    expression = f"a == {false_value} ? {true_value} : {false_value}"
    return _simple_unary(dtype_info, "logical_not", expression)
564
+
565
+
566
def _float_remainder(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``remainder(a, b)`` whose result takes the sign of the divisor.

    The generated C returns NAN when either operand is NaN or ``b`` is zero.
    Otherwise it starts from ``fmod(a, b)`` and adds ``b`` when the result is
    non-zero and its sign differs from that of ``b``.
    """
    c_type = dtype_info.c_type
    zero = _float_literal(0.0, dtype_info)
    fmod = _math_fn("fmod", dtype_info)
    signature = (
        f"static inline {c_type} {dtype_info.prefix}remainder({c_type} a, {c_type} b) {{"
    )
    body = [
        " if (isnan(a) || isnan(b)) {",
        " return NAN;",
        " }",
        f" if (b == {zero}) {{",
        " return NAN;",
        " }",
        f" {c_type} mod = {fmod}(a, b);",
        f" if (mod == {zero}) {{",
        " return mod;",
        " }",
        " if ((mod < 0) != (b < 0)) {",
        " mod += b;",
        " }",
        " return mod;",
    ]
    return _GeneratedScalar(lines=[signature, *body, "}"], deps=set(), includes=set())
587
+
588
+
589
def _float_floor_divide(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``floor_divide(a, b)`` as the floor of the quotient ``a / b``."""
    floor = _math_fn("floor", dtype_info)
    return _simple_binary(dtype_info, "floor_divide", f"{floor}(a / b)")
593
+
594
+
595
def _float_logaddexp(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``logaddexp(a, b)`` computed relative to the larger operand.

    The generated C returns NAN for NaN operands and -INFINITY when the
    larger operand is -INFINITY; otherwise
    ``max + log1p(exp(min - max))``.
    """
    c_type = dtype_info.c_type
    fmax = _math_fn("fmax", dtype_info)
    fmin = _math_fn("fmin", dtype_info)
    log1p = _math_fn("log1p", dtype_info)
    exp = _math_fn("exp", dtype_info)
    signature = (
        f"static inline {c_type} {dtype_info.prefix}logaddexp({c_type} a, {c_type} b) {{"
    )
    body = [
        " if (isnan(a) || isnan(b)) {",
        " return NAN;",
        " }",
        f" {c_type} max_val = {fmax}(a, b);",
        f" {c_type} min_val = {fmin}(a, b);",
        " if (max_val == -INFINITY) {",
        " return -INFINITY;",
        " }",
        f" return max_val + {log1p}({exp}(min_val - max_val));",
    ]
    return _GeneratedScalar(lines=[signature, *body, "}"], deps=set(), includes=set())
610
+
611
+
612
def _float_logaddexp2(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``logaddexp2(a, b)``, the base-2 counterpart of ``logaddexp``.

    The generated C returns NAN for NaN operands and -INFINITY when the
    larger operand is -INFINITY; otherwise
    ``max + log2(1 + exp2(min - max))``.
    """
    c_type = dtype_info.c_type
    one = _float_literal(1.0, dtype_info)
    fmax = _math_fn("fmax", dtype_info)
    fmin = _math_fn("fmin", dtype_info)
    log2 = _math_fn("log2", dtype_info)
    exp2 = _math_fn("exp2", dtype_info)
    signature = (
        f"static inline {c_type} {dtype_info.prefix}logaddexp2({c_type} a, {c_type} b) {{"
    )
    body = [
        " if (isnan(a) || isnan(b)) {",
        " return NAN;",
        " }",
        f" {c_type} max_val = {fmax}(a, b);",
        f" {c_type} min_val = {fmin}(a, b);",
        " if (max_val == -INFINITY) {",
        " return -INFINITY;",
        " }",
        f" return max_val + {log2}({one} + {exp2}(min_val - max_val));",
    ]
    return _GeneratedScalar(lines=[signature, *body, "}"], deps=set(), includes=set())
628
+
629
+
630
def _float_xlogy(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``xlogy(a, b)`` = ``a * log(b)`` with a zero short-circuit.

    The generated C returns NAN if either operand is NaN and zero whenever
    ``a`` is zero, regardless of ``b``.
    """
    c_type = dtype_info.c_type
    zero = _float_literal(0.0, dtype_info)
    log = _math_fn("log", dtype_info)
    signature = (
        f"static inline {c_type} {dtype_info.prefix}xlogy({c_type} a, {c_type} b) {{"
    )
    body = [
        " if (isnan(a) || isnan(b)) {",
        " return NAN;",
        " }",
        f" if (a == {zero}) {{",
        f" return {zero};",
        " }",
        f" return a * {log}(b);",
    ]
    return _GeneratedScalar(lines=[signature, *body, "}"], deps=set(), includes=set())
644
+
645
+
646
def _float_heaviside(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``heaviside(a, b)``: 1 for positive ``a``, ``b`` at zero, else 0."""
    c_type = dtype_info.c_type
    zero, one = (_float_literal(number, dtype_info) for number in (0.0, 1.0))
    signature = (
        f"static inline {c_type} {dtype_info.prefix}heaviside({c_type} a, {c_type} b) {{"
    )
    body = [
        f" if (a > {zero}) {{",
        f" return {one};",
        " }",
        f" if (a == {zero}) {{",
        " return b;",
        " }",
        f" return {zero};",
    ]
    return _GeneratedScalar(lines=[signature, *body, "}"], deps=set(), includes=set())
661
+
662
+
663
def _float_ldexp(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``ldexp(a, b)`` as ``a`` multiplied by ``exp2(b)``."""
    exp2 = _math_fn("exp2", dtype_info)
    return _simple_binary(dtype_info, "ldexp", f"a * {exp2}(b)")
667
+
668
+
669
def _float_reciprocal(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``reciprocal(a)`` as one divided by ``a``."""
    numerator = _float_literal(1.0, dtype_info)
    return _simple_unary(dtype_info, "reciprocal", f"{numerator} / a")
672
+
673
+
674
def _float_relu(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``relu(a)``: ``a`` when strictly positive, otherwise zero."""
    floor_value = _float_literal(0.0, dtype_info)
    expression = f"a > {floor_value} ? a : {floor_value}"
    return _simple_unary(dtype_info, "relu", expression)
677
+
678
+
679
def _float_rsqrt(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``rsqrt(a)`` as one divided by the square root of ``a``."""
    numerator = _float_literal(1.0, dtype_info)
    sqrt = _math_fn("sqrt", dtype_info)
    return _simple_unary(dtype_info, "rsqrt", f"{numerator} / {sqrt}(a)")
682
+
683
+
684
def _float_sigmoid(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``sigmoid(a)`` as ``1 / (1 + exp(-a))``."""
    unit = _float_literal(1.0, dtype_info)
    exp = _math_fn("exp", dtype_info)
    expression = f"{unit} / ({unit} + {exp}(-a))"
    return _simple_unary(dtype_info, "sigmoid", expression)
687
+
688
+
689
def _float_log_sigmoid(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``log_sigmoid(a)`` using a branch on the sign of ``a``.

    For ``a >= 0`` the generated C returns ``-log1p(exp(-a))``; otherwise it
    returns ``a - log1p(exp(a))``, so ``exp`` is only called on non-positive
    arguments.
    """
    c_type = dtype_info.c_type
    zero = _float_literal(0.0, dtype_info)
    log1p = _math_fn("log1p", dtype_info)
    exp = _math_fn("exp", dtype_info)
    signature = f"static inline {c_type} {dtype_info.prefix}log_sigmoid({c_type} a) {{"
    body = [
        f" if (a >= {zero}) {{",
        f" return -{log1p}({exp}(-a));",
        " }",
        f" return a - {log1p}({exp}(a));",
    ]
    return _GeneratedScalar(lines=[signature, *body, "}"], deps=set(), includes=set())
700
+
701
+
702
def _float_gelu(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the erf-based ``gelu(a)`` = ``0.5 * a * (1 + erf(a / sqrt(2)))``.

    The division by sqrt(2) is emitted as a multiplication by a constant.
    """
    c_type = dtype_info.c_type
    inv_sqrt2, half, one = (
        _float_literal(number, dtype_info)
        for number in (0.7071067811865475, 0.5, 1.0)
    )
    erf = _math_fn("erf", dtype_info)
    signature = f"static inline {c_type} {dtype_info.prefix}gelu({c_type} a) {{"
    body = [
        f" const {c_type} inv_sqrt2 = {inv_sqrt2};",
        f" return {half} * a * ({one} + {erf}(a * inv_sqrt2));",
    ]
    return _GeneratedScalar(lines=[signature, *body, "}"], deps=set(), includes=set())
713
+
714
+
715
def _float_elu(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the parameterless ``elu(a)`` with alpha, scale and input_scale all 1.

    The generated C returns ``scale * a`` for positive ``a`` and
    ``scale * alpha * (exp(input_scale * a) - 1)`` otherwise.
    """
    c_type = dtype_info.c_type
    one = _float_literal(1.0, dtype_info)
    exp = _math_fn("exp", dtype_info)
    signature = f"static inline {c_type} {dtype_info.prefix}elu({c_type} a) {{"
    constants = [
        f" const {c_type} {constant} = {one};"
        for constant in ("alpha", "scale", "input_scale")
    ]
    body = [
        " if (a > 0) {",
        " return scale * a;",
        " }",
        f" return scale * alpha * ({exp}(input_scale * a) - {one});",
    ]
    return _GeneratedScalar(
        lines=[signature, *constants, *body, "}"], deps=set(), includes=set()
    )
729
+
730
+
731
def _float_elu_param(
    dtype_info: _ScalarTypeInfo,
    params: tuple[float, ...],
    function_name: str,
) -> _GeneratedScalar:
    """Emit ELU with a baked-in ``alpha`` under the given C function name.

    Args:
        dtype_info: Float dtype the function operates on.
        params: Either empty (alpha defaults to 1.0) or ``(alpha,)``.
        function_name: Full name of the emitted C function.

    Raises:
        ScalarFunctionError: if more than one parameter is supplied.
    """
    if len(params) not in (0, 1):
        raise ScalarFunctionError("elu expects 1 parameter: alpha")
    (alpha_value,) = params if params else (1.0,)

    c_type = dtype_info.c_type
    alpha = _float_literal(alpha_value, dtype_info)
    one = _float_literal(1.0, dtype_info)
    exp = _math_fn("exp", dtype_info)
    signature = f"static inline {c_type} {function_name}({c_type} a) {{"
    body = [
        f" const {c_type} alpha = {alpha};",
        " if (a >= 0) {",
        " return a;",
        " }",
        f" return alpha * ({exp}(a) - {one});",
    ]
    return _GeneratedScalar(lines=[signature, *body, "}"], deps=set(), includes=set())
751
+
752
+
753
def _float_celu(
    dtype_info: _ScalarTypeInfo,
    params: tuple[float, ...],
    function_name: str,
) -> _GeneratedScalar:
    """Emit a CELU helper named ``function_name``.

    The generated C returns ``a`` for ``a > 0`` and
    ``alpha * (exp(a / alpha) - 1)`` otherwise. ``params`` is either empty
    (alpha defaults to 1.0) or ``(alpha,)``.

    Raises:
        ScalarFunctionError: if ``params`` is non-empty but not of length 1.
    """
    if params:
        if len(params) != 1:
            raise ScalarFunctionError("celu expects 1 parameter: alpha")
        alpha_value = params[0]
    else:
        alpha_value = 1.0
    c_type = dtype_info.c_type
    alpha_lit = _float_literal(alpha_value, dtype_info)
    one_lit = _float_literal(1.0, dtype_info)
    exp_fn = _math_fn('exp', dtype_info)
    return _GeneratedScalar(
        lines=[
            f"static inline {c_type} {function_name}({c_type} a) {{",
            f" const {c_type} alpha = {alpha_lit};",
            " if (a > 0) {",
            " return a;",
            " }",
            f" return alpha * ({exp_fn}(a / alpha) - {one_lit});",
            "}",
        ],
        deps=set(),
        includes=set(),
    )
773
+
774
+
775
def _float_affine(
    dtype_info: _ScalarTypeInfo,
    params: tuple[float, ...],
    function_name: str,
) -> _GeneratedScalar:
    """Emit an affine helper ``alpha * a + beta`` named ``function_name``.

    ``params`` is either empty (alpha=1.0, beta=0.0) or ``(alpha, beta)``.

    Raises:
        ScalarFunctionError: if ``params`` is non-empty but not of length 2.
    """
    has_params = bool(params)
    if has_params and len(params) != 2:
        raise ScalarFunctionError("affine expects 2 parameters: alpha, beta")
    alpha_value = params[0] if has_params else 1.0
    beta_value = params[1] if len(params) > 1 else 0.0
    c_type = dtype_info.c_type
    alpha_lit = _float_literal(alpha_value, dtype_info)
    beta_lit = _float_literal(beta_value, dtype_info)
    return _GeneratedScalar(
        lines=[
            f"static inline {c_type} {function_name}({c_type} a) {{",
            f" const {c_type} alpha = {alpha_lit};",
            f" const {c_type} beta = {beta_lit};",
            " return alpha * a + beta;",
            "}",
        ],
        deps=set(),
        includes=set(),
    )
794
+
795
+
796
def _float_scaled_tanh(
    dtype_info: _ScalarTypeInfo,
    params: tuple[float, ...],
    function_name: str,
) -> _GeneratedScalar:
    """Emit a scaled-tanh helper ``alpha * tanh(beta * a)`` named ``function_name``.

    ``params`` is either empty (alpha=1.0, beta=1.0) or ``(alpha, beta)``.

    Raises:
        ScalarFunctionError: if ``params`` is non-empty but not of length 2.
    """
    has_params = bool(params)
    if has_params and len(params) != 2:
        raise ScalarFunctionError("scaled_tanh expects 2 parameters: alpha, beta")
    alpha_value = params[0] if has_params else 1.0
    beta_value = params[1] if len(params) > 1 else 1.0
    c_type = dtype_info.c_type
    alpha_lit = _float_literal(alpha_value, dtype_info)
    beta_lit = _float_literal(beta_value, dtype_info)
    tanh_fn = _math_fn('tanh', dtype_info)
    return _GeneratedScalar(
        lines=[
            f"static inline {c_type} {function_name}({c_type} a) {{",
            f" const {c_type} alpha = {alpha_lit};",
            f" const {c_type} beta = {beta_lit};",
            f" return alpha * {tanh_fn}(beta * a);",
            "}",
        ],
        deps=set(),
        includes=set(),
    )
815
+
816
+
817
def _float_swish(
    dtype_info: _ScalarTypeInfo,
    params: tuple[float, ...],
    function_name: str,
) -> _GeneratedScalar:
    """Emit a swish helper ``a / (1 + exp(-alpha * a))`` named ``function_name``.

    ``params`` is either empty (alpha defaults to 1.0) or ``(alpha,)``.

    Raises:
        ScalarFunctionError: if ``params`` is non-empty but not of length 1.
    """
    if params:
        if len(params) != 1:
            raise ScalarFunctionError("swish expects 1 parameter: alpha")
        alpha_value = params[0]
    else:
        alpha_value = 1.0
    c_type = dtype_info.c_type
    alpha_lit = _float_literal(alpha_value, dtype_info)
    one_lit = _float_literal(1.0, dtype_info)
    exp_fn = _math_fn('exp', dtype_info)
    return _GeneratedScalar(
        lines=[
            f"static inline {c_type} {function_name}({c_type} a) {{",
            f" const {c_type} alpha = {alpha_lit};",
            f" return a / ({one_lit} + {exp_fn}(-alpha * a));",
            "}",
        ],
        deps=set(),
        includes=set(),
    )
834
+
835
+
836
def _float_leaky_relu_param(
    dtype_info: _ScalarTypeInfo,
    params: tuple[float, ...],
    function_name: str,
) -> _GeneratedScalar:
    """Emit a leaky-ReLU helper named ``function_name`` with a baked-in slope.

    The generated C returns ``alpha * a`` for ``a < 0`` and ``a`` otherwise.
    ``params`` is either empty (alpha defaults to 0.01) or ``(alpha,)``.

    Raises:
        ScalarFunctionError: if ``params`` is non-empty but not of length 1.
    """
    if params:
        if len(params) != 1:
            raise ScalarFunctionError("leaky_relu expects 1 parameter: alpha")
        alpha_value = params[0]
    else:
        alpha_value = 0.01
    c_type = dtype_info.c_type
    alpha_lit = _float_literal(alpha_value, dtype_info)
    return _GeneratedScalar(
        lines=[
            f"static inline {c_type} {function_name}({c_type} a) {{",
            f" const {c_type} alpha = {alpha_lit};",
            " return a < 0 ? alpha * a : a;",
            "}",
        ],
        deps=set(),
        includes=set(),
    )
852
+
853
+
854
def _float_thresholded_relu_param(
    dtype_info: _ScalarTypeInfo,
    params: tuple[float, ...],
    function_name: str,
) -> _GeneratedScalar:
    """Emit a thresholded-ReLU helper named ``function_name``.

    The generated C returns ``a`` when ``a > alpha`` and zero otherwise.
    ``params`` is either empty (alpha defaults to 1.0) or ``(alpha,)``.

    Raises:
        ScalarFunctionError: if ``params`` is non-empty but not of length 1.
    """
    if params:
        if len(params) != 1:
            raise ScalarFunctionError("thresholded_relu expects 1 parameter: alpha")
        alpha_value = params[0]
    else:
        alpha_value = 1.0
    c_type = dtype_info.c_type
    alpha_lit = _float_literal(alpha_value, dtype_info)
    zero_lit = _float_literal(0.0, dtype_info)
    return _GeneratedScalar(
        lines=[
            f"static inline {c_type} {function_name}({c_type} a) {{",
            f" const {c_type} alpha = {alpha_lit};",
            f" return a > alpha ? a : {zero_lit};",
            "}",
        ],
        deps=set(),
        includes=set(),
    )
871
+
872
+
873
def _float_hardsigmoid_param(
    dtype_info: _ScalarTypeInfo,
    params: tuple[float, ...],
    function_name: str,
) -> _GeneratedScalar:
    """Emit a hard-sigmoid helper named ``function_name``.

    The generated C computes ``a * alpha + beta`` and clamps it to [0, 1]
    with fmax/fmin. ``params`` is either empty (alpha=0.2, beta=0.5) or
    ``(alpha, beta)``.

    Raises:
        ScalarFunctionError: if ``params`` is non-empty but not of length 2.
    """
    has_params = bool(params)
    if has_params and len(params) != 2:
        raise ScalarFunctionError("hardsigmoid expects 2 parameters: alpha, beta")
    alpha_value = params[0] if has_params else 0.2
    beta_value = params[1] if len(params) > 1 else 0.5
    c_type = dtype_info.c_type
    alpha_lit = _float_literal(alpha_value, dtype_info)
    beta_lit = _float_literal(beta_value, dtype_info)
    lower = _float_literal(0.0, dtype_info)
    upper = _float_literal(1.0, dtype_info)
    fmin_fn = _math_fn('fmin', dtype_info)
    fmax_fn = _math_fn('fmax', dtype_info)
    return _GeneratedScalar(
        lines=[
            f"static inline {c_type} {function_name}({c_type} a) {{",
            f" {c_type} value = a * {alpha_lit} + {beta_lit};",
            f" {c_type} clamped = {fmin_fn}({upper}, {fmax_fn}({lower}, value));",
            " return clamped;",
            "}",
        ],
        deps=set(),
        includes=set(),
    )
894
+
895
+
896
def _float_shrink(
    dtype_info: _ScalarTypeInfo,
    params: tuple[float, ...],
    function_name: str,
) -> _GeneratedScalar:
    """Emit a shrink helper named ``function_name``.

    The generated C returns ``a + bias`` for ``a < -lambd``, ``a - bias`` for
    ``a > lambd`` and zero otherwise. ``params`` is either empty (bias=0.0,
    lambd=0.5) or ``(bias, lambd)``.

    Raises:
        ScalarFunctionError: if ``params`` is non-empty but not of length 2.
    """
    if params:
        if len(params) != 2:
            raise ScalarFunctionError("shrink expects 2 parameters: bias, lambd")
        bias_value, lambd_value = params[0], params[1]
    else:
        bias_value, lambd_value = 0.0, 0.5
    c_type = dtype_info.c_type
    bias_lit = _float_literal(bias_value, dtype_info)
    lambd_lit = _float_literal(lambd_value, dtype_info)
    zero_lit = _float_literal(0.0, dtype_info)
    declarations = [
        f" const {c_type} bias = {bias_lit};",
        f" const {c_type} lambd = {lambd_lit};",
    ]
    branches = [
        " if (a < -lambd) {",
        " return a + bias;",
        " }",
        " if (a > lambd) {",
        " return a - bias;",
        " }",
        f" return {zero_lit};",
    ]
    lines = [
        f"static inline {c_type} {function_name}({c_type} a) {{",
        *declarations,
        *branches,
        "}",
    ]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
922
+
923
+
924
def _float_leaky_relu(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the default leaky-ReLU helper with a fixed negative slope of 0.01."""
    c_type = dtype_info.c_type
    slope_lit = _float_literal(0.01, dtype_info)
    return _GeneratedScalar(
        lines=[
            f"static inline {c_type} {dtype_info.prefix}leaky_relu({c_type} a) {{",
            f" const {c_type} negative_slope = {slope_lit};",
            " return a > 0 ? a : negative_slope * a;",
            "}",
        ],
        deps=set(),
        includes=set(),
    )
933
+
934
+
935
def _float_softplus(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the softplus helper with beta=1.0 and threshold=20.0.

    The generated C returns ``a`` unchanged once ``beta * a`` exceeds the
    threshold, otherwise ``log1p(exp(beta * a)) / beta``.
    """
    c_type = dtype_info.c_type
    beta_lit = _float_literal(1.0, dtype_info)
    threshold_lit = _float_literal(20.0, dtype_info)
    log1p_fn = _math_fn('log1p', dtype_info)
    exp_fn = _math_fn('exp', dtype_info)
    lines = [
        f"static inline {c_type} {dtype_info.prefix}softplus({c_type} a) {{",
        f" const {c_type} beta = {beta_lit};",
        f" const {c_type} threshold = {threshold_lit};",
        " if (beta * a > threshold) {",
        " return a;",
        " }",
        f" return {log1p_fn}({exp_fn}(beta * a)) / beta;",
        "}",
    ]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
949
+
950
+
951
def _float_softsign(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the softsign helper: ``a / (1 + fabs(a))``."""
    one_lit = _float_literal(1.0, dtype_info)
    fabs_fn = _math_fn('fabs', dtype_info)
    expression = f"a / ({one_lit} + {fabs_fn}(a))"
    return _simple_unary(dtype_info, "softsign", expression)
958
+
959
+
960
def _float_silu(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the SiLU helper: ``a / (1 + exp(-a))``."""
    one_lit = _float_literal(1.0, dtype_info)
    exp_fn = _math_fn('exp', dtype_info)
    expression = f"a / ({one_lit} + {exp_fn}(-a))"
    return _simple_unary(dtype_info, "silu", expression)
963
+
964
+
965
def _float_mish(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the Mish helper: ``a * tanh(log1p(exp(a)))`` with two shortcuts.

    The generated C returns ``a`` directly for ``a > 20`` and ``a * exp(a)``
    for ``a < -20``; only inputs between those bounds go through the full
    softplus/tanh expression.
    """
    # The original also built an unused 0.0 literal; it never reached the
    # emitted C, so it is dropped here.
    twenty = _float_literal(20.0, dtype_info)
    lines = [
        f"static inline {dtype_info.c_type} {dtype_info.prefix}mish({dtype_info.c_type} a) {{",
        f" if (a > {twenty}) {{",
        " return a;",
        " }",
        f" if (a < -{twenty}) {{",
        f" {dtype_info.c_type} exp_a = {_math_fn('exp', dtype_info)}(a);",
        " return a * exp_a;",
        " }",
        f" {dtype_info.c_type} softplus = {_math_fn('log1p', dtype_info)}({_math_fn('exp', dtype_info)}(a));",
        f" return a * {_math_fn('tanh', dtype_info)}(softplus);",
        "}",
    ]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
982
+
983
+
984
def _float_selu(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the SELU helper with fixed ``alpha`` and ``scale`` constants.

    The constants are written as decimal text and get an ``f`` suffix only
    when the dtype suffix is ``"f32"``. The generated C returns ``scale * a``
    for ``a > 0`` and ``scale * alpha * (exp(a) - 1)`` otherwise.
    """
    literal_suffix = "f" if dtype_info.suffix == "f32" else ""
    alpha_literal = "1.6732632423543772848170429916717" + literal_suffix
    scale_literal = "1.0507009873554804934193349852946" + literal_suffix
    c_type = dtype_info.c_type
    one_lit = _float_literal(1.0, dtype_info)
    exp_fn = _math_fn('exp', dtype_info)
    lines = [
        f"static inline {c_type} {dtype_info.prefix}selu({c_type} a) {{",
        f" const {c_type} alpha = {alpha_literal};",
        f" const {c_type} scale = {scale_literal};",
        " if (a > 0) {",
        " return scale * a;",
        " }",
        f" return scale * alpha * ({exp_fn}(a) - {one_lit});",
        "}",
    ]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
1001
+
1002
+
1003
def _float_thresholded_relu(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the default thresholded-ReLU helper (threshold fixed at 1.0)."""
    zero_lit = _float_literal(0.0, dtype_info)
    threshold_lit = _float_literal(1.0, dtype_info)
    expression = f"a > {threshold_lit} ? a : {zero_lit}"
    return _simple_unary(dtype_info, "thresholded_relu", expression)
1011
+
1012
+
1013
def _float_relu6(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the ReLU6 helper: ``fmin(6, fmax(0, a))``."""
    lower = _float_literal(0.0, dtype_info)
    upper = _float_literal(6.0, dtype_info)
    fmin_fn = _math_fn('fmin', dtype_info)
    fmax_fn = _math_fn('fmax', dtype_info)
    expression = f"{fmin_fn}({upper}, {fmax_fn}({lower}, a))"
    return _simple_unary(dtype_info, "relu6", expression)
1021
+
1022
+
1023
def _float_hardsigmoid(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the default hard-sigmoid helper (alpha=0.2, beta=0.5).

    The generated C clamps ``a * alpha + beta`` to [0, 1] with fmax/fmin.
    """
    c_type = dtype_info.c_type
    alpha_lit = _float_literal(0.2, dtype_info)
    beta_lit = _float_literal(0.5, dtype_info)
    lower = _float_literal(0.0, dtype_info)
    upper = _float_literal(1.0, dtype_info)
    fmin_fn = _math_fn('fmin', dtype_info)
    fmax_fn = _math_fn('fmax', dtype_info)
    return _GeneratedScalar(
        lines=[
            f"static inline {c_type} {dtype_info.prefix}hardsigmoid({c_type} a) {{",
            f" {c_type} value = a * {alpha_lit} + {beta_lit};",
            f" {c_type} clamped = {fmin_fn}({upper}, {fmax_fn}({lower}, value));",
            " return clamped;",
            "}",
        ],
        deps=set(),
        includes=set(),
    )
1036
+
1037
+
1038
def _float_hardswish(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the hard-swish helper: ``a * clamp(a + 3, 0, 6) / 6``."""
    c_type = dtype_info.c_type
    shift_lit = _float_literal(3.0, dtype_info)
    upper = _float_literal(6.0, dtype_info)
    lower = _float_literal(0.0, dtype_info)
    fmin_fn = _math_fn('fmin', dtype_info)
    fmax_fn = _math_fn('fmax', dtype_info)
    return _GeneratedScalar(
        lines=[
            f"static inline {c_type} {dtype_info.prefix}hardswish({c_type} a) {{",
            f" {c_type} shifted = a + {shift_lit};",
            f" {c_type} clamped = {fmin_fn}({upper}, {fmax_fn}({lower}, shifted));",
            f" return a * clamped / {upper};",
            "}",
        ],
        deps=set(),
        includes=set(),
    )
1050
+
1051
+
1052
def _float_sign(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the sign helper.

    The generated C returns NaN input unchanged, 1 for positive input,
    -1 for negative input and 0 otherwise.
    """
    c_type = dtype_info.c_type
    zero_lit = _float_literal(0.0, dtype_info)
    pos_lit = _float_literal(1.0, dtype_info)
    neg_lit = _float_literal(-1.0, dtype_info)
    nan_guard = [
        " if (isnan(a)) {",
        " return a;",
        " }",
    ]
    sign_branches = [
        f" if (a > {zero_lit}) {{",
        f" return {pos_lit};",
        " }",
        f" if (a < {zero_lit}) {{",
        f" return {neg_lit};",
        " }",
        f" return {zero_lit};",
    ]
    lines = [
        f"static inline {c_type} {dtype_info.prefix}sign({c_type} a) {{",
        *nan_guard,
        *sign_branches,
        "}",
    ]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
1071
+
1072
+
1073
def _float_round(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the ``round`` helper via ``_float_unary_math`` with base name ``round``."""
    op_name = base_name = "round"
    return _float_unary_math(dtype_info, op_name, base_name)
1075
+
1076
+
1077
def _float_trunc(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the ``trunc`` helper via ``_float_unary_math`` with base name ``trunc``."""
    op_name = base_name = "trunc"
    return _float_unary_math(dtype_info, op_name, base_name)
1079
+
1080
+
1081
def _float_angle(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the angle helper for real inputs.

    The generated C returns NaN input unchanged, the pi macro for negative
    input and zero otherwise. ``REF_PI_F`` is used when the dtype suffix is
    ``"f32"``, ``REF_PI_D`` for everything else.
    """
    c_type = dtype_info.c_type
    zero_lit = _float_literal(0.0, dtype_info)
    pi_macro = "REF_PI_D"
    if dtype_info.suffix == "f32":
        pi_macro = "REF_PI_F"
    return _GeneratedScalar(
        lines=[
            f"static inline {c_type} {dtype_info.prefix}angle({c_type} a) {{",
            " if (isnan(a)) {",
            " return a;",
            " }",
            f" return a < {zero_lit} ? {pi_macro} : {zero_lit};",
            "}",
        ],
        deps=set(),
        includes=set(),
    )
1093
+
1094
+
1095
def _float_conj(dtype_info: _ScalarTypeInfo, name: str) -> _GeneratedScalar:
    """Emit an identity helper called ``name``; the generated expression is just ``a``."""
    identity_expr = "a"
    return _simple_unary(dtype_info, name, identity_expr)
1097
+
1098
+
1099
def _float_deg2rad(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the degrees-to-radians helper: ``a * (pi / 180)``."""
    pi_macro = "REF_PI_D"
    if dtype_info.suffix == "f32":
        pi_macro = "REF_PI_F"
    half_turn = _float_literal(180.0, dtype_info)
    expression = f"a * ({pi_macro} / {half_turn})"
    return _simple_unary(dtype_info, "deg2rad", expression)
1103
+
1104
+
1105
def _float_rad2deg(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the radians-to-degrees helper: ``a * (180 / pi)``."""
    pi_macro = "REF_PI_D"
    if dtype_info.suffix == "f32":
        pi_macro = "REF_PI_F"
    half_turn = _float_literal(180.0, dtype_info)
    expression = f"a * ({half_turn} / {pi_macro})"
    return _simple_unary(dtype_info, "rad2deg", expression)
1109
+
1110
+
1111
def _float_digamma_f64() -> _GeneratedScalar:
    """Emit the double-precision digamma helper ``ref_scalar_f64_digamma``.

    The generated C works in four stages: NaN/inf pass through unchanged;
    non-positive input returns NAN at integers and otherwise recurses on
    ``1 - x`` with a ``pi / tan(pi * x)`` correction; positive input is
    shifted up to at least 10 while subtracting ``1/x`` terms; finally a
    ``log(x)`` plus inverse-power series is added.
    """
    signature = ["static inline double ref_scalar_f64_digamma(double x) {"]
    non_finite_guard = [
        " if (isnan(x) || isinf(x)) {",
        " return x;",
        " }",
    ]
    reflection = [
        " if (x <= 0.0) {",
        " double frac = x - floor(x);",
        " if (frac == 0.0) {",
        " return NAN;",
        " }",
        " return ref_scalar_f64_digamma(1.0 - x) - REF_PI_D / tan(REF_PI_D * x);",
        " }",
    ]
    recurrence = [
        " double result = 0.0;",
        " while (x < 10.0) {",
        " result -= 1.0 / x;",
        " x += 1.0;",
        " }",
    ]
    series = [
        " double inv = 1.0 / x;",
        " double inv2 = inv * inv;",
        " result += log(x) - 0.5 * inv",
        " - inv2 * (1.0 / 12.0 - inv2 * (1.0 / 120.0",
        " - inv2 * (1.0 / 252.0 - inv2 * (1.0 / 240.0",
        " - inv2 * (1.0 / 132.0 - inv2 * (691.0 / 32760.0))))));",
        " return result;",
        "}",
    ]
    c_lines = signature + non_finite_guard + reflection + recurrence + series
    return _GeneratedScalar(lines=c_lines, deps=set(), includes=set())
1139
+
1140
+
1141
def _float_digamma(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the digamma helper for ``dtype_info``.

    For the ``"f64"`` suffix this returns the double implementation itself.
    Any other dtype gets a thin wrapper that casts to double, calls
    ``ref_scalar_f64_digamma`` and casts the result to ``float``; the wrapper
    records the f64 digamma helper as a dependency.
    """
    is_double = dtype_info.suffix == "f64"
    if is_double:
        return _float_digamma_f64()
    c_type = dtype_info.c_type
    wrapper = [
        f"static inline {c_type} {dtype_info.prefix}digamma({c_type} x) {{",
        " return (float)ref_scalar_f64_digamma((double)x);",
        "}",
    ]
    f64_key = _scalar_key_from_op(_SCALAR_TYPES[ScalarType.F64], "digamma")
    return _GeneratedScalar(lines=wrapper, deps={f64_key}, includes=set())
1151
+
1152
+
1153
def _float_erfinv(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the inverse error function helper.

    The generated C handles the edge cases first: NaN passes through, +/-1
    map to +/-INFINITY, values outside [-1, 1] give NAN and 0 maps to 0.
    Otherwise it builds a closed-form initial estimate from ``log(1 - x*x)``
    and the constant 0.147, then refines it with two Newton steps on
    ``erf(approx) - x``.
    """
    c_type = dtype_info.c_type
    pi_macro = "REF_PI_D"
    if dtype_info.suffix == "f32":
        pi_macro = "REF_PI_F"
    one_lit = _float_literal(1.0, dtype_info)
    zero_lit = _float_literal(0.0, dtype_info)
    two_lit = _float_literal(2.0, dtype_info)
    coeff_lit = _float_literal(0.147, dtype_info)
    log_fn = _math_fn('log', dtype_info)
    sqrt_fn = _math_fn('sqrt', dtype_info)
    fmax_fn = _math_fn('fmax', dtype_info)
    erf_fn = _math_fn('erf', dtype_info)
    exp_fn = _math_fn('exp', dtype_info)

    edge_cases = [
        " if (isnan(x)) {",
        " return x;",
        " }",
        f" if (x <= -{one_lit}) {{",
        f" return x == -{one_lit} ? -INFINITY : NAN;",
        " }",
        f" if (x >= {one_lit}) {{",
        f" return x == {one_lit} ? INFINITY : NAN;",
        " }",
        f" if (x == {zero_lit}) {{",
        f" return {zero_lit};",
        " }",
    ]
    initial_estimate = [
        f" {c_type} a = {coeff_lit};",
        f" {c_type} ln = {log_fn}({one_lit} - x * x);",
        f" {c_type} term = {two_lit} / ({pi_macro} * a) + ln / {two_lit};",
        f" {c_type} inner = term * term - ln / a;",
        f" {c_type} approx = {sqrt_fn}({fmax_fn}({zero_lit}, {sqrt_fn}(inner) - term));",
        f" if (x < {zero_lit}) {{",
        " approx = -approx;",
        " }",
    ]
    newton_refinement = [
        " for (int i = 0; i < 2; ++i) {",
        f" {c_type} err = {erf_fn}(approx) - x;",
        f" {c_type} deriv = {two_lit} / {sqrt_fn}({pi_macro}) * {exp_fn}(-approx * approx);",
        " approx -= err / deriv;",
        " }",
        " return approx;",
    ]
    c_lines = [
        f"static inline {c_type} {dtype_info.prefix}erfinv({c_type} x) {{",
        *edge_cases,
        *initial_estimate,
        *newton_refinement,
        "}",
    ]
    return _GeneratedScalar(lines=c_lines, deps=set(), includes=set())
1190
+
1191
+
1192
def _float_frac(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the fractional-part helper: ``a - trunc(a)``."""
    trunc_fn = _math_fn('trunc', dtype_info)
    expression = f"a - {trunc_fn}(a)"
    return _simple_unary(dtype_info, "frac", expression)
1194
+
1195
+
1196
def _float_i0(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the ``i0`` helper as two polynomial approximations.

    The generated C splits on ``|x| < 3.75``: below the cutoff it evaluates a
    polynomial in ``(x / 3.75)^2``; otherwise it evaluates
    ``exp(|x|) / sqrt(|x|)`` times a polynomial in ``3.75 / |x|``.
    """

    def lit(value: float) -> str:
        # Shorthand so the long coefficient chains below stay readable.
        return _float_literal(value, dtype_info)

    # The original also built an unused 0.0 literal; it never reached the
    # emitted C, so it is dropped here.
    c_type = dtype_info.c_type
    cutoff = lit(3.75)
    fabs_fn = _math_fn('fabs', dtype_info)
    exp_fn = _math_fn('exp', dtype_info)
    sqrt_fn = _math_fn('sqrt', dtype_info)
    lines = [
        f"static inline {c_type} {dtype_info.prefix}i0({c_type} x) {{",
        f" {c_type} ax = {fabs_fn}(x);",
        f" if (ax < {cutoff}) {{",
        f" {c_type} y = x / {cutoff};",
        " y *= y;",
        f" return {lit(1.0)} + y * ({lit(3.5156229)} + y * ({lit(3.0899424)} + y * ({lit(1.2067492)}",
        f" + y * ({lit(0.2659732)} + y * ({lit(0.0360768)} + y * {lit(0.0045813)})))));",
        " }",
        f" {c_type} y = {cutoff} / ax;",
        f" return ({exp_fn}(ax) / {sqrt_fn}(ax)) * ({lit(0.39894228)} + y * ({lit(0.01328592)}",
        f" + y * ({lit(0.00225319)} + y * ({lit(-0.00157565)} + y * ({lit(0.00916281)}",
        f" + y * ({lit(-0.02057706)} + y * ({lit(0.02635537)}",
        f" + y * ({lit(-0.01647633)} + y * {lit(0.00392377)}))))))));",
        "}",
    ]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
1216
+
1217
+
1218
def _float_logit(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the logit helper: ``log(a / (1 - a))`` with explicit edge cases.

    The generated C returns NaN input unchanged, -INFINITY at 0, INFINITY
    at 1 and NAN for anything outside [0, 1].
    """
    c_type = dtype_info.c_type
    zero_lit = _float_literal(0.0, dtype_info)
    one_lit = _float_literal(1.0, dtype_info)
    log_fn = _math_fn('log', dtype_info)
    edge_cases = [
        " if (isnan(a)) {",
        " return a;",
        " }",
        f" if (a == {zero_lit}) {{",
        " return -INFINITY;",
        " }",
        f" if (a == {one_lit}) {{",
        " return INFINITY;",
        " }",
        f" if (a < {zero_lit} || a > {one_lit}) {{",
        " return NAN;",
        " }",
    ]
    c_lines = [
        f"static inline {c_type} {dtype_info.prefix}logit({c_type} a) {{",
        *edge_cases,
        f" return {log_fn}(a / ({one_lit} - a));",
        "}",
    ]
    return _GeneratedScalar(lines=c_lines, deps=set(), includes=set())
1239
+
1240
+
1241
def _float_nan_to_num(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the ``nan_to_num`` helper.

    The generated C maps NaN to zero and +/-infinity to the largest finite
    value of the helper's own type with the matching sign; finite input is
    returned unchanged.
    """
    zero = _float_literal(0.0, dtype_info)
    # FLT_MAX is the float limit only. Emitting it in the double helper would
    # clamp +/-inf to about 3.4e38 instead of the largest finite double, so
    # the f64 variant uses DBL_MAX. All other dtypes keep FLT_MAX as before.
    max_name = "DBL_MAX" if dtype_info.suffix == "f64" else "FLT_MAX"
    lines = [
        f"static inline {dtype_info.c_type} {dtype_info.prefix}nan_to_num({dtype_info.c_type} a) {{",
        " if (isnan(a)) {",
        f" return {zero};",
        " }",
        " if (isinf(a)) {",
        f" return signbit(a) ? -{max_name} : {max_name};",
        " }",
        " return a;",
        "}",
    ]
    # FLT_MAX / DBL_MAX are declared in <float.h>; request it explicitly
    # rather than relying on another helper to have pulled it in.
    return _GeneratedScalar(lines=lines, deps=set(), includes={"#include <float.h>"})
1255
+
1256
+
1257
def _float_sgn(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the ``sgn`` helper.

    The generated C returns zero for NaN, 1 for positive input, -1 for
    negative input and 0 otherwise.
    """
    c_type = dtype_info.c_type
    zero_lit = _float_literal(0.0, dtype_info)
    pos_lit = _float_literal(1.0, dtype_info)
    neg_lit = _float_literal(-1.0, dtype_info)
    nan_guard = [
        " if (isnan(a)) {",
        f" return {zero_lit};",
        " }",
    ]
    sign_branches = [
        f" if (a > {zero_lit}) {{",
        f" return {pos_lit};",
        " }",
        f" if (a < {zero_lit}) {{",
        f" return {neg_lit};",
        " }",
        f" return {zero_lit};",
    ]
    c_lines = [
        f"static inline {c_type} {dtype_info.prefix}sgn({c_type} a) {{",
        *nan_guard,
        *sign_branches,
        "}",
    ]
    return _GeneratedScalar(lines=c_lines, deps=set(), includes=set())
1276
+
1277
+
1278
def _float_sinc(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the normalized sinc helper: ``sin(pi * a) / (pi * a)``, with 1 at ``a == 0``."""
    c_type = dtype_info.c_type
    zero_lit = _float_literal(0.0, dtype_info)
    one_lit = _float_literal(1.0, dtype_info)
    pi_macro = "REF_PI_D"
    if dtype_info.suffix == "f32":
        pi_macro = "REF_PI_F"
    sin_fn = _math_fn('sin', dtype_info)
    return _GeneratedScalar(
        lines=[
            f"static inline {c_type} {dtype_info.prefix}sinc({c_type} a) {{",
            f" if (a == {zero_lit}) {{",
            f" return {one_lit};",
            " }",
            f" {c_type} x = {pi_macro} * a;",
            f" return {sin_fn}(x) / x;",
            "}",
        ],
        deps=set(),
        includes=set(),
    )
1292
+
1293
+
1294
def _float_square(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the ``square`` helper: ``a * a``."""
    expression = "a * a"
    return _simple_unary(dtype_info, "square", expression)
1296
+
1297
+
1298
def _float_positive(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the ``positive`` helper, which returns its argument unchanged."""
    identity_expr = "a"
    return _simple_unary(dtype_info, "positive", identity_expr)
1300
+
1301
+
1302
def _float_clamp_min(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the ``clamp_min`` helper: ``fmax(a, b)``."""
    fmax_fn = _math_fn('fmax', dtype_info)
    expression = f"{fmax_fn}(a, b)"
    return _simple_binary(dtype_info, "clamp_min", expression)
1304
+
1305
+
1306
def _float_clamp_max(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the ``clamp_max`` helper: ``fmin(a, b)``."""
    fmin_fn = _math_fn('fmin', dtype_info)
    expression = f"{fmin_fn}(a, b)"
    return _simple_binary(dtype_info, "clamp_max", expression)
1308
+
1309
def _float_binary_op_handler(name: str, op: str) -> Callable[[_ScalarTypeInfo], _GeneratedScalar]:
    """Build a handler that emits the binary helper ``name`` as ``a <op> b``."""
    # The expression only depends on ``op``, so it is formatted once here.
    expression = f"a {op} b"

    def handler(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
        return _simple_binary(dtype_info, name, expression)

    return handler
1314
+
1315
+
1316
def _float_unary_math_handler(name: str, base: str | None = None) -> Callable[
    [_ScalarTypeInfo], _GeneratedScalar
]:
    """Build a handler that forwards to ``_float_unary_math`` for ``name``.

    ``base`` is the second name passed through; when it is omitted (or
    otherwise falsy) ``name`` is used for both.
    """
    resolved_base = base if base else name

    def handler(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
        return _float_unary_math(dtype_info, name, resolved_base)

    return handler
1325
+
1326
+
1327
def _float_binary_math_handler(name: str, base: str | None = None) -> Callable[
    [_ScalarTypeInfo], _GeneratedScalar
]:
    """Build a handler that forwards to ``_float_binary_math`` for ``name``.

    ``base`` is the second name passed through; when it is omitted (or
    otherwise falsy) ``name`` is used for both.
    """
    resolved_base = base if base else name

    def handler(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
        return _float_binary_math(dtype_info, name, resolved_base)

    return handler
1336
+
1337
+
1338
def _float_comparison_handler(name: str, op: str) -> Callable[[_ScalarTypeInfo], _GeneratedScalar]:
    """Build a handler that forwards ``name`` and the C operator ``op`` to ``_float_comparison``."""

    def handler(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
        generated = _float_comparison(dtype_info, name, op)
        return generated

    return handler
1343
+
1344
+
1345
def _float_logical_or(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``logical_or``: true when either operand differs from zero."""
    zero_lit = _float_literal(0.0, dtype_info)
    condition = f"(a != {zero_lit} || b != {zero_lit})"
    return _float_logical_binary(dtype_info, "logical_or", condition)
1348
+
1349
+
1350
def _float_logical_and(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``logical_and``: true when both operands differ from zero."""
    zero_lit = _float_literal(0.0, dtype_info)
    condition = f"(a != {zero_lit} && b != {zero_lit})"
    return _float_logical_binary(dtype_info, "logical_and", condition)
1353
+
1354
+
1355
def _float_logical_xor(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``logical_xor``: true when exactly one operand differs from zero."""
    zero_lit = _float_literal(0.0, dtype_info)
    condition = f"((a != {zero_lit}) != (b != {zero_lit}))"
    return _float_logical_binary(dtype_info, "logical_xor", condition)
1358
+
1359
+
1360
# Maps a canonical float op name to the generator that emits its C helper.
# _float_from_ops looks names up here and raises ScalarFunctionError on a miss.
_FLOAT_OP_DISPATCH: Mapping[str, Callable[[_ScalarTypeInfo], _GeneratedScalar]] = {
    # Arithmetic and min/max.
    "abs": _float_unary_math_handler("abs", "fabs"),
    "add": _float_binary_op_handler("add", "+"),
    "sub": _float_binary_op_handler("sub", "-"),
    "mul": _float_binary_op_handler("mul", "*"),
    "div": _float_binary_op_handler("div", "/"),
    "maximum": _float_binary_math_handler("maximum", "fmax"),
    "fmax": _float_binary_math_handler("fmax", "fmax"),
    "minimum": _float_binary_math_handler("minimum", "fmin"),
    "fmin": _float_binary_math_handler("fmin", "fmin"),
    # Comparisons.
    "le": _float_comparison_handler("le", "<="),
    "lt": _float_comparison_handler("lt", "<"),
    "ge": _float_comparison_handler("ge", ">="),
    "gt": _float_comparison_handler("gt", ">"),
    "eq": _float_comparison_handler("eq", "=="),
    "ne": _float_comparison_handler("ne", "!="),
    # Logical ops.
    "logical_or": _float_logical_or,
    "logical_and": _float_logical_and,
    "logical_xor": _float_logical_xor,
    "logical_not": _float_logical_not,
    # Other binary ops.
    "copysign": _float_binary_math_handler("copysign"),
    "hypot": _float_binary_math_handler("hypot"),
    "atan2": _float_binary_math_handler("atan2"),
    "pow": _float_binary_math_handler("pow"),
    "fmod": _float_binary_math_handler("fmod"),
    "remainder": _float_remainder,
    "floor_divide": _float_floor_divide,
    "logaddexp": _float_logaddexp,
    "logaddexp2": _float_logaddexp2,
    "nextafter": _float_binary_math_handler("nextafter"),
    "xlogy": _float_xlogy,
    "heaviside": _float_heaviside,
    "ldexp": _float_ldexp,
    "clamp_min": _float_clamp_min,
    "clamp_max": _float_clamp_max,
    # Unary ops, mostly direct math-function wrappers.
    "neg": lambda dtype_info: _simple_unary(dtype_info, "neg", "-a"),
    "reciprocal": _float_reciprocal,
    "relu": _float_relu,
    "ceil": _float_unary_math_handler("ceil"),
    "floor": _float_unary_math_handler("floor"),
    "sin": _float_unary_math_handler("sin"),
    "cos": _float_unary_math_handler("cos"),
    "sqrt": _float_unary_math_handler("sqrt"),
    "cbrt": _float_unary_math_handler("cbrt"),
    "exp": _float_unary_math_handler("exp"),
    "tanh": _float_unary_math_handler("tanh"),
    "log": _float_unary_math_handler("log"),
    "acos": _float_unary_math_handler("acos"),
    "acosh": _float_unary_math_handler("acosh"),
    "asin": _float_unary_math_handler("asin"),
    "asinh": _float_unary_math_handler("asinh"),
    "atan": _float_unary_math_handler("atan"),
    "atanh": _float_unary_math_handler("atanh"),
    "cosh": _float_unary_math_handler("cosh"),
    "sinh": _float_unary_math_handler("sinh"),
    "tan": _float_unary_math_handler("tan"),
    "erf": _float_unary_math_handler("erf"),
    "erfc": _float_unary_math_handler("erfc"),
    "expm1": _float_unary_math_handler("expm1"),
    "log1p": _float_unary_math_handler("log1p"),
    "log2": _float_unary_math_handler("log2"),
    "log10": _float_unary_math_handler("log10"),
    "exp2": _float_unary_math_handler("exp2"),
    "lgamma": _float_unary_math_handler("lgamma"),
    "isfinite": _float_isfinite,
    "rsqrt": _float_rsqrt,
    # Activations with fixed default parameters.
    "sigmoid": _float_sigmoid,
    "log_sigmoid": _float_log_sigmoid,
    "gelu": _float_gelu,
    "elu": _float_elu,
    "leaky_relu": _float_leaky_relu,
    "softplus": _float_softplus,
    "softsign": _float_softsign,
    "silu": _float_silu,
    "mish": _float_mish,
    "selu": _float_selu,
    "relu6": _float_relu6,
    "hardsigmoid": _float_hardsigmoid,
    "hardswish": _float_hardswish,
    "thresholded_relu": _float_thresholded_relu,
    # Remaining unary ops.
    "sign": _float_sign,
    "round": _float_round,
    "trunc": _float_trunc,
    "angle": _float_angle,
    "conj": lambda dtype_info: _float_conj(dtype_info, "conj"),
    "conj_physical": lambda dtype_info: _float_conj(dtype_info, "conj_physical"),
    "deg2rad": _float_deg2rad,
    "digamma": _float_digamma,
    "erfinv": _float_erfinv,
    "frac": _float_frac,
    "i0": _float_i0,
    "logit": _float_logit,
    "isnan": _float_isnan,
    "isinf": _float_isinf,
    "isneginf": _float_isneginf,
    "isposinf": _float_isposinf,
    "nan_to_num": _float_nan_to_num,
    "positive": _float_positive,
    "rad2deg": _float_rad2deg,
    "real": lambda dtype_info: _simple_unary(dtype_info, "real", "a"),
    "sgn": _float_sgn,
    "sinc": _float_sinc,
    "square": _float_square,
}
1464
+
1465
# Maps a ScalarFunction to the generator for its parameterized variant.
# Each generator takes (dtype_info, params, function_name): ``params`` are the
# constants baked into the helper and ``function_name`` is the C name to emit.
_PARAMETERIZED_FLOAT_OPS: Mapping[
    ScalarFunction,
    Callable[[_ScalarTypeInfo, tuple[float, ...], str], _GeneratedScalar],
] = {
    ScalarFunction.AFFINE: _float_affine,
    ScalarFunction.CELU: _float_celu,
    ScalarFunction.ELU: _float_elu_param,
    ScalarFunction.HARDSIGMOID: _float_hardsigmoid_param,
    ScalarFunction.LEAKY_RELU: _float_leaky_relu_param,
    ScalarFunction.SCALED_TANH: _float_scaled_tanh,
    ScalarFunction.SHRINK: _float_shrink,
    ScalarFunction.SWISH: _float_swish,
    ScalarFunction.THRESHOLDED_RELU: _float_thresholded_relu_param,
}
1479
+
1480
+
1481
def _float_from_ops(dtype_info: _ScalarTypeInfo, name: str) -> _GeneratedScalar:
    """Generate the float helper for op ``name``.

    When ``_normalize_op_name`` maps ``name`` to a different canonical name,
    a one-argument forwarding wrapper is emitted that depends on the
    canonical helper. Otherwise the generator is looked up in
    ``_FLOAT_OP_DISPATCH``.

    Raises:
        ScalarFunctionError: if the canonical name has no registered handler.
    """
    canonical_name = _normalize_op_name(name)
    is_alias = canonical_name != name
    if is_alias:
        # NOTE(review): the wrapper always takes a single argument; confirm
        # that aliases never resolve to binary ops.
        c_type = dtype_info.c_type
        prefix = dtype_info.prefix
        wrapper = [
            f"static inline {c_type} {prefix}{name}({c_type} a) {{",
            f" return {prefix}{canonical_name}(a);",
            "}",
        ]
        canonical_key = _scalar_key_from_op(dtype_info, canonical_name)
        return _GeneratedScalar(lines=wrapper, deps={canonical_key}, includes=set())
    generator = _FLOAT_OP_DISPATCH.get(canonical_name)
    if generator is not None:
        return generator(dtype_info)
    raise ScalarFunctionError(f"unsupported float scalar op: {canonical_name}")
1496
+
1497
+
1498
def _int_from_f32(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit the ``from_f32`` helper that converts a float to an integer type.

    Three variants are generated, chosen from ``dtype_info``:

    * signed: non-finite input maps to ``INT<bits>_MIN``; finite input is
      saturated to ``[INT<bits>_MIN, INT<bits>_MAX]`` before the cast.
    * unsigned 32/64-bit: non-finite and non-positive input map to 0, and
      values at or above ``UINT<bits>_MAX`` saturate.
    * other unsigned widths: non-finite input maps to 0 and everything else
      is cast directly.
    """
    includes: Set[str] = {"#include <math.h>", "#include <stdint.h>"}
    signature = (
        f"static inline {dtype_info.c_type} {dtype_info.prefix}from_f32(float value) {{"
    )
    cast_and_close: List[str] = [f" return ({dtype_info.c_type})value;", "}"]

    if dtype_info.is_signed:
        includes.add("#include <limits.h>")
        min_name = f"INT{dtype_info.bits}_MIN"
        max_name = f"INT{dtype_info.bits}_MAX"
        lines: List[str] = [
            signature,
            " if (!isfinite(value)) {",
            f" return {min_name};",
            " }",
            # Use >= rather than >: (float)INT32_MAX and (float)INT64_MAX
            # round up to 2^31 / 2^63, so an input exactly equal to that
            # value would otherwise reach the cast out of range, which is
            # undefined behaviour in C. This matches the unsigned branch.
            f" if (value >= (float){max_name}) {{",
            f" return {max_name};",
            " }",
            f" if (value < (float){min_name}) {{",
            f" return {min_name};",
            " }",
            *cast_and_close,
        ]
        return _GeneratedScalar(lines=lines, deps=set(), includes=includes)

    max_name = f"UINT{dtype_info.bits}_MAX"
    if dtype_info.bits in {32, 64}:
        lines = [
            signature,
            " if (!isfinite(value)) {",
            " return 0;",
            " }",
            " if (value <= 0.0f) {",
            " return 0;",
            " }",
            f" if (value >= (float){max_name}) {{",
            f" return {max_name};",
            " }",
            *cast_and_close,
        ]
        return _GeneratedScalar(lines=lines, deps=set(), includes=includes)

    # NOTE(review): narrower unsigned types are cast without range clamping,
    # unlike the two branches above; left unchanged — confirm it is intended.
    lines = [
        signature,
        " if (!isfinite(value)) {",
        " return 0;",
        " }",
        *cast_and_close,
    ]
    return _GeneratedScalar(lines=lines, deps=set(), includes=includes)
1552
+
1553
+
1554
def _int_unary_from_f32(dtype_info: _ScalarTypeInfo, name: str) -> _GeneratedScalar:
    """Generate a unary integer op implemented via the f32 version of ``name``.

    The emitted C function casts its argument to ``float``, calls
    ``ref_scalar_f32_<name>`` and converts the result back with
    ``<prefix>from_f32``.  Both callees are recorded as dependencies.
    """
    c_type = dtype_info.c_type
    prefix = dtype_info.prefix
    signature = f"static inline {c_type} {prefix}{name}({c_type} a) {{"
    body = f"    return {prefix}from_f32(ref_scalar_f32_{name}((float)a));"
    conversion_dep = _conversion_key_from_alias(dtype_info, "from_f32")
    f32_dep = _scalar_key_from_op(_SCALAR_TYPES[ScalarType.F32], name)
    return _GeneratedScalar(
        lines=[signature, body, "}"],
        deps={conversion_dep, f32_dep},
        includes=set(),
    )
1565
+
1566
+
1567
def _int_binary_from_f32(dtype_info: _ScalarTypeInfo, name: str) -> _GeneratedScalar:
    """Generate a binary integer op implemented via the f32 version of ``name``.

    Both operands are cast to ``float``, passed to ``ref_scalar_f32_<name>``
    and the result is converted back with ``<prefix>from_f32``.  Both callees
    are recorded as dependencies.
    """
    c_type = dtype_info.c_type
    prefix = dtype_info.prefix
    signature = f"static inline {c_type} {prefix}{name}({c_type} a, {c_type} b) {{"
    body = f"    return {prefix}from_f32(ref_scalar_f32_{name}((float)a, (float)b));"
    conversion_dep = _conversion_key_from_alias(dtype_info, "from_f32")
    f32_dep = _scalar_key_from_op(_SCALAR_TYPES[ScalarType.F32], name)
    return _GeneratedScalar(
        lines=[signature, body, "}"],
        deps={conversion_dep, f32_dep},
        includes=set(),
    )
1578
+
1579
+
1580
def _int_bool_literal(dtype_info: _ScalarTypeInfo, value: bool) -> str:
    """Return the C literal (``1`` or ``0``) that encodes ``value``.

    Types flagged ``is_small_int`` get an explicit cast to their C type;
    all other types get the bare digit.
    """
    digit = 1 if value else 0
    if not dtype_info.is_small_int:
        return str(digit)
    return f"({dtype_info.c_type}){digit}"
1584
+
1585
+
1586
def _int_binary_op(dtype_info: _ScalarTypeInfo, name: str, op: str) -> _GeneratedScalar:
    """Generate ``<name>(a, b)`` as ``a <op> b`` passed through ``_cast_value``."""
    return _simple_binary(dtype_info, name, _cast_value(f"a {op} b", dtype_info))
1589
+
1590
+
1591
def _int_unary_op(dtype_info: _ScalarTypeInfo, name: str, expr: str) -> _GeneratedScalar:
    """Generate ``<name>(a)`` returning ``expr`` passed through ``_cast_value``."""
    cast_expr = _cast_value(expr, dtype_info)
    return _simple_unary(dtype_info, name, cast_expr)
1593
+
1594
+
1595
def _int_comparison(dtype_info: _ScalarTypeInfo, name: str, op: str) -> _GeneratedScalar:
    """Generate a comparison ``a <op> b`` that yields the type's 1/0 literal."""
    true_literal = _int_bool_literal(dtype_info, True)
    false_literal = _int_bool_literal(dtype_info, False)
    expression = f"a {op} b ? {true_literal} : {false_literal}"
    return _simple_binary(dtype_info, name, expression)
1599
+
1600
+
1601
def _int_logical(dtype_info: _ScalarTypeInfo, name: str, expr: str) -> _GeneratedScalar:
    """Generate a binary logical op mapping the C condition ``expr`` to 1/0."""
    true_literal = _int_bool_literal(dtype_info, True)
    false_literal = _int_bool_literal(dtype_info, False)
    expression = f"{expr} ? {true_literal} : {false_literal}"
    return _simple_binary(dtype_info, name, expression)
1605
+
1606
+
1607
def _int_logical_not(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``logical_not(a)``: 1 when ``a`` equals the zero literal, else 0."""
    false_literal = _int_bool_literal(dtype_info, False)
    true_literal = _int_bool_literal(dtype_info, True)
    expression = f"a == {false_literal} ? {true_literal} : {false_literal}"
    return _simple_unary(dtype_info, "logical_not", expression)
1611
+
1612
+
1613
def _int_abs(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``abs(a)`` for an integer type.

    Unsigned types return ``a`` unchanged.  Signed types return
    ``INT<bits>_MIN`` for ``INT<bits>_MIN`` (its negation is not
    representable) and the usual magnitude otherwise.
    """
    if dtype_info.is_signed:
        c_type = dtype_info.c_type
        lowest = f"INT{dtype_info.bits}_MIN"
        body = [
            f"static inline {c_type} {dtype_info.prefix}abs({c_type} a) {{",
            f"    if (a == {lowest}) {{",
            f"        return {lowest};",
            "    }",
            "    return a < 0 ? -a : a;",
            "}",
        ]
        return _GeneratedScalar(
            lines=body, deps=set(), includes={"#include <limits.h>"}
        )
    return _simple_unary(dtype_info, "abs", "a")
1627
+
1628
+
1629
def _int_absolute(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``absolute(a)`` as a forwarding wrapper around ``abs(a)``."""
    c_type = dtype_info.c_type
    prefix = dtype_info.prefix
    return _GeneratedScalar(
        lines=[
            f"static inline {c_type} {prefix}absolute({c_type} a) {{",
            f"    return {prefix}abs(a);",
            "}",
        ],
        deps={_scalar_key_from_op(dtype_info, "abs")},
        includes=set(),
    )
1637
+
1638
+
1639
def _int_div(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate integer ``div(a, b)`` using C truncating division.

    The emitted function returns ``0`` when ``b == 0``.  For signed types it
    also returns ``INT<bits>_MIN`` for ``INT<bits>_MIN / -1``: that quotient
    is not representable, so evaluating it in C is undefined behaviour (and
    traps on common hardware).  This matches how the ``neg`` helper treats
    the minimum value.

    Args:
        dtype_info: Description of the integer type.

    Returns:
        The generated lines with no dependencies; signed types request
        ``<stdint.h>`` for the ``INT<bits>_MIN`` macro.
    """
    c_type = dtype_info.c_type
    expr = _cast_value("a / b", dtype_info)
    includes = set()
    lines = [
        f"static inline {c_type} {dtype_info.prefix}div({c_type} a, {c_type} b) {{",
        "    if (b == 0) {",
        "        return 0;",
        "    }",
    ]
    if dtype_info.is_signed:
        min_name = f"INT{dtype_info.bits}_MIN"
        includes.add("#include <stdint.h>")
        lines.extend(
            [
                f"    if (a == {min_name} && b == -1) {{",
                f"        return {min_name};",
                "    }",
            ]
        )
    lines.extend([f"    return {expr};", "}"])
    return _GeneratedScalar(lines=lines, deps=set(), includes=includes)
1650
+
1651
+
1652
def _int_fmod(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate integer ``fmod(a, b)`` using the C ``%`` operator.

    The emitted function returns ``0`` when ``b == 0``.  For signed types it
    also returns ``0`` when ``b == -1``: every value modulo ``-1`` is zero,
    and evaluating ``INT<bits>_MIN % -1`` directly is undefined behaviour in
    C (and traps on common hardware).

    Args:
        dtype_info: Description of the integer type.

    Returns:
        The generated lines with no dependencies and no extra includes.
    """
    c_type = dtype_info.c_type
    expr = _cast_value("a % b", dtype_info)
    lines = [
        f"static inline {c_type} {dtype_info.prefix}fmod({c_type} a, {c_type} b) {{",
        "    if (b == 0) {",
        "        return 0;",
        "    }",
    ]
    if dtype_info.is_signed:
        # Signed only: for unsigned types `b == -1` would match the maximum value.
        lines.extend(
            [
                "    if (b == -1) {",
                "        return 0;",
                "    }",
            ]
        )
    lines.extend([f"    return {expr};", "}"])
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
1663
+
1664
+
1665
def _int_remainder(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate integer ``remainder(a, b)`` whose result follows the sign of ``b``.

    The emitted function returns ``0`` when ``b == 0``.  Otherwise it
    computes ``a % b`` and, when the result is non-zero and its sign differs
    from that of ``b``, adds ``b`` to it.  For signed types ``b == -1``
    short-circuits to ``0``: the mathematical result is always zero, and
    ``INT<bits>_MIN % -1`` is undefined behaviour in C.

    Args:
        dtype_info: Description of the integer type.

    Returns:
        The generated lines with no dependencies and no extra includes.
    """
    c_type = dtype_info.c_type
    expr = _cast_value("a % b", dtype_info)
    lines = [
        f"static inline {c_type} {dtype_info.prefix}remainder({c_type} a, {c_type} b) {{",
        "    if (b == 0) {",
        "        return 0;",
        "    }",
    ]
    if dtype_info.is_signed:
        # Signed only: for unsigned types `b == -1` would match the maximum value.
        lines.extend(
            [
                "    if (b == -1) {",
                "        return 0;",
                "    }",
            ]
        )
    lines.extend(
        [
            f"    {c_type} mod = {expr};",
            "    if (mod == 0) {",
            "        return mod;",
            "    }",
            "    if ((mod < 0) != (b < 0)) {",
            "        mod += b;",
            "    }",
            "    return mod;",
            "}",
        ]
    )
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
1683
+
1684
+
1685
def _int_floor_divide(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate integer ``floor_divide(a, b)``.

    The emitted function returns ``0`` when ``b == 0``.  Unsigned types use
    plain C division.  Signed types take the truncated quotient and subtract
    one when the remainder is non-zero and its sign differs from that of
    ``b``.  ``INT<bits>_MIN / -1`` is answered with ``INT<bits>_MIN`` up
    front, because evaluating that quotient or remainder is undefined
    behaviour in C.

    Args:
        dtype_info: Description of the integer type.

    Returns:
        The generated lines with no dependencies; signed types request
        ``<stdint.h>`` for the ``INT<bits>_MIN`` macro.
    """
    c_type = dtype_info.c_type
    includes = set()
    lines = [
        f"static inline {c_type} {dtype_info.prefix}floor_divide({c_type} a, {c_type} b) {{",
        "    if (b == 0) {",
        "        return 0;",
        "    }",
    ]
    if dtype_info.is_signed:
        min_name = f"INT{dtype_info.bits}_MIN"
        includes.add("#include <stdint.h>")
        lines.extend(
            [
                f"    if (a == {min_name} && b == -1) {{",
                f"        return {min_name};",
                "    }",
                f"    {c_type} quo = a / b;",
                f"    {c_type} rem = a % b;",
                "    if (rem != 0 && ((rem < 0) != (b < 0))) {",
                "        quo -= 1;",
                "    }",
                "    return quo;",
                "}",
            ]
        )
    else:
        expr = _cast_value("a / b", dtype_info)
        lines.extend([f"    return {expr};", "}"])
    return _GeneratedScalar(lines=lines, deps=set(), includes=includes)
1713
+
1714
+
1715
def _int_copysign(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``copysign(a, b)``: ``abs(a)``, negated when ``b`` is negative.

    The emitted code calls ``<prefix>abs``, which is recorded as a dependency.
    """
    c_type = dtype_info.c_type
    prefix = dtype_info.prefix
    body = [
        f"static inline {c_type} {prefix}copysign({c_type} a, {c_type} b) {{",
        f"    {c_type} magnitude = {prefix}abs(a);",
        "    return b < 0 ? -magnitude : magnitude;",
        "}",
    ]
    abs_dep = _scalar_key_from_op(dtype_info, "abs")
    return _GeneratedScalar(lines=body, deps={abs_dep}, includes=set())
1724
+
1725
+
1726
def _int_neg(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``neg(a)`` for an integer type.

    Unsigned types compute ``0 - a`` passed through ``_cast_value``.  Signed
    types return ``INT<bits>_MIN`` unchanged and ``-a`` for every other value.
    """
    if not dtype_info.is_signed:
        return _simple_unary(dtype_info, "neg", _cast_value("0 - a", dtype_info))
    c_type = dtype_info.c_type
    lowest = f"INT{dtype_info.bits}_MIN"
    body = [
        f"static inline {c_type} {dtype_info.prefix}neg({c_type} a) {{",
        f"    if (a == {lowest}) {{",
        f"        return {lowest};",
        "    }",
        "    return -a;",
        "}",
    ]
    return _GeneratedScalar(lines=body, deps=set(), includes={"#include <limits.h>"})
1741
+
1742
+
1743
def _int_reciprocal(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``reciprocal(a)`` as integer ``1 / a``, returning 0 for ``a == 0``."""
    c_type = dtype_info.c_type
    quotient = _cast_value("1 / a", dtype_info)
    body = [
        f"static inline {c_type} {dtype_info.prefix}reciprocal({c_type} a) {{",
        "    if (a == 0) {",
        "        return 0;",
        "    }",
        f"    return {quotient};",
        "}",
    ]
    return _GeneratedScalar(lines=body, deps=set(), includes=set())
1754
+
1755
+
1756
def _int_relu(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``relu(a)``: clamps negatives to 0 for signed types, identity otherwise."""
    expression = "a > 0 ? a : 0" if dtype_info.is_signed else "a"
    return _simple_unary(dtype_info, "relu", expression)
1760
+
1761
+
1762
def _int_ceil_floor(dtype_info: _ScalarTypeInfo, name: str) -> _GeneratedScalar:
    """Generate ``<name>(a)`` as the identity; used for integer ``ceil`` and ``floor``."""
    identity_expr = "a"
    return _simple_unary(dtype_info, name, identity_expr)
1764
+
1765
+
1766
def _int_round(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate integer ``round(a)``, which returns ``a`` unchanged."""
    identity_expr = "a"
    return _simple_unary(dtype_info, "round", identity_expr)
1768
+
1769
+
1770
def _int_trunc(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate integer ``trunc(a)``, which returns ``a`` unchanged."""
    identity_expr = "a"
    return _simple_unary(dtype_info, "trunc", identity_expr)
1772
+
1773
+
1774
def _int_frac(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate integer ``frac(a)``, which always returns 0.

    The emitted body contains ``(void)a;`` so the parameter is referenced.
    """
    c_type = dtype_info.c_type
    return _GeneratedScalar(
        lines=[
            f"static inline {c_type} {dtype_info.prefix}frac({c_type} a) {{",
            "    (void)a;",
            "    return 0;",
            "}",
        ],
        deps=set(),
        includes=set(),
    )
1782
+
1783
+
1784
def _int_sign(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``sign(a)``: 1, -1 or 0 for signed types; 1 or 0 for unsigned types."""
    if dtype_info.is_signed:
        c_type = dtype_info.c_type
        body = [
            f"static inline {c_type} {dtype_info.prefix}sign({c_type} a) {{",
            "    if (a > 0) {",
            "        return 1;",
            "    }",
            "    if (a < 0) {",
            "        return -1;",
            "    }",
            "    return 0;",
            "}",
        ]
        return _GeneratedScalar(lines=body, deps=set(), includes=set())
    return _simple_unary(dtype_info, "sign", "a > 0 ? 1 : 0")
1799
+
1800
+
1801
def _int_conj(dtype_info: _ScalarTypeInfo, name: str) -> _GeneratedScalar:
    """Generate ``<name>(a)`` as the identity; used for ``conj`` and ``conj_physical``."""
    identity_expr = "a"
    return _simple_unary(dtype_info, name, identity_expr)
1803
+
1804
+
1805
def _int_positive(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate integer ``positive(a)``, which returns ``a`` unchanged."""
    identity_expr = "a"
    return _simple_unary(dtype_info, "positive", identity_expr)
1807
+
1808
+
1809
def _int_sgn(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``sgn(a)``: 1, -1 or 0 for signed types; 1 or 0 for unsigned types."""
    if dtype_info.is_signed:
        c_type = dtype_info.c_type
        body = [
            f"static inline {c_type} {dtype_info.prefix}sgn({c_type} a) {{",
            "    if (a > 0) {",
            "        return 1;",
            "    }",
            "    if (a < 0) {",
            "        return -1;",
            "    }",
            "    return 0;",
            "}",
        ]
        return _GeneratedScalar(lines=body, deps=set(), includes=set())
    return _simple_unary(dtype_info, "sgn", "a > 0 ? 1 : 0")
1824
+
1825
+
1826
def _int_square(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``square(a)`` as ``a * a`` passed through ``_cast_value``."""
    product = _cast_value("a * a", dtype_info)
    return _simple_unary(dtype_info, "square", product)
1828
+
1829
def _int_binary_op_handler(name: str, op: str) -> Callable[[_ScalarTypeInfo], _GeneratedScalar]:
    """Return a dispatch handler that binds ``name`` and ``op`` for ``_int_binary_op``."""

    def generate(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
        return _int_binary_op(dtype_info, name, op)

    return generate
1834
+
1835
+
1836
def _int_simple_binary_handler(name: str, expr: str) -> Callable[[_ScalarTypeInfo], _GeneratedScalar]:
    """Return a dispatch handler that binds ``name`` and ``expr`` for ``_simple_binary``."""

    def generate(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
        return _simple_binary(dtype_info, name, expr)

    return generate
1841
+
1842
+
1843
def _int_comparison_handler(name: str, op: str) -> Callable[[_ScalarTypeInfo], _GeneratedScalar]:
    """Return a dispatch handler that binds ``name`` and ``op`` for ``_int_comparison``."""

    def generate(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
        return _int_comparison(dtype_info, name, op)

    return generate
1848
+
1849
+
1850
def _int_logical_handler(name: str, expr: str) -> Callable[[_ScalarTypeInfo], _GeneratedScalar]:
    """Return a dispatch handler that binds ``name`` and ``expr`` for ``_int_logical``."""

    def generate(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
        return _int_logical(dtype_info, name, expr)

    return generate
1855
+
1856
+
1857
# Generators for integer ops that have a native C implementation, keyed by
# canonical op name.  Families sharing a factory are built from small tables;
# insertion order is the same as the flat listing this replaces.
_INT_OP_DISPATCH: Mapping[str, Callable[[_ScalarTypeInfo], _GeneratedScalar]] = {
    "abs": _int_abs,
    "absolute": _int_absolute,
    **{
        op_name: _int_binary_op_handler(op_name, c_op)
        for op_name, c_op in (
            ("add", "+"),
            ("sub", "-"),
            ("mul", "*"),
            ("bitwise_and", "&"),
            ("bitwise_or", "|"),
            ("bitwise_xor", "^"),
            ("bitwise_left_shift", "<<"),
            ("bitwise_right_shift", ">>"),
        )
    },
    "bitwise_not": lambda info: _int_unary_op(info, "bitwise_not", "~a"),
    "div": _int_div,
    **{
        op_name: _int_simple_binary_handler(op_name, c_expr)
        for op_name, c_expr in (
            ("maximum", "a > b ? a : b"),
            ("minimum", "a < b ? a : b"),
        )
    },
    **{
        op_name: _int_comparison_handler(op_name, c_op)
        for op_name, c_op in (
            ("le", "<="),
            ("lt", "<"),
            ("ge", ">="),
            ("gt", ">"),
            ("eq", "=="),
            ("ne", "!="),
        )
    },
    **{
        op_name: _int_logical_handler(op_name, c_expr)
        for op_name, c_expr in (
            ("logical_or", "(a != 0 || b != 0)"),
            ("logical_and", "(a != 0 && b != 0)"),
            ("logical_xor", "((a != 0) != (b != 0))"),
        )
    },
    "logical_not": _int_logical_not,
    **{
        op_name: _int_simple_binary_handler(op_name, c_expr)
        for op_name, c_expr in (
            ("fmax", "a > b ? a : b"),
            ("fmin", "a < b ? a : b"),
        )
    },
    "copysign": _int_copysign,
    "fmod": _int_fmod,
    "remainder": _int_remainder,
    "floor_divide": _int_floor_divide,
    **{
        op_name: _int_simple_binary_handler(op_name, c_expr)
        for op_name, c_expr in (
            ("clamp_min", "a > b ? a : b"),
            ("clamp_max", "a < b ? a : b"),
        )
    },
    "neg": _int_neg,
    "reciprocal": _int_reciprocal,
    "relu": _int_relu,
    "ceil": lambda info: _int_ceil_floor(info, "ceil"),
    "floor": lambda info: _int_ceil_floor(info, "floor"),
    "round": _int_round,
    "trunc": _int_trunc,
    "frac": _int_frac,
    "sign": _int_sign,
    "conj": lambda info: _int_conj(info, "conj"),
    "conj_physical": lambda info: _int_conj(info, "conj_physical"),
    "positive": _int_positive,
    "real": lambda info: _simple_unary(info, "real", "a"),
    "sgn": _int_sgn,
    "square": _int_square,
}
1906
+
1907
+
1908
def _int_from_ops(dtype_info: _ScalarTypeInfo, name: str) -> _GeneratedScalar:
    """Generate the integer implementation of the scalar op ``name``.

    Resolution order:

    1. a non-canonical name becomes a one-argument wrapper that forwards to
       the canonical function, which is recorded as a dependency;
    2. ``from_f32`` is handled by ``_int_from_f32``;
    3. ops listed in ``_INT_OP_DISPATCH`` use their dedicated generator;
    4. remaining ops go through the f32 implementation when
       ``ScalarFunction.int_from_f32_arity`` is 1 or 2.

    Raises:
        ScalarFunctionError: If none of the above applies.
    """
    canonical_name = _normalize_op_name(name)
    if canonical_name != name:
        c_type = dtype_info.c_type
        prefix = dtype_info.prefix
        wrapper = [
            f"static inline {c_type} {prefix}{name}({c_type} a) {{",
            f"    return {prefix}{canonical_name}(a);",
            "}",
        ]
        return _GeneratedScalar(
            lines=wrapper,
            deps={_scalar_key_from_op(dtype_info, canonical_name)},
            includes=set(),
        )

    if canonical_name == "from_f32":
        return _int_from_f32(dtype_info)

    native = _INT_OP_DISPATCH.get(canonical_name)
    if native is not None:
        return native(dtype_info)

    function = ScalarFunction.from_op_name(canonical_name)
    if function.int_from_f32_arity == 1:
        return _int_unary_from_f32(dtype_info, canonical_name)
    if function.int_from_f32_arity == 2:
        return _int_binary_from_f32(dtype_info, canonical_name)
    raise ScalarFunctionError(f"unsupported int scalar op: {canonical_name}")
1930
+
1931
+
1932
def _bool_to_f32() -> _GeneratedScalar:
    """Generate ``ref_scalar_bool_to_f32``, mapping ``true``/``false`` to ``1.0f``/``0.0f``."""
    return _GeneratedScalar(
        lines=[
            "static inline float ref_scalar_bool_to_f32(bool value) {",
            "    return value ? 1.0f : 0.0f;",
            "}",
        ],
        deps=set(),
        includes=set(),
    )
1939
+
1940
+
1941
def _bool_from_f32() -> _GeneratedScalar:
    """Generate ``ref_scalar_bool_from_f32``, which is true for any value other than ``0.0f``."""
    return _GeneratedScalar(
        lines=[
            "static inline bool ref_scalar_bool_from_f32(float value) {",
            "    return value != 0.0f;",
            "}",
        ],
        deps=set(),
        includes=set(),
    )
1948
+
1949
+
1950
def _bool_bitwise(dtype_info: _ScalarTypeInfo, name: str, expr: str) -> _GeneratedScalar:
    """Generate the binary bitwise bool op ``name`` returning the C expression ``expr``."""
    generated = _simple_binary(dtype_info, name, expr)
    return generated
1952
+
1953
+
1954
def _bool_bitwise_not(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate bool ``bitwise_not(a)`` as the C expression ``!a``."""
    negation = "!a"
    return _simple_unary(dtype_info, "bitwise_not", negation)
1956
+
1957
+
1958
def _bool_logical(dtype_info: _ScalarTypeInfo, name: str, expr: str) -> _GeneratedScalar:
    """Generate the binary logical bool op ``name`` returning the C expression ``expr``."""
    generated = _simple_binary(dtype_info, name, expr)
    return generated
1960
+
1961
+
1962
def _bool_logical_not(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate bool ``logical_not(a)`` as the C expression ``!a``."""
    negation = "!a"
    return _simple_unary(dtype_info, "logical_not", negation)
1964
+
1965
+
1966
def _bool_comparison(dtype_info: _ScalarTypeInfo, name: str, op: str) -> _GeneratedScalar:
    """Generate the bool comparison ``name`` as the C expression ``a <op> b``."""
    comparison = f"a {op} b"
    return _simple_binary(dtype_info, name, comparison)
1968
+
1969
+
1970
def _bool_unary_from_f32(name: str) -> _GeneratedScalar:
    """Generate a unary bool op implemented via the f32 version of ``name``.

    The emitted function converts its argument with
    ``ref_scalar_bool_to_f32``, calls ``ref_scalar_f32_<name>`` and converts
    the result back with ``ref_scalar_bool_from_f32``.  All three callees
    are recorded as dependencies.
    """
    body = [
        f"static inline bool ref_scalar_bool_{name}(bool a) {{",
        f"    return ref_scalar_bool_from_f32(ref_scalar_f32_{name}(ref_scalar_bool_to_f32(a)));",
        "}",
    ]
    bool_info = _SCALAR_TYPES[ScalarType.BOOL]
    from_dep = _conversion_key_from_alias(bool_info, "from_f32")
    to_dep = _conversion_key_from_alias(bool_info, "to_f32")
    f32_dep = _scalar_key_from_op(_SCALAR_TYPES[ScalarType.F32], name)
    return _GeneratedScalar(
        lines=body, deps={from_dep, to_dep, f32_dep}, includes=set()
    )
1983
+
1984
+
1985
def _bool_binary_from_f32(name: str) -> _GeneratedScalar:
    """Generate a binary bool op implemented via the f32 version of ``name``.

    Both operands are converted with ``ref_scalar_bool_to_f32``, passed to
    ``ref_scalar_f32_<name>`` and the result is converted back with
    ``ref_scalar_bool_from_f32``.  All three callees are recorded as
    dependencies.
    """
    body = [
        f"static inline bool ref_scalar_bool_{name}(bool a, bool b) {{",
        "    return ref_scalar_bool_from_f32(",
        f"        ref_scalar_f32_{name}(ref_scalar_bool_to_f32(a), ref_scalar_bool_to_f32(b))",
        "    );",
        "}",
    ]
    bool_info = _SCALAR_TYPES[ScalarType.BOOL]
    from_dep = _conversion_key_from_alias(bool_info, "from_f32")
    to_dep = _conversion_key_from_alias(bool_info, "to_f32")
    f32_dep = _scalar_key_from_op(_SCALAR_TYPES[ScalarType.F32], name)
    return _GeneratedScalar(
        lines=body, deps={from_dep, to_dep, f32_dep}, includes=set()
    )
2000
+
2001
+
2002
# Zero-argument generators for bool ops that have a native C implementation,
# keyed by canonical op name.  Each entry is a lambda so that _SCALAR_TYPES
# (defined further down the module) is looked up only when a handler runs.
_BOOL_OP_DISPATCH: Mapping[str, Callable[[], _GeneratedScalar]] = {
    "to_f32": _bool_to_f32,
    "from_f32": _bool_from_f32,
    "bitwise_and": lambda: _bool_bitwise(
        _SCALAR_TYPES[ScalarType.BOOL], "bitwise_and", "a & b"
    ),
    "bitwise_or": lambda: _bool_bitwise(
        _SCALAR_TYPES[ScalarType.BOOL], "bitwise_or", "a | b"
    ),
    "bitwise_xor": lambda: _bool_bitwise(
        _SCALAR_TYPES[ScalarType.BOOL], "bitwise_xor", "a ^ b"
    ),
    "bitwise_not": lambda: _bool_bitwise_not(_SCALAR_TYPES[ScalarType.BOOL]),
    "logical_or": lambda: _bool_logical(
        _SCALAR_TYPES[ScalarType.BOOL], "logical_or", "a || b"
    ),
    "logical_and": lambda: _bool_logical(
        _SCALAR_TYPES[ScalarType.BOOL], "logical_and", "a && b"
    ),
    "logical_xor": lambda: _bool_logical(
        _SCALAR_TYPES[ScalarType.BOOL], "logical_xor", "a != b"
    ),
    "logical_not": lambda: _bool_logical_not(_SCALAR_TYPES[ScalarType.BOOL]),
    "le": lambda: _bool_comparison(_SCALAR_TYPES[ScalarType.BOOL], "le", "<="),
    "lt": lambda: _bool_comparison(_SCALAR_TYPES[ScalarType.BOOL], "lt", "<"),
    "ge": lambda: _bool_comparison(_SCALAR_TYPES[ScalarType.BOOL], "ge", ">="),
    "gt": lambda: _bool_comparison(_SCALAR_TYPES[ScalarType.BOOL], "gt", ">"),
    "eq": lambda: _bool_comparison(_SCALAR_TYPES[ScalarType.BOOL], "eq", "=="),
    "ne": lambda: _bool_comparison(_SCALAR_TYPES[ScalarType.BOOL], "ne", "!="),
}
2020
+
2021
+
2022
def _bool_from_ops(name: str) -> _GeneratedScalar:
    """Generate the bool implementation of the scalar op ``name``.

    Resolution order:

    1. a non-canonical name becomes a one-argument wrapper that forwards to
       the canonical function, which is recorded as a dependency;
    2. ops listed in ``_BOOL_OP_DISPATCH`` use their dedicated generator;
    3. remaining ops go through the f32 implementation when
       ``ScalarFunction.bool_from_f32_arity`` is 1 or 2.

    Raises:
        ScalarFunctionError: If none of the above applies.
    """
    canonical_name = _normalize_op_name(name)
    if canonical_name != name:
        bool_info = _SCALAR_TYPES[ScalarType.BOOL]
        c_type = bool_info.c_type
        prefix = bool_info.prefix
        wrapper = [
            f"static inline {c_type} {prefix}{name}({c_type} a) {{",
            f"    return {prefix}{canonical_name}(a);",
            "}",
        ]
        return _GeneratedScalar(
            lines=wrapper,
            deps={_scalar_key_from_op(bool_info, canonical_name)},
            includes=set(),
        )

    native = _BOOL_OP_DISPATCH.get(canonical_name)
    if native is not None:
        return native()

    function = ScalarFunction.from_op_name(canonical_name)
    if function.bool_from_f32_arity == 1:
        return _bool_unary_from_f32(canonical_name)
    if function.bool_from_f32_arity == 2:
        return _bool_binary_from_f32(canonical_name)
    raise ScalarFunctionError(f"unsupported bool scalar op: {canonical_name}")
2043
+
2044
+
2045
# Static description of every supported scalar type.  Each table row is
# (scalar_type, c_type, suffix, is_float, is_bool, is_signed, is_small_int, bits);
# the C function prefix is always "ref_scalar_<suffix>_".
_SCALAR_TYPES: Dict[ScalarType, _ScalarTypeInfo] = {
    scalar_type: _ScalarTypeInfo(
        scalar_type=scalar_type,
        c_type=c_type,
        prefix=f"ref_scalar_{suffix}_",
        suffix=suffix,
        is_float=is_float,
        is_bool=is_bool,
        is_signed=is_signed,
        is_small_int=is_small_int,
        bits=bits,
    )
    for (
        scalar_type,
        c_type,
        suffix,
        is_float,
        is_bool,
        is_signed,
        is_small_int,
        bits,
    ) in (
        (ScalarType.F32, "float", "f32", True, False, True, False, None),
        (ScalarType.F64, "double", "f64", True, False, True, False, None),
        (ScalarType.I8, "int8_t", "i8", False, False, True, True, 8),
        (ScalarType.I16, "int16_t", "i16", False, False, True, True, 16),
        (ScalarType.I32, "int32_t", "i32", False, False, True, False, 32),
        (ScalarType.I64, "int64_t", "i64", False, False, True, False, 64),
        (ScalarType.U8, "uint8_t", "u8", False, False, False, True, 8),
        (ScalarType.U16, "uint16_t", "u16", False, False, False, True, 16),
        (ScalarType.U32, "uint32_t", "u32", False, False, False, False, 32),
        (ScalarType.U64, "uint64_t", "u64", False, False, False, False, 64),
        (ScalarType.BOOL, "bool", "bool", False, True, False, False, None),
    )
}
2168
+
2169
+
2170
# Read-only alias of _SCALAR_TYPES used for lookups by enum member.
_SCALAR_TYPE_BY_ENUM: Mapping[ScalarType, _ScalarTypeInfo] = _SCALAR_TYPES


# Maps each ScalarFunction.CONVERT_FROM_<X> member to ScalarType.<X>.
_CONVERSION_SOURCE_BY_FUNCTION: Mapping[ScalarFunction, ScalarType] = {
    getattr(ScalarFunction, f"CONVERT_FROM_{source.name}"): source
    for source in (
        ScalarType.F32,
        ScalarType.F64,
        ScalarType.I8,
        ScalarType.I16,
        ScalarType.I32,
        ScalarType.I64,
        ScalarType.U8,
        ScalarType.U16,
        ScalarType.U32,
        ScalarType.U64,
        ScalarType.BOOL,
    )
}
2186
+
2187
+
2188
def _scalar_type_info(dtype: ScalarType) -> _ScalarTypeInfo:
    """Return the ``_ScalarTypeInfo`` registered for ``dtype``.

    Raises:
        ScalarFunctionError: If ``dtype`` has no entry; the original
            ``KeyError`` is chained as the cause.
    """
    try:
        info = _SCALAR_TYPE_BY_ENUM[dtype]
    except KeyError as exc:
        message = f"unsupported scalar dtype: {dtype.value}"
        raise ScalarFunctionError(message) from exc
    else:
        return info
2195
+
2196
+
2197
def _supported_ops(dtype_info: _ScalarTypeInfo) -> Set[str]:
    """Return the canonical names of all ops that can be generated for ``dtype_info``.

    The set contains every non-conversion ``ScalarFunction`` that reports
    support for the dtype, plus the ``from_f32`` alias for non-float types
    and the ``to_f32`` alias for bool.
    """
    supported: Set[str] = set()
    for function in ScalarFunction:
        if function.value.startswith("convert_from_"):
            continue
        if function.supports_dtype(dtype_info):
            supported.add(_normalize_op_name(function.value))
    if not dtype_info.is_float:
        supported.add("from_f32")
    if dtype_info.is_bool:
        supported.add("to_f32")
    return supported
2209
+
2210
+
2211
def validate_scalar_function_supported_ops() -> None:
    """Check that ``ScalarFunction`` and ``_supported_ops`` agree.

    One representative dtype is checked per category (F32, BOOL, I8, U8).
    Two kinds of drift are reported:

    * a category supports an op that is neither a ``ScalarFunction`` nor one
      of the ``from_f32`` / ``to_f32`` aliases;
    * a ``ScalarFunction`` op is supported by none of the categories.

    Raises:
        AssertionError: If any drift is found; the message lists every finding.
    """
    conversion_aliases = {"from_f32", "to_f32"}
    scalar_ops = set()
    for function in ScalarFunction:
        if not function.value.startswith("convert_from_"):
            scalar_ops.add(_normalize_op_name(function.value))

    representatives = (
        ("float", ScalarType.F32),
        ("bool", ScalarType.BOOL),
        ("signed_int", ScalarType.I8),
        ("unsigned_int", ScalarType.U8),
    )
    categories = {
        label: _supported_ops(_SCALAR_TYPES[scalar_type])
        for label, scalar_type in representatives
    }

    errors: List[str] = []
    for category, supported in categories.items():
        missing = sorted(supported - scalar_ops - conversion_aliases)
        if not missing:
            continue
        errors.append(
            f"{category} missing ScalarFunction ops (defined in _supported_ops): {missing}"
        )

    supported_union = set().union(*categories.values()) - conversion_aliases
    unexpected_extras = sorted(scalar_ops - supported_union)
    if unexpected_extras:
        errors.append(
            "ScalarFunction ops not supported by any dtype category: "
            f"{unexpected_extras}"
        )

    if not errors:
        return
    raise AssertionError(
        "ScalarFunction/_supported_ops drift detected:\n" + "\n".join(errors)
    )
2242
+
2243
+
2244
def _scalar_info_for_key(key: ScalarFunctionKey) -> tuple[_ScalarTypeInfo, str]:
    """Resolve ``key`` to the dtype that owns the function and its op name.

    Ordinary ops belong to the key's return type and use the function's own
    value as the op name.  Of the conversions, only two are accepted:
    from F32 (the ``from_f32`` op of the return type) and from BOOL to F32
    (the ``to_f32`` op of bool).

    Raises:
        ScalarFunctionError: For any other conversion.
    """
    if key.function not in _CONVERSION_SOURCE_BY_FUNCTION:
        return _scalar_type_info(key.return_type), key.function.value

    source_type = _CONVERSION_SOURCE_BY_FUNCTION[key.function]
    if source_type == ScalarType.F32:
        return _scalar_type_info(key.return_type), "from_f32"
    if source_type == ScalarType.BOOL and key.return_type == ScalarType.F32:
        return _scalar_type_info(source_type), "to_f32"
    raise ScalarFunctionError(
        f"unsupported scalar conversion from {source_type.value} to {key.return_type.value}"
    )
2259
+
2260
+
2261
def _generate_scalar(key: ScalarFunctionKey) -> _GeneratedScalar:
    """Generate the C source, dependencies and includes for ``key``.

    Float types try ``_PARAMETERIZED_FLOAT_OPS`` first and fall back to
    ``_float_from_ops``; bool and integer types use their own generators.
    The generator's includes are then extended with the headers required
    by the dtype category.

    Raises:
        ScalarFunctionError: If the op is not supported for the dtype.
    """
    dtype_info, op_name = _scalar_info_for_key(key)
    if _normalize_op_name(op_name) not in _supported_ops(dtype_info):
        raise ScalarFunctionError(
            f"unsupported scalar op {op_name} for {dtype_info.suffix}"
        )

    if dtype_info.is_float:
        param_handler = _PARAMETERIZED_FLOAT_OPS.get(key.function)
        if param_handler is None:
            generated = _float_from_ops(dtype_info, op_name)
        else:
            generated = param_handler(
                dtype_info, key.params, _function_name_for_key(key)
            )
    elif dtype_info.is_bool:
        generated = _bool_from_ops(op_name)
    else:
        generated = _int_from_ops(dtype_info, op_name)

    is_integer = not dtype_info.is_float and not dtype_info.is_bool
    includes = set(generated.includes)
    if dtype_info.is_float:
        includes.update({"#include <math.h>", "#include <float.h>"})
    if is_integer:
        includes.update({"#include <stdint.h>"})
        if dtype_info.is_signed:
            includes.add("#include <limits.h>")
    if dtype_info.is_bool:
        includes.add("#include <stdbool.h>")
    return _GeneratedScalar(
        lines=generated.lines, deps=generated.deps, includes=includes
    )
2289
+
2290
+
2291
def _function_name_for_key(key: ScalarFunctionKey) -> str:
    """Return the C function name that implements ``key``.

    Conversions from F32 to an integer or bool type are named
    ``<target prefix>from_f32``; the BOOL to F32 conversion is named
    ``<bool prefix>to_f32``.  Ordinary ops are named ``<prefix><op>``.
    Every name ends with the suffix derived from ``key.params``.

    Raises:
        ScalarFunctionError: For any other conversion, or for an op that is
            not supported by the return type.
    """
    param_suffix = _param_suffix(key.params)

    if key.function in _CONVERSION_SOURCE_BY_FUNCTION:
        source_type = _CONVERSION_SOURCE_BY_FUNCTION[key.function]
        f32_targets = {
            ScalarType.I8,
            ScalarType.I16,
            ScalarType.I32,
            ScalarType.I64,
            ScalarType.U8,
            ScalarType.U16,
            ScalarType.U32,
            ScalarType.U64,
            ScalarType.BOOL,
        }
        if source_type == ScalarType.F32 and key.return_type in f32_targets:
            target_info = _scalar_type_info(key.return_type)
            return f"{target_info.prefix}from_f32{param_suffix}"
        if source_type == ScalarType.BOOL and key.return_type == ScalarType.F32:
            source_info = _scalar_type_info(source_type)
            return f"{source_info.prefix}to_f32{param_suffix}"
        raise ScalarFunctionError(
            f"unsupported scalar conversion from {source_type.value} to {key.return_type.value}"
        )

    op_name = key.function.value
    dtype_info = _scalar_type_info(key.return_type)
    if _normalize_op_name(op_name) in _supported_ops(dtype_info):
        return f"{dtype_info.prefix}{op_name}{param_suffix}"
    raise ScalarFunctionError(
        f"unsupported scalar op {op_name} for {dtype_info.suffix}"
    )
2329
+
2330
+
2331
class ScalarFunctionRegistry:
    """Collects requested scalar functions and renders their C source.

    Keys are remembered in first-request order.  Code is generated lazily,
    once per key, when includes or source lines are asked for.
    """

    def __init__(self) -> None:
        # Keys in first-request order, with a set for fast membership tests.
        self._requested: List[ScalarFunctionKey] = []
        self._requested_set: Set[ScalarFunctionKey] = set()
        # Resolved C function names and generated code, both cached per key.
        self._key_to_name: Dict[ScalarFunctionKey, str] = {}
        self._generated: Dict[ScalarFunctionKey, _GeneratedScalar] = {}

    def request(self, key: ScalarFunctionKey) -> str:
        """Register ``key`` for emission and return its C function name."""
        cached = self._key_to_name.get(key)
        if cached is not None:
            return cached
        resolved = _function_name_for_key(key)
        self._key_to_name[key] = resolved
        self._register_key(key)
        return resolved

    def _register_key(self, key: ScalarFunctionKey) -> None:
        """Append ``key`` to the request list unless it is already present."""
        if key not in self._requested_set:
            self._requested.append(key)
            self._requested_set.add(key)

    def include_lines(self) -> List[str]:
        """Return the sorted ``#include`` lines followed by the ``REF_PI_*`` defines.

        Includes are gathered from every requested function and, transitively,
        from its dependencies.  The defines are appended unconditionally.
        """
        gathered: Set[str] = set()
        seen: Set[ScalarFunctionKey] = set()

        def walk(node: ScalarFunctionKey) -> None:
            if node in seen:
                return
            self._ensure_generated(node)
            generated = self._generated[node]
            seen.add(node)
            for child in generated.deps:
                walk(child)
            gathered.update(generated.includes)

        for requested_key in self._requested:
            walk(requested_key)

        pi_defines = [
            "#ifndef REF_PI_F",
            "#define REF_PI_F 3.14159265358979323846f",
            "#endif",
            "#ifndef REF_PI_D",
            "#define REF_PI_D 3.14159265358979323846",
            "#endif",
        ]
        return sorted(gathered) + pi_defines

    def render(self) -> List[str]:
        """Return the C source lines of all requested functions.

        Each function's dependencies are emitted before it, ordered by C
        function name.  Every function is followed by one blank line, and
        trailing blank lines are stripped.  Nothing requested yields ``[]``.
        """
        if not self._requested:
            return []
        output: List[str] = []
        done: Set[ScalarFunctionKey] = set()

        def write(node: ScalarFunctionKey) -> None:
            if node in done:
                return
            self._ensure_generated(node)
            generated = self._generated[node]
            for child in sorted(generated.deps, key=_function_name_for_key):
                write(child)
            output.extend(generated.lines)
            output.append("")
            done.add(node)

        for requested_key in self._requested:
            write(requested_key)
        while output and output[-1] == "":
            output.pop()
        return output

    def _ensure_generated(self, key: ScalarFunctionKey) -> None:
        """Generate and cache the code for ``key`` if that has not happened yet."""
        if key not in self._generated:
            self._generated[key] = _generate_scalar(key)