emx-onnx-cgen 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of emx-onnx-cgen might be problematic. Click here for more details.
- emx_onnx_cgen/__init__.py +6 -0
- emx_onnx_cgen/__main__.py +9 -0
- emx_onnx_cgen/_build_info.py +3 -0
- emx_onnx_cgen/cli.py +328 -0
- emx_onnx_cgen/codegen/__init__.py +25 -0
- emx_onnx_cgen/codegen/c_emitter.py +9044 -0
- emx_onnx_cgen/compiler.py +601 -0
- emx_onnx_cgen/dtypes.py +40 -0
- emx_onnx_cgen/errors.py +14 -0
- emx_onnx_cgen/ir/__init__.py +3 -0
- emx_onnx_cgen/ir/model.py +55 -0
- emx_onnx_cgen/lowering/__init__.py +3 -0
- emx_onnx_cgen/lowering/arg_reduce.py +99 -0
- emx_onnx_cgen/lowering/attention.py +421 -0
- emx_onnx_cgen/lowering/average_pool.py +229 -0
- emx_onnx_cgen/lowering/batch_normalization.py +116 -0
- emx_onnx_cgen/lowering/cast.py +70 -0
- emx_onnx_cgen/lowering/common.py +72 -0
- emx_onnx_cgen/lowering/concat.py +31 -0
- emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
- emx_onnx_cgen/lowering/conv.py +192 -0
- emx_onnx_cgen/lowering/cumsum.py +118 -0
- emx_onnx_cgen/lowering/depth_space.py +114 -0
- emx_onnx_cgen/lowering/dropout.py +46 -0
- emx_onnx_cgen/lowering/elementwise.py +164 -0
- emx_onnx_cgen/lowering/expand.py +151 -0
- emx_onnx_cgen/lowering/eye_like.py +43 -0
- emx_onnx_cgen/lowering/flatten.py +60 -0
- emx_onnx_cgen/lowering/gather.py +48 -0
- emx_onnx_cgen/lowering/gather_elements.py +60 -0
- emx_onnx_cgen/lowering/gemm.py +139 -0
- emx_onnx_cgen/lowering/grid_sample.py +149 -0
- emx_onnx_cgen/lowering/group_normalization.py +68 -0
- emx_onnx_cgen/lowering/identity.py +43 -0
- emx_onnx_cgen/lowering/instance_normalization.py +50 -0
- emx_onnx_cgen/lowering/layer_normalization.py +110 -0
- emx_onnx_cgen/lowering/logsoftmax.py +47 -0
- emx_onnx_cgen/lowering/lp_normalization.py +45 -0
- emx_onnx_cgen/lowering/lrn.py +104 -0
- emx_onnx_cgen/lowering/lstm.py +355 -0
- emx_onnx_cgen/lowering/matmul.py +120 -0
- emx_onnx_cgen/lowering/maxpool.py +195 -0
- emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
- emx_onnx_cgen/lowering/pad.py +287 -0
- emx_onnx_cgen/lowering/range.py +104 -0
- emx_onnx_cgen/lowering/reduce.py +544 -0
- emx_onnx_cgen/lowering/registry.py +51 -0
- emx_onnx_cgen/lowering/reshape.py +188 -0
- emx_onnx_cgen/lowering/resize.py +445 -0
- emx_onnx_cgen/lowering/rms_normalization.py +67 -0
- emx_onnx_cgen/lowering/shape.py +78 -0
- emx_onnx_cgen/lowering/size.py +33 -0
- emx_onnx_cgen/lowering/slice.py +425 -0
- emx_onnx_cgen/lowering/softmax.py +47 -0
- emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
- emx_onnx_cgen/lowering/split.py +150 -0
- emx_onnx_cgen/lowering/squeeze.py +161 -0
- emx_onnx_cgen/lowering/tile.py +81 -0
- emx_onnx_cgen/lowering/transpose.py +46 -0
- emx_onnx_cgen/lowering/unsqueeze.py +157 -0
- emx_onnx_cgen/lowering/variadic.py +95 -0
- emx_onnx_cgen/lowering/where.py +73 -0
- emx_onnx_cgen/onnx_import.py +261 -0
- emx_onnx_cgen/ops.py +565 -0
- emx_onnx_cgen/runtime/__init__.py +1 -0
- emx_onnx_cgen/runtime/evaluator.py +2206 -0
- emx_onnx_cgen/validation.py +76 -0
- emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
- emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
- emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
- emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
- emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
- shared/__init__.py +2 -0
- shared/scalar_functions.py +2405 -0
- shared/scalar_types.py +243 -0
|
@@ -0,0 +1,2405 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from enum import Enum
|
|
5
|
+
import math
|
|
6
|
+
from typing import Callable, Dict, List, Mapping, Set
|
|
7
|
+
|
|
8
|
+
from shared.scalar_types import ScalarFunctionError, ScalarType
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass(frozen=True)
|
|
12
|
+
class _ScalarTypeInfo:
|
|
13
|
+
scalar_type: ScalarType
|
|
14
|
+
c_type: str
|
|
15
|
+
prefix: str
|
|
16
|
+
suffix: str
|
|
17
|
+
is_float: bool
|
|
18
|
+
is_bool: bool
|
|
19
|
+
is_signed: bool
|
|
20
|
+
is_small_int: bool
|
|
21
|
+
bits: int | None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass(frozen=True)
|
|
25
|
+
class _GeneratedScalar:
|
|
26
|
+
lines: List[str]
|
|
27
|
+
deps: Set[ScalarFunctionKey]
|
|
28
|
+
includes: Set[str]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _scalar_function_spec(
|
|
32
|
+
value: str,
|
|
33
|
+
*,
|
|
34
|
+
supports_float: bool = True,
|
|
35
|
+
supports_signed_int: bool = True,
|
|
36
|
+
supports_unsigned_int: bool = True,
|
|
37
|
+
supports_bool: bool = True,
|
|
38
|
+
int_from_f32_arity: int | None = None,
|
|
39
|
+
bool_from_f32_arity: int | None = None,
|
|
40
|
+
) -> tuple[
|
|
41
|
+
str,
|
|
42
|
+
bool,
|
|
43
|
+
bool,
|
|
44
|
+
bool,
|
|
45
|
+
bool,
|
|
46
|
+
int | None,
|
|
47
|
+
int | None,
|
|
48
|
+
]:
|
|
49
|
+
return (
|
|
50
|
+
value,
|
|
51
|
+
supports_float,
|
|
52
|
+
supports_signed_int,
|
|
53
|
+
supports_unsigned_int,
|
|
54
|
+
supports_bool,
|
|
55
|
+
int_from_f32_arity,
|
|
56
|
+
bool_from_f32_arity,
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _common_unary_from_f32_spec(
    value: str,
) -> tuple[str, bool, bool, bool, bool, int | None, int | None]:
    """Spec for a unary op on every dtype with both f32 arities set to 1."""
    unary = 1
    return _scalar_function_spec(
        value,
        bool_from_f32_arity=unary,
        int_from_f32_arity=unary,
    )
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _common_binary_from_f32_spec(
    value: str,
) -> tuple[str, bool, bool, bool, bool, int | None, int | None]:
    """Spec for a binary op on every dtype with both f32 arities set to 2."""
    binary = 2
    return _scalar_function_spec(
        value,
        bool_from_f32_arity=binary,
        int_from_f32_arity=binary,
    )
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _bool_unary_from_f32_spec(
    value: str, *, supports_unsigned_int: bool = True
) -> tuple[str, bool, bool, bool, bool, int | None, int | None]:
    """Spec for a unary op whose bool path has f32 arity 1.

    ``int_from_f32_arity`` stays ``None``. Unsigned-int support can be
    switched off via ``supports_unsigned_int``.
    """
    return _scalar_function_spec(
        value,
        bool_from_f32_arity=1,
        supports_unsigned_int=supports_unsigned_int,
    )
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _bool_binary_from_f32_spec(
    value: str,
) -> tuple[str, bool, bool, bool, bool, int | None, int | None]:
    """Spec for a binary op on every dtype whose bool path has f32 arity 2."""
    spec = _scalar_function_spec(value, bool_from_f32_arity=2)
    return spec
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _no_float_spec(
    value: str,
) -> tuple[str, bool, bool, bool, bool, int | None, int | None]:
    """Spec for an op available on integer and bool dtypes but not on floats."""
    spec = _scalar_function_spec(value, supports_float=False)
    return spec
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _int_only_spec(
    value: str,
) -> tuple[str, bool, bool, bool, bool, int | None, int | None]:
    """Spec for an op available only on signed and unsigned integer dtypes."""
    spec = _scalar_function_spec(value, supports_bool=False, supports_float=False)
    return spec
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _conversion_spec(
    value: str,
) -> tuple[str, bool, bool, bool, bool, int | None, int | None]:
    """Spec for a conversion function: every dtype-category flag is False."""
    disabled = dict.fromkeys(
        (
            "supports_float",
            "supports_signed_int",
            "supports_unsigned_int",
            "supports_bool",
        ),
        False,
    )
    return _scalar_function_spec(value, **disabled)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class ScalarFunction(str, Enum):
    """Closed set of scalar functions known to the generator.

    Each member is declared with a 7-tuple built by one of the ``*_spec``
    helpers. ``__new__`` unpacks it: the first item becomes the member's
    string value (used for ``ScalarFunction("abs")`` lookups) and the rest
    become the attributes ``supports_float``, ``supports_signed_int``,
    ``supports_unsigned_int``, ``supports_bool``, ``int_from_f32_arity``
    and ``bool_from_f32_arity``.
    """

    def __new__(
        cls,
        value: str,
        supports_float: bool,
        supports_signed_int: bool,
        supports_unsigned_int: bool,
        supports_bool: bool,
        int_from_f32_arity: int | None = None,
        bool_from_f32_arity: int | None = None,
    ) -> "ScalarFunction":
        member = str.__new__(cls, value)
        # Only the op name is the enum value; the capabilities are plain attributes.
        member._value_ = value
        member.bool_from_f32_arity = bool_from_f32_arity
        member.int_from_f32_arity = int_from_f32_arity
        member.supports_bool = supports_bool
        member.supports_unsigned_int = supports_unsigned_int
        member.supports_signed_int = supports_signed_int
        member.supports_float = supports_float
        return member

    ABS = _bool_unary_from_f32_spec("abs")
    ABSOLUTE = _bool_unary_from_f32_spec("absolute")
    ACOS = _common_unary_from_f32_spec("acos")
    ACOSH = _common_unary_from_f32_spec("acosh")
    ADD = _bool_binary_from_f32_spec("add")
    ANGLE = _common_unary_from_f32_spec("angle")
    ARCCOS = _common_unary_from_f32_spec("arccos")
    ARCSIN = _common_unary_from_f32_spec("arcsin")
    ARCSINH = _common_unary_from_f32_spec("arcsinh")
    ARCTAN = _common_unary_from_f32_spec("arctan")
    ASIN = _common_unary_from_f32_spec("asin")
    ASINH = _common_unary_from_f32_spec("asinh")
    ATAN = _common_unary_from_f32_spec("atan")
    ATAN2 = _common_binary_from_f32_spec("atan2")
    ATANH = _common_unary_from_f32_spec("atanh")
    AFFINE = _common_unary_from_f32_spec("affine")
    BITWISE_AND = _no_float_spec("bitwise_and")
    BITWISE_LEFT_SHIFT = _int_only_spec("bitwise_left_shift")
    BITWISE_NOT = _no_float_spec("bitwise_not")
    BITWISE_OR = _no_float_spec("bitwise_or")
    BITWISE_RIGHT_SHIFT = _int_only_spec("bitwise_right_shift")
    BITWISE_XOR = _no_float_spec("bitwise_xor")
    CBRT = _common_unary_from_f32_spec("cbrt")
    CEIL = _bool_unary_from_f32_spec("ceil")
    CELU = _common_unary_from_f32_spec("celu")
    CLAMP_MAX = _bool_binary_from_f32_spec("clamp_max")
    CLAMP_MIN = _bool_binary_from_f32_spec("clamp_min")
    CONJ = _bool_unary_from_f32_spec("conj", supports_unsigned_int=False)
    CONJ_PHYSICAL = _bool_unary_from_f32_spec("conj_physical", supports_unsigned_int=False)
    COPYSIGN = _bool_binary_from_f32_spec("copysign")
    COS = _common_unary_from_f32_spec("cos")
    COSH = _common_unary_from_f32_spec("cosh")
    DEG2RAD = _common_unary_from_f32_spec("deg2rad")
    DIGAMMA = _common_unary_from_f32_spec("digamma")
    DIV = _bool_binary_from_f32_spec("div")
    ELU = _common_unary_from_f32_spec("elu")
    EQ = _scalar_function_spec("eq")
    ERF = _common_unary_from_f32_spec("erf")
    ERFC = _common_unary_from_f32_spec("erfc")
    ERFINV = _common_unary_from_f32_spec("erfinv")
    EXP = _common_unary_from_f32_spec("exp")
    EXP2 = _common_unary_from_f32_spec("exp2")
    EXPM1 = _common_unary_from_f32_spec("expm1")
    FLOOR = _bool_unary_from_f32_spec("floor")
    FLOOR_DIVIDE = _bool_binary_from_f32_spec("floor_divide")
    FMAX = _bool_binary_from_f32_spec("fmax")
    FMIN = _bool_binary_from_f32_spec("fmin")
    FMOD = _bool_binary_from_f32_spec("fmod")
    FRAC = _bool_unary_from_f32_spec("frac", supports_unsigned_int=False)
    GE = _scalar_function_spec("ge")
    GELU = _common_unary_from_f32_spec("gelu")
    GT = _scalar_function_spec("gt")
    HARDSIGMOID = _common_unary_from_f32_spec("hardsigmoid")
    HARDSWISH = _common_unary_from_f32_spec("hardswish")
    HEAVISIDE = _common_binary_from_f32_spec("heaviside")
    HYPOT = _common_binary_from_f32_spec("hypot")
    I0 = _common_unary_from_f32_spec("i0")
    ISFINITE = _common_unary_from_f32_spec("isfinite")
    ISINF = _common_unary_from_f32_spec("isinf")
    ISNAN = _common_unary_from_f32_spec("isnan")
    ISNEGINF = _common_unary_from_f32_spec("isneginf")
    ISPOSINF = _common_unary_from_f32_spec("isposinf")
    LDEXP = _common_binary_from_f32_spec("ldexp")
    LE = _scalar_function_spec("le")
    LEAKY_RELU = _common_unary_from_f32_spec("leaky_relu")
    LGAMMA = _common_unary_from_f32_spec("lgamma")
    LOG = _common_unary_from_f32_spec("log")
    LOG10 = _common_unary_from_f32_spec("log10")
    LOG1P = _common_unary_from_f32_spec("log1p")
    LOG2 = _common_unary_from_f32_spec("log2")
    LOG_SIGMOID = _common_unary_from_f32_spec("log_sigmoid")
    LOGADDEXP = _common_binary_from_f32_spec("logaddexp")
    LOGADDEXP2 = _common_binary_from_f32_spec("logaddexp2")
    LOGICAL_AND = _scalar_function_spec("logical_and")
    LOGICAL_NOT = _scalar_function_spec("logical_not")
    LOGICAL_OR = _scalar_function_spec("logical_or")
    LOGICAL_XOR = _scalar_function_spec("logical_xor")
    LOGIT = _common_unary_from_f32_spec("logit")
    LT = _scalar_function_spec("lt")
    MAXIMUM = _bool_binary_from_f32_spec("maximum")
    MEAN = _scalar_function_spec(
        "mean",
        supports_signed_int=False,
        supports_unsigned_int=False,
        supports_bool=False,
    )
    MINIMUM = _bool_binary_from_f32_spec("minimum")
    MISH = _common_unary_from_f32_spec("mish")
    MUL = _bool_binary_from_f32_spec("mul")
    NAN_TO_NUM = _common_unary_from_f32_spec("nan_to_num")
    NE = _scalar_function_spec("ne")
    NEG = _bool_unary_from_f32_spec("neg")
    NEXTAFTER = _common_binary_from_f32_spec("nextafter")
    POSITIVE = _bool_unary_from_f32_spec("positive", supports_unsigned_int=False)
    POW = _common_binary_from_f32_spec("pow")
    PRELU = _scalar_function_spec(
        "prelu",
        supports_signed_int=False,
        supports_unsigned_int=False,
        supports_bool=False,
    )
    RAD2DEG = _common_unary_from_f32_spec("rad2deg")
    SCALED_TANH = _common_unary_from_f32_spec("scaled_tanh")
    REAL = _bool_unary_from_f32_spec("real", supports_unsigned_int=False)
    RECIPROCAL = _bool_unary_from_f32_spec("reciprocal")
    RELU = _bool_unary_from_f32_spec("relu")
    RELU6 = _common_unary_from_f32_spec("relu6")
    REMAINDER = _bool_binary_from_f32_spec("remainder")
    ROUND = _bool_unary_from_f32_spec("round")
    RSQRT = _common_unary_from_f32_spec("rsqrt")
    SELU = _common_unary_from_f32_spec("selu")
    SGN = _bool_unary_from_f32_spec("sgn", supports_unsigned_int=False)
    SIGMOID = _common_unary_from_f32_spec("sigmoid")
    SIGN = _bool_unary_from_f32_spec("sign", supports_unsigned_int=False)
    SILU = _common_unary_from_f32_spec("silu")
    SIN = _common_unary_from_f32_spec("sin")
    SINC = _common_unary_from_f32_spec("sinc")
    SINH = _common_unary_from_f32_spec("sinh")
    SOFTPLUS = _common_unary_from_f32_spec("softplus")
    SOFTSIGN = _scalar_function_spec(
        "softsign",
        supports_signed_int=False,
        supports_unsigned_int=False,
        supports_bool=False,
    )
    SQRT = _common_unary_from_f32_spec("sqrt")
    SQUARE = _bool_unary_from_f32_spec("square", supports_unsigned_int=False)
    SHRINK = _common_unary_from_f32_spec("shrink")
    SUB = _bool_binary_from_f32_spec("sub")
    SWISH = _common_unary_from_f32_spec("swish")
    TAN = _common_unary_from_f32_spec("tan")
    TANH = _common_unary_from_f32_spec("tanh")
    THRESHOLDED_RELU = _scalar_function_spec(
        "thresholded_relu",
        supports_signed_int=False,
        supports_unsigned_int=False,
        supports_bool=False,
    )
    TRUNC = _bool_unary_from_f32_spec("trunc", supports_unsigned_int=False)
    XLOGY = _common_binary_from_f32_spec("xlogy")
    CONVERT_FROM_F32 = _conversion_spec("convert_from_f32")
    CONVERT_FROM_F64 = _conversion_spec("convert_from_f64")
    CONVERT_FROM_I8 = _conversion_spec("convert_from_i8")
    CONVERT_FROM_I16 = _conversion_spec("convert_from_i16")
    CONVERT_FROM_I32 = _conversion_spec("convert_from_i32")
    CONVERT_FROM_I64 = _conversion_spec("convert_from_i64")
    CONVERT_FROM_U8 = _conversion_spec("convert_from_u8")
    CONVERT_FROM_U16 = _conversion_spec("convert_from_u16")
    CONVERT_FROM_U32 = _conversion_spec("convert_from_u32")
    CONVERT_FROM_U64 = _conversion_spec("convert_from_u64")
    CONVERT_FROM_BOOL = _conversion_spec("convert_from_bool")

    def supports_dtype(self, dtype_info: _ScalarTypeInfo) -> bool:
        """Return whether this function is declared for ``dtype_info``'s category.

        Categories are checked in the order float, bool, signed. Any type
        that is none of these is treated as an unsigned integer.
        """
        if dtype_info.is_float:
            allowed = self.supports_float
        elif dtype_info.is_bool:
            allowed = self.supports_bool
        elif dtype_info.is_signed:
            allowed = self.supports_signed_int
        else:
            allowed = self.supports_unsigned_int
        return allowed

    @classmethod
    def from_op_name(cls, op_name: str) -> "ScalarFunction":
        """Look up a member by its op-name value, e.g. ``"abs"``.

        Raises:
            ScalarFunctionError: if no member has that value.
        """
        try:
            member = cls(op_name)
        except ValueError as exc:
            raise ScalarFunctionError(
                f"unknown scalar function op name: {op_name}"
            ) from exc
        return member

    @classmethod
    def from_onnx_op(cls, op_type: str) -> "ScalarFunction":
        """Map an ONNX operator type, e.g. ``"Add"``, to its scalar function.

        The name goes through ``_normalize_op_name`` first. The error message
        reports the normalized name.

        Raises:
            ScalarFunctionError: if the normalized op type has no mapping.
        """
        op_type = _normalize_op_name(op_type)
        try:
            return _ONNX_OP_TO_SCALAR_FUNCTION[op_type]
        except KeyError as exc:
            raise ScalarFunctionError(
                f"unsupported ONNX scalar op: {op_type}"
            ) from exc
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
@dataclass(frozen=True)
class ScalarFunctionKey:
    """Hashable identity of one scalar function specialisation.

    Because the dataclass is frozen, equality and hashing cover all three
    fields.
    """

    # Which scalar function this key names.
    function: ScalarFunction
    # Scalar type the function returns.
    return_type: ScalarType
    # Float parameters carried with the key (none by default).
    params: tuple[float, ...] = ()

    @classmethod
    def for_torch_dtype(
        cls, function: ScalarFunction, dtype: object
    ) -> "ScalarFunctionKey":
        """Build a key whose return type comes from ``ScalarType.from_torch_dtype(dtype)``."""
        resolved_type = ScalarType.from_torch_dtype(dtype)
        return cls(function=function, return_type=resolved_type)
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
def _conversion_key_from_alias(
    dtype_info: _ScalarTypeInfo, alias: str
) -> ScalarFunctionKey:
    """Resolve a conversion alias to a ScalarFunctionKey.

    * ``"from_f32"`` -> CONVERT_FROM_F32, returning ``dtype_info``'s scalar type.
    * ``"to_f32"``   -> CONVERT_FROM_BOOL, returning F32.

    Raises:
        ScalarFunctionError: for any other alias.
    """
    if alias == "to_f32":
        # NOTE(review): the source conversion is CONVERT_FROM_BOOL whatever
        # dtype_info is. That is correct if "to_f32" is only requested for
        # bool types; verify callers never request it for integer dtypes.
        return ScalarFunctionKey(
            function=ScalarFunction.CONVERT_FROM_BOOL,
            return_type=ScalarType.F32,
            params=(),
        )
    if alias == "from_f32":
        return ScalarFunctionKey(
            function=ScalarFunction.CONVERT_FROM_F32,
            return_type=dtype_info.scalar_type,
            params=(),
        )
    raise ScalarFunctionError(f"unknown conversion alias: {alias}")
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
def _scalar_key_from_op(
    dtype_info: _ScalarTypeInfo, op_name: str
) -> ScalarFunctionKey:
    """Build the key for ``op_name`` specialised to ``dtype_info``.

    The name is normalized first. The conversion aliases ``from_f32`` and
    ``to_f32`` are delegated to ``_conversion_key_from_alias``.

    Raises:
        ScalarFunctionError: if the name is neither a conversion alias nor a
            known scalar function.
    """
    name = _normalize_op_name(op_name)
    if name not in {"from_f32", "to_f32"}:
        return ScalarFunctionKey(
            function=ScalarFunction.from_op_name(name),
            return_type=dtype_info.scalar_type,
            params=(),
        )
    return _conversion_key_from_alias(dtype_info, name)
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
_OP_ALIASES = {
|
|
355
|
+
"absolute": "abs",
|
|
356
|
+
"arccos": "acos",
|
|
357
|
+
"arcsin": "asin",
|
|
358
|
+
"arcsinh": "asinh",
|
|
359
|
+
"arctan": "atan",
|
|
360
|
+
}
|
|
361
|
+
|
|
362
|
+
# ONNX operator type -> scalar function, used by ScalarFunction.from_onnx_op.
# Several ONNX ops share one function (Add and Sum -> ADD), and Identity
# maps to POSITIVE.
# NOTE(review): Mod maps to FMOD whatever ONNX's `fmod` attribute says;
# confirm callers handle fmod=0 (floored integer modulo) separately.
_ONNX_OP_TO_SCALAR_FUNCTION = dict(
    Abs=ScalarFunction.ABS,
    Acos=ScalarFunction.ACOS,
    Acosh=ScalarFunction.ACOSH,
    Add=ScalarFunction.ADD,
    And=ScalarFunction.LOGICAL_AND,
    Asin=ScalarFunction.ASIN,
    Asinh=ScalarFunction.ASINH,
    Atan=ScalarFunction.ATAN,
    Atanh=ScalarFunction.ATANH,
    BitwiseAnd=ScalarFunction.BITWISE_AND,
    BitwiseNot=ScalarFunction.BITWISE_NOT,
    BitwiseOr=ScalarFunction.BITWISE_OR,
    BitwiseXor=ScalarFunction.BITWISE_XOR,
    Ceil=ScalarFunction.CEIL,
    Celu=ScalarFunction.CELU,
    Cos=ScalarFunction.COS,
    Cosh=ScalarFunction.COSH,
    Div=ScalarFunction.DIV,
    Elu=ScalarFunction.ELU,
    Equal=ScalarFunction.EQ,
    Erf=ScalarFunction.ERF,
    Exp=ScalarFunction.EXP,
    Floor=ScalarFunction.FLOOR,
    Gelu=ScalarFunction.GELU,
    Greater=ScalarFunction.GT,
    GreaterOrEqual=ScalarFunction.GE,
    HardSigmoid=ScalarFunction.HARDSIGMOID,
    HardSwish=ScalarFunction.HARDSWISH,
    Identity=ScalarFunction.POSITIVE,
    LeakyRelu=ScalarFunction.LEAKY_RELU,
    Less=ScalarFunction.LT,
    LessOrEqual=ScalarFunction.LE,
    Log=ScalarFunction.LOG,
    Max=ScalarFunction.MAXIMUM,
    Mean=ScalarFunction.MEAN,
    Min=ScalarFunction.MINIMUM,
    Mod=ScalarFunction.FMOD,
    Mul=ScalarFunction.MUL,
    Neg=ScalarFunction.NEG,
    Not=ScalarFunction.LOGICAL_NOT,
    Or=ScalarFunction.LOGICAL_OR,
    PRelu=ScalarFunction.PRELU,
    Pow=ScalarFunction.POW,
    Reciprocal=ScalarFunction.RECIPROCAL,
    Relu=ScalarFunction.RELU,
    Round=ScalarFunction.ROUND,
    Selu=ScalarFunction.SELU,
    Sigmoid=ScalarFunction.SIGMOID,
    Sign=ScalarFunction.SIGN,
    Sin=ScalarFunction.SIN,
    Sinh=ScalarFunction.SINH,
    Softplus=ScalarFunction.SOFTPLUS,
    Softsign=ScalarFunction.SOFTSIGN,
    Shrink=ScalarFunction.SHRINK,
    Sqrt=ScalarFunction.SQRT,
    Sub=ScalarFunction.SUB,
    Sum=ScalarFunction.ADD,
    Swish=ScalarFunction.SWISH,
    Tan=ScalarFunction.TAN,
    Tanh=ScalarFunction.TANH,
    ThresholdedRelu=ScalarFunction.THRESHOLDED_RELU,
    Xor=ScalarFunction.LOGICAL_XOR,
)
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
_NO_SUFFIX_MATH = {"isfinite", "isnan", "isinf", "signbit"}
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
def _float_literal(value: float, dtype_info: _ScalarTypeInfo) -> str:
|
|
432
|
+
if dtype_info.suffix == "f32":
|
|
433
|
+
if value == int(value):
|
|
434
|
+
return f"{int(value)}.0f"
|
|
435
|
+
literal = f"{value}"
|
|
436
|
+
if "e" in literal or "E" in literal:
|
|
437
|
+
return f"{literal}f"
|
|
438
|
+
if "." not in literal:
|
|
439
|
+
literal = f"{literal}.0"
|
|
440
|
+
return f"{literal}f"
|
|
441
|
+
if value == int(value):
|
|
442
|
+
return f"{int(value)}.0"
|
|
443
|
+
literal = f"{value}"
|
|
444
|
+
if "." not in literal and "e" not in literal and "E" not in literal:
|
|
445
|
+
literal = f"{literal}.0"
|
|
446
|
+
return literal
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
def _param_suffix(params: tuple[float, ...]) -> str:
|
|
450
|
+
if not params:
|
|
451
|
+
return ""
|
|
452
|
+
parts: list[str] = []
|
|
453
|
+
for value in params:
|
|
454
|
+
if math.isnan(value):
|
|
455
|
+
encoded = "nan"
|
|
456
|
+
elif math.isinf(value):
|
|
457
|
+
encoded = "neg_inf" if value < 0 else "inf"
|
|
458
|
+
else:
|
|
459
|
+
encoded = format(value, ".17g")
|
|
460
|
+
if encoded == "-0":
|
|
461
|
+
encoded = "0"
|
|
462
|
+
encoded = encoded.replace("e-", "e_neg").replace("e+", "e")
|
|
463
|
+
encoded = encoded.replace("-", "neg").replace(".", "p")
|
|
464
|
+
parts.append(encoded)
|
|
465
|
+
return "__" + "_".join(parts)
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
def _math_fn(base: str, dtype_info: _ScalarTypeInfo) -> str:
    """Return the <math.h> name of ``base`` at ``dtype_info``'s precision.

    For ``f32`` types the ``f``-suffixed variant (``sinf``, ``expf``, ...)
    is chosen, except for the type-generic names in ``_NO_SUFFIX_MATH``.
    """
    wants_float_variant = dtype_info.suffix == "f32" and base not in _NO_SUFFIX_MATH
    return base + "f" if wants_float_variant else base
|
|
472
|
+
|
|
473
|
+
|
|
474
|
+
def _normalize_op_name(op_name: str) -> str:
    """Map an alias from ``_OP_ALIASES`` (e.g. ``"arccos"``) to its canonical name.

    Names that are not aliases are returned unchanged.
    """
    return _OP_ALIASES[op_name] if op_name in _OP_ALIASES else op_name
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
def _cast_value(expr: str, dtype_info: _ScalarTypeInfo) -> str:
|
|
479
|
+
if dtype_info.is_small_int:
|
|
480
|
+
return f"({dtype_info.c_type})({expr})"
|
|
481
|
+
return expr
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
def _simple_unary(dtype_info: _ScalarTypeInfo, name: str, expr: str) -> _GeneratedScalar:
    """Emit ``static inline T <prefix><name>(T a) { return <expr>; }``.

    ``T`` is ``dtype_info.c_type``. No dependencies or includes are recorded.
    """
    ctype = dtype_info.c_type
    header = f"static inline {ctype} {dtype_info.prefix}{name}({ctype} a) {{"
    return _GeneratedScalar(
        lines=[header, f" return {expr};", "}"],
        deps=set(),
        includes=set(),
    )
|
|
491
|
+
|
|
492
|
+
|
|
493
|
+
def _simple_binary(dtype_info: _ScalarTypeInfo, name: str, expr: str) -> _GeneratedScalar:
    """Emit ``static inline T <prefix><name>(T a, T b) { return <expr>; }``.

    ``T`` is ``dtype_info.c_type``. No dependencies or includes are recorded.
    """
    ctype = dtype_info.c_type
    header = f"static inline {ctype} {dtype_info.prefix}{name}({ctype} a, {ctype} b) {{"
    return _GeneratedScalar(
        lines=[header, f" return {expr};", "}"],
        deps=set(),
        includes=set(),
    )
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
def _float_unary_math(dtype_info: _ScalarTypeInfo, name: str, base: str) -> _GeneratedScalar:
    """Emit ``name(a)`` as a direct call to the <math.h> function ``base``."""
    call = f"{_math_fn(base, dtype_info)}(a)"
    return _simple_unary(dtype_info, name, call)
|
|
504
|
+
|
|
505
|
+
|
|
506
|
+
def _float_binary_math(dtype_info: _ScalarTypeInfo, name: str, base: str) -> _GeneratedScalar:
    """Emit ``name(a, b)`` as a direct call to the <math.h> function ``base``."""
    call = f"{_math_fn(base, dtype_info)}(a, b)"
    return _simple_binary(dtype_info, name, call)
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
def _float_isfinite(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``isfinite(a)`` that returns 1.0 or 0.0 in the float type."""
    true_lit, false_lit = (_float_literal(v, dtype_info) for v in (1.0, 0.0))
    return _simple_unary(dtype_info, "isfinite", f"isfinite(a) ? {true_lit} : {false_lit}")
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
def _float_isnan(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``isnan(a)`` that returns 1.0 or 0.0 in the float type."""
    true_lit, false_lit = (_float_literal(v, dtype_info) for v in (1.0, 0.0))
    return _simple_unary(dtype_info, "isnan", f"isnan(a) ? {true_lit} : {false_lit}")
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
def _float_isinf(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``isinf(a)`` that returns 1.0 or 0.0 in the float type."""
    true_lit, false_lit = (_float_literal(v, dtype_info) for v in (1.0, 0.0))
    return _simple_unary(dtype_info, "isinf", f"isinf(a) ? {true_lit} : {false_lit}")
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
def _float_isneginf(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``isneginf(a)``: 1.0 when ``a`` is an infinity with its sign bit set, else 0.0."""
    true_lit, false_lit = (_float_literal(v, dtype_info) for v in (1.0, 0.0))
    condition = "(isinf(a) && signbit(a))"
    return _simple_unary(dtype_info, "isneginf", f"{condition} ? {true_lit} : {false_lit}")
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
def _float_isposinf(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``isposinf(a)``: 1.0 when ``a`` is an infinity with its sign bit clear, else 0.0."""
    true_lit, false_lit = (_float_literal(v, dtype_info) for v in (1.0, 0.0))
    condition = "(isinf(a) && !signbit(a))"
    return _simple_unary(dtype_info, "isposinf", f"{condition} ? {true_lit} : {false_lit}")
|
|
542
|
+
|
|
543
|
+
|
|
544
|
+
def _float_comparison(
    dtype_info: _ScalarTypeInfo, name: str, op: str
) -> _GeneratedScalar:
    """Emit a comparison ``a <op> b`` that returns 1.0 or 0.0 in the float type."""
    true_lit, false_lit = (_float_literal(v, dtype_info) for v in (1.0, 0.0))
    return _simple_binary(dtype_info, name, f"a {op} b ? {true_lit} : {false_lit}")
|
|
550
|
+
|
|
551
|
+
|
|
552
|
+
def _float_logical_binary(
    dtype_info: _ScalarTypeInfo, name: str, expr: str
) -> _GeneratedScalar:
    """Emit a binary predicate ``expr`` (over ``a``/``b``) that returns 1.0 or 0.0."""
    true_lit, false_lit = (_float_literal(v, dtype_info) for v in (1.0, 0.0))
    return _simple_binary(dtype_info, name, f"{expr} ? {true_lit} : {false_lit}")
|
|
558
|
+
|
|
559
|
+
|
|
560
|
+
def _float_logical_not(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``logical_not(a)``: 1.0 when ``a`` equals zero, else 0.0."""
    true_lit, false_lit = (_float_literal(v, dtype_info) for v in (1.0, 0.0))
    return _simple_unary(
        dtype_info, "logical_not", f"a == {false_lit} ? {true_lit} : {false_lit}"
    )
|
|
564
|
+
|
|
565
|
+
|
|
566
|
+
def _float_remainder(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``remainder(a, b)``, a floored modulo.

    NaN inputs and ``b == 0`` give NAN. Otherwise the function computes
    ``fmod(a, b)``. A zero result is returned as is. A non-zero result
    whose sign differs from ``b``'s is shifted by ``b``, so non-zero results
    follow the sign of ``b``.
    """
    ctype = dtype_info.c_type
    zero = _float_literal(0.0, dtype_info)
    fmod_fn = _math_fn("fmod", dtype_info)
    lines = [
        f"static inline {ctype} {dtype_info.prefix}remainder({ctype} a, {ctype} b) {{",
        " if (isnan(a) || isnan(b)) {",
        " return NAN;",
        " }",
        f" if (b == {zero}) {{",
        " return NAN;",
        " }",
        f" {ctype} mod = {fmod_fn}(a, b);",
        f" if (mod == {zero}) {{",
        " return mod;",
        " }",
        " if ((mod < 0) != (b < 0)) {",
        " mod += b;",
        " }",
        " return mod;",
        "}",
    ]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
def _float_floor_divide(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``floor_divide(a, b)`` as ``floor(a / b)``.

    Note: rounding ``a / b`` before flooring can differ from Python's ``//``.
    For example, Python gives ``1.0 // 0.1 == 9.0``, while ``floor(1.0 / 0.1)``
    is 10.
    """
    floor_fn = _math_fn("floor", dtype_info)
    return _simple_binary(dtype_info, "floor_divide", f"{floor_fn}(a / b)")
|
|
593
|
+
|
|
594
|
+
|
|
595
|
+
def _float_logaddexp(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``logaddexp(a, b) = log(exp(a) + exp(b))`` in an overflow-safe form.

    The result is computed as ``max + log1p(exp(min - max))``. NaN inputs
    give NAN.

    When both inputs are the same infinity, ``min - max`` would be
    ``inf - inf`` (NaN), so that case returns the shared infinity directly.
    This covers both ``-inf``/``-inf`` (as before) and ``+inf``/``+inf``,
    which previously produced NaN instead of ``+inf``.
    """
    ctype = dtype_info.c_type
    log1p_fn = _math_fn("log1p", dtype_info)
    exp_fn = _math_fn("exp", dtype_info)
    lines = [
        f"static inline {ctype} {dtype_info.prefix}logaddexp({ctype} a, {ctype} b) {{",
        " if (isnan(a) || isnan(b)) {",
        " return NAN;",
        " }",
        f" {ctype} max_val = {_math_fn('fmax', dtype_info)}(a, b);",
        f" {ctype} min_val = {_math_fn('fmin', dtype_info)}(a, b);",
        " if (isinf(max_val) && max_val == min_val) {",
        " return max_val;",
        " }",
        f" return max_val + {log1p_fn}({exp_fn}(min_val - max_val));",
        "}",
    ]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
610
|
+
|
|
611
|
+
|
|
612
|
+
def _float_logaddexp2(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``logaddexp2(a, b) = log2(2**a + 2**b)`` in an overflow-safe form.

    The result is computed as ``max + log2(1 + exp2(min - max))``. NaN
    inputs give NAN.

    When both inputs are the same infinity, ``min - max`` would be
    ``inf - inf`` (NaN), so that case returns the shared infinity directly.
    This covers both ``-inf``/``-inf`` (as before) and ``+inf``/``+inf``,
    which previously produced NaN instead of ``+inf``.
    """
    ctype = dtype_info.c_type
    one = _float_literal(1.0, dtype_info)
    log2_fn = _math_fn("log2", dtype_info)
    exp2_fn = _math_fn("exp2", dtype_info)
    lines = [
        f"static inline {ctype} {dtype_info.prefix}logaddexp2({ctype} a, {ctype} b) {{",
        " if (isnan(a) || isnan(b)) {",
        " return NAN;",
        " }",
        f" {ctype} max_val = {_math_fn('fmax', dtype_info)}(a, b);",
        f" {ctype} min_val = {_math_fn('fmin', dtype_info)}(a, b);",
        " if (isinf(max_val) && max_val == min_val) {",
        " return max_val;",
        " }",
        f" return max_val + {log2_fn}({one} + {exp2_fn}(min_val - max_val));",
        "}",
    ]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
628
|
+
|
|
629
|
+
|
|
630
|
+
def _float_xlogy(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit ``xlogy(a, b) = a * log(b)``.

    NaN in either input gives NAN. ``a == 0`` gives 0 without evaluating
    ``log(b)``.
    """
    ctype = dtype_info.c_type
    zero = _float_literal(0.0, dtype_info)
    log_fn = _math_fn("log", dtype_info)
    lines = [
        f"static inline {ctype} {dtype_info.prefix}xlogy({ctype} a, {ctype} b) {{",
        " if (isnan(a) || isnan(b)) {",
        " return NAN;",
        " }",
        f" if (a == {zero}) {{",
        f" return {zero};",
        " }",
        f" return a * {log_fn}(b);",
        "}",
    ]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
644
|
+
|
|
645
|
+
|
|
646
|
+
def _float_heaviside(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate the C helper ``<prefix>heaviside(a, b)``.

    The emitted function returns 1 for ``a > 0`` and ``b`` for ``a == 0``.
    It returns 0 otherwise, which includes negative ``a`` and NaN ``a``
    because both comparisons are false for NaN.

    Args:
        dtype_info: Scalar type descriptor (C type, name prefix, literals).

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.
    """
    ctype = dtype_info.c_type
    zero = _float_literal(0.0, dtype_info)
    one = _float_literal(1.0, dtype_info)
    lines = [f"static inline {ctype} {dtype_info.prefix}heaviside({ctype} a, {ctype} b) {{"]
    lines += [f" if (a > {zero}) {{", f" return {one};", " }"]
    lines += [f" if (a == {zero}) {{", " return b;", " }"]
    lines += [f" return {zero};", "}"]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
661
|
+
|
|
662
|
+
|
|
663
|
+
def _float_ldexp(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Build the binary ``ldexp`` helper from the C expression ``a * exp2(b)``."""
    exp2_fn = _math_fn('exp2', dtype_info)
    expression = f"a * {exp2_fn}(b)"
    return _simple_binary(dtype_info, "ldexp", expression)
|
|
667
|
+
|
|
668
|
+
|
|
669
|
+
def _float_reciprocal(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Build the unary ``reciprocal`` helper from the C expression ``1 / a``."""
    numerator = _float_literal(1.0, dtype_info)
    return _simple_unary(dtype_info, "reciprocal", f"{numerator} / a")
|
|
672
|
+
|
|
673
|
+
|
|
674
|
+
def _float_relu(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Build the unary ``relu`` helper: ``a`` when ``a > 0``, otherwise zero.

    The emitted ternary also maps NaN to zero, because ``NaN > 0`` is false.
    """
    zero_lit = _float_literal(0.0, dtype_info)
    relu_expr = f"a > {zero_lit} ? a : {zero_lit}"
    return _simple_unary(dtype_info, "relu", relu_expr)
|
|
677
|
+
|
|
678
|
+
|
|
679
|
+
def _float_rsqrt(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Build the unary ``rsqrt`` helper from the C expression ``1 / sqrt(a)``."""
    one = _float_literal(1.0, dtype_info)
    sqrt_fn = _math_fn('sqrt', dtype_info)
    return _simple_unary(dtype_info, "rsqrt", f"{one} / {sqrt_fn}(a)")
|
|
682
|
+
|
|
683
|
+
|
|
684
|
+
def _float_sigmoid(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Build the unary ``sigmoid`` helper from ``1 / (1 + exp(-a))``."""
    one = _float_literal(1.0, dtype_info)
    denominator = f"({one} + {_math_fn('exp', dtype_info)}(-a))"
    return _simple_unary(dtype_info, "sigmoid", f"{one} / {denominator}")
|
|
687
|
+
|
|
688
|
+
|
|
689
|
+
def _float_log_sigmoid(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate the C helper ``<prefix>log_sigmoid(a)`` = log(1 / (1 + exp(-a))).

    The emitted code branches on the sign of ``a`` so that ``exp`` always
    receives a non-positive argument. It computes ``-log1p(exp(-a))`` for
    ``a >= 0`` and ``a - log1p(exp(a))`` otherwise.

    Args:
        dtype_info: Scalar type descriptor (C type, name prefix, literals).

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.
    """
    zero = _float_literal(0.0, dtype_info)
    ctype = dtype_info.c_type
    log1p_fn = _math_fn('log1p', dtype_info)
    exp_fn = _math_fn('exp', dtype_info)
    non_negative_branch = f" return -{log1p_fn}({exp_fn}(-a));"
    negative_branch = f" return a - {log1p_fn}({exp_fn}(a));"
    lines = [
        f"static inline {ctype} {dtype_info.prefix}log_sigmoid({ctype} a) {{",
        f" if (a >= {zero}) {{",
        non_negative_branch,
        " }",
        negative_branch,
        "}",
    ]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
700
|
+
|
|
701
|
+
|
|
702
|
+
def _float_gelu(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate the erf-based C helper ``<prefix>gelu(a)``.

    The emitted expression is ``0.5 * a * (1 + erf(a * (1 / sqrt(2))))``.

    Args:
        dtype_info: Scalar type descriptor (C type, name prefix, literals).

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.
    """
    c_type = dtype_info.c_type
    inv_sqrt2 = _float_literal(0.7071067811865475, dtype_info)
    half = _float_literal(0.5, dtype_info)
    one = _float_literal(1.0, dtype_info)
    erf_fn = _math_fn('erf', dtype_info)
    signature = f"static inline {c_type} {dtype_info.prefix}gelu({c_type} a) {{"
    body = [
        f" const {c_type} inv_sqrt2 = {inv_sqrt2};",
        f" return {half} * a * ({one} + {erf_fn}(a * inv_sqrt2));",
    ]
    return _GeneratedScalar(lines=[signature, *body, "}"], deps=set(), includes=set())
|
|
713
|
+
|
|
714
|
+
|
|
715
|
+
def _float_elu(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate the C helper ``<prefix>elu(a)`` with alpha, scale and input_scale fixed at 1.

    The emitted code returns ``scale * a`` for ``a > 0``. Otherwise it
    returns ``scale * alpha * (exp(input_scale * a) - 1)``.

    Args:
        dtype_info: Scalar type descriptor (C type, name prefix, literals).

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.
    """
    c_type = dtype_info.c_type
    one = _float_literal(1.0, dtype_info)
    constants = [
        f" const {c_type} {const_name} = {one};"
        for const_name in ("alpha", "scale", "input_scale")
    ]
    lines = [
        f"static inline {c_type} {dtype_info.prefix}elu({c_type} a) {{",
        *constants,
        " if (a > 0) {",
        " return scale * a;",
        " }",
        f" return scale * alpha * ({_math_fn('exp', dtype_info)}(input_scale * a) - {one});",
        "}",
    ]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
729
|
+
|
|
730
|
+
|
|
731
|
+
def _float_elu_param(
    dtype_info: _ScalarTypeInfo,
    params: tuple[float, ...],
    function_name: str,
) -> _GeneratedScalar:
    """Generate a parameterised ELU helper named ``function_name``.

    The emitted code returns ``a`` for ``a >= 0`` and
    ``alpha * (exp(a) - 1)`` otherwise.

    Args:
        dtype_info: Scalar type descriptor (C type, literals).
        params: Empty (alpha defaults to 1.0) or ``(alpha,)``.
        function_name: Name of the emitted C function.

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.

    Raises:
        ScalarFunctionError: If more than one parameter is supplied.
    """
    if params and len(params) > 1:
        raise ScalarFunctionError("elu expects 1 parameter: alpha")
    (alpha_value,) = params or (1.0,)
    c_type = dtype_info.c_type
    alpha = _float_literal(alpha_value, dtype_info)
    one = _float_literal(1.0, dtype_info)
    exp_fn = _math_fn('exp', dtype_info)
    lines = [
        f"static inline {c_type} {function_name}({c_type} a) {{",
        f" const {c_type} alpha = {alpha};",
        " if (a >= 0) {",
        " return a;",
        " }",
        f" return alpha * ({exp_fn}(a) - {one});",
        "}",
    ]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
751
|
+
|
|
752
|
+
|
|
753
|
+
def _float_celu(
    dtype_info: _ScalarTypeInfo,
    params: tuple[float, ...],
    function_name: str,
) -> _GeneratedScalar:
    """Generate a CELU helper named ``function_name``.

    The emitted code returns ``a`` for ``a > 0`` and
    ``alpha * (exp(a / alpha) - 1)`` otherwise.

    Args:
        dtype_info: Scalar type descriptor (C type, literals).
        params: Empty (alpha defaults to 1.0) or ``(alpha,)``.
        function_name: Name of the emitted C function.

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.

    Raises:
        ScalarFunctionError: If more than one parameter is supplied.
    """
    if params:
        if len(params) != 1:
            raise ScalarFunctionError("celu expects 1 parameter: alpha")
        alpha_value = params[0]
    else:
        alpha_value = 1.0
    c_type = dtype_info.c_type
    alpha = _float_literal(alpha_value, dtype_info)
    one = _float_literal(1.0, dtype_info)
    negative_branch = f" return alpha * ({_math_fn('exp', dtype_info)}(a / alpha) - {one});"
    lines = [
        f"static inline {c_type} {function_name}({c_type} a) {{",
        f" const {c_type} alpha = {alpha};",
        " if (a > 0) {",
        " return a;",
        " }",
        negative_branch,
        "}",
    ]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
773
|
+
|
|
774
|
+
|
|
775
|
+
def _float_affine(
    dtype_info: _ScalarTypeInfo,
    params: tuple[float, ...],
    function_name: str,
) -> _GeneratedScalar:
    """Generate an affine helper ``function_name(a) = alpha * a + beta``.

    Args:
        dtype_info: Scalar type descriptor (C type, literals).
        params: Empty (alpha=1.0, beta=0.0) or ``(alpha, beta)``.
        function_name: Name of the emitted C function.

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.

    Raises:
        ScalarFunctionError: If params is non-empty and not exactly two values.
    """
    defaults = (1.0, 0.0)
    if params and len(params) != len(defaults):
        raise ScalarFunctionError("affine expects 2 parameters: alpha, beta")
    alpha_value, beta_value = params if len(params) else defaults
    c_type = dtype_info.c_type
    alpha = _float_literal(alpha_value, dtype_info)
    beta = _float_literal(beta_value, dtype_info)
    lines = [
        f"static inline {c_type} {function_name}({c_type} a) {{",
        f" const {c_type} alpha = {alpha};",
        f" const {c_type} beta = {beta};",
        " return alpha * a + beta;",
        "}",
    ]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
794
|
+
|
|
795
|
+
|
|
796
|
+
def _float_scaled_tanh(
    dtype_info: _ScalarTypeInfo,
    params: tuple[float, ...],
    function_name: str,
) -> _GeneratedScalar:
    """Generate ``function_name(a) = alpha * tanh(beta * a)``.

    Args:
        dtype_info: Scalar type descriptor (C type, literals).
        params: Empty (alpha=1.0, beta=1.0) or ``(alpha, beta)``.
        function_name: Name of the emitted C function.

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.

    Raises:
        ScalarFunctionError: If params is non-empty and not exactly two values.
    """
    defaults = (1.0, 1.0)
    if params and len(params) != len(defaults):
        raise ScalarFunctionError("scaled_tanh expects 2 parameters: alpha, beta")
    alpha_value, beta_value = params if len(params) else defaults
    c_type = dtype_info.c_type
    alpha = _float_literal(alpha_value, dtype_info)
    beta = _float_literal(beta_value, dtype_info)
    tanh_fn = _math_fn('tanh', dtype_info)
    lines = [
        f"static inline {c_type} {function_name}({c_type} a) {{",
        f" const {c_type} alpha = {alpha};",
        f" const {c_type} beta = {beta};",
        f" return alpha * {tanh_fn}(beta * a);",
        "}",
    ]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
815
|
+
|
|
816
|
+
|
|
817
|
+
def _float_swish(
    dtype_info: _ScalarTypeInfo,
    params: tuple[float, ...],
    function_name: str,
) -> _GeneratedScalar:
    """Generate ``function_name(a) = a / (1 + exp(-alpha * a))``, i.e. a * sigmoid(alpha * a).

    Args:
        dtype_info: Scalar type descriptor (C type, literals).
        params: Empty (alpha defaults to 1.0) or ``(alpha,)``.
        function_name: Name of the emitted C function.

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.

    Raises:
        ScalarFunctionError: If more than one parameter is supplied.
    """
    if params:
        if len(params) != 1:
            raise ScalarFunctionError("swish expects 1 parameter: alpha")
        alpha_value = params[0]
    else:
        alpha_value = 1.0
    c_type = dtype_info.c_type
    alpha = _float_literal(alpha_value, dtype_info)
    one = _float_literal(1.0, dtype_info)
    exp_fn = _math_fn('exp', dtype_info)
    lines = [
        f"static inline {c_type} {function_name}({c_type} a) {{",
        f" const {c_type} alpha = {alpha};",
        f" return a / ({one} + {exp_fn}(-alpha * a));",
        "}",
    ]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
834
|
+
|
|
835
|
+
|
|
836
|
+
def _float_leaky_relu_param(
    dtype_info: _ScalarTypeInfo,
    params: tuple[float, ...],
    function_name: str,
) -> _GeneratedScalar:
    """Generate a leaky ReLU helper: ``alpha * a`` for ``a < 0``, else ``a``.

    Args:
        dtype_info: Scalar type descriptor (C type, literals).
        params: Empty (alpha defaults to 0.01) or ``(alpha,)``.
        function_name: Name of the emitted C function.

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.

    Raises:
        ScalarFunctionError: If more than one parameter is supplied.
    """
    if params and len(params) > 1:
        raise ScalarFunctionError("leaky_relu expects 1 parameter: alpha")
    (alpha_value,) = params or (0.01,)
    c_type = dtype_info.c_type
    alpha = _float_literal(alpha_value, dtype_info)
    return _GeneratedScalar(
        lines=[
            f"static inline {c_type} {function_name}({c_type} a) {{",
            f" const {c_type} alpha = {alpha};",
            " return a < 0 ? alpha * a : a;",
            "}",
        ],
        deps=set(),
        includes=set(),
    )
|
|
852
|
+
|
|
853
|
+
|
|
854
|
+
def _float_thresholded_relu_param(
    dtype_info: _ScalarTypeInfo,
    params: tuple[float, ...],
    function_name: str,
) -> _GeneratedScalar:
    """Generate a thresholded ReLU helper: ``a`` when ``a > alpha``, else zero.

    Args:
        dtype_info: Scalar type descriptor (C type, literals).
        params: Empty (alpha defaults to 1.0) or ``(alpha,)``.
        function_name: Name of the emitted C function.

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.

    Raises:
        ScalarFunctionError: If more than one parameter is supplied.
    """
    if params:
        if len(params) != 1:
            raise ScalarFunctionError("thresholded_relu expects 1 parameter: alpha")
        alpha_value = params[0]
    else:
        alpha_value = 1.0
    c_type = dtype_info.c_type
    alpha = _float_literal(alpha_value, dtype_info)
    zero = _float_literal(0.0, dtype_info)
    return _GeneratedScalar(
        lines=[
            f"static inline {c_type} {function_name}({c_type} a) {{",
            f" const {c_type} alpha = {alpha};",
            f" return a > alpha ? a : {zero};",
            "}",
        ],
        deps=set(),
        includes=set(),
    )
|
|
871
|
+
|
|
872
|
+
|
|
873
|
+
def _float_hardsigmoid_param(
    dtype_info: _ScalarTypeInfo,
    params: tuple[float, ...],
    function_name: str,
) -> _GeneratedScalar:
    """Generate a hard-sigmoid helper: ``fmin(1, fmax(0, a * alpha + beta))``.

    Args:
        dtype_info: Scalar type descriptor (C type, literals).
        params: Empty (alpha=0.2, beta=0.5) or ``(alpha, beta)``.
        function_name: Name of the emitted C function.

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.

    Raises:
        ScalarFunctionError: If params is non-empty and not exactly two values.
    """
    defaults = (0.2, 0.5)
    if params and len(params) != len(defaults):
        raise ScalarFunctionError("hardsigmoid expects 2 parameters: alpha, beta")
    alpha_value, beta_value = params if len(params) else defaults
    c_type = dtype_info.c_type
    alpha = _float_literal(alpha_value, dtype_info)
    beta = _float_literal(beta_value, dtype_info)
    zero = _float_literal(0.0, dtype_info)
    one = _float_literal(1.0, dtype_info)
    clamp_expr = f"{_math_fn('fmin', dtype_info)}({one}, {_math_fn('fmax', dtype_info)}({zero}, value))"
    lines = [
        f"static inline {c_type} {function_name}({c_type} a) {{",
        f" {c_type} value = a * {alpha} + {beta};",
        f" {c_type} clamped = {clamp_expr};",
        " return clamped;",
        "}",
    ]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
894
|
+
|
|
895
|
+
|
|
896
|
+
def _float_shrink(
    dtype_info: _ScalarTypeInfo,
    params: tuple[float, ...],
    function_name: str,
) -> _GeneratedScalar:
    """Generate a shrink helper named ``function_name``.

    The emitted code computes ``a + bias`` for ``a < -lambd`` and
    ``a - bias`` for ``a > lambd``. It returns zero otherwise.

    Args:
        dtype_info: Scalar type descriptor (C type, literals).
        params: Empty (bias=0.0, lambd=0.5) or ``(bias, lambd)``.
        function_name: Name of the emitted C function.

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.

    Raises:
        ScalarFunctionError: If params is non-empty and not exactly two values.
    """
    if params and len(params) != 2:
        raise ScalarFunctionError("shrink expects 2 parameters: bias, lambd")
    bias_value, lambd_value = params or (0.0, 0.5)
    c_type = dtype_info.c_type
    bias = _float_literal(bias_value, dtype_info)
    lambd = _float_literal(lambd_value, dtype_info)
    zero = _float_literal(0.0, dtype_info)
    lines = [f"static inline {c_type} {function_name}({c_type} a) {{"]
    lines += [f" const {c_type} bias = {bias};", f" const {c_type} lambd = {lambd};"]
    lines += [" if (a < -lambd) {", " return a + bias;", " }"]
    lines += [" if (a > lambd) {", " return a - bias;", " }"]
    lines += [f" return {zero};", "}"]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
922
|
+
|
|
923
|
+
|
|
924
|
+
def _float_leaky_relu(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``<prefix>leaky_relu(a)`` with a fixed negative slope of 0.01.

    Args:
        dtype_info: Scalar type descriptor (C type, name prefix, literals).

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.
    """
    c_type = dtype_info.c_type
    slope = _float_literal(0.01, dtype_info)
    signature = f"static inline {c_type} {dtype_info.prefix}leaky_relu({c_type} a) {{"
    return _GeneratedScalar(
        lines=[
            signature,
            f" const {c_type} negative_slope = {slope};",
            " return a > 0 ? a : negative_slope * a;",
            "}",
        ],
        deps=set(),
        includes=set(),
    )
|
|
933
|
+
|
|
934
|
+
|
|
935
|
+
def _float_softplus(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``<prefix>softplus(a)`` with beta = 1 and threshold = 20.

    The emitted code returns ``a`` when ``beta * a > threshold``. Otherwise it
    returns ``log1p(exp(beta * a)) / beta``.

    Args:
        dtype_info: Scalar type descriptor (C type, name prefix, literals).

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.
    """
    c_type = dtype_info.c_type
    beta = _float_literal(1.0, dtype_info)
    threshold = _float_literal(20.0, dtype_info)
    smooth = f"{_math_fn('log1p', dtype_info)}({_math_fn('exp', dtype_info)}(beta * a)) / beta"
    lines = [
        f"static inline {c_type} {dtype_info.prefix}softplus({c_type} a) {{",
        f" const {c_type} beta = {beta};",
        f" const {c_type} threshold = {threshold};",
        " if (beta * a > threshold) {",
        " return a;",
        " }",
        f" return {smooth};",
        "}",
    ]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
949
|
+
|
|
950
|
+
|
|
951
|
+
def _float_softsign(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Build the unary ``softsign`` helper from ``a / (1 + fabs(a))``."""
    one = _float_literal(1.0, dtype_info)
    fabs_fn = _math_fn('fabs', dtype_info)
    return _simple_unary(dtype_info, "softsign", f"a / ({one} + {fabs_fn}(a))")
|
|
958
|
+
|
|
959
|
+
|
|
960
|
+
def _float_silu(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Build the unary ``silu`` helper from ``a / (1 + exp(-a))``."""
    one = _float_literal(1.0, dtype_info)
    exp_fn = _math_fn('exp', dtype_info)
    silu_expr = f"a / ({one} + {exp_fn}(-a))"
    return _simple_unary(dtype_info, "silu", silu_expr)
|
|
963
|
+
|
|
964
|
+
|
|
965
|
+
def _float_mish(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``<prefix>mish(a)`` = a * tanh(softplus(a)).

    The emitted code short-circuits the tails:

    * ``a > 20`` returns ``a``;
    * ``a < -20`` returns ``a * exp(a)``;
    * otherwise it computes ``a * tanh(log1p(exp(a)))``.

    Args:
        dtype_info: Scalar type descriptor (C type, name prefix, literals).

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.
    """
    # The former unused ``zero`` literal was dropped; only the +/-20 cut-off is needed.
    twenty = _float_literal(20.0, dtype_info)
    c_type = dtype_info.c_type
    lines = [
        f"static inline {c_type} {dtype_info.prefix}mish({c_type} a) {{",
        f" if (a > {twenty}) {{",
        " return a;",
        " }",
        f" if (a < -{twenty}) {{",
        f" {c_type} exp_a = {_math_fn('exp', dtype_info)}(a);",
        " return a * exp_a;",
        " }",
        f" {c_type} softplus = {_math_fn('log1p', dtype_info)}({_math_fn('exp', dtype_info)}(a));",
        f" return a * {_math_fn('tanh', dtype_info)}(softplus);",
        "}",
    ]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
982
|
+
|
|
983
|
+
|
|
984
|
+
def _float_selu(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``<prefix>selu(a)`` with the standard SELU alpha/scale constants.

    The constants are emitted as raw decimal strings. An ``f`` suffix is
    appended only when ``dtype_info.suffix == "f32"``. The emitted code returns
    ``scale * a`` for ``a > 0`` and ``scale * alpha * (exp(a) - 1)`` otherwise.

    Args:
        dtype_info: Scalar type descriptor (C type, name prefix, suffix).

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.
    """
    literal_suffix = "f" if dtype_info.suffix == "f32" else ""
    alpha_literal = "1.6732632423543772848170429916717" + literal_suffix
    scale_literal = "1.0507009873554804934193349852946" + literal_suffix
    one = _float_literal(1.0, dtype_info)
    c_type = dtype_info.c_type
    lines = [
        f"static inline {c_type} {dtype_info.prefix}selu({c_type} a) {{",
        f" const {c_type} alpha = {alpha_literal};",
        f" const {c_type} scale = {scale_literal};",
        " if (a > 0) {",
        " return scale * a;",
        " }",
        f" return scale * alpha * ({_math_fn('exp', dtype_info)}(a) - {one});",
        "}",
    ]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
1001
|
+
|
|
1002
|
+
|
|
1003
|
+
def _float_thresholded_relu(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Build the unary ``thresholded_relu`` helper with a fixed threshold of 1."""
    zero = _float_literal(0.0, dtype_info)
    threshold = _float_literal(1.0, dtype_info)
    expression = f"a > {threshold} ? a : {zero}"
    return _simple_unary(dtype_info, "thresholded_relu", expression)
|
|
1011
|
+
|
|
1012
|
+
|
|
1013
|
+
def _float_relu6(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Build the unary ``relu6`` helper: ``fmin(6, fmax(0, a))``."""
    zero = _float_literal(0.0, dtype_info)
    six = _float_literal(6.0, dtype_info)
    lower_clamped = f"{_math_fn('fmax', dtype_info)}({zero}, a)"
    return _simple_unary(dtype_info, "relu6", f"{_math_fn('fmin', dtype_info)}({six}, {lower_clamped})")
|
|
1021
|
+
|
|
1022
|
+
|
|
1023
|
+
def _float_hardsigmoid(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``<prefix>hardsigmoid(a)`` with alpha = 0.2 and beta = 0.5.

    This delegates to ``_float_hardsigmoid_param`` with no explicit
    parameters, so that function's defaults (0.2, 0.5) apply. The helper is
    named ``<prefix>hardsigmoid``, and the emitted lines are the same as
    writing the fixed-constant version out by hand.

    Args:
        dtype_info: Scalar type descriptor (C type, name prefix, literals).

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.
    """
    return _float_hardsigmoid_param(dtype_info, (), f"{dtype_info.prefix}hardsigmoid")
|
|
1036
|
+
|
|
1037
|
+
|
|
1038
|
+
def _float_hardswish(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``<prefix>hardswish(a)`` = a * fmin(6, fmax(0, a + 3)) / 6.

    Args:
        dtype_info: Scalar type descriptor (C type, name prefix, literals).

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.
    """
    c_type = dtype_info.c_type
    three, six, zero = (_float_literal(value, dtype_info) for value in (3.0, 6.0, 0.0))
    clamp_expr = f"{_math_fn('fmin', dtype_info)}({six}, {_math_fn('fmax', dtype_info)}({zero}, shifted))"
    lines = [
        f"static inline {c_type} {dtype_info.prefix}hardswish({c_type} a) {{",
        f" {c_type} shifted = a + {three};",
        f" {c_type} clamped = {clamp_expr};",
        f" return a * clamped / {six};",
        "}",
    ]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
1050
|
+
|
|
1051
|
+
|
|
1052
|
+
def _float_sign(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``<prefix>sign(a)``: NaN passes through, then 1 / -1 / 0.

    Args:
        dtype_info: Scalar type descriptor (C type, name prefix, literals).

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.
    """
    c_type = dtype_info.c_type
    zero = _float_literal(0.0, dtype_info)
    one = _float_literal(1.0, dtype_info)
    minus_one = _float_literal(-1.0, dtype_info)
    guarded_returns = (
        ("isnan(a)", "a"),
        (f"a > {zero}", one),
        (f"a < {zero}", minus_one),
    )
    lines = [f"static inline {c_type} {dtype_info.prefix}sign({c_type} a) {{"]
    for condition, result in guarded_returns:
        lines.extend([f" if ({condition}) {{", f" return {result};", " }"])
    lines.extend([f" return {zero};", "}"])
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
1071
|
+
|
|
1072
|
+
|
|
1073
|
+
def _float_round(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Build the ``round`` helper by wrapping the C math function of the same name."""
    op_name = "round"
    return _float_unary_math(dtype_info, op_name, op_name)
|
|
1075
|
+
|
|
1076
|
+
|
|
1077
|
+
def _float_trunc(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Build the ``trunc`` helper by wrapping the C math function of the same name."""
    op_name = "trunc"
    return _float_unary_math(dtype_info, op_name, op_name)
|
|
1079
|
+
|
|
1080
|
+
|
|
1081
|
+
def _float_angle(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``<prefix>angle(a)`` for real inputs.

    The emitted code returns NaN unchanged, pi for negative ``a`` and zero
    otherwise. ``REF_PI_F`` is used when ``dtype_info.suffix == "f32"`` and
    ``REF_PI_D`` for every other suffix.

    Args:
        dtype_info: Scalar type descriptor (C type, name prefix, suffix).

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.
    """
    zero = _float_literal(0.0, dtype_info)
    pi_macro = {"f32": "REF_PI_F"}.get(dtype_info.suffix, "REF_PI_D")
    c_type = dtype_info.c_type
    return _GeneratedScalar(
        lines=[
            f"static inline {c_type} {dtype_info.prefix}angle({c_type} a) {{",
            " if (isnan(a)) {",
            " return a;",
            " }",
            f" return a < {zero} ? {pi_macro} : {zero};",
            "}",
        ],
        deps=set(),
        includes=set(),
    )
|
|
1093
|
+
|
|
1094
|
+
|
|
1095
|
+
def _float_conj(dtype_info: _ScalarTypeInfo, name: str) -> _GeneratedScalar:
    """Build an identity unary helper called ``name`` (it returns ``a`` unchanged)."""
    identity_expr = "a"
    return _simple_unary(dtype_info, name, identity_expr)
|
|
1097
|
+
|
|
1098
|
+
|
|
1099
|
+
def _float_deg2rad(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Build the unary ``deg2rad`` helper: ``a * (pi / 180)``."""
    pi_macro = {"f32": "REF_PI_F"}.get(dtype_info.suffix, "REF_PI_D")
    degrees_per_half_turn = _float_literal(180.0, dtype_info)
    factor = f"({pi_macro} / {degrees_per_half_turn})"
    return _simple_unary(dtype_info, "deg2rad", f"a * {factor}")
|
|
1103
|
+
|
|
1104
|
+
|
|
1105
|
+
def _float_rad2deg(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Build the unary ``rad2deg`` helper: ``a * (180 / pi)``."""
    pi_macro = {"f32": "REF_PI_F"}.get(dtype_info.suffix, "REF_PI_D")
    degrees_per_half_turn = _float_literal(180.0, dtype_info)
    factor = f"({degrees_per_half_turn} / {pi_macro})"
    return _simple_unary(dtype_info, "rad2deg", f"a * {factor}")
|
|
1109
|
+
|
|
1110
|
+
|
|
1111
|
+
def _float_digamma_f64() -> _GeneratedScalar:
    """Generate ``ref_scalar_f64_digamma(double x)``, the double-precision digamma.

    Steps performed by the emitted C:

    * NaN and +/-inf are returned unchanged;
    * for ``x <= 0``, integral values (``x - floor(x) == 0``) give NaN, and
      other values use the reflection ``psi(1 - x) - pi / tan(pi * x)``
      through a recursive call;
    * ``x`` is shifted up to at least 10, subtracting ``1 / x`` per step;
    * an asymptotic series in ``1 / x**2`` is added to ``log(x) - 0.5 / x``.

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.
    """
    signature = "static inline double ref_scalar_f64_digamma(double x) {"
    special_values = [
        " if (isnan(x) || isinf(x)) {",
        " return x;",
        " }",
    ]
    reflection = [
        " if (x <= 0.0) {",
        " double frac = x - floor(x);",
        " if (frac == 0.0) {",
        " return NAN;",
        " }",
        " return ref_scalar_f64_digamma(1.0 - x) - REF_PI_D / tan(REF_PI_D * x);",
        " }",
    ]
    recurrence = [
        " double result = 0.0;",
        " while (x < 10.0) {",
        " result -= 1.0 / x;",
        " x += 1.0;",
        " }",
    ]
    asymptotic_series = [
        " double inv = 1.0 / x;",
        " double inv2 = inv * inv;",
        " result += log(x) - 0.5 * inv",
        " - inv2 * (1.0 / 12.0 - inv2 * (1.0 / 120.0",
        " - inv2 * (1.0 / 252.0 - inv2 * (1.0 / 240.0",
        " - inv2 * (1.0 / 132.0 - inv2 * (691.0 / 32760.0))))));",
        " return result;",
    ]
    lines = [signature, *special_values, *reflection, *recurrence, *asymptotic_series, "}"]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
1139
|
+
|
|
1140
|
+
|
|
1141
|
+
def _float_digamma(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate the digamma helper for ``dtype_info``.

    For ``suffix == "f64"`` this returns the standalone double implementation
    from ``_float_digamma_f64``. For any other type it emits a thin wrapper.
    The wrapper widens its argument to ``double``, calls
    ``ref_scalar_f64_digamma`` and converts the result back to the helper's
    own C type. It also records the f64 digamma helper as a dependency.

    Args:
        dtype_info: Scalar type descriptor (C type, name prefix, suffix).

    Returns:
        The generated helper (plus the f64 dependency for non-f64 types).
    """
    if dtype_info.suffix == "f64":
        return _float_digamma_f64()
    c_type = dtype_info.c_type
    lines = [
        f"static inline {c_type} {dtype_info.prefix}digamma({c_type} x) {{",
        # Cast back to this helper's declared return type instead of a
        # hard-coded (float); identical output for f32, correct for any other type.
        f" return ({c_type})ref_scalar_f64_digamma((double)x);",
        "}",
    ]
    deps = {_scalar_key_from_op(_SCALAR_TYPES[ScalarType.F64], "digamma")}
    return _GeneratedScalar(lines=lines, deps=deps, includes=set())
|
|
1151
|
+
|
|
1152
|
+
|
|
1153
|
+
def _float_erfinv(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``<prefix>erfinv(x)``, the inverse error function.

    The emitted C works in three stages:

    * domain guards: NaN is returned unchanged, exactly -1/1 give
      -INFINITY/INFINITY, values beyond [-1, 1] give NaN, and 0 gives 0;
    * a closed-form initial guess built from ``a = 0.147`` and
      ``ln = log(1 - x*x)``, with the sign copied from ``x``;
    * two Newton steps on ``erf(approx) - x`` using the derivative
      ``2 / sqrt(pi) * exp(-approx**2)``.

    Args:
        dtype_info: Scalar type descriptor (C type, name prefix, suffix).

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.
    """
    c_type = dtype_info.c_type
    pi = "REF_PI_F" if dtype_info.suffix == "f32" else "REF_PI_D"
    one = _float_literal(1.0, dtype_info)
    zero = _float_literal(0.0, dtype_info)
    two = _float_literal(2.0, dtype_info)
    a_literal = _float_literal(0.147, dtype_info)
    sqrt_fn = _math_fn('sqrt', dtype_info)

    domain_guards = [
        " if (isnan(x)) {",
        " return x;",
        " }",
        f" if (x <= -{one}) {{",
        f" return x == -{one} ? -INFINITY : NAN;",
        " }",
        f" if (x >= {one}) {{",
        f" return x == {one} ? INFINITY : NAN;",
        " }",
        f" if (x == {zero}) {{",
        f" return {zero};",
        " }",
    ]
    initial_guess = [
        f" {c_type} a = {a_literal};",
        f" {c_type} ln = {_math_fn('log', dtype_info)}({one} - x * x);",
        f" {c_type} term = {two} / ({pi} * a) + ln / {two};",
        f" {c_type} inner = term * term - ln / a;",
        f" {c_type} approx = {sqrt_fn}({_math_fn('fmax', dtype_info)}({zero}, {sqrt_fn}(inner) - term));",
        f" if (x < {zero}) {{",
        " approx = -approx;",
        " }",
    ]
    newton_refinement = [
        " for (int i = 0; i < 2; ++i) {",
        f" {c_type} err = {_math_fn('erf', dtype_info)}(approx) - x;",
        f" {c_type} deriv = {two} / {sqrt_fn}({pi}) * {_math_fn('exp', dtype_info)}(-approx * approx);",
        " approx -= err / deriv;",
        " }",
    ]
    lines = [
        f"static inline {c_type} {dtype_info.prefix}erfinv({c_type} x) {{",
        *domain_guards,
        *initial_guess,
        *newton_refinement,
        " return approx;",
        "}",
    ]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
1190
|
+
|
|
1191
|
+
|
|
1192
|
+
def _float_frac(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Build the unary ``frac`` helper: ``a - trunc(a)`` (keeps the sign of ``a``)."""
    trunc_fn = _math_fn('trunc', dtype_info)
    return _simple_unary(dtype_info, "frac", f"a - {trunc_fn}(a)")
|
|
1194
|
+
|
|
1195
|
+
|
|
1196
|
+
def _float_i0(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``<prefix>i0(x)``, the order-0 modified Bessel function of the first kind.

    The emitted code uses the Abramowitz & Stegun polynomial approximations
    (formulas 9.8.1 / 9.8.2):

    * for ``|x| < 3.75``: a polynomial in ``(x / 3.75)**2``;
    * otherwise: ``exp(|x|) / sqrt(|x|)`` times a polynomial in ``3.75 / |x|``.

    Args:
        dtype_info: Scalar type descriptor (C type, name prefix, literals).

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.
    """
    c_type = dtype_info.c_type

    def lit(value):
        # Format a coefficient as a literal of this helper's float type.
        return _float_literal(value, dtype_info)

    # The former unused ``zero`` literal was dropped; only the 3.75 knot is needed.
    knot = lit(3.75)
    small_argument = [
        f" if (ax < {knot}) {{",
        f" {c_type} y = x / {knot};",
        " y *= y;",
        f" return {lit(1.0)} + y * ({lit(3.5156229)} + y * ({lit(3.0899424)} + y * ({lit(1.2067492)}",
        f" + y * ({lit(0.2659732)} + y * ({lit(0.0360768)} + y * {lit(0.0045813)})))));",
        " }",
    ]
    large_argument = [
        f" {c_type} y = {knot} / ax;",
        f" return ({_math_fn('exp', dtype_info)}(ax) / {_math_fn('sqrt', dtype_info)}(ax)) * ({lit(0.39894228)} + y * ({lit(0.01328592)}",
        f" + y * ({lit(0.00225319)} + y * ({lit(-0.00157565)} + y * ({lit(0.00916281)}",
        f" + y * ({lit(-0.02057706)} + y * ({lit(0.02635537)}",
        f" + y * ({lit(-0.01647633)} + y * {lit(0.00392377)}))))))));",
    ]
    lines = [
        f"static inline {c_type} {dtype_info.prefix}i0({c_type} x) {{",
        f" {c_type} ax = {_math_fn('fabs', dtype_info)}(x);",
        *small_argument,
        *large_argument,
        "}",
    ]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
1216
|
+
|
|
1217
|
+
|
|
1218
|
+
def _float_logit(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``<prefix>logit(a)`` = log(a / (1 - a)).

    Special cases in the emitted code: NaN is returned unchanged, 0 gives
    -INFINITY, 1 gives INFINITY and values outside [0, 1] give NaN.

    Args:
        dtype_info: Scalar type descriptor (C type, name prefix, literals).

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.
    """
    c_type = dtype_info.c_type
    zero = _float_literal(0.0, dtype_info)
    one = _float_literal(1.0, dtype_info)
    special_cases = (
        ("isnan(a)", "a"),
        (f"a == {zero}", "-INFINITY"),
        (f"a == {one}", "INFINITY"),
        (f"a < {zero} || a > {one}", "NAN"),
    )
    lines = [f"static inline {c_type} {dtype_info.prefix}logit({c_type} a) {{"]
    for condition, value in special_cases:
        lines += [f" if ({condition}) {{", f" return {value};", " }"]
    lines += [f" return {_math_fn('log', dtype_info)}(a / ({one} - a));", "}"]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
1239
|
+
|
|
1240
|
+
|
|
1241
|
+
def _float_nan_to_num(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``<prefix>nan_to_num(a)``.

    The emitted code maps NaN to zero and maps +/-infinity to the signed
    largest finite bound. That bound is ``DBL_MAX`` when
    ``dtype_info.suffix == "f64"`` and ``FLT_MAX`` otherwise. Finite values
    pass through unchanged.

    Args:
        dtype_info: Scalar type descriptor (C type, name prefix, suffix).

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.
    """
    zero = _float_literal(0.0, dtype_info)
    # FLT_MAX (~3.4e38) is only the right clamp for single precision; double
    # precision must clamp to DBL_MAX. Both macros live in <float.h>; this
    # helper declares no includes, so the header is presumably provided
    # elsewhere (as FLT_MAX already required) -- NOTE(review): confirm.
    max_finite = "DBL_MAX" if dtype_info.suffix == "f64" else "FLT_MAX"
    lines = [
        f"static inline {dtype_info.c_type} {dtype_info.prefix}nan_to_num({dtype_info.c_type} a) {{",
        " if (isnan(a)) {",
        f" return {zero};",
        " }",
        " if (isinf(a)) {",
        f" return signbit(a) ? -{max_finite} : {max_finite};",
        " }",
        " return a;",
        "}",
    ]
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
1255
|
+
|
|
1256
|
+
|
|
1257
|
+
def _float_sgn(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``<prefix>sgn(a)``: NaN maps to 0, then 1 / -1 / 0 by sign.

    Args:
        dtype_info: Scalar type descriptor (C type, name prefix, literals).

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.
    """
    c_type = dtype_info.c_type
    zero = _float_literal(0.0, dtype_info)
    one = _float_literal(1.0, dtype_info)
    minus_one = _float_literal(-1.0, dtype_info)
    guarded_returns = (
        ("isnan(a)", zero),
        (f"a > {zero}", one),
        (f"a < {zero}", minus_one),
    )
    lines = [f"static inline {c_type} {dtype_info.prefix}sgn({c_type} a) {{"]
    for condition, result in guarded_returns:
        lines.extend([f" if ({condition}) {{", f" return {result};", " }"])
    lines.extend([f" return {zero};", "}"])
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
1276
|
+
|
|
1277
|
+
|
|
1278
|
+
def _float_sinc(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate the normalised ``<prefix>sinc(a)`` = sin(pi * a) / (pi * a), with sinc(0) = 1.

    Args:
        dtype_info: Scalar type descriptor (C type, name prefix, suffix).

    Returns:
        The generated helper, with no scalar dependencies and no extra includes.
    """
    c_type = dtype_info.c_type
    zero = _float_literal(0.0, dtype_info)
    one = _float_literal(1.0, dtype_info)
    pi_macro = {"f32": "REF_PI_F"}.get(dtype_info.suffix, "REF_PI_D")
    sin_fn = _math_fn('sin', dtype_info)
    return _GeneratedScalar(
        lines=[
            f"static inline {c_type} {dtype_info.prefix}sinc({c_type} a) {{",
            f" if (a == {zero}) {{",
            f" return {one};",
            " }",
            f" {c_type} x = {pi_macro} * a;",
            f" return {sin_fn}(x) / x;",
            "}",
        ],
        deps=set(),
        includes=set(),
    )
|
|
1292
|
+
|
|
1293
|
+
|
|
1294
|
+
def _float_square(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Build the unary ``square`` helper: ``a * a``."""
    product = "a * a"
    return _simple_unary(dtype_info, "square", product)
|
|
1296
|
+
|
|
1297
|
+
|
|
1298
|
+
def _float_positive(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Build the unary ``positive`` helper, which returns ``a`` unchanged."""
    identity_expr = "a"
    return _simple_unary(dtype_info, "positive", identity_expr)
|
|
1300
|
+
|
|
1301
|
+
|
|
1302
|
+
def _float_clamp_min(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Build the binary ``clamp_min`` helper: ``fmax(a, b)`` (b is the lower bound)."""
    fmax_fn = _math_fn('fmax', dtype_info)
    return _simple_binary(dtype_info, "clamp_min", f"{fmax_fn}(a, b)")
|
|
1304
|
+
|
|
1305
|
+
|
|
1306
|
+
def _float_clamp_max(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Build the binary ``clamp_max`` helper: ``fmin(a, b)`` (b is the upper bound)."""
    fmin_fn = _math_fn('fmin', dtype_info)
    return _simple_binary(dtype_info, "clamp_max", f"{fmin_fn}(a, b)")
|
|
1308
|
+
|
|
1309
|
+
def _float_binary_op_handler(name: str, op: str) -> Callable[[_ScalarTypeInfo], _GeneratedScalar]:
    """Return a dispatch handler that emits ``name(a, b) = a <op> b``.

    Args:
        name: Operation name given to the generated helper.
        op: C infix operator placed between ``a`` and ``b``.

    Returns:
        A callable mapping a scalar type descriptor to its generated helper.
    """
    expression = f"a {op} b"

    def handler(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
        return _simple_binary(dtype_info, name, expression)

    return handler
|
|
1314
|
+
|
|
1315
|
+
|
|
1316
|
+
def _float_unary_math_handler(
    name: str, base: str | None = None
) -> Callable[[_ScalarTypeInfo], _GeneratedScalar]:
    """Return a dispatch handler that wraps a unary C math function.

    Args:
        name: Operation name given to the generated helper.
        base: Base name of the math function to call. Falls back to ``name``
            when ``None`` or empty.

    Returns:
        A callable mapping a scalar type descriptor to its generated helper.
    """
    math_base = base if base else name

    def handler(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
        return _float_unary_math(dtype_info, name, math_base)

    return handler
|
|
1325
|
+
|
|
1326
|
+
|
|
1327
|
+
def _float_binary_math_handler(
    name: str, base: str | None = None
) -> Callable[[_ScalarTypeInfo], _GeneratedScalar]:
    """Return a dispatch handler that wraps a binary C math function.

    Args:
        name: Operation name given to the generated helper.
        base: Base name of the math function to call. Falls back to ``name``
            when ``None`` or empty.

    Returns:
        A callable mapping a scalar type descriptor to its generated helper.
    """
    math_base = base if base else name

    def handler(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
        return _float_binary_math(dtype_info, name, math_base)

    return handler
|
|
1336
|
+
|
|
1337
|
+
|
|
1338
|
+
def _float_comparison_handler(
    name: str, op: str
) -> Callable[[_ScalarTypeInfo], _GeneratedScalar]:
    """Return a dispatch handler that emits the comparison ``a <op> b`` via ``_float_comparison``.

    Args:
        name: Operation name given to the generated helper.
        op: C comparison operator.

    Returns:
        A callable mapping a scalar type descriptor to its generated helper.
    """

    def handler(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
        return _float_comparison(dtype_info, name, op)

    return handler
|
|
1343
|
+
|
|
1344
|
+
|
|
1345
|
+
def _float_logical_or(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit float ``logical_or``: the condition holds when either operand is non-zero."""
    zero_lit = _float_literal(0.0, dtype_info)
    either_nonzero = f"(a != {zero_lit} || b != {zero_lit})"
    return _float_logical_binary(dtype_info, "logical_or", either_nonzero)
|
|
1348
|
+
|
|
1349
|
+
|
|
1350
|
+
def _float_logical_and(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit float ``logical_and``: the condition holds when both operands are non-zero."""
    zero_lit = _float_literal(0.0, dtype_info)
    both_nonzero = f"(a != {zero_lit} && b != {zero_lit})"
    return _float_logical_binary(dtype_info, "logical_and", both_nonzero)
|
|
1353
|
+
|
|
1354
|
+
|
|
1355
|
+
def _float_logical_xor(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit float ``logical_xor``: the condition holds when exactly one operand is non-zero."""
    zero_lit = _float_literal(0.0, dtype_info)
    exactly_one_nonzero = f"((a != {zero_lit}) != (b != {zero_lit}))"
    return _float_logical_binary(dtype_info, "logical_xor", exactly_one_nonzero)
|
|
1358
|
+
|
|
1359
|
+
|
|
1360
|
+
# Dispatch table for float scalar helpers, used by ``_float_from_ops``.
# Keys are canonical op names (after ``_normalize_op_name``). Each value is a
# callable taking the dtype info and returning the generated C helper.
# Entries built with ``_float_*_handler(...)`` share a common emitter; the
# remaining values are dedicated generator functions defined in this module.
_FLOAT_OP_DISPATCH: Mapping[str, Callable[[_ScalarTypeInfo], _GeneratedScalar]] = {
    # "abs" is emitted through the math helper with base "fabs".
    "abs": _float_unary_math_handler("abs", "fabs"),
    # Plain C infix arithmetic.
    "add": _float_binary_op_handler("add", "+"),
    "sub": _float_binary_op_handler("sub", "-"),
    "mul": _float_binary_op_handler("mul", "*"),
    "div": _float_binary_op_handler("div", "/"),
    # NOTE(review): "maximum"/"minimum" use the same fmax/fmin base as
    # "fmax"/"fmin". C fmax/fmin return the non-NaN operand when exactly one
    # input is NaN. Confirm this is the intended NaN behaviour for
    # maximum/minimum (some frameworks propagate NaN for these).
    "maximum": _float_binary_math_handler("maximum", "fmax"),
    "fmax": _float_binary_math_handler("fmax", "fmax"),
    "minimum": _float_binary_math_handler("minimum", "fmin"),
    "fmin": _float_binary_math_handler("fmin", "fmin"),
    # Comparisons, emitted via _float_comparison.
    "le": _float_comparison_handler("le", "<="),
    "lt": _float_comparison_handler("lt", "<"),
    "ge": _float_comparison_handler("ge", ">="),
    "gt": _float_comparison_handler("gt", ">"),
    "eq": _float_comparison_handler("eq", "=="),
    "ne": _float_comparison_handler("ne", "!="),
    # Logical ops treat any non-zero value as true.
    "logical_or": _float_logical_or,
    "logical_and": _float_logical_and,
    "logical_xor": _float_logical_xor,
    "logical_not": _float_logical_not,
    # Binary math functions whose helper name matches the base math function.
    "copysign": _float_binary_math_handler("copysign"),
    "hypot": _float_binary_math_handler("hypot"),
    "atan2": _float_binary_math_handler("atan2"),
    "pow": _float_binary_math_handler("pow"),
    "fmod": _float_binary_math_handler("fmod"),
    "remainder": _float_remainder,
    "floor_divide": _float_floor_divide,
    "logaddexp": _float_logaddexp,
    "logaddexp2": _float_logaddexp2,
    "nextafter": _float_binary_math_handler("nextafter"),
    "xlogy": _float_xlogy,
    "heaviside": _float_heaviside,
    "ldexp": _float_ldexp,
    "clamp_min": _float_clamp_min,
    "clamp_max": _float_clamp_max,
    "neg": lambda dtype_info: _simple_unary(dtype_info, "neg", "-a"),
    "reciprocal": _float_reciprocal,
    "relu": _float_relu,
    # Unary math functions whose helper name matches the base math function.
    "ceil": _float_unary_math_handler("ceil"),
    "floor": _float_unary_math_handler("floor"),
    "sin": _float_unary_math_handler("sin"),
    "cos": _float_unary_math_handler("cos"),
    "sqrt": _float_unary_math_handler("sqrt"),
    "cbrt": _float_unary_math_handler("cbrt"),
    "exp": _float_unary_math_handler("exp"),
    "tanh": _float_unary_math_handler("tanh"),
    "log": _float_unary_math_handler("log"),
    "acos": _float_unary_math_handler("acos"),
    "acosh": _float_unary_math_handler("acosh"),
    "asin": _float_unary_math_handler("asin"),
    "asinh": _float_unary_math_handler("asinh"),
    "atan": _float_unary_math_handler("atan"),
    "atanh": _float_unary_math_handler("atanh"),
    "cosh": _float_unary_math_handler("cosh"),
    "sinh": _float_unary_math_handler("sinh"),
    "tan": _float_unary_math_handler("tan"),
    "erf": _float_unary_math_handler("erf"),
    "erfc": _float_unary_math_handler("erfc"),
    "expm1": _float_unary_math_handler("expm1"),
    "log1p": _float_unary_math_handler("log1p"),
    "log2": _float_unary_math_handler("log2"),
    "log10": _float_unary_math_handler("log10"),
    "exp2": _float_unary_math_handler("exp2"),
    "lgamma": _float_unary_math_handler("lgamma"),
    # Activations and other ops with dedicated generators.
    "isfinite": _float_isfinite,
    "rsqrt": _float_rsqrt,
    "sigmoid": _float_sigmoid,
    "log_sigmoid": _float_log_sigmoid,
    "gelu": _float_gelu,
    "elu": _float_elu,
    "leaky_relu": _float_leaky_relu,
    "softplus": _float_softplus,
    "softsign": _float_softsign,
    "silu": _float_silu,
    "mish": _float_mish,
    "selu": _float_selu,
    "relu6": _float_relu6,
    "hardsigmoid": _float_hardsigmoid,
    "hardswish": _float_hardswish,
    "thresholded_relu": _float_thresholded_relu,
    "sign": _float_sign,
    "round": _float_round,
    "trunc": _float_trunc,
    "angle": _float_angle,
    "conj": lambda dtype_info: _float_conj(dtype_info, "conj"),
    "conj_physical": lambda dtype_info: _float_conj(dtype_info, "conj_physical"),
    "deg2rad": _float_deg2rad,
    "digamma": _float_digamma,
    "erfinv": _float_erfinv,
    "frac": _float_frac,
    "i0": _float_i0,
    "logit": _float_logit,
    "isnan": _float_isnan,
    "isinf": _float_isinf,
    "isneginf": _float_isneginf,
    "isposinf": _float_isposinf,
    "nan_to_num": _float_nan_to_num,
    "positive": _float_positive,
    "rad2deg": _float_rad2deg,
    # Real-valued scalars: "real" is the identity.
    "real": lambda dtype_info: _simple_unary(dtype_info, "real", "a"),
    "sgn": _float_sgn,
    "sinc": _float_sinc,
    "square": _float_square,
}
|
|
1464
|
+
|
|
1465
|
+
# Float ops that take extra numeric parameters, keyed by ScalarFunction.
# Each generator receives (dtype info, tuple of floats, str) and returns the
# generated helper. Presumably the tuple carries the op's attribute values
# (e.g. alpha/beta) and the str is the helper name to emit — confirm against
# the call site, which is outside this view.
_PARAMETERIZED_FLOAT_OPS: Mapping[
    ScalarFunction,
    Callable[[_ScalarTypeInfo, tuple[float, ...], str], _GeneratedScalar],
] = {
    ScalarFunction.AFFINE: _float_affine,
    ScalarFunction.CELU: _float_celu,
    ScalarFunction.ELU: _float_elu_param,
    ScalarFunction.HARDSIGMOID: _float_hardsigmoid_param,
    ScalarFunction.LEAKY_RELU: _float_leaky_relu_param,
    ScalarFunction.SCALED_TANH: _float_scaled_tanh,
    ScalarFunction.SHRINK: _float_shrink,
    ScalarFunction.SWISH: _float_swish,
    ScalarFunction.THRESHOLDED_RELU: _float_thresholded_relu_param,
}
|
|
1479
|
+
|
|
1480
|
+
|
|
1481
|
+
def _float_from_ops(dtype_info: _ScalarTypeInfo, name: str) -> _GeneratedScalar:
    """Generate the float scalar helper for op ``name``.

    If ``name`` normalizes to a different canonical name, this emits a thin
    forwarding wrapper that depends on the canonical helper. Otherwise the
    canonical op is looked up in ``_FLOAT_OP_DISPATCH``.

    Raises:
        ScalarFunctionError: if the op has no float generator.
    """
    canonical = _normalize_op_name(name)
    if canonical == name:
        generator = _FLOAT_OP_DISPATCH.get(canonical)
        if generator is None:
            raise ScalarFunctionError(f"unsupported float scalar op: {canonical}")
        return generator(dtype_info)

    # NOTE(review): the alias wrapper always has a single parameter. If
    # _normalize_op_name can map an alias of a binary op, this signature
    # would be wrong — confirm.
    forwarder = [
        f"static inline {dtype_info.c_type} {dtype_info.prefix}{name}({dtype_info.c_type} a) {{",
        f" return {dtype_info.prefix}{canonical}(a);",
        "}",
    ]
    return _GeneratedScalar(
        lines=forwarder,
        deps={_scalar_key_from_op(dtype_info, canonical)},
        includes=set(),
    )
|
|
1496
|
+
|
|
1497
|
+
|
|
1498
|
+
def _int_from_f32(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``<prefix>from_f32``, converting a ``float`` to this integer type.

    Emitted behaviour:
      * signed types: non-finite -> MIN; values at or above the float image of
        MAX -> MAX; values below MIN -> MIN; otherwise a truncating cast.
      * unsigned 32/64-bit: non-finite or <= 0 -> 0; values at or above the
        float image of MAX -> MAX; otherwise a truncating cast.
      * other unsigned widths: non-finite -> 0; otherwise a plain cast.

    Args:
        dtype_info: Integer scalar type description (``bits``, ``is_signed``,
            ``c_type`` and ``prefix`` are read).

    Returns:
        The generated helper. It has no dependencies and needs math/stdint
        headers, plus limits.h for signed types.
    """
    c_type = dtype_info.c_type
    includes: Set[str] = {"#include <math.h>", "#include <stdint.h>"}
    lines: List[str] = [
        f"static inline {c_type} {dtype_info.prefix}from_f32(float value) {{",
    ]
    if dtype_info.is_signed:
        includes.add("#include <limits.h>")
        min_name = f"INT{dtype_info.bits}_MIN"
        max_name = f"INT{dtype_info.bits}_MAX"
        lines.extend(
            [
                " if (!isfinite(value)) {",
                f" return {min_name};",
                " }",
                # (float)INT32_MAX / (float)INT64_MAX round UP to 2^31 / 2^63,
                # which is out of range. Use >= so that boundary value saturates
                # instead of reaching the cast below: an out-of-range
                # float->int conversion is undefined behaviour in C.
                f" if (value >= (float){max_name}) {{",
                f" return {max_name};",
                " }",
                f" if (value < (float){min_name}) {{",
                f" return {min_name};",
                " }",
            ]
        )
    elif dtype_info.bits in {32, 64}:
        max_name = f"UINT{dtype_info.bits}_MAX"
        lines.extend(
            [
                " if (!isfinite(value)) {",
                " return 0;",
                " }",
                " if (value <= 0.0f) {",
                " return 0;",
                " }",
                f" if (value >= (float){max_name}) {{",
                f" return {max_name};",
                " }",
            ]
        )
    else:
        # NOTE(review): for narrow unsigned types, negative or too-large values
        # go straight to the cast, which is undefined behaviour in C
        # (C11 6.3.1.4). Decide whether saturation or wraparound is intended.
        lines.extend(
            [
                " if (!isfinite(value)) {",
                " return 0;",
                " }",
            ]
        )
    lines.extend([f" return ({c_type})value;", "}"])
    return _GeneratedScalar(lines=lines, deps=set(), includes=includes)
|
|
1552
|
+
|
|
1553
|
+
|
|
1554
|
+
def _int_unary_from_f32(dtype_info: _ScalarTypeInfo, name: str) -> _GeneratedScalar:
    """Emit integer op ``name`` by round-tripping through the f32 implementation.

    The emitted helper casts its argument to ``float``, calls
    ``ref_scalar_f32_<name>``, and converts the result back with
    ``<prefix>from_f32``. Both of those are recorded as dependencies.
    """
    c_type = dtype_info.c_type
    prefix = dtype_info.prefix
    f32_info = _SCALAR_TYPES[ScalarType.F32]
    body = [
        f"static inline {c_type} {prefix}{name}({c_type} a) {{",
        f" return {prefix}from_f32(ref_scalar_f32_{name}((float)a));",
        "}",
    ]
    dependencies = {
        _conversion_key_from_alias(dtype_info, "from_f32"),
        _scalar_key_from_op(f32_info, name),
    }
    return _GeneratedScalar(lines=body, deps=dependencies, includes=set())
|
|
1565
|
+
|
|
1566
|
+
|
|
1567
|
+
def _int_binary_from_f32(dtype_info: _ScalarTypeInfo, name: str) -> _GeneratedScalar:
    """Emit binary integer op ``name`` via the f32 implementation.

    Both operands are cast to ``float``, passed to ``ref_scalar_f32_<name>``,
    and the result is converted back with ``<prefix>from_f32``.
    """
    c_type = dtype_info.c_type
    prefix = dtype_info.prefix
    f32_info = _SCALAR_TYPES[ScalarType.F32]
    body = [
        f"static inline {c_type} {prefix}{name}({c_type} a, {c_type} b) {{",
        f" return {prefix}from_f32(ref_scalar_f32_{name}((float)a, (float)b));",
        "}",
    ]
    dependencies = {
        _conversion_key_from_alias(dtype_info, "from_f32"),
        _scalar_key_from_op(f32_info, name),
    }
    return _GeneratedScalar(lines=body, deps=dependencies, includes=set())
|
|
1578
|
+
|
|
1579
|
+
|
|
1580
|
+
def _int_bool_literal(dtype_info: _ScalarTypeInfo, value: bool) -> str:
|
|
1581
|
+
if dtype_info.is_small_int:
|
|
1582
|
+
return f"({dtype_info.c_type}){1 if value else 0}"
|
|
1583
|
+
return "1" if value else "0"
|
|
1584
|
+
|
|
1585
|
+
|
|
1586
|
+
def _int_binary_op(dtype_info: _ScalarTypeInfo, name: str, op: str) -> _GeneratedScalar:
    """Emit integer helper ``name`` computing ``a <op> b``, wrapped by ``_cast_value``."""
    return _simple_binary(dtype_info, name, _cast_value(f"a {op} b", dtype_info))
|
|
1589
|
+
|
|
1590
|
+
|
|
1591
|
+
def _int_unary_op(dtype_info: _ScalarTypeInfo, name: str, expr: str) -> _GeneratedScalar:
    """Emit unary integer helper ``name`` whose body is ``expr``, wrapped by ``_cast_value``."""
    casted_expr = _cast_value(expr, dtype_info)
    return _simple_unary(dtype_info, name, casted_expr)
|
|
1593
|
+
|
|
1594
|
+
|
|
1595
|
+
def _int_comparison(dtype_info: _ScalarTypeInfo, name: str, op: str) -> _GeneratedScalar:
    """Emit integer comparison ``name``: ``a <op> b`` mapped to a 1/0 literal of the type."""
    true_lit, false_lit = (_int_bool_literal(dtype_info, flag) for flag in (True, False))
    ternary = f"a {op} b ? {true_lit} : {false_lit}"
    return _simple_binary(dtype_info, name, ternary)
|
|
1599
|
+
|
|
1600
|
+
|
|
1601
|
+
def _int_logical(dtype_info: _ScalarTypeInfo, name: str, expr: str) -> _GeneratedScalar:
    """Emit binary integer logical helper ``name``: condition ``expr`` mapped to 1/0."""
    true_lit, false_lit = (_int_bool_literal(dtype_info, flag) for flag in (True, False))
    ternary = f"{expr} ? {true_lit} : {false_lit}"
    return _simple_binary(dtype_info, name, ternary)
|
|
1605
|
+
|
|
1606
|
+
|
|
1607
|
+
def _int_logical_not(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit integer ``logical_not``: 1 when the argument is zero, else 0."""
    false_lit = _int_bool_literal(dtype_info, False)
    true_lit = _int_bool_literal(dtype_info, True)
    ternary = f"a == {false_lit} ? {true_lit} : {false_lit}"
    return _simple_unary(dtype_info, "logical_not", ternary)
|
|
1611
|
+
|
|
1612
|
+
|
|
1613
|
+
def _int_abs(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit integer ``abs``.

    Unsigned types return the argument unchanged. Signed types return MIN
    unchanged, because negating it would overflow; other values get their
    magnitude.
    """
    if dtype_info.is_signed:
        c_type = dtype_info.c_type
        int_min = f"INT{dtype_info.bits}_MIN"
        body = [
            f"static inline {c_type} {dtype_info.prefix}abs({c_type} a) {{",
            f" if (a == {int_min}) {{",
            f" return {int_min};",
            " }",
            " return a < 0 ? -a : a;",
            "}",
        ]
        return _GeneratedScalar(lines=body, deps=set(), includes={"#include <limits.h>"})
    return _simple_unary(dtype_info, "abs", "a")
|
|
1627
|
+
|
|
1628
|
+
|
|
1629
|
+
def _int_absolute(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit integer ``absolute`` as a forwarder to ``<prefix>abs`` (recorded as a dependency)."""
    c_type = dtype_info.c_type
    prefix = dtype_info.prefix
    return _GeneratedScalar(
        lines=[
            f"static inline {c_type} {prefix}absolute({c_type} a) {{",
            f" return {prefix}abs(a);",
            "}",
        ],
        deps={_scalar_key_from_op(dtype_info, "abs")},
        includes=set(),
    )
|
|
1637
|
+
|
|
1638
|
+
|
|
1639
|
+
def _int_div(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``<prefix>div``: C truncating integer division.

    Division by zero returns 0. For signed types, ``MIN / -1`` is not
    representable: it is undefined behaviour in C and traps (SIGFPE) on
    common hardware for 32/64-bit types. It is therefore guarded and
    returns MIN, matching how ``_int_neg``/``_int_abs`` treat MIN.

    Args:
        dtype_info: Integer scalar type description.

    Returns:
        The generated helper. Signed types add a limits.h include.
    """
    c_type = dtype_info.c_type
    includes: Set[str] = set()
    lines = [
        f"static inline {c_type} {dtype_info.prefix}div({c_type} a, {c_type} b) {{",
        " if (b == 0) {",
        " return 0;",
        " }",
    ]
    if dtype_info.is_signed:
        min_name = f"INT{dtype_info.bits}_MIN"
        includes.add("#include <limits.h>")
        lines.extend(
            [
                f" if (a == {min_name} && b == -1) {{",
                f" return {min_name};",
                " }",
            ]
        )
    expr = _cast_value("a / b", dtype_info)
    lines.extend([f" return {expr};", "}"])
    return _GeneratedScalar(lines=lines, deps=set(), includes=includes)
|
|
1650
|
+
|
|
1651
|
+
|
|
1652
|
+
def _int_fmod(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``<prefix>fmod``: C ``%`` remainder (sign follows the dividend).

    A zero divisor returns 0. For signed types, a divisor of -1 returns 0
    directly: ``INT_MIN % -1`` is undefined behaviour in C and traps on
    common hardware, while the mathematical result is always 0.
    """
    c_type = dtype_info.c_type
    lines = [
        f"static inline {c_type} {dtype_info.prefix}fmod({c_type} a, {c_type} b) {{",
        " if (b == 0) {",
        " return 0;",
        " }",
    ]
    if dtype_info.is_signed:
        lines.extend(
            [
                " if (b == -1) {",
                " return 0;",
                " }",
            ]
        )
    expr = _cast_value("a % b", dtype_info)
    lines.extend([f" return {expr};", "}"])
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
1663
|
+
|
|
1664
|
+
|
|
1665
|
+
def _int_remainder(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``<prefix>remainder``: modulo whose sign follows the divisor.

    A zero divisor returns 0. The C ``%`` result is shifted by ``b`` when it is
    non-zero and its sign differs from the divisor's. For signed types, a
    divisor of -1 returns 0 up front, because ``INT_MIN % -1`` is undefined
    behaviour in C (the true result is 0).
    """
    c_type = dtype_info.c_type
    lines = [
        f"static inline {c_type} {dtype_info.prefix}remainder({c_type} a, {c_type} b) {{",
        " if (b == 0) {",
        " return 0;",
        " }",
    ]
    if dtype_info.is_signed:
        lines.extend(
            [
                " if (b == -1) {",
                " return 0;",
                " }",
            ]
        )
    expr = _cast_value("a % b", dtype_info)
    lines.extend(
        [
            f" {c_type} mod = {expr};",
            " if (mod == 0) {",
            " return mod;",
            " }",
            # Shift into the divisor's sign (floored-modulo convention).
            " if ((mod < 0) != (b < 0)) {",
            " mod += b;",
            " }",
            " return mod;",
            "}",
        ]
    )
    return _GeneratedScalar(lines=lines, deps=set(), includes=set())
|
|
1683
|
+
|
|
1684
|
+
|
|
1685
|
+
def _int_floor_divide(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``<prefix>floor_divide``: integer division rounding toward -infinity.

    A zero divisor returns 0. Unsigned types use plain C division. For signed
    types:
      * the truncated quotient is decremented when the remainder is non-zero
        and its sign differs from the divisor's;
      * ``MIN / -1`` (and ``MIN % -1``) is guarded and returns MIN, because
        both expressions are undefined behaviour in C and trap on common
        hardware.
    """
    c_type = dtype_info.c_type
    includes: Set[str] = set()
    lines = [
        f"static inline {c_type} {dtype_info.prefix}floor_divide({c_type} a, {c_type} b) {{",
        " if (b == 0) {",
        " return 0;",
        " }",
    ]
    if dtype_info.is_signed:
        min_name = f"INT{dtype_info.bits}_MIN"
        includes.add("#include <limits.h>")
        lines.extend(
            [
                f" if (a == {min_name} && b == -1) {{",
                f" return {min_name};",
                " }",
                f" {c_type} quo = a / b;",
                f" {c_type} rem = a % b;",
                " if (rem != 0 && ((rem < 0) != (b < 0))) {",
                " quo -= 1;",
                " }",
                " return quo;",
                "}",
            ]
        )
    else:
        expr = _cast_value("a / b", dtype_info)
        lines.extend([f" return {expr};", "}"])
    return _GeneratedScalar(lines=lines, deps=set(), includes=includes)
|
|
1713
|
+
|
|
1714
|
+
|
|
1715
|
+
def _int_copysign(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Generate ``<prefix>copysign``: the magnitude of ``a`` with the sign of ``b``.

    When ``b`` is negative, the result -|a| is formed without negating a
    negative value: ``a`` itself if it is negative, else ``-a``. This avoids the
    signed overflow the previous ``-abs(a)`` hit for ``a == MIN``, where
    ``abs`` returns MIN. Otherwise the result is ``<prefix>abs(a)``, which is
    recorded as a dependency.
    """
    c_type = dtype_info.c_type
    prefix = dtype_info.prefix
    lines = [
        f"static inline {c_type} {prefix}copysign({c_type} a, {c_type} b) {{",
        " if (b < 0) {",
        " return a < 0 ? a : -a;",
        " }",
        f" return {prefix}abs(a);",
        "}",
    ]
    deps = {_scalar_key_from_op(dtype_info, "abs")}
    return _GeneratedScalar(lines=lines, deps=deps, includes=set())
|
|
1724
|
+
|
|
1725
|
+
|
|
1726
|
+
def _int_neg(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit integer ``neg``.

    Unsigned types compute ``0 - a`` through ``_cast_value``. Signed types
    return MIN unchanged (negating it would overflow) and ``-a`` otherwise.
    """
    if not dtype_info.is_signed:
        return _simple_unary(dtype_info, "neg", _cast_value("0 - a", dtype_info))
    c_type = dtype_info.c_type
    int_min = f"INT{dtype_info.bits}_MIN"
    body = [
        f"static inline {c_type} {dtype_info.prefix}neg({c_type} a) {{",
        f" if (a == {int_min}) {{",
        f" return {int_min};",
        " }",
        " return -a;",
        "}",
    ]
    return _GeneratedScalar(lines=body, deps=set(), includes={"#include <limits.h>"})
|
|
1741
|
+
|
|
1742
|
+
|
|
1743
|
+
def _int_reciprocal(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit integer ``reciprocal`` as ``1 / a`` with integer division; zero maps to 0."""
    c_type = dtype_info.c_type
    quotient = _cast_value("1 / a", dtype_info)
    body = [f"static inline {c_type} {dtype_info.prefix}reciprocal({c_type} a) {{"]
    body += [" if (a == 0) {", " return 0;", " }"]
    body += [f" return {quotient};", "}"]
    return _GeneratedScalar(lines=body, deps=set(), includes=set())
|
|
1754
|
+
|
|
1755
|
+
|
|
1756
|
+
def _int_relu(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit integer ``relu``: clamp negatives to 0 (identity for unsigned types)."""
    body = "a > 0 ? a : 0" if dtype_info.is_signed else "a"
    return _simple_unary(dtype_info, "relu", body)
|
|
1760
|
+
|
|
1761
|
+
|
|
1762
|
+
def _int_ceil_floor(dtype_info: _ScalarTypeInfo, name: str) -> _GeneratedScalar:
    """Emit integer ``ceil``/``floor`` (selected by ``name``), which are the identity on integers."""
    identity_expr = "a"
    return _simple_unary(dtype_info, name, identity_expr)
|
|
1764
|
+
|
|
1765
|
+
|
|
1766
|
+
def _int_round(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit integer ``round``, which is the identity on integers."""
    identity_expr = "a"
    return _simple_unary(dtype_info, "round", identity_expr)
|
|
1768
|
+
|
|
1769
|
+
|
|
1770
|
+
def _int_trunc(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit integer ``trunc``, which is the identity on integers."""
    identity_expr = "a"
    return _simple_unary(dtype_info, "trunc", identity_expr)
|
|
1772
|
+
|
|
1773
|
+
|
|
1774
|
+
def _int_frac(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit integer ``frac``: integers have no fractional part, so it always returns 0.

    The argument is explicitly discarded with ``(void)a`` in the emitted C.
    """
    c_type = dtype_info.c_type
    header = f"static inline {c_type} {dtype_info.prefix}frac({c_type} a) {{"
    return _GeneratedScalar(
        lines=[header, " (void)a;", " return 0;", "}"],
        deps=set(),
        includes=set(),
    )
|
|
1782
|
+
|
|
1783
|
+
|
|
1784
|
+
def _int_sign(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit integer ``sign``: 1, 0 or -1 (unsigned types only produce 1 or 0)."""
    if dtype_info.is_signed:
        c_type = dtype_info.c_type
        header = f"static inline {c_type} {dtype_info.prefix}sign({c_type} a) {{"
        positive_branch = [" if (a > 0) {", " return 1;", " }"]
        negative_branch = [" if (a < 0) {", " return -1;", " }"]
        body = [header, *positive_branch, *negative_branch, " return 0;", "}"]
        return _GeneratedScalar(lines=body, deps=set(), includes=set())
    return _simple_unary(dtype_info, "sign", "a > 0 ? 1 : 0")
|
|
1799
|
+
|
|
1800
|
+
|
|
1801
|
+
def _int_conj(dtype_info: _ScalarTypeInfo, name: str) -> _GeneratedScalar:
    """Emit integer ``conj``/``conj_physical`` (per ``name``): the identity for real values."""
    identity_expr = "a"
    return _simple_unary(dtype_info, name, identity_expr)
|
|
1803
|
+
|
|
1804
|
+
|
|
1805
|
+
def _int_positive(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit integer ``positive``, which returns its argument unchanged."""
    identity_expr = "a"
    return _simple_unary(dtype_info, "positive", identity_expr)
|
|
1807
|
+
|
|
1808
|
+
|
|
1809
|
+
def _int_sgn(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit integer ``sgn``: 1, 0 or -1 (unsigned types only produce 1 or 0).

    For real integers this has the same result as ``sign``.
    """
    if dtype_info.is_signed:
        c_type = dtype_info.c_type
        header = f"static inline {c_type} {dtype_info.prefix}sgn({c_type} a) {{"
        positive_branch = [" if (a > 0) {", " return 1;", " }"]
        negative_branch = [" if (a < 0) {", " return -1;", " }"]
        body = [header, *positive_branch, *negative_branch, " return 0;", "}"]
        return _GeneratedScalar(lines=body, deps=set(), includes=set())
    return _simple_unary(dtype_info, "sgn", "a > 0 ? 1 : 0")
|
|
1824
|
+
|
|
1825
|
+
|
|
1826
|
+
def _int_square(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit integer ``square`` as ``a * a`` wrapped by ``_cast_value``.

    NOTE(review): for 32/64-bit signed types the product can overflow, which
    is undefined behaviour in C; confirm whether wraparound is required.
    """
    product = _cast_value("a * a", dtype_info)
    return _simple_unary(dtype_info, "square", product)
|
|
1828
|
+
|
|
1829
|
+
def _int_binary_op_handler(name: str, op: str) -> Callable[[_ScalarTypeInfo], _GeneratedScalar]:
|
|
1830
|
+
def handler(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
|
|
1831
|
+
return _int_binary_op(dtype_info, name, op)
|
|
1832
|
+
|
|
1833
|
+
return handler
|
|
1834
|
+
|
|
1835
|
+
|
|
1836
|
+
def _int_simple_binary_handler(name: str, expr: str) -> Callable[[_ScalarTypeInfo], _GeneratedScalar]:
|
|
1837
|
+
def handler(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
|
|
1838
|
+
return _simple_binary(dtype_info, name, expr)
|
|
1839
|
+
|
|
1840
|
+
return handler
|
|
1841
|
+
|
|
1842
|
+
|
|
1843
|
+
def _int_comparison_handler(name: str, op: str) -> Callable[[_ScalarTypeInfo], _GeneratedScalar]:
|
|
1844
|
+
def handler(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
|
|
1845
|
+
return _int_comparison(dtype_info, name, op)
|
|
1846
|
+
|
|
1847
|
+
return handler
|
|
1848
|
+
|
|
1849
|
+
|
|
1850
|
+
def _int_logical_handler(name: str, expr: str) -> Callable[[_ScalarTypeInfo], _GeneratedScalar]:
|
|
1851
|
+
def handler(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
|
|
1852
|
+
return _int_logical(dtype_info, name, expr)
|
|
1853
|
+
|
|
1854
|
+
return handler
|
|
1855
|
+
|
|
1856
|
+
|
|
1857
|
+
# Dispatch table for integer scalar helpers, consulted by ``_int_from_ops``.
# Ops missing here fall back to the f32 round-trip generators, based on
# ScalarFunction.int_from_f32_arity.
_INT_OP_DISPATCH: Mapping[str, Callable[[_ScalarTypeInfo], _GeneratedScalar]] = {
    "abs": _int_abs,
    "absolute": _int_absolute,
    # Infix arithmetic/bitwise ops, wrapped by _cast_value.
    # NOTE(review): there is no guard against signed overflow or shift counts
    # >= the type width, both undefined behaviour in C — confirm acceptable.
    "add": _int_binary_op_handler("add", "+"),
    "sub": _int_binary_op_handler("sub", "-"),
    "mul": _int_binary_op_handler("mul", "*"),
    "bitwise_and": _int_binary_op_handler("bitwise_and", "&"),
    "bitwise_or": _int_binary_op_handler("bitwise_or", "|"),
    "bitwise_xor": _int_binary_op_handler("bitwise_xor", "^"),
    "bitwise_left_shift": _int_binary_op_handler("bitwise_left_shift", "<<"),
    "bitwise_right_shift": _int_binary_op_handler("bitwise_right_shift", ">>"),
    "bitwise_not": lambda dtype_info: _int_unary_op(dtype_info, "bitwise_not", "~a"),
    "div": _int_div,
    "maximum": _int_simple_binary_handler("maximum", "a > b ? a : b"),
    "minimum": _int_simple_binary_handler("minimum", "a < b ? a : b"),
    # Comparisons produce a 1/0 literal of the integer type.
    "le": _int_comparison_handler("le", "<="),
    "lt": _int_comparison_handler("lt", "<"),
    "ge": _int_comparison_handler("ge", ">="),
    "gt": _int_comparison_handler("gt", ">"),
    "eq": _int_comparison_handler("eq", "=="),
    "ne": _int_comparison_handler("ne", "!="),
    # Logical ops treat any non-zero value as true.
    "logical_or": _int_logical_handler("logical_or", "(a != 0 || b != 0)"),
    "logical_and": _int_logical_handler("logical_and", "(a != 0 && b != 0)"),
    "logical_xor": _int_logical_handler("logical_xor", "((a != 0) != (b != 0))"),
    "logical_not": _int_logical_not,
    # No NaNs in integers, so fmax/fmin reduce to plain max/min.
    "fmax": _int_simple_binary_handler("fmax", "a > b ? a : b"),
    "fmin": _int_simple_binary_handler("fmin", "a < b ? a : b"),
    "copysign": _int_copysign,
    "fmod": _int_fmod,
    "remainder": _int_remainder,
    "floor_divide": _int_floor_divide,
    "clamp_min": _int_simple_binary_handler("clamp_min", "a > b ? a : b"),
    "clamp_max": _int_simple_binary_handler("clamp_max", "a < b ? a : b"),
    "neg": _int_neg,
    "reciprocal": _int_reciprocal,
    "relu": _int_relu,
    # Rounding ops are the identity on integers.
    "ceil": lambda dtype_info: _int_ceil_floor(dtype_info, "ceil"),
    "floor": lambda dtype_info: _int_ceil_floor(dtype_info, "floor"),
    "round": _int_round,
    "trunc": _int_trunc,
    "frac": _int_frac,
    "sign": _int_sign,
    # Complex-number ops degenerate to the identity for real integers.
    "conj": lambda dtype_info: _int_conj(dtype_info, "conj"),
    "conj_physical": lambda dtype_info: _int_conj(dtype_info, "conj_physical"),
    "positive": _int_positive,
    "real": lambda dtype_info: _simple_unary(dtype_info, "real", "a"),
    "sgn": _int_sgn,
    "square": _int_square,
}
|
|
1906
|
+
|
|
1907
|
+
|
|
1908
|
+
def _int_from_ops(dtype_info: _ScalarTypeInfo, name: str) -> _GeneratedScalar:
    """Generate the integer scalar helper for op ``name``.

    Resolution order:
      1. an alias name emits a unary forwarder to the canonical helper;
      2. ``from_f32`` is generated by ``_int_from_f32``;
      3. ops in ``_INT_OP_DISPATCH`` use their native integer generator;
      4. otherwise the op is evaluated through its f32 implementation,
         unary or binary according to ``ScalarFunction.int_from_f32_arity``.

    Raises:
        ScalarFunctionError: if none of the above applies.
    """
    canonical = _normalize_op_name(name)
    if canonical != name:
        # NOTE(review): the alias forwarder is always unary — confirm that
        # no binary op has an alias.
        forwarder = [
            f"static inline {dtype_info.c_type} {dtype_info.prefix}{name}({dtype_info.c_type} a) {{",
            f" return {dtype_info.prefix}{canonical}(a);",
            "}",
        ]
        return _GeneratedScalar(
            lines=forwarder,
            deps={_scalar_key_from_op(dtype_info, canonical)},
            includes=set(),
        )
    if canonical == "from_f32":
        return _int_from_f32(dtype_info)
    native = _INT_OP_DISPATCH.get(canonical)
    if native is not None:
        return native(dtype_info)
    arity = ScalarFunction.from_op_name(canonical).int_from_f32_arity
    if arity == 1:
        return _int_unary_from_f32(dtype_info, canonical)
    if arity == 2:
        return _int_binary_from_f32(dtype_info, canonical)
    raise ScalarFunctionError(f"unsupported int scalar op: {canonical}")
|
|
1930
|
+
|
|
1931
|
+
|
|
1932
|
+
def _bool_to_f32() -> _GeneratedScalar:
    """Emit ``ref_scalar_bool_to_f32``, mapping true/false to 1.0f/0.0f."""
    signature = "static inline float ref_scalar_bool_to_f32(bool value) {"
    body = " return value ? 1.0f : 0.0f;"
    return _GeneratedScalar(lines=[signature, body, "}"], deps=set(), includes=set())
|
|
1939
|
+
|
|
1940
|
+
|
|
1941
|
+
def _bool_from_f32() -> _GeneratedScalar:
    """Emit ``ref_scalar_bool_from_f32``: true for any value that compares unequal to 0.0f."""
    signature = "static inline bool ref_scalar_bool_from_f32(float value) {"
    body = " return value != 0.0f;"
    return _GeneratedScalar(lines=[signature, body, "}"], deps=set(), includes=set())
|
|
1948
|
+
|
|
1949
|
+
|
|
1950
|
+
def _bool_bitwise(dtype_info: _ScalarTypeInfo, name: str, expr: str) -> _GeneratedScalar:
    """Emit a binary bool bitwise helper ``name`` whose body is ``expr``."""
    helper_name, body_expr = name, expr
    return _simple_binary(dtype_info, helper_name, body_expr)
|
|
1952
|
+
|
|
1953
|
+
|
|
1954
|
+
def _bool_bitwise_not(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit bool ``bitwise_not`` as logical negation (``!a``)."""
    negation = "!a"
    return _simple_unary(dtype_info, "bitwise_not", negation)
|
|
1956
|
+
|
|
1957
|
+
|
|
1958
|
+
def _bool_logical(dtype_info: _ScalarTypeInfo, name: str, expr: str) -> _GeneratedScalar:
    """Emit a binary bool logical helper ``name`` whose body is ``expr``."""
    helper_name, body_expr = name, expr
    return _simple_binary(dtype_info, helper_name, body_expr)
|
|
1960
|
+
|
|
1961
|
+
|
|
1962
|
+
def _bool_logical_not(dtype_info: _ScalarTypeInfo) -> _GeneratedScalar:
    """Emit bool ``logical_not`` as ``!a``."""
    negation = "!a"
    return _simple_unary(dtype_info, "logical_not", negation)
|
|
1964
|
+
|
|
1965
|
+
|
|
1966
|
+
def _bool_comparison(dtype_info: _ScalarTypeInfo, name: str, op: str) -> _GeneratedScalar:
    """Emit bool comparison ``name`` as the C expression ``a <op> b``."""
    comparison = f"a {op} b"
    return _simple_binary(dtype_info, name, comparison)
|
|
1968
|
+
|
|
1969
|
+
|
|
1970
|
+
def _bool_unary_from_f32(name: str) -> _GeneratedScalar:
    """Emit bool op ``name`` by converting to f32, calling the f32 helper, and converting back.

    Dependencies are the bool to/from f32 conversions and ``ref_scalar_f32_<name>``.
    """
    roundtrip = f"ref_scalar_bool_from_f32(ref_scalar_f32_{name}(ref_scalar_bool_to_f32(a)))"
    body = [
        f"static inline bool ref_scalar_bool_{name}(bool a) {{",
        f" return {roundtrip};",
        "}",
    ]
    bool_info = _SCALAR_TYPES[ScalarType.BOOL]
    dependencies = {
        _conversion_key_from_alias(bool_info, alias) for alias in ("from_f32", "to_f32")
    }
    dependencies.add(_scalar_key_from_op(_SCALAR_TYPES[ScalarType.F32], name))
    return _GeneratedScalar(lines=body, deps=dependencies, includes=set())
|
|
1983
|
+
|
|
1984
|
+
|
|
1985
|
+
def _bool_binary_from_f32(name: str) -> _GeneratedScalar:
    """Emit binary bool op ``name`` via the f32 helper.

    Both operands are converted to float, ``ref_scalar_f32_<name>`` is
    applied, and the result is converted back to bool.
    """
    inner_call = f"ref_scalar_f32_{name}(ref_scalar_bool_to_f32(a), ref_scalar_bool_to_f32(b))"
    body = [
        f"static inline bool ref_scalar_bool_{name}(bool a, bool b) {{",
        " return ref_scalar_bool_from_f32(",
        f" {inner_call}",
        " );",
        "}",
    ]
    bool_info = _SCALAR_TYPES[ScalarType.BOOL]
    dependencies = {
        _conversion_key_from_alias(bool_info, alias) for alias in ("from_f32", "to_f32")
    }
    dependencies.add(_scalar_key_from_op(_SCALAR_TYPES[ScalarType.F32], name))
    return _GeneratedScalar(lines=body, deps=dependencies, includes=set())
|
|
2000
|
+
|
|
2001
|
+
|
|
2002
|
+
# Dispatch table for bool scalar helpers, consulted by ``_bool_from_ops``.
# Values take no arguments because there is a single bool type. The lambdas
# look up ``_SCALAR_TYPES`` lazily, which matters because that table is
# defined after this one in the module.
_BOOL_OP_DISPATCH: Mapping[str, Callable[[], _GeneratedScalar]] = {
    # Conversions to/from f32, used by the f32 round-trip fallbacks.
    "to_f32": _bool_to_f32,
    "from_f32": _bool_from_f32,
    "bitwise_and": lambda: _simple_binary(_SCALAR_TYPES[ScalarType.BOOL], "bitwise_and", "a & b"),
    "bitwise_or": lambda: _simple_binary(_SCALAR_TYPES[ScalarType.BOOL], "bitwise_or", "a | b"),
    "bitwise_xor": lambda: _simple_binary(_SCALAR_TYPES[ScalarType.BOOL], "bitwise_xor", "a ^ b"),
    "bitwise_not": lambda: _bool_bitwise_not(_SCALAR_TYPES[ScalarType.BOOL]),
    "logical_or": lambda: _bool_logical(_SCALAR_TYPES[ScalarType.BOOL], "logical_or", "a || b"),
    "logical_and": lambda: _bool_logical(_SCALAR_TYPES[ScalarType.BOOL], "logical_and", "a && b"),
    # For bools, xor is simply inequality.
    "logical_xor": lambda: _bool_logical(_SCALAR_TYPES[ScalarType.BOOL], "logical_xor", "a != b"),
    "logical_not": lambda: _bool_logical_not(_SCALAR_TYPES[ScalarType.BOOL]),
    "le": lambda: _bool_comparison(_SCALAR_TYPES[ScalarType.BOOL], "le", "<="),
    "lt": lambda: _bool_comparison(_SCALAR_TYPES[ScalarType.BOOL], "lt", "<"),
    "ge": lambda: _bool_comparison(_SCALAR_TYPES[ScalarType.BOOL], "ge", ">="),
    "gt": lambda: _bool_comparison(_SCALAR_TYPES[ScalarType.BOOL], "gt", ">"),
    "eq": lambda: _bool_comparison(_SCALAR_TYPES[ScalarType.BOOL], "eq", "=="),
    "ne": lambda: _bool_comparison(_SCALAR_TYPES[ScalarType.BOOL], "ne", "!="),
}
|
|
2020
|
+
|
|
2021
|
+
|
|
2022
|
+
def _bool_from_ops(name: str) -> _GeneratedScalar:
    """Generate the bool scalar helper for op ``name``.

    Resolution order:
      1. an alias name emits a unary forwarder to the canonical helper;
      2. ops in ``_BOOL_OP_DISPATCH`` use their native generator;
      3. otherwise the op is evaluated through its f32 implementation,
         unary or binary according to ``ScalarFunction.bool_from_f32_arity``.

    Raises:
        ScalarFunctionError: if none of the above applies.
    """
    canonical = _normalize_op_name(name)
    if canonical != name:
        bool_info = _SCALAR_TYPES[ScalarType.BOOL]
        forwarder = [
            f"static inline {bool_info.c_type} {bool_info.prefix}{name}({bool_info.c_type} a) {{",
            f" return {bool_info.prefix}{canonical}(a);",
            "}",
        ]
        return _GeneratedScalar(
            lines=forwarder,
            deps={_scalar_key_from_op(bool_info, canonical)},
            includes=set(),
        )
    native = _BOOL_OP_DISPATCH.get(canonical)
    if native is not None:
        return native()
    arity = ScalarFunction.from_op_name(canonical).bool_from_f32_arity
    if arity == 1:
        return _bool_unary_from_f32(canonical)
    if arity == 2:
        return _bool_binary_from_f32(canonical)
    raise ScalarFunctionError(f"unsupported bool scalar op: {canonical}")
|
|
2043
|
+
|
|
2044
|
+
|
|
2045
|
+
# Metadata for every scalar dtype the generator can emit helpers for.
# Row layout: (dtype, C type, suffix, is_float, is_bool, is_signed,
# is_small_int, bits). Helper names use the prefix ``ref_scalar_<suffix>_``;
# ``bits`` is None for the floating-point types and for bool.
_SCALAR_TYPES: Dict[ScalarType, _ScalarTypeInfo] = {
    row_dtype: _ScalarTypeInfo(
        scalar_type=row_dtype,
        c_type=row_c_type,
        prefix=f"ref_scalar_{row_suffix}_",
        suffix=row_suffix,
        is_float=row_is_float,
        is_bool=row_is_bool,
        is_signed=row_is_signed,
        is_small_int=row_is_small_int,
        bits=row_bits,
    )
    for (
        row_dtype,
        row_c_type,
        row_suffix,
        row_is_float,
        row_is_bool,
        row_is_signed,
        row_is_small_int,
        row_bits,
    ) in (
        (ScalarType.F32, "float", "f32", True, False, True, False, None),
        (ScalarType.F64, "double", "f64", True, False, True, False, None),
        (ScalarType.I8, "int8_t", "i8", False, False, True, True, 8),
        (ScalarType.I16, "int16_t", "i16", False, False, True, True, 16),
        (ScalarType.I32, "int32_t", "i32", False, False, True, False, 32),
        (ScalarType.I64, "int64_t", "i64", False, False, True, False, 64),
        (ScalarType.U8, "uint8_t", "u8", False, False, False, True, 8),
        (ScalarType.U16, "uint16_t", "u16", False, False, False, True, 16),
        (ScalarType.U32, "uint32_t", "u32", False, False, False, False, 32),
        (ScalarType.U64, "uint64_t", "u64", False, False, False, False, 64),
        (ScalarType.BOOL, "bool", "bool", False, True, False, False, None),
    )
}


# The very same dict object as ``_SCALAR_TYPES``, exposed with a read-only
# ``Mapping`` annotation for lookups.
_SCALAR_TYPE_BY_ENUM: Mapping[ScalarType, _ScalarTypeInfo] = _SCALAR_TYPES
|
|
2171
|
+
|
|
2172
|
+
|
|
2173
|
+
# Source dtype of each ``CONVERT_FROM_<T>`` scalar function. As written, the
# key resolvers in this module accept only the F32 and BOOL sources and reject
# the others with ScalarFunctionError.
_CONVERSION_SOURCE_BY_FUNCTION: Mapping[ScalarFunction, ScalarType] = dict(
    [
        (ScalarFunction.CONVERT_FROM_F32, ScalarType.F32),
        (ScalarFunction.CONVERT_FROM_F64, ScalarType.F64),
        (ScalarFunction.CONVERT_FROM_I8, ScalarType.I8),
        (ScalarFunction.CONVERT_FROM_I16, ScalarType.I16),
        (ScalarFunction.CONVERT_FROM_I32, ScalarType.I32),
        (ScalarFunction.CONVERT_FROM_I64, ScalarType.I64),
        (ScalarFunction.CONVERT_FROM_U8, ScalarType.U8),
        (ScalarFunction.CONVERT_FROM_U16, ScalarType.U16),
        (ScalarFunction.CONVERT_FROM_U32, ScalarType.U32),
        (ScalarFunction.CONVERT_FROM_U64, ScalarType.U64),
        (ScalarFunction.CONVERT_FROM_BOOL, ScalarType.BOOL),
    ]
)
|
|
2186
|
+
|
|
2187
|
+
|
|
2188
|
+
def _scalar_type_info(dtype: ScalarType) -> _ScalarTypeInfo:
    """Return the generator metadata registered for *dtype*.

    Raises:
        ScalarFunctionError: if *dtype* has no entry in the type table. The
            underlying ``KeyError`` is chained as the cause.
    """
    try:
        info = _SCALAR_TYPE_BY_ENUM[dtype]
    except KeyError as missing:
        message = f"unsupported scalar dtype: {dtype.value}"
        raise ScalarFunctionError(message) from missing
    return info
|
|
2195
|
+
|
|
2196
|
+
|
|
2197
|
+
def _supported_ops(dtype_info: _ScalarTypeInfo) -> Set[str]:
    """Return the normalized op names that can be generated for *dtype_info*.

    The result contains every non-conversion ``ScalarFunction`` whose
    ``supports_dtype`` accepts the dtype. It also contains the conversion
    pseudo-ops: ``from_f32`` for every non-float dtype, and ``to_f32`` for
    bool.
    """
    ops: Set[str] = set()
    for function in ScalarFunction:
        if function.value.startswith("convert_from_"):
            # Conversions are exposed only through the from_f32/to_f32 aliases.
            continue
        if function.supports_dtype(dtype_info):
            ops.add(_normalize_op_name(function.value))
    if not dtype_info.is_float:
        ops.add("from_f32")
    if dtype_info.is_bool:
        ops.add("to_f32")
    return ops
|
|
2209
|
+
|
|
2210
|
+
|
|
2211
|
+
def validate_scalar_function_supported_ops() -> None:
    """Check that ``ScalarFunction`` and ``_supported_ops`` agree.

    F32, BOOL, I8 and U8 stand in for the float, bool, signed-int and
    unsigned-int categories. Two kinds of drift are reported:

    * ops a category claims to support that are not non-conversion
      ``ScalarFunction`` members (the ``from_f32``/``to_f32`` aliases are
      ignored), and
    * non-conversion ``ScalarFunction`` ops that none of the representative
      dtypes supports.

    Raises:
        AssertionError: listing every discrepancy found.
    """
    # NOTE(review): as written, ``_supported_ops`` only yields names derived
    # from ScalarFunction plus the two aliases, so the first check acts as a
    # guard against future edits rather than something that can fire today.
    conversion_aliases = {"from_f32", "to_f32"}
    scalar_ops = {
        _normalize_op_name(fn.value)
        for fn in ScalarFunction
        if not fn.value.startswith("convert_from_")
    }
    representatives = (
        ("float", ScalarType.F32),
        ("bool", ScalarType.BOOL),
        ("signed_int", ScalarType.I8),
        ("unsigned_int", ScalarType.U8),
    )
    categories = {
        label: _supported_ops(_SCALAR_TYPES[dtype])
        for label, dtype in representatives
    }

    problems: List[str] = []
    for label, ops in categories.items():
        missing = sorted(ops - scalar_ops - conversion_aliases)
        if missing:
            problems.append(
                f"{label} missing ScalarFunction ops (defined in _supported_ops): {missing}"
            )

    covered: Set[str] = set()
    for ops in categories.values():
        covered |= ops
    covered -= conversion_aliases
    uncovered = sorted(scalar_ops - covered)
    if uncovered:
        problems.append(
            f"ScalarFunction ops not supported by any dtype category: {uncovered}"
        )

    if problems:
        raise AssertionError(
            "ScalarFunction/_supported_ops drift detected:\n" + "\n".join(problems)
        )
|
|
2242
|
+
|
|
2243
|
+
|
|
2244
|
+
def _scalar_info_for_key(key: ScalarFunctionKey) -> tuple[_ScalarTypeInfo, str]:
    """Resolve *key* to the dtype metadata and op name used for generation.

    Ordinary ops resolve to the return dtype and ``key.function.value``.
    Conversions are mapped onto pseudo-ops:

    * ``CONVERT_FROM_F32`` -> (return dtype, ``"from_f32"``)
    * ``CONVERT_FROM_BOOL`` returning F32 -> (bool dtype, ``"to_f32"``)

    Raises:
        ScalarFunctionError: for any other conversion, or when a dtype has
            no metadata entry.
    """
    source_type = _CONVERSION_SOURCE_BY_FUNCTION.get(key.function)
    if source_type is None:
        return _scalar_type_info(key.return_type), key.function.value
    if source_type == ScalarType.F32:
        return _scalar_type_info(key.return_type), "from_f32"
    if source_type == ScalarType.BOOL and key.return_type == ScalarType.F32:
        return _scalar_type_info(source_type), "to_f32"
    raise ScalarFunctionError(
        f"unsupported scalar conversion from {source_type.value} to {key.return_type.value}"
    )
|
|
2259
|
+
|
|
2260
|
+
|
|
2261
|
+
def _generate_scalar(key: ScalarFunctionKey) -> _GeneratedScalar:
    """Generate the C helper for *key* together with the headers it needs.

    The op is first checked against ``_supported_ops`` for the resolved
    dtype. It is then generated by one of:

    * a parameterized float generator from ``_PARAMETERIZED_FLOAT_OPS``,
    * the plain float path ``_float_from_ops``,
    * ``_bool_from_ops``, or
    * ``_int_from_ops``.

    The generator's includes are extended with per-dtype headers:
    ``math.h``/``float.h`` for floats, ``stdint.h`` (plus ``limits.h`` when
    signed) for integers, and ``stdbool.h`` for bool.

    Raises:
        ScalarFunctionError: when the op is not supported for the dtype.
    """
    dtype_info, op_name = _scalar_info_for_key(key)
    if _normalize_op_name(op_name) not in _supported_ops(dtype_info):
        raise ScalarFunctionError(
            f"unsupported scalar op {op_name} for {dtype_info.suffix}"
        )

    if dtype_info.is_float:
        param_handler = _PARAMETERIZED_FLOAT_OPS.get(key.function)
        if param_handler is None:
            generated = _float_from_ops(dtype_info, op_name)
        else:
            # Parameterized helpers need the full (parameter-suffixed) name.
            generated = param_handler(
                dtype_info, key.params, _function_name_for_key(key)
            )
    elif dtype_info.is_bool:
        generated = _bool_from_ops(op_name)
    else:
        generated = _int_from_ops(dtype_info, op_name)

    headers = set(generated.includes)
    if dtype_info.is_float:
        headers |= {"#include <math.h>", "#include <float.h>"}
    elif not dtype_info.is_bool:
        headers.add("#include <stdint.h>")
        if dtype_info.is_signed:
            headers.add("#include <limits.h>")
    if dtype_info.is_bool:
        headers.add("#include <stdbool.h>")

    return _GeneratedScalar(
        lines=generated.lines, deps=generated.deps, includes=headers
    )
|
|
2289
|
+
|
|
2290
|
+
|
|
2291
|
+
def _function_name_for_key(key: ScalarFunctionKey) -> str:
    """Return the C function name emitted for *key*.

    Naming rules:

    * Ordinary ops: ``<return-dtype prefix><op><param suffix>``.
    * ``CONVERT_FROM_F32`` to an integer or bool dtype:
      ``<target prefix>from_f32<param suffix>``.
    * ``CONVERT_FROM_BOOL`` to F32: ``<bool prefix>to_f32<param suffix>``.

    Raises:
        ScalarFunctionError: for any other conversion, or when the return
            dtype does not support the op.
    """
    param_suffix = _param_suffix(key.params)

    source_type = _CONVERSION_SOURCE_BY_FUNCTION.get(key.function)
    if source_type is not None:
        target = key.return_type
        from_f32_targets = {
            ScalarType.I8,
            ScalarType.I16,
            ScalarType.I32,
            ScalarType.I64,
            ScalarType.U8,
            ScalarType.U16,
            ScalarType.U32,
            ScalarType.U64,
            ScalarType.BOOL,
        }
        if source_type == ScalarType.F32 and target in from_f32_targets:
            return f"{_scalar_type_info(target).prefix}from_f32{param_suffix}"
        if source_type == ScalarType.BOOL and target == ScalarType.F32:
            return f"{_scalar_type_info(source_type).prefix}to_f32{param_suffix}"
        raise ScalarFunctionError(
            f"unsupported scalar conversion from {source_type.value} to {target.value}"
        )

    op_name = key.function.value
    dtype_info = _scalar_type_info(key.return_type)
    if _normalize_op_name(op_name) not in _supported_ops(dtype_info):
        raise ScalarFunctionError(
            f"unsupported scalar op {op_name} for {dtype_info.suffix}"
        )
    return f"{dtype_info.prefix}{op_name}{param_suffix}"
|
|
2329
|
+
|
|
2330
|
+
|
|
2331
|
+
class ScalarFunctionRegistry:
    """Collect requested scalar helper functions and render their C code.

    Keys are remembered in first-request order. Each key's generated helper
    is produced lazily and cached. When rendering, a helper's dependencies
    are emitted before the helper itself.
    """

    def __init__(self) -> None:
        # Requested keys, deduplicated, in first-request order.
        self._requested: List[ScalarFunctionKey] = []
        # Same keys as ``_requested``, for O(1) membership checks.
        self._requested_set: Set[ScalarFunctionKey] = set()
        # Cached C function name per key.
        self._key_to_name: Dict[ScalarFunctionKey, str] = {}
        # Cached generation result per key, filled on demand.
        self._generated: Dict[ScalarFunctionKey, _GeneratedScalar] = {}

    def request(self, key: ScalarFunctionKey) -> str:
        """Register *key* for emission and return its C function name.

        Name resolution happens before registration, so a key that fails to
        resolve (``ScalarFunctionError``) is not recorded.
        """
        if key not in self._key_to_name:
            self._key_to_name[key] = _function_name_for_key(key)
        self._register_key(key)
        return self._key_to_name[key]

    def _register_key(self, key: ScalarFunctionKey) -> None:
        """Append *key* to the request list unless it is already there."""
        if key not in self._requested_set:
            self._requested_set.add(key)
            self._requested.append(key)

    def include_lines(self) -> List[str]:
        """Return the header lines needed by every requested helper.

        The result is the sorted union of the includes of all requested keys
        and their transitive dependencies (generated on demand). It is
        followed by guarded ``REF_PI_F``/``REF_PI_D`` definitions. The
        constant preamble is returned even when nothing was requested.
        """
        headers: Set[str] = set()
        seen: Set[ScalarFunctionKey] = set()

        def visit(node: ScalarFunctionKey) -> None:
            if node in seen:
                return
            self._ensure_generated(node)
            generated = self._generated[node]
            # Mark before recursing so that shared dependencies are visited once.
            seen.add(node)
            for dependency in generated.deps:
                visit(dependency)
            headers.update(generated.includes)

        for requested in self._requested:
            visit(requested)

        pi_preamble = [
            "#ifndef REF_PI_F",
            "#define REF_PI_F 3.14159265358979323846f",
            "#endif",
            "#ifndef REF_PI_D",
            "#define REF_PI_D 3.14159265358979323846",
            "#endif",
        ]
        return sorted(headers) + pi_preamble

    def render(self) -> List[str]:
        """Return the C source lines for every requested helper.

        Each helper is preceded by its dependencies, visited in order of
        their function names, and followed by a blank separator line.
        Trailing blank lines are stripped. Returns ``[]`` when nothing was
        requested.
        """
        if not self._requested:
            return []
        output: List[str] = []
        done: Set[ScalarFunctionKey] = set()

        def emit_with_deps(node: ScalarFunctionKey) -> None:
            if node in done:
                return
            self._ensure_generated(node)
            generated = self._generated[node]
            # NOTE(review): ``node`` is marked done only after its
            # dependencies are emitted, so a dependency cycle would recurse
            # without bound (RecursionError).
            for dependency in sorted(generated.deps, key=_function_name_for_key):
                emit_with_deps(dependency)
            output.extend(generated.lines)
            output.append("")
            done.add(node)

        for requested in self._requested:
            emit_with_deps(requested)

        while output and output[-1] == "":
            output.pop()
        return output

    def _ensure_generated(self, key: ScalarFunctionKey) -> None:
        """Generate and cache the helper for *key* unless already cached."""
        if key not in self._generated:
            self._generated[key] = _generate_scalar(key)
|