gumath 0.2.0dev5 → 0.2.0dev8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CONTRIBUTING.md +7 -2
- data/Gemfile +0 -3
- data/ext/ruby_gumath/GPATH +0 -0
- data/ext/ruby_gumath/GRTAGS +0 -0
- data/ext/ruby_gumath/GTAGS +0 -0
- data/ext/ruby_gumath/extconf.rb +0 -5
- data/ext/ruby_gumath/functions.c +10 -2
- data/ext/ruby_gumath/gufunc_object.c +15 -4
- data/ext/ruby_gumath/gufunc_object.h +9 -3
- data/ext/ruby_gumath/gumath/Makefile +63 -0
- data/ext/ruby_gumath/gumath/Makefile.in +1 -0
- data/ext/ruby_gumath/gumath/config.h +56 -0
- data/ext/ruby_gumath/gumath/config.h.in +3 -0
- data/ext/ruby_gumath/gumath/config.log +497 -0
- data/ext/ruby_gumath/gumath/config.status +1034 -0
- data/ext/ruby_gumath/gumath/configure +375 -4
- data/ext/ruby_gumath/gumath/configure.ac +47 -3
- data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
- data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
- data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
- data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
- data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
- data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
- data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
- data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
- data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
- data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
- data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
- data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
- data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
- data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
- data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
- data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
- data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
- data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
- data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
- data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
- data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
- data/ext/ruby_gumath/gumath/setup.py +67 -6
- data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
- data/ext/ruby_gumath/include/gumath.h +55 -14
- data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
- data/ext/ruby_gumath/lib/libgumath.a +0 -0
- data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/ruby_gumath.c +231 -70
- data/ext/ruby_gumath/ruby_gumath.h +4 -1
- data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
- data/ext/ruby_gumath/util.c +34 -0
- data/ext/ruby_gumath/util.h +9 -0
- data/gumath.gemspec +3 -2
- data/lib/gumath.rb +55 -1
- data/lib/gumath/version.rb +2 -2
- data/lib/ruby_gumath.so +0 -0
- metadata +63 -10
- data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
- data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
- data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
|
@@ -0,0 +1,1100 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* BSD 3-Clause License
|
|
3
|
+
*
|
|
4
|
+
* Copyright (c) 2017-2018, plures
|
|
5
|
+
* All rights reserved.
|
|
6
|
+
*
|
|
7
|
+
* Redistribution and use in source and binary forms, with or without
|
|
8
|
+
* modification, are permitted provided that the following conditions are met:
|
|
9
|
+
*
|
|
10
|
+
* 1. Redistributions of source code must retain the above copyright notice,
|
|
11
|
+
* this list of conditions and the following disclaimer.
|
|
12
|
+
*
|
|
13
|
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
14
|
+
* this list of conditions and the following disclaimer in the documentation
|
|
15
|
+
* and/or other materials provided with the distribution.
|
|
16
|
+
*
|
|
17
|
+
* 3. Neither the name of the copyright holder nor the names of its
|
|
18
|
+
* contributors may be used to endorse or promote products derived from
|
|
19
|
+
* this software without specific prior written permission.
|
|
20
|
+
*
|
|
21
|
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
22
|
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
23
|
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
24
|
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
25
|
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
26
|
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
27
|
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
28
|
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
29
|
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
30
|
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
31
|
+
*/
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
#include <stdlib.h>
|
|
35
|
+
#include <stdint.h>
|
|
36
|
+
#include <string.h>
|
|
37
|
+
#include "ndtypes.h"
|
|
38
|
+
#include "xnd.h"
|
|
39
|
+
#include "gumath.h"
|
|
40
|
+
#include "common.h"
|
|
41
|
+
#include "cpu_device_unary.h"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
/****************************************************************************/
|
|
45
|
+
/* Kernel locations for optimized lookup */
|
|
46
|
+
/****************************************************************************/
|
|
47
|
+
|
|
48
|
+
static int
|
|
49
|
+
copy_kernel_location(const ndt_t *in, const ndt_t *out, ndt_context_t *ctx)
|
|
50
|
+
{
|
|
51
|
+
const ndt_t *t = ndt_dtype(in);
|
|
52
|
+
const ndt_t *u = ndt_dtype(out);
|
|
53
|
+
|
|
54
|
+
switch (t->tag) {
|
|
55
|
+
case Bool: {
|
|
56
|
+
switch (u->tag) {
|
|
57
|
+
case Bool: return 0;
|
|
58
|
+
case Uint8: return 6;
|
|
59
|
+
case Uint16: return 12;
|
|
60
|
+
case Uint32: return 18;
|
|
61
|
+
case Uint64: return 24;
|
|
62
|
+
case Int8: return 30;
|
|
63
|
+
case Int16: return 36;
|
|
64
|
+
case Int32: return 42;
|
|
65
|
+
case Int64: return 48;
|
|
66
|
+
case BFloat16: return 54;
|
|
67
|
+
case Float16: return 60;
|
|
68
|
+
case Float32: return 66;
|
|
69
|
+
case Float64: return 72;
|
|
70
|
+
case Complex32: return 78;
|
|
71
|
+
case Complex64: return 84;
|
|
72
|
+
case Complex128: return 90;
|
|
73
|
+
default: goto invalid_combination;
|
|
74
|
+
}
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
case Uint8: {
|
|
78
|
+
switch (u->tag) {
|
|
79
|
+
case Uint8: return 96;
|
|
80
|
+
case Uint16: return 102;
|
|
81
|
+
case Uint32: return 108;
|
|
82
|
+
case Uint64: return 114;
|
|
83
|
+
case Int16: return 120;
|
|
84
|
+
case Int32: return 126;
|
|
85
|
+
case Int64: return 132;
|
|
86
|
+
case BFloat16: return 138;
|
|
87
|
+
case Float16: return 144;
|
|
88
|
+
case Float32: return 150;
|
|
89
|
+
case Float64: return 156;
|
|
90
|
+
case Complex32: return 162;
|
|
91
|
+
case Complex64: return 168;
|
|
92
|
+
case Complex128: return 174;
|
|
93
|
+
default: goto invalid_combination;
|
|
94
|
+
}
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
case Uint16: {
|
|
98
|
+
switch (u->tag) {
|
|
99
|
+
case Uint16: return 180;
|
|
100
|
+
case Uint32: return 186;
|
|
101
|
+
case Uint64: return 192;
|
|
102
|
+
case Int32: return 198;
|
|
103
|
+
case Int64: return 204;
|
|
104
|
+
case Float32: return 210;
|
|
105
|
+
case Float64: return 216;
|
|
106
|
+
case Complex64: return 222;
|
|
107
|
+
case Complex128: return 228;
|
|
108
|
+
default: goto invalid_combination;
|
|
109
|
+
}
|
|
110
|
+
}
|
|
111
|
+
|
|
112
|
+
case Uint32: {
|
|
113
|
+
switch (u->tag) {
|
|
114
|
+
case Uint32: return 234;
|
|
115
|
+
case Uint64: return 240;
|
|
116
|
+
case Int64: return 246;
|
|
117
|
+
case Float64: return 252;
|
|
118
|
+
case Complex128: return 258;
|
|
119
|
+
default: goto invalid_combination;
|
|
120
|
+
}
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
case Uint64: {
|
|
124
|
+
switch (u->tag) {
|
|
125
|
+
case Uint64: return 264;
|
|
126
|
+
default: goto invalid_combination;
|
|
127
|
+
}
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
case Int8: {
|
|
131
|
+
switch (u->tag) {
|
|
132
|
+
case Int8: return 270;
|
|
133
|
+
case Int16: return 276;
|
|
134
|
+
case Int32: return 282;
|
|
135
|
+
case Int64: return 288;
|
|
136
|
+
case BFloat16: return 294;
|
|
137
|
+
case Float16: return 300;
|
|
138
|
+
case Float32: return 306;
|
|
139
|
+
case Float64: return 312;
|
|
140
|
+
case Complex32: return 318;
|
|
141
|
+
case Complex64: return 324;
|
|
142
|
+
case Complex128: return 330;
|
|
143
|
+
default: goto invalid_combination;
|
|
144
|
+
}
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
case Int16: {
|
|
148
|
+
switch (u->tag) {
|
|
149
|
+
case Int16: return 336;
|
|
150
|
+
case Int32: return 342;
|
|
151
|
+
case Int64: return 348;
|
|
152
|
+
case Float32: return 354;
|
|
153
|
+
case Float64: return 360;
|
|
154
|
+
case Complex64: return 366;
|
|
155
|
+
case Complex128: return 372;
|
|
156
|
+
default: goto invalid_combination;
|
|
157
|
+
}
|
|
158
|
+
}
|
|
159
|
+
|
|
160
|
+
case Int32: {
|
|
161
|
+
switch (u->tag) {
|
|
162
|
+
case Int32: return 378;
|
|
163
|
+
case Int64: return 384;
|
|
164
|
+
case Float64: return 390;
|
|
165
|
+
case Complex128: return 396;
|
|
166
|
+
default: goto invalid_combination;
|
|
167
|
+
}
|
|
168
|
+
}
|
|
169
|
+
|
|
170
|
+
case Int64: {
|
|
171
|
+
switch (u->tag) {
|
|
172
|
+
case Int64: return 402;
|
|
173
|
+
default: goto invalid_combination;
|
|
174
|
+
}
|
|
175
|
+
}
|
|
176
|
+
|
|
177
|
+
case BFloat16: {
|
|
178
|
+
switch (u->tag) {
|
|
179
|
+
case BFloat16: return 408;
|
|
180
|
+
case Float32: return 414;
|
|
181
|
+
case Float64: return 420;
|
|
182
|
+
case Complex64: return 426;
|
|
183
|
+
case Complex128: return 432;
|
|
184
|
+
default: goto invalid_combination;
|
|
185
|
+
}
|
|
186
|
+
}
|
|
187
|
+
|
|
188
|
+
case Float16: {
|
|
189
|
+
switch (u->tag) {
|
|
190
|
+
case Float16: return 438;
|
|
191
|
+
case Float32: return 444;
|
|
192
|
+
case Float64: return 450;
|
|
193
|
+
case Complex32: return 456;
|
|
194
|
+
case Complex64: return 462;
|
|
195
|
+
case Complex128: return 468;
|
|
196
|
+
default: goto invalid_combination;
|
|
197
|
+
}
|
|
198
|
+
}
|
|
199
|
+
|
|
200
|
+
case Float32: {
|
|
201
|
+
switch (u->tag) {
|
|
202
|
+
case Float32: return 474;
|
|
203
|
+
case Float64: return 480;
|
|
204
|
+
case Complex64: return 486;
|
|
205
|
+
case Complex128: return 492;
|
|
206
|
+
default: goto invalid_combination;
|
|
207
|
+
}
|
|
208
|
+
}
|
|
209
|
+
|
|
210
|
+
case Float64: {
|
|
211
|
+
switch (u->tag) {
|
|
212
|
+
case Float64: return 498;
|
|
213
|
+
case Complex128: return 504;
|
|
214
|
+
default: goto invalid_combination;
|
|
215
|
+
}
|
|
216
|
+
}
|
|
217
|
+
|
|
218
|
+
case Complex32: {
|
|
219
|
+
switch (u->tag) {
|
|
220
|
+
case Complex32: return 510;
|
|
221
|
+
case Complex64: return 516;
|
|
222
|
+
case Complex128: return 522;
|
|
223
|
+
default: goto invalid_combination;
|
|
224
|
+
}
|
|
225
|
+
}
|
|
226
|
+
|
|
227
|
+
case Complex64: {
|
|
228
|
+
switch (u->tag) {
|
|
229
|
+
case Complex64: return 528;
|
|
230
|
+
case Complex128: return 534;
|
|
231
|
+
default: goto invalid_combination;
|
|
232
|
+
}
|
|
233
|
+
}
|
|
234
|
+
|
|
235
|
+
case Complex128: {
|
|
236
|
+
switch (u->tag) {
|
|
237
|
+
case Complex128: return 540;
|
|
238
|
+
default: goto invalid_combination;
|
|
239
|
+
}
|
|
240
|
+
}
|
|
241
|
+
|
|
242
|
+
default: goto invalid_combination;
|
|
243
|
+
}
|
|
244
|
+
|
|
245
|
+
invalid_combination:
|
|
246
|
+
ndt_err_format(ctx, NDT_ValueError, "invalid dtype");
|
|
247
|
+
return -1;
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
static int
|
|
251
|
+
invert_kernel_location(const ndt_t *in, const ndt_t *out, ndt_context_t *ctx)
|
|
252
|
+
{
|
|
253
|
+
const ndt_t *t = ndt_dtype(in);
|
|
254
|
+
(void)out;
|
|
255
|
+
|
|
256
|
+
switch (t->tag) {
|
|
257
|
+
case Bool: return 0;
|
|
258
|
+
|
|
259
|
+
case Uint8: return 6;
|
|
260
|
+
case Uint16: return 12;
|
|
261
|
+
case Uint32: return 18;
|
|
262
|
+
case Uint64: return 24;
|
|
263
|
+
|
|
264
|
+
case Int8: return 30;
|
|
265
|
+
case Int16: return 36;
|
|
266
|
+
case Int32: return 42;
|
|
267
|
+
case Int64: return 48;
|
|
268
|
+
|
|
269
|
+
default:
|
|
270
|
+
ndt_err_format(ctx, NDT_ValueError, "invalid dtype");
|
|
271
|
+
return -1;
|
|
272
|
+
}
|
|
273
|
+
}
|
|
274
|
+
|
|
275
|
+
static int
|
|
276
|
+
negative_kernel_location(const ndt_t *in, const ndt_t *out, ndt_context_t *ctx)
|
|
277
|
+
{
|
|
278
|
+
const ndt_t *t = ndt_dtype(in);
|
|
279
|
+
(void)out;
|
|
280
|
+
|
|
281
|
+
switch (t->tag) {
|
|
282
|
+
case Uint8: return 0;
|
|
283
|
+
case Uint16: return 6;
|
|
284
|
+
case Uint32: return 12;
|
|
285
|
+
|
|
286
|
+
case Int8: return 18;
|
|
287
|
+
case Int16: return 24;
|
|
288
|
+
case Int32: return 30;
|
|
289
|
+
case Int64: return 36;
|
|
290
|
+
|
|
291
|
+
case BFloat16: return 42;
|
|
292
|
+
case Float16: return 48;
|
|
293
|
+
case Float32: return 54;
|
|
294
|
+
case Float64: return 60;
|
|
295
|
+
|
|
296
|
+
case Complex32: return 66;
|
|
297
|
+
case Complex64: return 72;
|
|
298
|
+
case Complex128: return 78;
|
|
299
|
+
|
|
300
|
+
default:
|
|
301
|
+
ndt_err_format(ctx, NDT_ValueError, "invalid dtype");
|
|
302
|
+
return -1;
|
|
303
|
+
}
|
|
304
|
+
}
|
|
305
|
+
|
|
306
|
+
static int
|
|
307
|
+
math_kernel_location(const ndt_t *in, const ndt_t *out, ndt_context_t *ctx)
|
|
308
|
+
{
|
|
309
|
+
const ndt_t *t = ndt_dtype(in);
|
|
310
|
+
(void)out;
|
|
311
|
+
|
|
312
|
+
switch (t->tag) {
|
|
313
|
+
case Uint8: return 0;
|
|
314
|
+
case Int8: return 6;
|
|
315
|
+
case Float16: return 12;
|
|
316
|
+
|
|
317
|
+
case BFloat16: return 18;
|
|
318
|
+
|
|
319
|
+
case Uint16: return 24;
|
|
320
|
+
case Int16: return 30;
|
|
321
|
+
case Float32: return 36;
|
|
322
|
+
|
|
323
|
+
case Uint32: return 42;
|
|
324
|
+
case Int32: return 48;
|
|
325
|
+
case Float64: return 54;
|
|
326
|
+
|
|
327
|
+
case Complex32: return 60;
|
|
328
|
+
case Complex64: return 66;
|
|
329
|
+
case Complex128: return 72;
|
|
330
|
+
|
|
331
|
+
default:
|
|
332
|
+
ndt_err_format(ctx, NDT_ValueError, "invalid dtype");
|
|
333
|
+
return -1;
|
|
334
|
+
}
|
|
335
|
+
}
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
/*****************************************************************************/
|
|
339
|
+
/* CPU-specific unary macros */
|
|
340
|
+
/*****************************************************************************/
|
|
341
|
+
|
|
342
|
+
/*
 * CPU_HOST_UNARY(name, t0, t1)
 *
 * Defines four static host-side kernels for the unary function `name` with
 * input type t0 and output type t1.  Each one forwards to the matching
 * gm_cpu_device_* kernel, then refreshes the output bitmap if the output
 * dtype is optional, and returns 0:
 *
 *   gm_cpu_host_fixed_1D_C_<name>_<t0>_<t1>  - fixed 1D, no steps passed
 *   gm_cpu_host_fixed_1D_S_<name>_<t0>_<t1>  - fixed 1D, steps passed
 *   gm_cpu_host_array_1D_C_<name>_<t0>_<t1>  - flexible array, 1D
 *   gm_cpu_host_0D_<name>_<t0>_<t1>          - single element
 *
 * The array variant returns -1 if array_shape_check() fails.
 */
#define CPU_HOST_UNARY(name, t0, t1) \
static int \
gm_cpu_host_fixed_1D_C_##name##_##t0##_##t1(xnd_t stack[], ndt_context_t *ctx) \
{ \
    const char *src = apply_index(&stack[0]); \
    char *dst = apply_index(&stack[1]); \
    const int64_t nelem = xnd_fixed_shape(&stack[0]); \
    (void)ctx; \
\
    gm_cpu_device_fixed_1D_C_##name##_##t0##_##t1(src, dst, nelem); \
\
    if (ndt_is_optional(ndt_dtype(stack[1].type))) { \
        unary_update_bitmap_1D_S(stack); \
    } \
\
    return 0; \
} \
\
static int \
gm_cpu_host_fixed_1D_S_##name##_##t0##_##t1(xnd_t stack[], ndt_context_t *ctx) \
{ \
    const char *src = apply_index(&stack[0]); \
    char *dst = apply_index(&stack[1]); \
    const int64_t nelem = xnd_fixed_shape(&stack[0]); \
    const int64_t src_step = xnd_fixed_step(&stack[0]); \
    const int64_t dst_step = xnd_fixed_step(&stack[1]); \
    (void)ctx; \
\
    gm_cpu_device_fixed_1D_S_##name##_##t0##_##t1(src, dst, src_step, dst_step, nelem); \
\
    if (ndt_is_optional(ndt_dtype(stack[1].type))) { \
        unary_update_bitmap_1D_S(stack); \
    } \
\
    return 0; \
} \
\
static int \
gm_cpu_host_array_1D_C_##name##_##t0##_##t1(xnd_t stack[], ndt_context_t *ctx) \
{ \
    const char *src = XND_ARRAY_DATA(stack[0].ptr); \
    const int64_t nelem = XND_ARRAY_SHAPE(stack[0].ptr); \
    (void)ctx; \
\
    if (array_shape_check(&stack[1], nelem, ctx) < 0) { \
        return -1; \
    } \
    /* The output data pointer is read only after the shape check, */ \
    /* exactly as in the original statement order. */ \
    char *dst = XND_ARRAY_DATA(stack[1].ptr); \
\
    gm_cpu_device_fixed_1D_C_##name##_##t0##_##t1(src, dst, nelem); \
\
    if (ndt_is_optional(ndt_dtype(stack[1].type))) { \
        unary_update_bitmap_1D_S(stack); \
    } \
\
    return 0; \
} \
\
static int \
gm_cpu_host_0D_##name##_##t0##_##t1(xnd_t stack[], ndt_context_t *ctx) \
{ \
    const char *src = stack[0].ptr; \
    char *dst = stack[1].ptr; \
    (void)ctx; \
\
    gm_cpu_device_0D_##name##_##t0##_##t1(src, dst); \
\
    if (ndt_is_optional(ndt_dtype(stack[1].type))) { \
        unary_update_bitmap_0D(stack); \
    } \
\
    return 0; \
}
|
|
415
|
+
|
|
416
|
+
/*
 * CPU_HOST_NOIMPL(name, t0, t1)
 *
 * Defines the same four static kernel names as CPU_HOST_UNARY (fixed_1D_C,
 * fixed_1D_S, 0D and array_1D_C) for a type combination that has no
 * implementation.  Every generated function ignores `stack`, records a
 * NotImplementedError in ctx and returns -1.
 */
#define CPU_HOST_NOIMPL(name, t0, t1) \
static int \
gm_cpu_host_fixed_1D_C_##name##_##t0##_##t1(xnd_t stack[], ndt_context_t *ctx) \
{ \
    (void)stack; \
\
    ndt_err_format( \
        ctx, NDT_NotImplementedError, \
        "implementation for " STRINGIZE(name) " : " \
        STRINGIZE(t0) " -> " STRINGIZE(t1) \
        " currently requires double rounding"); \
\
    return -1; \
} \
\
static int \
gm_cpu_host_fixed_1D_S_##name##_##t0##_##t1(xnd_t stack[], ndt_context_t *ctx) \
{ \
    (void)stack; \
\
    ndt_err_format( \
        ctx, NDT_NotImplementedError, \
        "implementation for " STRINGIZE(name) " : " \
        STRINGIZE(t0) " -> " STRINGIZE(t1) \
        " currently requires double rounding"); \
\
    return -1; \
} \
\
static int \
gm_cpu_host_0D_##name##_##t0##_##t1(xnd_t stack[], ndt_context_t *ctx) \
{ \
    (void)stack; \
\
    ndt_err_format( \
        ctx, NDT_NotImplementedError, \
        "implementation for " STRINGIZE(name) " : " \
        STRINGIZE(t0) " -> " STRINGIZE(t1) \
        " currently requires double rounding"); \
\
    return -1; \
} \
\
static int \
gm_cpu_host_array_1D_C_##name##_##t0##_##t1(xnd_t stack[], ndt_context_t *ctx) \
{ \
    (void)stack; \
\
    ndt_err_format( \
        ctx, NDT_NotImplementedError, \
        "implementation for " STRINGIZE(name) " : " \
        STRINGIZE(t0) " -> " STRINGIZE(t1) \
        " currently requires double rounding"); \
\
    return -1; \
}
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
/*
 * CPU_HOST_UNARY_INIT(funcname, func, t0, t1)
 *
 * Expands to six comma-separated designated initializers (no trailing
 * comma) that register the kernels generated for `func` under the name
 * `funcname`.  The entries are, in order:
 *
 *   1. "... * t0 -> ... * t1"
 *   2. "... * ?t0 -> ... * ?t1"
 *   3. "var... * t0 -> var... * t1"
 *   4. "var... * ?t0 -> var... * ?t1"
 *   5. "array... * t0 -> array... * t1"
 *   6. "array... * ?t0 -> array... * ?t1"
 *
 * NOTE(review): the six entries per type pair presumably explain why the
 * *_kernel_location functions return multiples of 6 -- confirm.
 */
#define CPU_HOST_UNARY_INIT(funcname, func, t0, t1) \
  { \
    .name = STRINGIZE(funcname), \
    .sig = "... * " STRINGIZE(t0) " -> ... * " STRINGIZE(t1), \
    .OptC = gm_cpu_host_fixed_1D_C_##func##_##t0##_##t1, \
    .OptS = gm_cpu_host_fixed_1D_S_##func##_##t0##_##t1, \
    .C = gm_cpu_host_0D_##func##_##t0##_##t1 \
  }, \
\
  { \
    .name = STRINGIZE(funcname), \
    .sig = "... * ?" STRINGIZE(t0) " -> ... * ?" STRINGIZE(t1), \
    .OptC = gm_cpu_host_fixed_1D_C_##func##_##t0##_##t1, \
    .OptS = gm_cpu_host_fixed_1D_S_##func##_##t0##_##t1, \
    .C = gm_cpu_host_0D_##func##_##t0##_##t1 \
  }, \
\
  { \
    .name = STRINGIZE(funcname), \
    .sig = "var... * " STRINGIZE(t0) " -> var... * " STRINGIZE(t1), \
    .Xnd = gm_cpu_host_0D_##func##_##t0##_##t1 \
  }, \
\
  { \
    .name = STRINGIZE(funcname), \
    .sig = "var... * ?" STRINGIZE(t0) " -> var... * ?" STRINGIZE(t1), \
    .Xnd = gm_cpu_host_0D_##func##_##t0##_##t1 \
  }, \
\
  { \
    .name = STRINGIZE(funcname), \
    .sig = "array... * " STRINGIZE(t0) " -> array... * " STRINGIZE(t1), \
    .OptC = gm_cpu_host_array_1D_C_##func##_##t0##_##t1, \
    .C = gm_cpu_host_0D_##func##_##t0##_##t1 \
  }, \
\
  { \
    .name = STRINGIZE(funcname), \
    .sig = "array... * ?" STRINGIZE(t0) " -> array... * ?" STRINGIZE(t1), \
    .OptC = gm_cpu_host_array_1D_C_##func##_##t0##_##t1, \
    .C = gm_cpu_host_0D_##func##_##t0##_##t1 \
  }
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
/* Remove any `bool` macro so that `bool`, used below as a t0/t1 macro
   argument, stays a plain identifier.
   NOTE(review): presumably needed because a <stdbool.h>-style `bool` macro
   would otherwise be expanded inside STRINGIZE(t0) -- confirm. */
#undef bool
|
|
503
|
+
/* bool_t is an alias for the C boolean type _Bool.
   NOTE(review): presumably referenced by the kernel declarations pulled in
   from cpu_device_unary.h -- confirm. */
#define bool_t _Bool
|
|
504
|
+
|
|
505
|
+
|
|
506
|
+
/*****************************************************************************/
|
|
507
|
+
/* Copy */
|
|
508
|
+
/*****************************************************************************/
|
|
509
|
+
|
|
510
|
+
/*
 * CPU_HOST_ALL_UNARY(name)
 *
 * Instantiates the host kernels of `name` for every supported
 * (input, output) type pair, grouped by input type.  Pairs that involve
 * float16 or complex32 are generated with CPU_HOST_NOIMPL, i.e. they exist
 * but only report NotImplementedError; all other pairs use CPU_HOST_UNARY.
 */
#define CPU_HOST_ALL_UNARY(name) \
    /* bool -> ... */ \
    CPU_HOST_UNARY(name, bool, bool) \
    CPU_HOST_UNARY(name, bool, uint8) \
    CPU_HOST_UNARY(name, bool, uint16) \
    CPU_HOST_UNARY(name, bool, uint32) \
    CPU_HOST_UNARY(name, bool, uint64) \
    CPU_HOST_UNARY(name, bool, int8) \
    CPU_HOST_UNARY(name, bool, int16) \
    CPU_HOST_UNARY(name, bool, int32) \
    CPU_HOST_UNARY(name, bool, int64) \
    CPU_HOST_UNARY(name, bool, bfloat16) \
    CPU_HOST_NOIMPL(name, bool, float16) \
    CPU_HOST_UNARY(name, bool, float32) \
    CPU_HOST_UNARY(name, bool, float64) \
    CPU_HOST_NOIMPL(name, bool, complex32) \
    CPU_HOST_UNARY(name, bool, complex64) \
    CPU_HOST_UNARY(name, bool, complex128) \
\
    /* uint8 -> ... */ \
    CPU_HOST_UNARY(name, uint8, uint8) \
    CPU_HOST_UNARY(name, uint8, uint16) \
    CPU_HOST_UNARY(name, uint8, uint32) \
    CPU_HOST_UNARY(name, uint8, uint64) \
    CPU_HOST_UNARY(name, uint8, int16) \
    CPU_HOST_UNARY(name, uint8, int32) \
    CPU_HOST_UNARY(name, uint8, int64) \
    CPU_HOST_UNARY(name, uint8, bfloat16) \
    CPU_HOST_NOIMPL(name, uint8, float16) \
    CPU_HOST_UNARY(name, uint8, float32) \
    CPU_HOST_UNARY(name, uint8, float64) \
    CPU_HOST_NOIMPL(name, uint8, complex32) \
    CPU_HOST_UNARY(name, uint8, complex64) \
    CPU_HOST_UNARY(name, uint8, complex128) \
\
    /* uint16 -> ... */ \
    CPU_HOST_UNARY(name, uint16, uint16) \
    CPU_HOST_UNARY(name, uint16, uint32) \
    CPU_HOST_UNARY(name, uint16, uint64) \
    CPU_HOST_UNARY(name, uint16, int32) \
    CPU_HOST_UNARY(name, uint16, int64) \
    CPU_HOST_UNARY(name, uint16, float32) \
    CPU_HOST_UNARY(name, uint16, float64) \
    CPU_HOST_UNARY(name, uint16, complex64) \
    CPU_HOST_UNARY(name, uint16, complex128) \
\
    /* uint32 -> ... */ \
    CPU_HOST_UNARY(name, uint32, uint32) \
    CPU_HOST_UNARY(name, uint32, uint64) \
    CPU_HOST_UNARY(name, uint32, int64) \
    CPU_HOST_UNARY(name, uint32, float64) \
    CPU_HOST_UNARY(name, uint32, complex128) \
\
    /* uint64 -> ... */ \
    CPU_HOST_UNARY(name, uint64, uint64) \
\
    /* int8 -> ... */ \
    CPU_HOST_UNARY(name, int8, int8) \
    CPU_HOST_UNARY(name, int8, int16) \
    CPU_HOST_UNARY(name, int8, int32) \
    CPU_HOST_UNARY(name, int8, int64) \
    CPU_HOST_UNARY(name, int8, bfloat16) \
    CPU_HOST_NOIMPL(name, int8, float16) \
    CPU_HOST_UNARY(name, int8, float32) \
    CPU_HOST_UNARY(name, int8, float64) \
    CPU_HOST_NOIMPL(name, int8, complex32) \
    CPU_HOST_UNARY(name, int8, complex64) \
    CPU_HOST_UNARY(name, int8, complex128) \
\
    /* int16 -> ... */ \
    CPU_HOST_UNARY(name, int16, int16) \
    CPU_HOST_UNARY(name, int16, int32) \
    CPU_HOST_UNARY(name, int16, int64) \
    CPU_HOST_UNARY(name, int16, float32) \
    CPU_HOST_UNARY(name, int16, float64) \
    CPU_HOST_UNARY(name, int16, complex64) \
    CPU_HOST_UNARY(name, int16, complex128) \
\
    /* int32 -> ... */ \
    CPU_HOST_UNARY(name, int32, int32) \
    CPU_HOST_UNARY(name, int32, int64) \
    CPU_HOST_UNARY(name, int32, float64) \
    CPU_HOST_UNARY(name, int32, complex128) \
\
    /* int64 -> ... */ \
    CPU_HOST_UNARY(name, int64, int64) \
\
    /* bfloat16 -> ... */ \
    CPU_HOST_UNARY(name, bfloat16, bfloat16) \
    CPU_HOST_UNARY(name, bfloat16, float32) \
    CPU_HOST_UNARY(name, bfloat16, float64) \
    CPU_HOST_UNARY(name, bfloat16, complex64) \
    CPU_HOST_UNARY(name, bfloat16, complex128) \
\
    /* float16 -> ... (all unimplemented) */ \
    CPU_HOST_NOIMPL(name, float16, float16) \
    CPU_HOST_NOIMPL(name, float16, float32) \
    CPU_HOST_NOIMPL(name, float16, float64) \
    CPU_HOST_NOIMPL(name, float16, complex32) \
    CPU_HOST_NOIMPL(name, float16, complex64) \
    CPU_HOST_NOIMPL(name, float16, complex128) \
\
    /* float32 -> ... */ \
    CPU_HOST_UNARY(name, float32, float32) \
    CPU_HOST_UNARY(name, float32, float64) \
    CPU_HOST_UNARY(name, float32, complex64) \
    CPU_HOST_UNARY(name, float32, complex128) \
\
    /* float64 -> ... */ \
    CPU_HOST_UNARY(name, float64, float64) \
    CPU_HOST_UNARY(name, float64, complex128) \
\
    /* complex32 -> ... (all unimplemented) */ \
    CPU_HOST_NOIMPL(name, complex32, complex32) \
    CPU_HOST_NOIMPL(name, complex32, complex64) \
    CPU_HOST_NOIMPL(name, complex32, complex128) \
\
    /* complex64 -> ... */ \
    CPU_HOST_UNARY(name, complex64, complex64) \
    CPU_HOST_UNARY(name, complex64, complex128) \
\
    /* complex128 -> ... */ \
    CPU_HOST_UNARY(name, complex128, complex128)
|
|
617
|
+
|
|
618
|
+
/*
 * Comma-separated list of CPU_HOST_UNARY_INIT entries, one per (input, output)
 * type pair of a unary kernel family.  Intended for use inside a kernel table
 * initializer; the list carries no trailing comma.
 *
 *   name:  forwarded as the first argument of every entry.
 *   func:  second argument of every entry whose output type is not float16.
 *   hfunc: second argument of the entries whose output type is float16.
 */
#define CPU_HOST_ALL_UNARY_INIT(name, func, hfunc) \
    CPU_HOST_UNARY_INIT(name, func, bool, bool), \
    CPU_HOST_UNARY_INIT(name, func, bool, uint8), \
    CPU_HOST_UNARY_INIT(name, func, bool, uint16), \
    CPU_HOST_UNARY_INIT(name, func, bool, uint32), \
    CPU_HOST_UNARY_INIT(name, func, bool, uint64), \
    CPU_HOST_UNARY_INIT(name, func, bool, int8), \
    CPU_HOST_UNARY_INIT(name, func, bool, int16), \
    CPU_HOST_UNARY_INIT(name, func, bool, int32), \
    CPU_HOST_UNARY_INIT(name, func, bool, int64), \
    CPU_HOST_UNARY_INIT(name, func, bool, bfloat16), \
    CPU_HOST_UNARY_INIT(name, hfunc, bool, float16), \
    CPU_HOST_UNARY_INIT(name, func, bool, float32), \
    CPU_HOST_UNARY_INIT(name, func, bool, float64), \
    CPU_HOST_UNARY_INIT(name, func, bool, complex32), \
    CPU_HOST_UNARY_INIT(name, func, bool, complex64), \
    CPU_HOST_UNARY_INIT(name, func, bool, complex128), \
    \
    CPU_HOST_UNARY_INIT(name, func, uint8, uint8), \
    CPU_HOST_UNARY_INIT(name, func, uint8, uint16), \
    CPU_HOST_UNARY_INIT(name, func, uint8, uint32), \
    CPU_HOST_UNARY_INIT(name, func, uint8, uint64), \
    CPU_HOST_UNARY_INIT(name, func, uint8, int16), \
    CPU_HOST_UNARY_INIT(name, func, uint8, int32), \
    CPU_HOST_UNARY_INIT(name, func, uint8, int64), \
    CPU_HOST_UNARY_INIT(name, func, uint8, bfloat16), \
    CPU_HOST_UNARY_INIT(name, hfunc, uint8, float16), \
    CPU_HOST_UNARY_INIT(name, func, uint8, float32), \
    CPU_HOST_UNARY_INIT(name, func, uint8, float64), \
    CPU_HOST_UNARY_INIT(name, func, uint8, complex32), \
    CPU_HOST_UNARY_INIT(name, func, uint8, complex64), \
    CPU_HOST_UNARY_INIT(name, func, uint8, complex128), \
    \
    CPU_HOST_UNARY_INIT(name, func, uint16, uint16), \
    CPU_HOST_UNARY_INIT(name, func, uint16, uint32), \
    CPU_HOST_UNARY_INIT(name, func, uint16, uint64), \
    CPU_HOST_UNARY_INIT(name, func, uint16, int32), \
    CPU_HOST_UNARY_INIT(name, func, uint16, int64), \
    CPU_HOST_UNARY_INIT(name, func, uint16, float32), \
    CPU_HOST_UNARY_INIT(name, func, uint16, float64), \
    CPU_HOST_UNARY_INIT(name, func, uint16, complex64), \
    CPU_HOST_UNARY_INIT(name, func, uint16, complex128), \
    \
    CPU_HOST_UNARY_INIT(name, func, uint32, uint32), \
    CPU_HOST_UNARY_INIT(name, func, uint32, uint64), \
    CPU_HOST_UNARY_INIT(name, func, uint32, int64), \
    CPU_HOST_UNARY_INIT(name, func, uint32, float64), \
    CPU_HOST_UNARY_INIT(name, func, uint32, complex128), \
    \
    CPU_HOST_UNARY_INIT(name, func, uint64, uint64), \
    \
    CPU_HOST_UNARY_INIT(name, func, int8, int8), \
    CPU_HOST_UNARY_INIT(name, func, int8, int16), \
    CPU_HOST_UNARY_INIT(name, func, int8, int32), \
    CPU_HOST_UNARY_INIT(name, func, int8, int64), \
    CPU_HOST_UNARY_INIT(name, func, int8, bfloat16), \
    CPU_HOST_UNARY_INIT(name, hfunc, int8, float16), \
    CPU_HOST_UNARY_INIT(name, func, int8, float32), \
    CPU_HOST_UNARY_INIT(name, func, int8, float64), \
    CPU_HOST_UNARY_INIT(name, func, int8, complex32), \
    CPU_HOST_UNARY_INIT(name, func, int8, complex64), \
    CPU_HOST_UNARY_INIT(name, func, int8, complex128), \
    \
    CPU_HOST_UNARY_INIT(name, func, int16, int16), \
    CPU_HOST_UNARY_INIT(name, func, int16, int32), \
    CPU_HOST_UNARY_INIT(name, func, int16, int64), \
    CPU_HOST_UNARY_INIT(name, func, int16, float32), \
    CPU_HOST_UNARY_INIT(name, func, int16, float64), \
    CPU_HOST_UNARY_INIT(name, func, int16, complex64), \
    CPU_HOST_UNARY_INIT(name, func, int16, complex128), \
    \
    CPU_HOST_UNARY_INIT(name, func, int32, int32), \
    CPU_HOST_UNARY_INIT(name, func, int32, int64), \
    CPU_HOST_UNARY_INIT(name, func, int32, float64), \
    CPU_HOST_UNARY_INIT(name, func, int32, complex128), \
    \
    CPU_HOST_UNARY_INIT(name, func, int64, int64), \
    \
    CPU_HOST_UNARY_INIT(name, func, bfloat16, bfloat16), \
    CPU_HOST_UNARY_INIT(name, func, bfloat16, float32), \
    CPU_HOST_UNARY_INIT(name, func, bfloat16, float64), \
    CPU_HOST_UNARY_INIT(name, func, bfloat16, complex64), \
    CPU_HOST_UNARY_INIT(name, func, bfloat16, complex128), \
    \
    CPU_HOST_UNARY_INIT(name, hfunc, float16, float16), \
    CPU_HOST_UNARY_INIT(name, func, float16, float32), \
    CPU_HOST_UNARY_INIT(name, func, float16, float64), \
    CPU_HOST_UNARY_INIT(name, func, float16, complex32), \
    CPU_HOST_UNARY_INIT(name, func, float16, complex64), \
    CPU_HOST_UNARY_INIT(name, func, float16, complex128), \
    \
    CPU_HOST_UNARY_INIT(name, func, float32, float32), \
    CPU_HOST_UNARY_INIT(name, func, float32, float64), \
    CPU_HOST_UNARY_INIT(name, func, float32, complex64), \
    CPU_HOST_UNARY_INIT(name, func, float32, complex128), \
    \
    CPU_HOST_UNARY_INIT(name, func, float64, float64), \
    CPU_HOST_UNARY_INIT(name, func, float64, complex128), \
    \
    CPU_HOST_UNARY_INIT(name, func, complex32, complex32), \
    CPU_HOST_UNARY_INIT(name, func, complex32, complex64), \
    CPU_HOST_UNARY_INIT(name, func, complex32, complex128), \
    \
    CPU_HOST_UNARY_INIT(name, func, complex64, complex64), \
    CPU_HOST_UNARY_INIT(name, func, complex64, complex128), \
    \
    CPU_HOST_UNARY_INIT(name, func, complex128, complex128)
|
|
725
|
+
|
|
726
|
+
|
|
727
|
+
CPU_HOST_ALL_UNARY(copy)
|
|
728
|
+
CPU_HOST_ALL_UNARY(abs)
|
|
729
|
+
|
|
730
|
+
|
|
731
|
+
static const gm_kernel_init_t unary_copy[] = {
|
|
732
|
+
/* COPY */
|
|
733
|
+
CPU_HOST_ALL_UNARY_INIT(copy, copy, copy),
|
|
734
|
+
CPU_HOST_ALL_UNARY_INIT(abs, abs, abs),
|
|
735
|
+
|
|
736
|
+
{ .name = NULL, .sig = NULL }
|
|
737
|
+
};
|
|
738
|
+
|
|
739
|
+
/*****************************************************************************/
|
|
740
|
+
/* Bitwise NOT */
|
|
741
|
+
/*****************************************************************************/
|
|
742
|
+
|
|
743
|
+
CPU_HOST_UNARY(invert, bool, bool)
|
|
744
|
+
|
|
745
|
+
CPU_HOST_UNARY(invert, uint8, uint8)
|
|
746
|
+
CPU_HOST_UNARY(invert, uint16, uint16)
|
|
747
|
+
CPU_HOST_UNARY(invert, uint32, uint32)
|
|
748
|
+
CPU_HOST_UNARY(invert, uint64, uint64)
|
|
749
|
+
|
|
750
|
+
CPU_HOST_UNARY(invert, int8, int8)
|
|
751
|
+
CPU_HOST_UNARY(invert, int16, int16)
|
|
752
|
+
CPU_HOST_UNARY(invert, int32, int32)
|
|
753
|
+
CPU_HOST_UNARY(invert, int64, int64)
|
|
754
|
+
|
|
755
|
+
|
|
756
|
+
static const gm_kernel_init_t unary_invert[] = {
|
|
757
|
+
/* INVERT */
|
|
758
|
+
CPU_HOST_UNARY_INIT(invert, invert, bool, bool),
|
|
759
|
+
|
|
760
|
+
CPU_HOST_UNARY_INIT(invert, invert, uint8, uint8),
|
|
761
|
+
CPU_HOST_UNARY_INIT(invert, invert, uint16, uint16),
|
|
762
|
+
CPU_HOST_UNARY_INIT(invert, invert, uint32, uint32),
|
|
763
|
+
CPU_HOST_UNARY_INIT(invert, invert, uint64, uint64),
|
|
764
|
+
|
|
765
|
+
CPU_HOST_UNARY_INIT(invert, invert, int8, int8),
|
|
766
|
+
CPU_HOST_UNARY_INIT(invert, invert, int16, int16),
|
|
767
|
+
CPU_HOST_UNARY_INIT(invert, invert, int32, int32),
|
|
768
|
+
CPU_HOST_UNARY_INIT(invert, invert, int64, int64),
|
|
769
|
+
|
|
770
|
+
{ .name = NULL, .sig = NULL }
|
|
771
|
+
};
|
|
772
|
+
|
|
773
|
+
|
|
774
|
+
/*****************************************************************************/
|
|
775
|
+
/* Negative */
|
|
776
|
+
/*****************************************************************************/
|
|
777
|
+
|
|
778
|
+
CPU_HOST_UNARY(negative, uint8, int16)
|
|
779
|
+
CPU_HOST_UNARY(negative, uint16, int32)
|
|
780
|
+
CPU_HOST_UNARY(negative, uint32, int64)
|
|
781
|
+
|
|
782
|
+
CPU_HOST_UNARY(negative, int8, int8)
|
|
783
|
+
CPU_HOST_UNARY(negative, int16, int16)
|
|
784
|
+
CPU_HOST_UNARY(negative, int32, int32)
|
|
785
|
+
CPU_HOST_UNARY(negative, int64, int64)
|
|
786
|
+
|
|
787
|
+
CPU_HOST_UNARY(negative, bfloat16, bfloat16)
|
|
788
|
+
CPU_HOST_NOIMPL(negative, float16, float16)
|
|
789
|
+
CPU_HOST_UNARY(negative, float32, float32)
|
|
790
|
+
CPU_HOST_UNARY(negative, float64, float64)
|
|
791
|
+
|
|
792
|
+
CPU_HOST_NOIMPL(negative, complex32, complex32)
|
|
793
|
+
CPU_HOST_UNARY(negative, complex64, complex64)
|
|
794
|
+
CPU_HOST_UNARY(negative, complex128, complex128)
|
|
795
|
+
|
|
796
|
+
|
|
797
|
+
static const gm_kernel_init_t unary_negative[] = {
|
|
798
|
+
/* NEGATIVE */
|
|
799
|
+
CPU_HOST_UNARY_INIT(negative, negative, uint8, int16),
|
|
800
|
+
CPU_HOST_UNARY_INIT(negative, negative, uint16, int32),
|
|
801
|
+
CPU_HOST_UNARY_INIT(negative, negative, uint32, int64),
|
|
802
|
+
|
|
803
|
+
CPU_HOST_UNARY_INIT(negative, negative, int8, int8),
|
|
804
|
+
CPU_HOST_UNARY_INIT(negative, negative, int16, int16),
|
|
805
|
+
CPU_HOST_UNARY_INIT(negative, negative, int32, int32),
|
|
806
|
+
CPU_HOST_UNARY_INIT(negative, negative, int64, int64),
|
|
807
|
+
|
|
808
|
+
CPU_HOST_UNARY_INIT(negative, negative, bfloat16, bfloat16),
|
|
809
|
+
CPU_HOST_UNARY_INIT(negative, negative, float16, float16),
|
|
810
|
+
CPU_HOST_UNARY_INIT(negative, negative, float32, float32),
|
|
811
|
+
CPU_HOST_UNARY_INIT(negative, negative, float64, float64),
|
|
812
|
+
|
|
813
|
+
CPU_HOST_UNARY_INIT(negative, negative, complex32, complex32),
|
|
814
|
+
CPU_HOST_UNARY_INIT(negative, negative, complex64, complex64),
|
|
815
|
+
CPU_HOST_UNARY_INIT(negative, negative, complex128, complex128),
|
|
816
|
+
|
|
817
|
+
{ .name = NULL, .sig = NULL }
|
|
818
|
+
};
|
|
819
|
+
|
|
820
|
+
|
|
821
|
+
/*****************************************************************************/
|
|
822
|
+
/* Math */
|
|
823
|
+
/*****************************************************************************/
|
|
824
|
+
|
|
825
|
+
/*
 * Half-precision group: the three pairs with float16 output (uint8, int8 and
 * float16 input).  The kernel name is `name` with an "f16" suffix pasted on.
 */
#define _CPU_ALL_HALF_MATH(name) \
    CPU_HOST_UNARY(name##f16, uint8, float16) \
    CPU_HOST_UNARY(name##f16, int8, float16) \
    CPU_HOST_UNARY(name##f16, float16, float16)

/* Same three pairs as _CPU_ALL_HALF_MATH, declared through CPU_HOST_NOIMPL. */
#define _CPU_ALL_HALF_MATH_NOIMPL(name) \
    CPU_HOST_NOIMPL(name##f16, uint8, float16) \
    CPU_HOST_NOIMPL(name##f16, int8, float16) \
    CPU_HOST_NOIMPL(name##f16, float16, float16)

/*
 * Complex group: complex32 goes through CPU_HOST_NOIMPL, complex64 and
 * complex128 through CPU_HOST_UNARY.  The kernel name is `name` unchanged.
 */
#define _CPU_ALL_COMPLEX_MATH(name) \
    CPU_HOST_NOIMPL(name, complex32, complex32) \
    CPU_HOST_UNARY(name, complex64, complex64) \
    CPU_HOST_UNARY(name, complex128, complex128)

/* Complex group with all three pairs declared through CPU_HOST_NOIMPL. */
#define _CPU_ALL_COMPLEX_MATH_NOIMPL(name) \
    CPU_HOST_NOIMPL(name, complex32, complex32) \
    CPU_HOST_NOIMPL(name, complex64, complex64) \
    CPU_HOST_NOIMPL(name, complex128, complex128)
|
|
844
|
+
|
|
845
|
+
/*
 * Real group: seven CPU_HOST_UNARY kernels.
 *
 *   bfloat16 -> bfloat16            kernel name: name with "b16" suffix
 *   uint16/int16/float32 -> float32 kernel name: name with "f" suffix
 *   uint32/int32/float64 -> float64 kernel name: name unchanged
 *
 * The last line deliberately has no trailing backslash, so the definition
 * ends here regardless of what follows it in the file.
 */
#define _CPU_ALL_REAL_MATH(name) \
    CPU_HOST_UNARY(name##b16, bfloat16, bfloat16) \
    CPU_HOST_UNARY(name##f, uint16, float32) \
    CPU_HOST_UNARY(name##f, int16, float32) \
    CPU_HOST_UNARY(name##f, float32, float32) \
    CPU_HOST_UNARY(name, uint32, float64) \
    CPU_HOST_UNARY(name, int32, float64) \
    CPU_HOST_UNARY(name, float64, float64)
|
|
853
|
+
|
|
854
|
+
/*
 * Real-valued function: real group via CPU_HOST_UNARY, half and complex
 * groups via their NOIMPL variants.
 */
#define CPU_ALL_REAL_MATH(name) \
    _CPU_ALL_HALF_MATH_NOIMPL(name) \
    _CPU_ALL_REAL_MATH(name) \
    _CPU_ALL_COMPLEX_MATH_NOIMPL(name)

/* Like CPU_ALL_REAL_MATH, but the half group uses _CPU_ALL_HALF_MATH. */
#define CPU_ALL_REAL_MATH_WITH_HALF(name) \
    _CPU_ALL_HALF_MATH(name) \
    _CPU_ALL_REAL_MATH(name) \
    _CPU_ALL_COMPLEX_MATH_NOIMPL(name)

/*
 * Function that also has complex kernels: real and complex groups, with the
 * half group in its NOIMPL variant.
 */
#define CPU_ALL_COMPLEX_MATH(name) \
    _CPU_ALL_HALF_MATH_NOIMPL(name) \
    _CPU_ALL_REAL_MATH(name) \
    _CPU_ALL_COMPLEX_MATH(name)
|
|
868
|
+
|
|
869
|
+
/*
 * Function with half, real and complex kernels: expands to
 * _CPU_ALL_HALF_MATH, _CPU_ALL_REAL_MATH and _CPU_ALL_COMPLEX_MATH for `name`.
 *
 * The last line deliberately has no trailing backslash, so the definition
 * ends here regardless of what follows it in the file.
 */
#define CPU_ALL_COMPLEX_MATH_WITH_HALF(name) \
    _CPU_ALL_HALF_MATH(name) \
    _CPU_ALL_REAL_MATH(name) \
    _CPU_ALL_COMPLEX_MATH(name)
|
|
873
|
+
|
|
874
|
+
|
|
875
|
+
/*
 * Comma-separated CPU_HOST_UNARY_INIT entries (no trailing comma) for one
 * math function: 3 half, 1 bfloat16, 3 float32, 3 float64 and 3 complex
 * pairs.  Every entry passes `name` as its first argument; the second
 * argument carries the per-group suffix:
 *
 *   float16 output   -> name##f16
 *   bfloat16 output  -> name##b16
 *   float32 output   -> name##f
 *   float64, complex -> name
 */
#define CPU_ALL_UNARY_MATH_INIT(name) \
    CPU_HOST_UNARY_INIT(name, name##f16, uint8, float16), \
    CPU_HOST_UNARY_INIT(name, name##f16, int8, float16), \
    CPU_HOST_UNARY_INIT(name, name##f16, float16, float16), \
    \
    CPU_HOST_UNARY_INIT(name, name##b16, bfloat16, bfloat16), \
    \
    CPU_HOST_UNARY_INIT(name, name##f, uint16, float32), \
    CPU_HOST_UNARY_INIT(name, name##f, int16, float32), \
    CPU_HOST_UNARY_INIT(name, name##f, float32, float32), \
    \
    CPU_HOST_UNARY_INIT(name, name, uint32, float64), \
    CPU_HOST_UNARY_INIT(name, name, int32, float64), \
    CPU_HOST_UNARY_INIT(name, name, float64, float64), \
    \
    CPU_HOST_UNARY_INIT(name, name, complex32, complex32), \
    CPU_HOST_UNARY_INIT(name, name, complex64, complex64), \
    CPU_HOST_UNARY_INIT(name, name, complex128, complex128)
|
|
893
|
+
|
|
894
|
+
|
|
895
|
+
/*****************************************************************************/
|
|
896
|
+
/* Abs functions */
|
|
897
|
+
/*****************************************************************************/
|
|
898
|
+
|
|
899
|
+
CPU_ALL_REAL_MATH(fabs)
|
|
900
|
+
|
|
901
|
+
|
|
902
|
+
/*****************************************************************************/
|
|
903
|
+
/* Exponential functions */
|
|
904
|
+
/*****************************************************************************/
|
|
905
|
+
|
|
906
|
+
CPU_ALL_COMPLEX_MATH(exp)
|
|
907
|
+
CPU_ALL_REAL_MATH(exp2)
|
|
908
|
+
CPU_ALL_REAL_MATH(expm1)
|
|
909
|
+
|
|
910
|
+
|
|
911
|
+
/*****************************************************************************/
|
|
912
|
+
/* Logarithm functions */
|
|
913
|
+
/*****************************************************************************/
|
|
914
|
+
|
|
915
|
+
CPU_ALL_COMPLEX_MATH(log)
|
|
916
|
+
CPU_ALL_COMPLEX_MATH(log10)
|
|
917
|
+
CPU_ALL_REAL_MATH(log2)
|
|
918
|
+
CPU_ALL_REAL_MATH(log1p)
|
|
919
|
+
CPU_ALL_REAL_MATH(logb)
|
|
920
|
+
|
|
921
|
+
|
|
922
|
+
/*****************************************************************************/
|
|
923
|
+
/* Power functions */
|
|
924
|
+
/*****************************************************************************/
|
|
925
|
+
|
|
926
|
+
CPU_ALL_COMPLEX_MATH(sqrt)
|
|
927
|
+
CPU_ALL_REAL_MATH(cbrt)
|
|
928
|
+
|
|
929
|
+
|
|
930
|
+
/*****************************************************************************/
|
|
931
|
+
/* Trigonometric functions */
|
|
932
|
+
/*****************************************************************************/
|
|
933
|
+
|
|
934
|
+
CPU_ALL_COMPLEX_MATH(sin)
|
|
935
|
+
CPU_ALL_COMPLEX_MATH(cos)
|
|
936
|
+
CPU_ALL_COMPLEX_MATH(tan)
|
|
937
|
+
CPU_ALL_COMPLEX_MATH(asin)
|
|
938
|
+
CPU_ALL_COMPLEX_MATH(acos)
|
|
939
|
+
CPU_ALL_COMPLEX_MATH(atan)
|
|
940
|
+
|
|
941
|
+
|
|
942
|
+
/*****************************************************************************/
|
|
943
|
+
/* Hyperbolic functions */
|
|
944
|
+
/*****************************************************************************/
|
|
945
|
+
|
|
946
|
+
CPU_ALL_COMPLEX_MATH(sinh)
|
|
947
|
+
CPU_ALL_COMPLEX_MATH(cosh)
|
|
948
|
+
CPU_ALL_COMPLEX_MATH(tanh)
|
|
949
|
+
CPU_ALL_COMPLEX_MATH(asinh)
|
|
950
|
+
CPU_ALL_COMPLEX_MATH(acosh)
|
|
951
|
+
CPU_ALL_COMPLEX_MATH(atanh)
|
|
952
|
+
|
|
953
|
+
|
|
954
|
+
/*****************************************************************************/
|
|
955
|
+
/* Error and gamma functions */
|
|
956
|
+
/*****************************************************************************/
|
|
957
|
+
|
|
958
|
+
CPU_ALL_REAL_MATH(erf)
|
|
959
|
+
CPU_ALL_REAL_MATH(erfc)
|
|
960
|
+
CPU_ALL_REAL_MATH(lgamma)
|
|
961
|
+
CPU_ALL_REAL_MATH(tgamma)
|
|
962
|
+
|
|
963
|
+
|
|
964
|
+
/*****************************************************************************/
|
|
965
|
+
/* Ceiling, floor, trunc */
|
|
966
|
+
/*****************************************************************************/
|
|
967
|
+
|
|
968
|
+
CPU_ALL_REAL_MATH(ceil)
|
|
969
|
+
CPU_ALL_REAL_MATH(floor)
|
|
970
|
+
CPU_ALL_REAL_MATH(trunc)
|
|
971
|
+
CPU_ALL_REAL_MATH(round)
|
|
972
|
+
CPU_ALL_REAL_MATH(nearbyint)
|
|
973
|
+
|
|
974
|
+
|
|
975
|
+
static const gm_kernel_init_t unary_float[] = {
|
|
976
|
+
/* ABS */
|
|
977
|
+
CPU_ALL_UNARY_MATH_INIT(fabs),
|
|
978
|
+
|
|
979
|
+
/* EXPONENTIAL */
|
|
980
|
+
CPU_ALL_UNARY_MATH_INIT(exp),
|
|
981
|
+
CPU_ALL_UNARY_MATH_INIT(exp2),
|
|
982
|
+
CPU_ALL_UNARY_MATH_INIT(expm1),
|
|
983
|
+
|
|
984
|
+
/* LOGARITHM */
|
|
985
|
+
CPU_ALL_UNARY_MATH_INIT(log),
|
|
986
|
+
CPU_ALL_UNARY_MATH_INIT(log2),
|
|
987
|
+
CPU_ALL_UNARY_MATH_INIT(log10),
|
|
988
|
+
CPU_ALL_UNARY_MATH_INIT(log1p),
|
|
989
|
+
CPU_ALL_UNARY_MATH_INIT(logb),
|
|
990
|
+
|
|
991
|
+
/* POWER */
|
|
992
|
+
CPU_ALL_UNARY_MATH_INIT(sqrt),
|
|
993
|
+
CPU_ALL_UNARY_MATH_INIT(cbrt),
|
|
994
|
+
|
|
995
|
+
/* TRIGONOMETRIC */
|
|
996
|
+
CPU_ALL_UNARY_MATH_INIT(sin),
|
|
997
|
+
CPU_ALL_UNARY_MATH_INIT(cos),
|
|
998
|
+
CPU_ALL_UNARY_MATH_INIT(tan),
|
|
999
|
+
CPU_ALL_UNARY_MATH_INIT(asin),
|
|
1000
|
+
CPU_ALL_UNARY_MATH_INIT(acos),
|
|
1001
|
+
CPU_ALL_UNARY_MATH_INIT(atan),
|
|
1002
|
+
|
|
1003
|
+
/* HYPERBOLIC */
|
|
1004
|
+
CPU_ALL_UNARY_MATH_INIT(sinh),
|
|
1005
|
+
CPU_ALL_UNARY_MATH_INIT(cosh),
|
|
1006
|
+
CPU_ALL_UNARY_MATH_INIT(tanh),
|
|
1007
|
+
CPU_ALL_UNARY_MATH_INIT(asinh),
|
|
1008
|
+
CPU_ALL_UNARY_MATH_INIT(acosh),
|
|
1009
|
+
CPU_ALL_UNARY_MATH_INIT(atanh),
|
|
1010
|
+
|
|
1011
|
+
/* ERROR AND GAMMA */
|
|
1012
|
+
CPU_ALL_UNARY_MATH_INIT(erf),
|
|
1013
|
+
CPU_ALL_UNARY_MATH_INIT(erfc),
|
|
1014
|
+
CPU_ALL_UNARY_MATH_INIT(lgamma),
|
|
1015
|
+
CPU_ALL_UNARY_MATH_INIT(tgamma),
|
|
1016
|
+
|
|
1017
|
+
/* CEILING, FLOOR, TRUNC */
|
|
1018
|
+
CPU_ALL_UNARY_MATH_INIT(ceil),
|
|
1019
|
+
CPU_ALL_UNARY_MATH_INIT(floor),
|
|
1020
|
+
CPU_ALL_UNARY_MATH_INIT(trunc),
|
|
1021
|
+
CPU_ALL_UNARY_MATH_INIT(round),
|
|
1022
|
+
CPU_ALL_UNARY_MATH_INIT(nearbyint),
|
|
1023
|
+
|
|
1024
|
+
{ .name = NULL, .sig = NULL }
|
|
1025
|
+
};
|
|
1026
|
+
|
|
1027
|
+
|
|
1028
|
+
/****************************************************************************/
|
|
1029
|
+
/* Initialize kernel table */
|
|
1030
|
+
/****************************************************************************/
|
|
1031
|
+
|
|
1032
|
+
/*
 * Make `bool` available as a type name for the typecheck signatures below.
 *
 * The typedef is skipped when `bool` already exists: as a keyword (C23 and
 * later) or as the <stdbool.h> macro.  In both cases the plain typedef would
 * be a compile error.
 *
 * NOTE(review): presumably <stdbool.h> is not included directly because the
 * kernel macros above use `bool` as a type-name token -- confirm before
 * replacing this with an include.
 */
#if !defined(__bool_true_false_are_defined) && \
    (!defined(__STDC_VERSION__) || __STDC_VERSION__ < 202311L)
typedef _Bool bool;
#endif
|
|
1033
|
+
|
|
1034
|
+
static const gm_kernel_set_t *
|
|
1035
|
+
unary_copy_typecheck(ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
|
|
1036
|
+
const int64_t li[], int nin, int nout, bool check_broadcast,
|
|
1037
|
+
ndt_context_t *ctx)
|
|
1038
|
+
{
|
|
1039
|
+
return cpu_unary_typecheck(copy_kernel_location, spec, f, types, li,
|
|
1040
|
+
nin, nout, check_broadcast, ctx);
|
|
1041
|
+
}
|
|
1042
|
+
|
|
1043
|
+
static const gm_kernel_set_t *
|
|
1044
|
+
unary_invert_typecheck(ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
|
|
1045
|
+
const int64_t li[], int nin, int nout, bool check_broadcast,
|
|
1046
|
+
ndt_context_t *ctx)
|
|
1047
|
+
{
|
|
1048
|
+
return cpu_unary_typecheck(invert_kernel_location, spec, f, types, li,
|
|
1049
|
+
nin, nout, check_broadcast, ctx);
|
|
1050
|
+
}
|
|
1051
|
+
|
|
1052
|
+
static const gm_kernel_set_t *
|
|
1053
|
+
unary_negative_typecheck(ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
|
|
1054
|
+
const int64_t li[], int nin, int nout, bool check_broadcast,
|
|
1055
|
+
ndt_context_t *ctx)
|
|
1056
|
+
{
|
|
1057
|
+
return cpu_unary_typecheck(negative_kernel_location, spec, f, types, li,
|
|
1058
|
+
nin, nout, check_broadcast, ctx);
|
|
1059
|
+
}
|
|
1060
|
+
|
|
1061
|
+
static const gm_kernel_set_t *
|
|
1062
|
+
unary_math_typecheck(ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
|
|
1063
|
+
const int64_t li[], int nin, int nout, bool check_broadcast,
|
|
1064
|
+
ndt_context_t *ctx)
|
|
1065
|
+
{
|
|
1066
|
+
return cpu_unary_typecheck(math_kernel_location, spec, f, types, li,
|
|
1067
|
+
nin, nout, check_broadcast, ctx);
|
|
1068
|
+
}
|
|
1069
|
+
|
|
1070
|
+
int
|
|
1071
|
+
gm_init_cpu_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx)
|
|
1072
|
+
{
|
|
1073
|
+
const gm_kernel_init_t *k;
|
|
1074
|
+
|
|
1075
|
+
for (k = unary_copy; k->name != NULL; k++) {
|
|
1076
|
+
if (gm_add_kernel_typecheck(tbl, k, ctx, &unary_copy_typecheck) < 0) {
|
|
1077
|
+
return -1;
|
|
1078
|
+
}
|
|
1079
|
+
}
|
|
1080
|
+
|
|
1081
|
+
for (k = unary_invert; k->name != NULL; k++) {
|
|
1082
|
+
if (gm_add_kernel_typecheck(tbl, k, ctx, &unary_invert_typecheck) < 0) {
|
|
1083
|
+
return -1;
|
|
1084
|
+
}
|
|
1085
|
+
}
|
|
1086
|
+
|
|
1087
|
+
for (k = unary_negative; k->name != NULL; k++) {
|
|
1088
|
+
if (gm_add_kernel_typecheck(tbl, k, ctx, &unary_negative_typecheck) < 0) {
|
|
1089
|
+
return -1;
|
|
1090
|
+
}
|
|
1091
|
+
}
|
|
1092
|
+
|
|
1093
|
+
for (k = unary_float; k->name != NULL; k++) {
|
|
1094
|
+
if (gm_add_kernel_typecheck(tbl, k, ctx, &unary_math_typecheck) < 0) {
|
|
1095
|
+
return -1;
|
|
1096
|
+
}
|
|
1097
|
+
}
|
|
1098
|
+
|
|
1099
|
+
return 0;
|
|
1100
|
+
}
|