gumath 0.2.0dev5 → 0.2.0dev8

Files changed (99)
  1. checksums.yaml +4 -4
  2. data/CONTRIBUTING.md +7 -2
  3. data/Gemfile +0 -3
  4. data/ext/ruby_gumath/GPATH +0 -0
  5. data/ext/ruby_gumath/GRTAGS +0 -0
  6. data/ext/ruby_gumath/GTAGS +0 -0
  7. data/ext/ruby_gumath/extconf.rb +0 -5
  8. data/ext/ruby_gumath/functions.c +10 -2
  9. data/ext/ruby_gumath/gufunc_object.c +15 -4
  10. data/ext/ruby_gumath/gufunc_object.h +9 -3
  11. data/ext/ruby_gumath/gumath/Makefile +63 -0
  12. data/ext/ruby_gumath/gumath/Makefile.in +1 -0
  13. data/ext/ruby_gumath/gumath/config.h +56 -0
  14. data/ext/ruby_gumath/gumath/config.h.in +3 -0
  15. data/ext/ruby_gumath/gumath/config.log +497 -0
  16. data/ext/ruby_gumath/gumath/config.status +1034 -0
  17. data/ext/ruby_gumath/gumath/configure +375 -4
  18. data/ext/ruby_gumath/gumath/configure.ac +47 -3
  19. data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
  20. data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
  21. data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
  22. data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
  23. data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
  24. data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
  25. data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
  26. data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
  27. data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
  28. data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
  29. data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
  30. data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
  31. data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
  32. data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
  33. data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
  34. data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
  35. data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
  36. data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
  37. data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
  38. data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
  39. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
  40. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
  41. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
  42. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
  43. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
  44. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
  45. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
  46. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
  47. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
  48. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
  49. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
  50. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
  51. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
  52. data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
  53. data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
  54. data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
  55. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
  56. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
  57. data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
  58. data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
  59. data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
  60. data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
  61. data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
  62. data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
  63. data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
  64. data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
  65. data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
  66. data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
  67. data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
  68. data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
  69. data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
  70. data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
  71. data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
  72. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
  73. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
  74. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
  75. data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
  76. data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
  77. data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
  78. data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
  79. data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
  80. data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
  81. data/ext/ruby_gumath/gumath/setup.py +67 -6
  82. data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
  83. data/ext/ruby_gumath/include/gumath.h +55 -14
  84. data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
  85. data/ext/ruby_gumath/lib/libgumath.a +0 -0
  86. data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
  87. data/ext/ruby_gumath/ruby_gumath.c +231 -70
  88. data/ext/ruby_gumath/ruby_gumath.h +4 -1
  89. data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
  90. data/ext/ruby_gumath/util.c +34 -0
  91. data/ext/ruby_gumath/util.h +9 -0
  92. data/gumath.gemspec +3 -2
  93. data/lib/gumath.rb +55 -1
  94. data/lib/gumath/version.rb +2 -2
  95. data/lib/ruby_gumath.so +0 -0
  96. metadata +63 -10
  97. data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
  98. data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
  99. data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
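
The single hunk reproduced below appears to be the new file data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c (entry 51 above; its +1331 line count matches the hunk header). Diffs for the other changed files are not included here.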
@@ -0,0 +1,1331 @@
+ /*
+ * BSD 3-Clause License
+ *
+ * Copyright (c) 2017-2018, plures
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ * this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+
+ #include <stdlib.h>
+ #include <stdint.h>
+ #include <string.h>
+ #include "ndtypes.h"
+ #include "xnd.h"
+ #include "gumath.h"
+ #include "common.h"
+ #include "cuda_device_unary.h"
+
+
44
+ /****************************************************************************/
45
+ /* Kernel locations for optimized lookup */
46
+ /****************************************************************************/
47
+
48
+ static int
49
+ copy_kernel_location(const ndt_t *in, const ndt_t *out, ndt_context_t *ctx)
50
+ {
51
+ const ndt_t *t = ndt_dtype(in);
52
+ const ndt_t *u = ndt_dtype(out);
53
+
54
+ switch (t->tag) {
55
+ case Bool: {
56
+ switch (u->tag) {
57
+ case Bool: return 0;
58
+ case Uint8: return 2;
59
+ case Uint16: return 4;
60
+ case Uint32: return 6;
61
+ case Uint64: return 8;
62
+ case Int8: return 10;
63
+ case Int16: return 12;
64
+ case Int32: return 14;
65
+ case Int64: return 16;
66
+ case BFloat16: return 18;
67
+ case Float16: return 20;
68
+ case Float32: return 22;
69
+ case Float64: return 24;
70
+ case Complex32: return 26;
71
+ case Complex64: return 28;
72
+ case Complex128: return 30;
73
+ default: goto invalid_combination;
74
+ }
75
+ }
76
+
77
+ case Uint8: {
78
+ switch (u->tag) {
79
+ case Uint8: return 32;
80
+ case Uint16: return 34;
81
+ case Uint32: return 36;
82
+ case Uint64: return 38;
83
+ case Int16: return 40;
84
+ case Int32: return 42;
85
+ case Int64: return 44;
86
+ case BFloat16: return 46;
87
+ case Float16: return 48;
88
+ case Float32: return 50;
89
+ case Float64: return 52;
90
+ case Complex32: return 54;
91
+ case Complex64: return 56;
92
+ case Complex128: return 58;
93
+ default: goto invalid_combination;
94
+ }
95
+ }
96
+
97
+ case Uint16: {
98
+ switch (u->tag) {
99
+ case Uint16: return 60;
100
+ case Uint32: return 62;
101
+ case Uint64: return 64;
102
+ case Int32: return 66;
103
+ case Int64: return 68;
104
+ case Float32: return 70;
105
+ case Float64: return 72;
106
+ case Complex64: return 74;
107
+ case Complex128: return 76;
108
+ default: goto invalid_combination;
109
+ }
110
+ }
111
+
112
+ case Uint32: {
113
+ switch (u->tag) {
114
+ case Uint32: return 78;
115
+ case Uint64: return 80;
116
+ case Int64: return 82;
117
+ case Float64: return 84;
118
+ case Complex128: return 86;
119
+ default: goto invalid_combination;
120
+ }
121
+ }
122
+
123
+ case Uint64: {
124
+ switch (u->tag) {
125
+ case Uint64: return 88;
126
+ default: goto invalid_combination;
127
+ }
128
+ }
129
+
130
+ case Int8: {
131
+ switch (u->tag) {
132
+ case Int8: return 90;
133
+ case Int16: return 92;
134
+ case Int32: return 94;
135
+ case Int64: return 96;
136
+ case BFloat16: return 98;
137
+ case Float16: return 100;
138
+ case Float32: return 102;
139
+ case Float64: return 104;
140
+ case Complex32: return 106;
141
+ case Complex64: return 108;
142
+ case Complex128: return 110;
143
+ default: goto invalid_combination;
144
+ }
145
+ }
146
+
147
+ case Int16: {
148
+ switch (u->tag) {
149
+ case Int16: return 112;
150
+ case Int32: return 114;
151
+ case Int64: return 116;
152
+ case Float32: return 118;
153
+ case Float64: return 120;
154
+ case Complex64: return 122;
155
+ case Complex128: return 124;
156
+ default: goto invalid_combination;
157
+ }
158
+ }
159
+
160
+ case Int32: {
161
+ switch (u->tag) {
162
+ case Int32: return 126;
163
+ case Int64: return 128;
164
+ case Float64: return 130;
165
+ case Complex128: return 132;
166
+ default: goto invalid_combination;
167
+ }
168
+ }
169
+
170
+ case Int64: {
171
+ switch (u->tag) {
172
+ case Int64: return 134;
173
+ default: goto invalid_combination;
174
+ }
175
+ }
176
+
177
+ case BFloat16: {
178
+ switch (u->tag) {
179
+ case BFloat16: return 136;
180
+ case Float32: return 138;
181
+ case Float64: return 140;
182
+ case Complex64: return 142;
183
+ case Complex128: return 144;
184
+ default: goto invalid_combination;
185
+ }
186
+ }
187
+
188
+ case Float16: {
189
+ switch (u->tag) {
190
+ case Float16: return 146;
191
+ case Float32: return 148;
192
+ case Float64: return 150;
193
+ case Complex32: return 152;
194
+ case Complex64: return 154;
195
+ case Complex128: return 156;
196
+ default: goto invalid_combination;
197
+ }
198
+ }
199
+
200
+ case Float32: {
201
+ switch (u->tag) {
202
+ case Float32: return 158;
203
+ case Float64: return 160;
204
+ case Complex64: return 162;
205
+ case Complex128: return 164;
206
+ default: goto invalid_combination;
207
+ }
208
+ }
209
+
210
+ case Float64: {
211
+ switch (u->tag) {
212
+ case Float64: return 166;
213
+ case Complex128: return 168;
214
+ default: goto invalid_combination;
215
+ }
216
+ }
217
+
218
+ case Complex32: {
219
+ switch (u->tag) {
220
+ case Complex32: return 170;
221
+ case Complex64: return 172;
222
+ case Complex128: return 174;
223
+ default: goto invalid_combination;
224
+ }
225
+ }
226
+
227
+ case Complex64: {
228
+ switch (u->tag) {
229
+ case Complex64: return 176;
230
+ case Complex128: return 178;
231
+ default: goto invalid_combination;
232
+ }
233
+ }
234
+
235
+ case Complex128: {
236
+ switch (u->tag) {
237
+ case Complex128: return 180;
238
+ default: goto invalid_combination;
239
+ }
240
+ }
241
+
242
+ default: goto invalid_combination;
243
+ }
244
+
245
+ invalid_combination:
246
+ ndt_err_format(ctx, NDT_ValueError, "invalid dtype");
247
+ return -1;
248
+ }
249
+
250
+ static int
251
+ invert_kernel_location(const ndt_t *in, const ndt_t *out, ndt_context_t *ctx)
252
+ {
253
+ const ndt_t *t = ndt_dtype(in);
254
+ (void)out;
255
+
256
+ switch (t->tag) {
257
+ case Bool: return 0;
258
+
259
+ case Uint8: return 2;
260
+ case Uint16: return 4;
261
+ case Uint32: return 6;
262
+ case Uint64: return 8;
263
+
264
+ case Int8: return 10;
265
+ case Int16: return 12;
266
+ case Int32: return 14;
267
+ case Int64: return 16;
268
+
269
+ default:
270
+ ndt_err_format(ctx, NDT_ValueError, "invalid dtype");
271
+ return -1;
272
+ }
273
+ }
274
+
275
+ static int
276
+ negative_kernel_location(const ndt_t *in, const ndt_t *out, ndt_context_t *ctx)
277
+ {
278
+ const ndt_t *t = ndt_dtype(in);
279
+ (void)out;
280
+
281
+ switch (t->tag) {
282
+ case Uint8: return 0;
283
+ case Uint16: return 2;
284
+ case Uint32: return 4;
285
+
286
+ case Int8: return 6;
287
+ case Int16: return 8;
288
+ case Int32: return 10;
289
+ case Int64: return 12;
290
+
291
+ case BFloat16: return 14;
292
+ case Float16: return 16;
293
+ case Float32: return 18;
294
+ case Float64: return 20;
295
+
296
+ case Complex32: return 22;
297
+ case Complex64: return 24;
298
+ case Complex128: return 26;
299
+
300
+ default:
301
+ ndt_err_format(ctx, NDT_ValueError, "invalid dtype");
302
+ return -1;
303
+ }
304
+ }
305
+
306
+ static int
307
+ math_kernel_location(const ndt_t *in, const ndt_t *out, ndt_context_t *ctx)
308
+ {
309
+ const ndt_t *t = ndt_dtype(in);
310
+ (void)out;
311
+
312
+ switch (t->tag) {
313
+ case Uint8: return 0;
314
+ case Int8: return 2;
315
+ case Float16: return 4;
316
+
317
+ case BFloat16: return 6;
318
+
319
+ case Uint16: return 8;
320
+ case Int16: return 10;
321
+ case Float32: return 12;
322
+
323
+ case Uint32: return 14;
324
+ case Int32: return 16;
325
+ case Float64: return 18;
326
+
327
+ case Complex32: return 20;
328
+ case Complex64: return 22;
329
+ case Complex128: return 24;
330
+
331
+ default:
332
+ ndt_err_format(ctx, NDT_ValueError, "invalid dtype");
333
+ return -1;
334
+ }
335
+ }
336
+
337
+
338
+ /*****************************************************************************/
339
+ /* CUDA-specific unary macros */
340
+ /*****************************************************************************/
341
+
342
+ #define CUDA_HOST_UNARY(name, t0, t1) \
343
+ static int \
344
+ gm_cuda_host_fixed_1D_C_##name##_##t0##_##t1(xnd_t stack[], ndt_context_t *ctx) \
345
+ { \
346
+ const char *a0 = apply_index(&stack[0]); \
347
+ char *a1 = apply_index(&stack[1]); \
348
+ const int64_t N = xnd_fixed_shape(&stack[0]); \
349
+ (void)ctx; \
350
+ \
351
+ gm_cuda_device_fixed_1D_C_##name##_##t0##_##t1(a0, a1, N); \
352
+ \
353
+ if (ndt_is_optional(ndt_dtype(stack[1].type))) { \
354
+ unary_update_bitmap_1D_S(stack); \
355
+ } \
356
+ \
357
+ return 0; \
358
+ } \
359
+ \
360
+ static int \
361
+ gm_cuda_host_fixed_1D_S_##name##_##t0##_##t1(xnd_t stack[], ndt_context_t *ctx) \
362
+ { \
363
+ const char *a0 = apply_index(&stack[0]); \
364
+ char *a1 = apply_index(&stack[1]); \
365
+ const int64_t N = xnd_fixed_shape(&stack[0]); \
366
+ const int64_t s0 = xnd_fixed_step(&stack[0]); \
367
+ const int64_t s1 = xnd_fixed_step(&stack[1]); \
368
+ (void)ctx; \
369
+ \
370
+ gm_cuda_device_fixed_1D_S_##name##_##t0##_##t1(a0, a1, s0, s1, N); \
371
+ \
372
+ if (ndt_is_optional(ndt_dtype(stack[1].type))) { \
373
+ unary_update_bitmap_1D_S(stack); \
374
+ } \
375
+ \
376
+ return 0; \
377
+ } \
378
+ \
379
+ static int \
380
+ gm_cuda_host_0D_##name##_##t0##_##t1(xnd_t stack[], ndt_context_t *ctx) \
381
+ { \
382
+ const char *a0 = stack[0].ptr; \
383
+ char *a1 = stack[1].ptr; \
384
+ (void)ctx; \
385
+ \
386
+ gm_cuda_device_0D_##name##_##t0##_##t1(a0, a1); \
387
+ \
388
+ if (ndt_is_optional(ndt_dtype(stack[1].type))) { \
389
+ unary_update_bitmap_0D(stack); \
390
+ } \
391
+ \
392
+ return 0; \
393
+ }
394
+
395
+
396
+ #define CUDA_HOST_NOIMPL(name, t0, t1) \
397
+ static int \
398
+ gm_cuda_host_fixed_1D_C_##name##_##t0##_##t1(xnd_t stack[], ndt_context_t *ctx) \
399
+ { \
400
+ (void)stack; \
401
+ \
402
+ ndt_err_format(ctx, NDT_NotImplementedError, \
403
+ "implementation for " STRINGIZE(name) " : " \
404
+ STRINGIZE(t0) " -> " STRINGIZE(t1) \
405
+ " currently requires double rounding"); \
406
+ \
407
+ return -1; \
408
+ } \
409
+ \
410
+ static int \
411
+ gm_cuda_host_fixed_1D_S_##name##_##t0##_##t1(xnd_t stack[], ndt_context_t *ctx) \
412
+ { \
413
+ (void)stack; \
414
+ \
415
+ ndt_err_format(ctx, NDT_NotImplementedError, \
416
+ "implementation for " STRINGIZE(name) " : " \
417
+ STRINGIZE(t0) " -> " STRINGIZE(t1) \
418
+ " currently requires double rounding"); \
419
+ \
420
+ return -1; \
421
+ } \
422
+ \
423
+ static int \
424
+ gm_cuda_host_0D_##name##_##t0##_##t1(xnd_t stack[], ndt_context_t *ctx) \
425
+ { \
426
+ (void)stack; \
427
+ \
428
+ ndt_err_format(ctx, NDT_NotImplementedError, \
429
+ "implementation for " STRINGIZE(name) " : " \
430
+ STRINGIZE(t0) " -> " STRINGIZE(t1) \
431
+ " currently requires double rounding"); \
432
+ \
433
+ return -1; \
434
+ }
435
+
436
+
437
+ #define CUDA_HOST_UNARY_REDUCE(name, t0, t1) \
438
+ static int \
439
+ gm_cuda_host_1D_C_reduce_##name##_##t0##_##t1(xnd_t stack[], ndt_context_t *ctx) \
440
+ { \
441
+ const char *a0 = apply_index(&stack[0]); \
442
+ char *a1 = stack[1].ptr; \
443
+ const int64_t N = xnd_fixed_shape(&stack[0]); \
444
+ (void)ctx; \
445
+ \
446
+ gm_cuda_device_1D_C_reduce_##name##_##t0##_##t1(a0, a1, N); \
447
+ \
448
+ if (ndt_is_optional(ndt_dtype(stack[1].type))) { \
449
+ unary_reduce_bitmap_1D_S(stack); \
450
+ } \
451
+ \
452
+ return 0; \
453
+ }
454
+
455
+ #define CUDA_HOST_REDUCE_NOIMPL(name, t0, t1) \
456
+ static int \
457
+ gm_cuda_host_1D_C_reduce_##name##_##t0##_##t1(xnd_t stack[], ndt_context_t *ctx) \
458
+ { \
459
+ (void)stack; \
460
+ \
461
+ ndt_err_format(ctx, NDT_NotImplementedError, \
462
+ "No cuda thrust implementation for: " STRINGIZE(name) " : " \
463
+ STRINGIZE(t0) " -> " STRINGIZE(t1)); \
464
+ \
465
+ return -1; \
466
+ }
467
+
468
+
469
+ #define CUDA_HOST_UNARY_INIT(funcname, func, t0, t1) \
470
+ { .name = STRINGIZE(funcname), \
471
+ .sig = "... * " STRINGIZE(t0) " -> ... * " STRINGIZE(t1), \
472
+ .OptC = gm_cuda_host_fixed_1D_C_##func##_##t0##_##t1, \
473
+ .OptS = gm_cuda_host_fixed_1D_S_##func##_##t0##_##t1, \
474
+ .C = gm_cuda_host_0D_##func##_##t0##_##t1 }, \
475
+ \
476
+ { .name = STRINGIZE(funcname), \
477
+ .sig = "... * ?" STRINGIZE(t0) " -> ... * ?" STRINGIZE(t1), \
478
+ .OptC = gm_cuda_host_fixed_1D_C_##func##_##t0##_##t1, \
479
+ .OptS = gm_cuda_host_fixed_1D_S_##func##_##t0##_##t1, \
480
+ .C = gm_cuda_host_0D_##func##_##t0##_##t1 } \
481
+
482
+
483
+ #define CUDA_HOST_UNARY_REDUCE_INIT(funcname, func, t0, t1) \
484
+ { .name = "reduce_" STRINGIZE(funcname), \
485
+ .sig = "N * " STRINGIZE(t0) " -> " STRINGIZE(t1), \
486
+ .C = gm_cuda_host_1D_C_reduce_##func##_##t0##_##t1 }, \
487
+ \
488
+ { .name = "reduce_" STRINGIZE(funcname), \
489
+ .sig = "N * ?" STRINGIZE(t0) " -> ?" STRINGIZE(t1), \
490
+ .C = gm_cuda_host_1D_C_reduce_##func##_##t0##_##t1 }
491
+
492
+
493
+ #undef bool
494
+ #define bool_t _Bool
495
+
496
+
497
+ /*****************************************************************************/
498
+ /* Copy */
499
+ /*****************************************************************************/
500
+
501
+ #define CUDA_HOST_ALL_UNARY(name) \
502
+ CUDA_HOST_UNARY(name, bool, bool) \
503
+ CUDA_HOST_UNARY(name, bool, uint8) \
504
+ CUDA_HOST_UNARY(name, bool, uint16) \
505
+ CUDA_HOST_UNARY(name, bool, uint32) \
506
+ CUDA_HOST_UNARY(name, bool, uint64) \
507
+ CUDA_HOST_UNARY(name, bool, int8) \
508
+ CUDA_HOST_UNARY(name, bool, int16) \
509
+ CUDA_HOST_UNARY(name, bool, int32) \
510
+ CUDA_HOST_UNARY(name, bool, int64) \
511
+ CUDA_HOST_UNARY(name, bool, bfloat16) \
512
+ CUDA_HOST_UNARY(name, bool, float16) \
513
+ CUDA_HOST_UNARY(name, bool, float32) \
514
+ CUDA_HOST_UNARY(name, bool, float64) \
515
+ CUDA_HOST_NOIMPL(name, bool, complex32) \
516
+ CUDA_HOST_UNARY(name, bool, complex64) \
517
+ CUDA_HOST_UNARY(name, bool, complex128) \
518
+ \
519
+ CUDA_HOST_UNARY(name, uint8, uint8) \
520
+ CUDA_HOST_UNARY(name, uint8, uint16) \
521
+ CUDA_HOST_UNARY(name, uint8, uint32) \
522
+ CUDA_HOST_UNARY(name, uint8, uint64) \
523
+ CUDA_HOST_UNARY(name, uint8, int16) \
524
+ CUDA_HOST_UNARY(name, uint8, int32) \
525
+ CUDA_HOST_UNARY(name, uint8, int64) \
526
+ CUDA_HOST_UNARY(name, uint8, bfloat16) \
527
+ CUDA_HOST_UNARY(name, uint8, float16) \
528
+ CUDA_HOST_UNARY(name, uint8, float32) \
529
+ CUDA_HOST_UNARY(name, uint8, float64) \
530
+ CUDA_HOST_NOIMPL(name, uint8, complex32) \
531
+ CUDA_HOST_UNARY(name, uint8, complex64) \
532
+ CUDA_HOST_UNARY(name, uint8, complex128) \
533
+ \
534
+ CUDA_HOST_UNARY(name, uint16, uint16) \
535
+ CUDA_HOST_UNARY(name, uint16, uint32) \
536
+ CUDA_HOST_UNARY(name, uint16, uint64) \
537
+ CUDA_HOST_UNARY(name, uint16, int32) \
538
+ CUDA_HOST_UNARY(name, uint16, int64) \
539
+ CUDA_HOST_UNARY(name, uint16, float32) \
540
+ CUDA_HOST_UNARY(name, uint16, float64) \
541
+ CUDA_HOST_UNARY(name, uint16, complex64) \
542
+ CUDA_HOST_UNARY(name, uint16, complex128) \
543
+ \
544
+ CUDA_HOST_UNARY(name, uint32, uint32) \
545
+ CUDA_HOST_UNARY(name, uint32, uint64) \
546
+ CUDA_HOST_UNARY(name, uint32, int64) \
547
+ CUDA_HOST_UNARY(name, uint32, float64) \
548
+ CUDA_HOST_UNARY(name, uint32, complex128) \
549
+ \
550
+ CUDA_HOST_UNARY(name, uint64, uint64) \
551
+ \
552
+ CUDA_HOST_UNARY(name, int8, int8) \
553
+ CUDA_HOST_UNARY(name, int8, int16) \
554
+ CUDA_HOST_UNARY(name, int8, int32) \
555
+ CUDA_HOST_UNARY(name, int8, int64) \
556
+ CUDA_HOST_UNARY(name, int8, bfloat16) \
557
+ CUDA_HOST_UNARY(name, int8, float16) \
558
+ CUDA_HOST_UNARY(name, int8, float32) \
559
+ CUDA_HOST_UNARY(name, int8, float64) \
560
+ CUDA_HOST_NOIMPL(name, int8, complex32) \
561
+ CUDA_HOST_UNARY(name, int8, complex64) \
562
+ CUDA_HOST_UNARY(name, int8, complex128) \
563
+ \
564
+ CUDA_HOST_UNARY(name, int16, int16) \
565
+ CUDA_HOST_UNARY(name, int16, int32) \
566
+ CUDA_HOST_UNARY(name, int16, int64) \
567
+ CUDA_HOST_UNARY(name, int16, float32) \
568
+ CUDA_HOST_UNARY(name, int16, float64) \
569
+ CUDA_HOST_UNARY(name, int16, complex64) \
570
+ CUDA_HOST_UNARY(name, int16, complex128) \
571
+ \
572
+ CUDA_HOST_UNARY(name, int32, int32) \
573
+ CUDA_HOST_UNARY(name, int32, int64) \
574
+ CUDA_HOST_UNARY(name, int32, float64) \
575
+ CUDA_HOST_UNARY(name, int32, complex128) \
576
+ \
577
+ CUDA_HOST_UNARY(name, int64, int64) \
578
+ \
579
+ CUDA_HOST_UNARY(name, bfloat16, bfloat16) \
580
+ CUDA_HOST_UNARY(name, bfloat16, float32) \
581
+ CUDA_HOST_UNARY(name, bfloat16, float64) \
582
+ CUDA_HOST_UNARY(name, bfloat16, complex64) \
583
+ CUDA_HOST_UNARY(name, bfloat16, complex128) \
584
+ \
585
+ CUDA_HOST_UNARY(name, float16, float16) \
586
+ CUDA_HOST_UNARY(name, float16, float32) \
587
+ CUDA_HOST_UNARY(name, float16, float64) \
588
+ CUDA_HOST_NOIMPL(name, float16, complex32) \
589
+ CUDA_HOST_UNARY(name, float16, complex64) \
590
+ CUDA_HOST_UNARY(name, float16, complex128) \
591
+ \
592
+ CUDA_HOST_UNARY(name, float32, float32) \
593
+ CUDA_HOST_UNARY(name, float32, float64) \
594
+ CUDA_HOST_UNARY(name, float32, complex64) \
595
+ CUDA_HOST_UNARY(name, float32, complex128) \
596
+ \
597
+ CUDA_HOST_UNARY(name, float64, float64) \
598
+ CUDA_HOST_UNARY(name, float64, complex128) \
599
+ \
600
+ CUDA_HOST_NOIMPL(name, complex32, complex32) \
601
+ CUDA_HOST_NOIMPL(name, complex32, complex64) \
602
+ CUDA_HOST_NOIMPL(name, complex32, complex128) \
603
+ \
604
+ CUDA_HOST_UNARY(name, complex64, complex64) \
605
+ CUDA_HOST_UNARY(name, complex64, complex128) \
606
+ \
607
+ CUDA_HOST_UNARY(name, complex128, complex128)
608
+
609
+ #define CUDA_HOST_ALL_UNARY_INIT(name, func, hfunc) \
610
+ CUDA_HOST_UNARY_INIT(name, func, bool, bool), \
611
+ CUDA_HOST_UNARY_INIT(name, func, bool, uint8), \
612
+ CUDA_HOST_UNARY_INIT(name, func, bool, uint16), \
613
+ CUDA_HOST_UNARY_INIT(name, func, bool, uint32), \
614
+ CUDA_HOST_UNARY_INIT(name, func, bool, uint64), \
615
+ CUDA_HOST_UNARY_INIT(name, func, bool, int8), \
616
+ CUDA_HOST_UNARY_INIT(name, func, bool, int16), \
617
+ CUDA_HOST_UNARY_INIT(name, func, bool, int32), \
618
+ CUDA_HOST_UNARY_INIT(name, func, bool, int64), \
619
+ CUDA_HOST_UNARY_INIT(name, func, bool, bfloat16), \
620
+ CUDA_HOST_UNARY_INIT(name, hfunc, bool, float16), \
621
+ CUDA_HOST_UNARY_INIT(name, func, bool, float32), \
622
+ CUDA_HOST_UNARY_INIT(name, func, bool, float64), \
623
+ CUDA_HOST_UNARY_INIT(name, func, bool, complex32), \
624
+ CUDA_HOST_UNARY_INIT(name, func, bool, complex64), \
625
+ CUDA_HOST_UNARY_INIT(name, func, bool, complex128), \
626
+ \
627
+ CUDA_HOST_UNARY_INIT(name, func, uint8, uint8), \
628
+ CUDA_HOST_UNARY_INIT(name, func, uint8, uint16), \
629
+ CUDA_HOST_UNARY_INIT(name, func, uint8, uint32), \
630
+ CUDA_HOST_UNARY_INIT(name, func, uint8, uint64), \
631
+ CUDA_HOST_UNARY_INIT(name, func, uint8, int16), \
632
+ CUDA_HOST_UNARY_INIT(name, func, uint8, int32), \
633
+ CUDA_HOST_UNARY_INIT(name, func, uint8, int64), \
634
+ CUDA_HOST_UNARY_INIT(name, func, uint8, bfloat16), \
635
+ CUDA_HOST_UNARY_INIT(name, hfunc, uint8, float16), \
636
+ CUDA_HOST_UNARY_INIT(name, func, uint8, float32), \
637
+ CUDA_HOST_UNARY_INIT(name, func, uint8, float64), \
638
+ CUDA_HOST_UNARY_INIT(name, func, uint8, complex32), \
639
+ CUDA_HOST_UNARY_INIT(name, func, uint8, complex64), \
640
+ CUDA_HOST_UNARY_INIT(name, func, uint8, complex128), \
641
+ \
642
+ CUDA_HOST_UNARY_INIT(name, func, uint16, uint16), \
643
+ CUDA_HOST_UNARY_INIT(name, func, uint16, uint32), \
644
+ CUDA_HOST_UNARY_INIT(name, func, uint16, uint64), \
645
+ CUDA_HOST_UNARY_INIT(name, func, uint16, int32), \
646
+ CUDA_HOST_UNARY_INIT(name, func, uint16, int64), \
647
+ CUDA_HOST_UNARY_INIT(name, func, uint16, float32), \
648
+ CUDA_HOST_UNARY_INIT(name, func, uint16, float64), \
649
+ CUDA_HOST_UNARY_INIT(name, func, uint16, complex64), \
650
+ CUDA_HOST_UNARY_INIT(name, func, uint16, complex128), \
651
+ \
652
+ CUDA_HOST_UNARY_INIT(name, func, uint32, uint32), \
653
+ CUDA_HOST_UNARY_INIT(name, func, uint32, uint64), \
654
+ CUDA_HOST_UNARY_INIT(name, func, uint32, int64), \
655
+ CUDA_HOST_UNARY_INIT(name, func, uint32, float64), \
656
+ CUDA_HOST_UNARY_INIT(name, func, uint32, complex128), \
657
+ \
658
+ CUDA_HOST_UNARY_INIT(name, func, uint64, uint64), \
659
+ \
660
+ CUDA_HOST_UNARY_INIT(name, func, int8, int8), \
661
+ CUDA_HOST_UNARY_INIT(name, func, int8, int16), \
662
+ CUDA_HOST_UNARY_INIT(name, func, int8, int32), \
663
+ CUDA_HOST_UNARY_INIT(name, func, int8, int64), \
664
+ CUDA_HOST_UNARY_INIT(name, func, int8, bfloat16), \
665
+ CUDA_HOST_UNARY_INIT(name, hfunc, int8, float16), \
666
+ CUDA_HOST_UNARY_INIT(name, func, int8, float32), \
667
+ CUDA_HOST_UNARY_INIT(name, func, int8, float64), \
668
+ CUDA_HOST_UNARY_INIT(name, func, int8, complex32), \
669
+ CUDA_HOST_UNARY_INIT(name, func, int8, complex64), \
670
+ CUDA_HOST_UNARY_INIT(name, func, int8, complex128), \
671
+ \
672
+ CUDA_HOST_UNARY_INIT(name, func, int16, int16), \
673
+ CUDA_HOST_UNARY_INIT(name, func, int16, int32), \
674
+ CUDA_HOST_UNARY_INIT(name, func, int16, int64), \
675
+ CUDA_HOST_UNARY_INIT(name, func, int16, float32), \
676
+ CUDA_HOST_UNARY_INIT(name, func, int16, float64), \
677
+ CUDA_HOST_UNARY_INIT(name, func, int16, complex64), \
678
+ CUDA_HOST_UNARY_INIT(name, func, int16, complex128), \
679
+ \
680
+ CUDA_HOST_UNARY_INIT(name, func, int32, int32), \
681
+ CUDA_HOST_UNARY_INIT(name, func, int32, int64), \
682
+ CUDA_HOST_UNARY_INIT(name, func, int32, float64), \
683
+ CUDA_HOST_UNARY_INIT(name, func, int32, complex128), \
684
+ \
685
+ CUDA_HOST_UNARY_INIT(name, func, int64, int64), \
686
+ \
687
+ CUDA_HOST_UNARY_INIT(name, func, bfloat16, bfloat16), \
688
+ CUDA_HOST_UNARY_INIT(name, func, bfloat16, float32), \
689
+ CUDA_HOST_UNARY_INIT(name, func, bfloat16, float64), \
690
+ CUDA_HOST_UNARY_INIT(name, func, bfloat16, complex64), \
691
+ CUDA_HOST_UNARY_INIT(name, func, bfloat16, complex128), \
692
+ \
693
+ CUDA_HOST_UNARY_INIT(name, hfunc, float16, float16), \
694
+ CUDA_HOST_UNARY_INIT(name, func, float16, float32), \
695
+ CUDA_HOST_UNARY_INIT(name, func, float16, float64), \
696
+ CUDA_HOST_UNARY_INIT(name, func, float16, complex32), \
697
+ CUDA_HOST_UNARY_INIT(name, func, float16, complex64), \
698
+ CUDA_HOST_UNARY_INIT(name, func, float16, complex128), \
699
+ \
700
+ CUDA_HOST_UNARY_INIT(name, func, float32, float32), \
701
+ CUDA_HOST_UNARY_INIT(name, func, float32, float64), \
702
+ CUDA_HOST_UNARY_INIT(name, func, float32, complex64), \
703
+ CUDA_HOST_UNARY_INIT(name, func, float32, complex128), \
704
+ \
705
+ CUDA_HOST_UNARY_INIT(name, func, float64, float64), \
706
+ CUDA_HOST_UNARY_INIT(name, func, float64, complex128), \
707
+ \
708
+ CUDA_HOST_UNARY_INIT(name, func, complex32, complex32), \
709
+ CUDA_HOST_UNARY_INIT(name, func, complex32, complex64), \
710
+ CUDA_HOST_UNARY_INIT(name, func, complex32, complex128), \
711
+ \
712
+ CUDA_HOST_UNARY_INIT(name, func, complex64, complex64), \
713
+ CUDA_HOST_UNARY_INIT(name, func, complex64, complex128), \
714
+ \
715
+ CUDA_HOST_UNARY_INIT(name, func, complex128, complex128)
716
+
717
+
718
+ CUDA_HOST_ALL_UNARY(copy)
719
+ CUDA_HOST_ALL_UNARY(abs)
720
+
721
+
722
+ static const gm_kernel_init_t unary_copy[] = {
723
+ /* COPY */
724
+ CUDA_HOST_ALL_UNARY_INIT(copy, copy, copy),
725
+ CUDA_HOST_ALL_UNARY_INIT(abs, abs, abs),
726
+
727
+ { .name = NULL, .sig = NULL }
728
+ };
729
+
730
+ /*****************************************************************************/
731
+ /* Reduce */
732
+ /*****************************************************************************/
733
+
734
+ #define CUDA_HOST_ALL_UNARY_REDUCE(name) \
735
+ CUDA_HOST_UNARY_REDUCE(name, bool, bool) \
736
+ CUDA_HOST_UNARY_REDUCE(name, bool, uint8) \
737
+ CUDA_HOST_UNARY_REDUCE(name, bool, uint16) \
738
+ CUDA_HOST_UNARY_REDUCE(name, bool, uint32) \
739
+ CUDA_HOST_UNARY_REDUCE(name, bool, uint64) \
740
+ CUDA_HOST_UNARY_REDUCE(name, bool, int8) \
741
+ CUDA_HOST_UNARY_REDUCE(name, bool, int16) \
742
+ CUDA_HOST_UNARY_REDUCE(name, bool, int32) \
743
+ CUDA_HOST_UNARY_REDUCE(name, bool, int64) \
744
+ CUDA_HOST_REDUCE_NOIMPL(name, bool, bfloat16) \
745
+ CUDA_HOST_UNARY_REDUCE(name, bool, float16) \
746
+ CUDA_HOST_UNARY_REDUCE(name, bool, float32) \
747
+ CUDA_HOST_UNARY_REDUCE(name, bool, float64) \
748
+ CUDA_HOST_REDUCE_NOIMPL(name, bool, complex32) \
749
+ CUDA_HOST_REDUCE_NOIMPL(name, bool, complex64) \
750
+ CUDA_HOST_REDUCE_NOIMPL(name, bool, complex128) \
751
+ \
752
+ CUDA_HOST_UNARY_REDUCE(name, uint8, uint8) \
753
+ CUDA_HOST_UNARY_REDUCE(name, uint8, uint16) \
754
+ CUDA_HOST_UNARY_REDUCE(name, uint8, uint32) \
755
+ CUDA_HOST_UNARY_REDUCE(name, uint8, uint64) \
756
+ CUDA_HOST_UNARY_REDUCE(name, uint8, int16) \
757
+ CUDA_HOST_UNARY_REDUCE(name, uint8, int32) \
758
+ CUDA_HOST_UNARY_REDUCE(name, uint8, int64) \
759
+ CUDA_HOST_REDUCE_NOIMPL(name, uint8, bfloat16) \
760
+ CUDA_HOST_UNARY_REDUCE(name, uint8, float16) \
761
+ CUDA_HOST_UNARY_REDUCE(name, uint8, float32) \
762
+ CUDA_HOST_UNARY_REDUCE(name, uint8, float64) \
763
+ CUDA_HOST_REDUCE_NOIMPL(name, uint8, complex32) \
764
+ CUDA_HOST_REDUCE_NOIMPL(name, uint8, complex64) \
765
+ CUDA_HOST_REDUCE_NOIMPL(name, uint8, complex128) \
766
+ \
767
+ CUDA_HOST_UNARY_REDUCE(name, uint16, uint16) \
768
+ CUDA_HOST_UNARY_REDUCE(name, uint16, uint32) \
769
+ CUDA_HOST_UNARY_REDUCE(name, uint16, uint64) \
770
+ CUDA_HOST_UNARY_REDUCE(name, uint16, int32) \
771
+ CUDA_HOST_UNARY_REDUCE(name, uint16, int64) \
772
+ CUDA_HOST_UNARY_REDUCE(name, uint16, float32) \
773
+ CUDA_HOST_UNARY_REDUCE(name, uint16, float64) \
774
+ CUDA_HOST_REDUCE_NOIMPL(name, uint16, complex64) \
775
+ CUDA_HOST_REDUCE_NOIMPL(name, uint16, complex128) \
776
+ \
777
+ CUDA_HOST_UNARY_REDUCE(name, uint32, uint32) \
778
+ CUDA_HOST_UNARY_REDUCE(name, uint32, uint64) \
779
+ CUDA_HOST_UNARY_REDUCE(name, uint32, int64) \
780
+ CUDA_HOST_UNARY_REDUCE(name, uint32, float64) \
781
+ CUDA_HOST_REDUCE_NOIMPL(name, uint32, complex128) \
782
+ \
783
+ CUDA_HOST_UNARY_REDUCE(name, uint64, uint64) \
784
+ \
785
+ CUDA_HOST_UNARY_REDUCE(name, int8, int8) \
786
+ CUDA_HOST_UNARY_REDUCE(name, int8, int16) \
787
+ CUDA_HOST_UNARY_REDUCE(name, int8, int32) \
788
+ CUDA_HOST_UNARY_REDUCE(name, int8, int64) \
789
+ CUDA_HOST_REDUCE_NOIMPL(name, int8, bfloat16) \
790
+ CUDA_HOST_UNARY_REDUCE(name, int8, float16) \
791
+ CUDA_HOST_UNARY_REDUCE(name, int8, float32) \
792
+ CUDA_HOST_UNARY_REDUCE(name, int8, float64) \
793
+ CUDA_HOST_REDUCE_NOIMPL(name, int8, complex32) \
794
+ CUDA_HOST_REDUCE_NOIMPL(name, int8, complex64) \
795
+ CUDA_HOST_REDUCE_NOIMPL(name, int8, complex128) \
796
+ \
797
+ CUDA_HOST_UNARY_REDUCE(name, int16, int16) \
798
+ CUDA_HOST_UNARY_REDUCE(name, int16, int32) \
799
+ CUDA_HOST_UNARY_REDUCE(name, int16, int64) \
800
+ CUDA_HOST_UNARY_REDUCE(name, int16, float32) \
801
+ CUDA_HOST_UNARY_REDUCE(name, int16, float64) \
802
+ CUDA_HOST_REDUCE_NOIMPL(name, int16, complex64) \
803
+ CUDA_HOST_REDUCE_NOIMPL(name, int16, complex128) \
804
+ \
805
+ CUDA_HOST_UNARY_REDUCE(name, int32, int32) \
806
+ CUDA_HOST_UNARY_REDUCE(name, int32, int64) \
807
+ CUDA_HOST_UNARY_REDUCE(name, int32, float64) \
808
+ CUDA_HOST_REDUCE_NOIMPL(name, int32, complex128) \
809
+ \
810
+ CUDA_HOST_UNARY_REDUCE(name, int64, int64) \
811
+ \
812
+ CUDA_HOST_REDUCE_NOIMPL(name, bfloat16, bfloat16) \
813
+ CUDA_HOST_REDUCE_NOIMPL(name, bfloat16, float32) \
814
+ CUDA_HOST_REDUCE_NOIMPL(name, bfloat16, float64) \
815
+ CUDA_HOST_REDUCE_NOIMPL(name, bfloat16, complex64) \
816
+ CUDA_HOST_REDUCE_NOIMPL(name, bfloat16, complex128) \
817
+ \
818
+ CUDA_HOST_UNARY_REDUCE(name, float16, float16) \
819
+ CUDA_HOST_REDUCE_NOIMPL(name, float16, float32) \
820
+ CUDA_HOST_REDUCE_NOIMPL(name, float16, float64) \
821
+ CUDA_HOST_REDUCE_NOIMPL(name, float16, complex32) \
822
+ CUDA_HOST_REDUCE_NOIMPL(name, float16, complex64) \
823
+ CUDA_HOST_REDUCE_NOIMPL(name, float16, complex128) \
824
+ \
825
+ CUDA_HOST_UNARY_REDUCE(name, float32, float32) \
826
+ CUDA_HOST_UNARY_REDUCE(name, float32, float64) \
827
+ CUDA_HOST_REDUCE_NOIMPL(name, float32, complex64) \
828
+ CUDA_HOST_REDUCE_NOIMPL(name, float32, complex128) \
829
+ \
830
+ CUDA_HOST_UNARY_REDUCE(name, float64, float64) \
831
+ CUDA_HOST_REDUCE_NOIMPL(name, float64, complex128) \
832
+ \
833
+ CUDA_HOST_REDUCE_NOIMPL(name, complex32, complex32) \
834
+ CUDA_HOST_REDUCE_NOIMPL(name, complex32, complex64) \
835
+ CUDA_HOST_REDUCE_NOIMPL(name, complex32, complex128) \
836
+ \
837
+ CUDA_HOST_REDUCE_NOIMPL(name, complex64, complex64) \
838
+ CUDA_HOST_REDUCE_NOIMPL(name, complex64, complex128) \
839
+ \
840
+ CUDA_HOST_REDUCE_NOIMPL(name, complex128, complex128)
841
+
842
+ #define CUDA_HOST_ALL_UNARY_REDUCE_INIT(name, func) \
843
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, bool), \
844
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, uint8), \
845
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, uint16), \
846
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, uint32), \
847
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, uint64), \
848
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, int8), \
849
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, int16), \
850
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, int32), \
851
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, int64), \
852
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, bfloat16), \
853
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, float16), \
854
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, float32), \
855
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, float64), \
856
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, complex32), \
857
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, complex64), \
858
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, complex128), \
859
+ \
860
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, uint8), \
861
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, uint16), \
862
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, uint32), \
863
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, uint64), \
864
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, int16), \
865
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, int32), \
866
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, int64), \
867
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, bfloat16), \
868
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, float16), \
869
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, float32), \
870
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, float64), \
871
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, complex32), \
872
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, complex64), \
873
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, complex128), \
874
+ \
875
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint16, uint16), \
876
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint16, uint32), \
877
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint16, uint64), \
878
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint16, int32), \
879
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint16, int64), \
880
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint16, float32), \
881
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint16, float64), \
882
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint16, complex64), \
883
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint16, complex128), \
884
+ \
885
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint32, uint32), \
886
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint32, uint64), \
887
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint32, int64), \
888
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint32, float64), \
889
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint32, complex128), \
890
+ \
891
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint64, uint64), \
892
+ \
893
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int8, int8), \
894
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int8, int16), \
895
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int8, int32), \
896
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int8, int64), \
897
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int8, bfloat16), \
898
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int8, float16), \
899
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int8, float32), \
900
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int8, float64), \
901
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int8, complex32), \
902
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int8, complex64), \
903
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int8, complex128), \
904
+ \
905
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int16, int16), \
906
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int16, int32), \
907
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int16, int64), \
908
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int16, float32), \
909
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int16, float64), \
910
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int16, complex64), \
911
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int16, complex128), \
912
+ \
913
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int32, int32), \
914
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int32, int64), \
915
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int32, float64), \
916
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int32, complex128), \
917
+ \
918
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int64, int64), \
919
+ \
920
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bfloat16, bfloat16), \
921
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bfloat16, float32), \
922
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bfloat16, float64), \
923
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bfloat16, complex64), \
924
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bfloat16, complex128), \
925
+ \
926
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, float16, float16), \
927
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, float16, float32), \
928
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, float16, float64), \
929
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, float16, complex32), \
930
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, float16, complex64), \
931
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, float16, complex128), \
932
+ \
933
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, float32, float32), \
934
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, float32, float64), \
935
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, float32, complex64), \
936
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, float32, complex128), \
937
+ \
938
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, float64, float64), \
939
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, float64, complex128), \
940
+ \
941
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, complex32, complex32), \
942
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, complex32, complex64), \
943
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, complex32, complex128), \
944
+ \
945
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, complex64, complex64), \
946
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, complex64, complex128), \
947
+ \
948
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, complex128, complex128)
949
+
950
+
951
+ CUDA_HOST_ALL_UNARY_REDUCE(add)
952
+ CUDA_HOST_ALL_UNARY_REDUCE(multiply)
953
+
954
+
955
+ static const gm_kernel_init_t unary_reduce[] = {
956
+ /* REDUCE */
957
+ CUDA_HOST_ALL_UNARY_REDUCE_INIT(add, add),
958
+ CUDA_HOST_ALL_UNARY_REDUCE_INIT(multiply, multiply),
959
+
960
+ { .name = NULL, .sig = NULL }
961
+ };
962
+
963
+
964
+ /*****************************************************************************/
965
+ /* Bitwise NOT */
966
+ /*****************************************************************************/
967
+
968
+ CUDA_HOST_UNARY(invert, bool, bool)
969
+
970
+ CUDA_HOST_UNARY(invert, uint8, uint8)
971
+ CUDA_HOST_UNARY(invert, uint16, uint16)
972
+ CUDA_HOST_UNARY(invert, uint32, uint32)
973
+ CUDA_HOST_UNARY(invert, uint64, uint64)
974
+
975
+ CUDA_HOST_UNARY(invert, int8, int8)
976
+ CUDA_HOST_UNARY(invert, int16, int16)
977
+ CUDA_HOST_UNARY(invert, int32, int32)
978
+ CUDA_HOST_UNARY(invert, int64, int64)
979
+
980
+
981
+ static const gm_kernel_init_t unary_invert[] = {
982
+ /* INVERT */
983
+ CUDA_HOST_UNARY_INIT(invert, invert, bool, bool),
984
+
985
+ CUDA_HOST_UNARY_INIT(invert, invert, uint8, uint8),
986
+ CUDA_HOST_UNARY_INIT(invert, invert, uint16, uint16),
987
+ CUDA_HOST_UNARY_INIT(invert, invert, uint32, uint32),
988
+ CUDA_HOST_UNARY_INIT(invert, invert, uint64, uint64),
989
+
990
+ CUDA_HOST_UNARY_INIT(invert, invert, int8, int8),
991
+ CUDA_HOST_UNARY_INIT(invert, invert, int16, int16),
992
+ CUDA_HOST_UNARY_INIT(invert, invert, int32, int32),
993
+ CUDA_HOST_UNARY_INIT(invert, invert, int64, int64),
994
+
995
+ { .name = NULL, .sig = NULL }
996
+ };
997
+
998
+
999
+ /*****************************************************************************/
1000
+ /* Negative */
1001
+ /*****************************************************************************/
1002
+
1003
+ CUDA_HOST_UNARY(negative, uint8, int16)
1004
+ CUDA_HOST_UNARY(negative, uint16, int32)
1005
+ CUDA_HOST_UNARY(negative, uint32, int64)
1006
+
1007
+ CUDA_HOST_UNARY(negative, int8, int8)
1008
+ CUDA_HOST_UNARY(negative, int16, int16)
1009
+ CUDA_HOST_UNARY(negative, int32, int32)
1010
+ CUDA_HOST_UNARY(negative, int64, int64)
1011
+
1012
+ CUDA_HOST_UNARY(negative, bfloat16, bfloat16)
1013
+ CUDA_HOST_UNARY(negative, float16, float16)
1014
+ CUDA_HOST_UNARY(negative, float32, float32)
1015
+ CUDA_HOST_UNARY(negative, float64, float64)
1016
+
1017
+ CUDA_HOST_NOIMPL(negative, complex32, complex32)
1018
+ CUDA_HOST_UNARY(negative, complex64, complex64)
1019
+ CUDA_HOST_UNARY(negative, complex128, complex128)
1020
+
1021
+
1022
+ static const gm_kernel_init_t unary_negative[] = {
1023
+ /* NEGATIVE */
1024
+ CUDA_HOST_UNARY_INIT(negative, negative, uint8, int16),
1025
+ CUDA_HOST_UNARY_INIT(negative, negative, uint16, int32),
1026
+ CUDA_HOST_UNARY_INIT(negative, negative, uint32, int64),
1027
+
1028
+ CUDA_HOST_UNARY_INIT(negative, negative, int8, int8),
1029
+ CUDA_HOST_UNARY_INIT(negative, negative, int16, int16),
1030
+ CUDA_HOST_UNARY_INIT(negative, negative, int32, int32),
1031
+ CUDA_HOST_UNARY_INIT(negative, negative, int64, int64),
1032
+
1033
+ CUDA_HOST_UNARY_INIT(negative, negative, bfloat16, bfloat16),
1034
+ CUDA_HOST_UNARY_INIT(negative, negative, float16, float16),
1035
+ CUDA_HOST_UNARY_INIT(negative, negative, float32, float32),
1036
+ CUDA_HOST_UNARY_INIT(negative, negative, float64, float64),
1037
+
1038
+ CUDA_HOST_UNARY_INIT(negative, negative, complex32, complex32),
1039
+ CUDA_HOST_UNARY_INIT(negative, negative, complex64, complex64),
1040
+ CUDA_HOST_UNARY_INIT(negative, negative, complex128, complex128),
1041
+
1042
+ { .name = NULL, .sig = NULL }
1043
+ };
1044
+
1045
+
1046
+ /*****************************************************************************/
1047
+ /* Math */
1048
+ /*****************************************************************************/
1049
+
1050
+ #define _CUDA_ALL_HALF_MATH(name) \
1051
+ CUDA_HOST_UNARY(name##f16, uint8, float16) \
1052
+ CUDA_HOST_UNARY(name##f16, int8, float16) \
1053
+ CUDA_HOST_UNARY(name##f16, float16, float16)
1054
+
1055
+ #define _CUDA_ALL_HALF_MATH_NOIMPL(name) \
1056
+ CUDA_HOST_NOIMPL(name##f16, uint8, float16) \
1057
+ CUDA_HOST_NOIMPL(name##f16, int8, float16) \
1058
+ CUDA_HOST_NOIMPL(name##f16, float16, float16)
1059
+
1060
+ #define _CUDA_ALL_COMPLEX_MATH(name) \
1061
+ CUDA_HOST_NOIMPL(name, complex32, complex32) \
1062
+ CUDA_HOST_UNARY(name, complex64, complex64) \
1063
+ CUDA_HOST_UNARY(name, complex128, complex128)
1064
+
1065
+ #define _CUDA_ALL_COMPLEX_MATH_NOIMPL(name) \
1066
+ CUDA_HOST_NOIMPL(name, complex32, complex32) \
1067
+ CUDA_HOST_NOIMPL(name, complex64, complex64) \
1068
+ CUDA_HOST_NOIMPL(name, complex128, complex128)
1069
+
1070
+ #define _CUDA_ALL_REAL_MATH(name) \
1071
+ CUDA_HOST_UNARY(name##b16, bfloat16, bfloat16) \
1072
+ CUDA_HOST_UNARY(name##f, uint16, float32) \
1073
+ CUDA_HOST_UNARY(name##f, int16, float32) \
1074
+ CUDA_HOST_UNARY(name##f, float32, float32) \
1075
+ CUDA_HOST_UNARY(name, uint32, float64) \
1076
+ CUDA_HOST_UNARY(name, int32, float64) \
1077
+ CUDA_HOST_UNARY(name, float64, float64) \
1078
+
1079
+ #define CUDA_ALL_REAL_MATH(name) \
1080
+ _CUDA_ALL_HALF_MATH_NOIMPL(name) \
1081
+ _CUDA_ALL_REAL_MATH(name) \
1082
+ _CUDA_ALL_COMPLEX_MATH_NOIMPL(name)
1083
+
1084
+ #define CUDA_ALL_REAL_MATH_WITH_HALF(name) \
1085
+ _CUDA_ALL_HALF_MATH(name) \
1086
+ _CUDA_ALL_REAL_MATH(name) \
1087
+ _CUDA_ALL_COMPLEX_MATH_NOIMPL(name)
1088
+
1089
+ #define CUDA_ALL_COMPLEX_MATH(name) \
1090
+ _CUDA_ALL_HALF_MATH_NOIMPL(name) \
1091
+ _CUDA_ALL_REAL_MATH(name) \
1092
+ _CUDA_ALL_COMPLEX_MATH(name)
1093
+
1094
+ #define CUDA_ALL_COMPLEX_MATH_WITH_HALF(name) \
1095
+ _CUDA_ALL_HALF_MATH(name) \
1096
+ _CUDA_ALL_REAL_MATH(name) \
1097
+ _CUDA_ALL_COMPLEX_MATH(name) \
1098
+
1099
+
1100
+ #define CUDA_ALL_UNARY_MATH_INIT(name) \
1101
+ CUDA_HOST_UNARY_INIT(name, name##f16, uint8, float16), \
1102
+ CUDA_HOST_UNARY_INIT(name, name##f16, int8, float16), \
1103
+ CUDA_HOST_UNARY_INIT(name, name##f16, float16, float16), \
1104
+ \
1105
+ CUDA_HOST_UNARY_INIT(name, name##b16, bfloat16, bfloat16), \
1106
+ \
1107
+ CUDA_HOST_UNARY_INIT(name, name##f, uint16, float32), \
1108
+ CUDA_HOST_UNARY_INIT(name, name##f, int16, float32), \
1109
+ CUDA_HOST_UNARY_INIT(name, name##f, float32, float32), \
1110
+ \
1111
+ CUDA_HOST_UNARY_INIT(name, name, uint32, float64), \
1112
+ CUDA_HOST_UNARY_INIT(name, name, int32, float64), \
1113
+ CUDA_HOST_UNARY_INIT(name, name, float64, float64), \
1114
+ \
1115
+ CUDA_HOST_UNARY_INIT(name, name, complex32, complex32), \
1116
+ CUDA_HOST_UNARY_INIT(name, name, complex64, complex64), \
1117
+ CUDA_HOST_UNARY_INIT(name, name, complex128, complex128)
1118
+
1119
+
1120
+ /*****************************************************************************/
1121
+ /* Abs functions */
1122
+ /*****************************************************************************/
1123
+
1124
+ CUDA_ALL_REAL_MATH_WITH_HALF(fabs)
1125
+
1126
+
1127
+ /*****************************************************************************/
1128
+ /* Exponential functions */
1129
+ /*****************************************************************************/
1130
+
1131
+ CUDA_ALL_COMPLEX_MATH_WITH_HALF(exp)
1132
+ CUDA_ALL_REAL_MATH_WITH_HALF(exp2)
1133
+ CUDA_ALL_REAL_MATH(expm1)
1134
+
1135
+
1136
+ /*****************************************************************************/
1137
+ /* Logarithm functions */
1138
+ /*****************************************************************************/
1139
+
1140
+ CUDA_ALL_COMPLEX_MATH_WITH_HALF(log)
1141
+ CUDA_ALL_COMPLEX_MATH_WITH_HALF(log10)
1142
+ CUDA_ALL_REAL_MATH_WITH_HALF(log2)
1143
+ CUDA_ALL_REAL_MATH(log1p)
1144
+ CUDA_ALL_REAL_MATH(logb)
1145
+
1146
+
1147
+ /*****************************************************************************/
1148
+ /* Power functions */
1149
+ /*****************************************************************************/
1150
+
1151
+ CUDA_ALL_COMPLEX_MATH_WITH_HALF(sqrt)
1152
+ CUDA_ALL_REAL_MATH(cbrt)
1153
+
1154
+
1155
+ /*****************************************************************************/
1156
+ /* Trigonometric functions */
1157
+ /*****************************************************************************/
1158
+
1159
+ CUDA_ALL_COMPLEX_MATH_WITH_HALF(sin)
1160
+ CUDA_ALL_COMPLEX_MATH_WITH_HALF(cos)
1161
+ CUDA_ALL_COMPLEX_MATH(tan)
1162
+ CUDA_ALL_COMPLEX_MATH(asin)
1163
+ CUDA_ALL_COMPLEX_MATH(acos)
1164
+ CUDA_ALL_COMPLEX_MATH(atan)
1165
+
1166
+
1167
+ /*****************************************************************************/
1168
+ /* Hyperbolic functions */
1169
+ /*****************************************************************************/
1170
+
1171
+ CUDA_ALL_COMPLEX_MATH(sinh)
1172
+ CUDA_ALL_COMPLEX_MATH(cosh)
1173
+ CUDA_ALL_COMPLEX_MATH(tanh)
1174
+ CUDA_ALL_COMPLEX_MATH(asinh)
1175
+ CUDA_ALL_COMPLEX_MATH(acosh)
1176
+ CUDA_ALL_COMPLEX_MATH(atanh)
1177
+
1178
+
1179
+ /*****************************************************************************/
1180
+ /* Error and gamma functions */
1181
+ /*****************************************************************************/
1182
+
1183
+ CUDA_ALL_REAL_MATH(erf)
1184
+ CUDA_ALL_REAL_MATH(erfc)
1185
+ CUDA_ALL_REAL_MATH(lgamma)
1186
+ CUDA_ALL_REAL_MATH(tgamma)
1187
+
1188
+
1189
+ /*****************************************************************************/
1190
+ /* Ceiling, floor, trunc */
1191
+ /*****************************************************************************/
1192
+
1193
+ CUDA_ALL_REAL_MATH(ceil)
1194
+ CUDA_ALL_REAL_MATH(floor)
1195
+ CUDA_ALL_REAL_MATH(trunc)
1196
+ CUDA_ALL_REAL_MATH(round)
1197
+ CUDA_ALL_REAL_MATH(nearbyint)
1198
+
1199
+
1200
+ static const gm_kernel_init_t unary_float[] = {
1201
+ /* ABS */
1202
+ CUDA_ALL_UNARY_MATH_INIT(fabs),
1203
+
1204
+ /* EXPONENTIAL */
1205
+ CUDA_ALL_UNARY_MATH_INIT(exp),
1206
+ CUDA_ALL_UNARY_MATH_INIT(exp2),
1207
+ CUDA_ALL_UNARY_MATH_INIT(expm1),
1208
+
1209
+ /* LOGARITHM */
1210
+ CUDA_ALL_UNARY_MATH_INIT(log),
1211
+ CUDA_ALL_UNARY_MATH_INIT(log2),
1212
+ CUDA_ALL_UNARY_MATH_INIT(log10),
1213
+ CUDA_ALL_UNARY_MATH_INIT(log1p),
1214
+ CUDA_ALL_UNARY_MATH_INIT(logb),
1215
+
1216
+ /* POWER */
1217
+ CUDA_ALL_UNARY_MATH_INIT(sqrt),
1218
+ CUDA_ALL_UNARY_MATH_INIT(cbrt),
1219
+
1220
+ /* TRIGONOMETRIC */
1221
+ CUDA_ALL_UNARY_MATH_INIT(sin),
1222
+ CUDA_ALL_UNARY_MATH_INIT(cos),
1223
+ CUDA_ALL_UNARY_MATH_INIT(tan),
1224
+ CUDA_ALL_UNARY_MATH_INIT(asin),
1225
+ CUDA_ALL_UNARY_MATH_INIT(acos),
1226
+ CUDA_ALL_UNARY_MATH_INIT(atan),
1227
+
1228
+ /* HYPERBOLIC */
1229
+ CUDA_ALL_UNARY_MATH_INIT(sinh),
1230
+ CUDA_ALL_UNARY_MATH_INIT(cosh),
1231
+ CUDA_ALL_UNARY_MATH_INIT(tanh),
1232
+ CUDA_ALL_UNARY_MATH_INIT(asinh),
1233
+ CUDA_ALL_UNARY_MATH_INIT(acosh),
1234
+ CUDA_ALL_UNARY_MATH_INIT(atanh),
1235
+
1236
+ /* ERROR AND GAMMA */
1237
+ CUDA_ALL_UNARY_MATH_INIT(erf),
1238
+ CUDA_ALL_UNARY_MATH_INIT(erfc),
1239
+ CUDA_ALL_UNARY_MATH_INIT(lgamma),
1240
+ CUDA_ALL_UNARY_MATH_INIT(tgamma),
1241
+
1242
+ /* CEILING, FLOOR, TRUNC */
1243
+ CUDA_ALL_UNARY_MATH_INIT(ceil),
1244
+ CUDA_ALL_UNARY_MATH_INIT(floor),
1245
+ CUDA_ALL_UNARY_MATH_INIT(trunc),
1246
+ CUDA_ALL_UNARY_MATH_INIT(round),
1247
+ CUDA_ALL_UNARY_MATH_INIT(nearbyint),
1248
+
1249
+ { .name = NULL, .sig = NULL }
1250
+ };
1251
+
1252
+
+ /****************************************************************************/
+ /* Initialize kernel table */
+ /****************************************************************************/
+
+ typedef _Bool bool;
+
+ static const gm_kernel_set_t *
+ unary_copy_typecheck(ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
+                      const int64_t li[], int nin, int nout, bool check_broadcast,
+                      ndt_context_t *ctx)
+ {
+     return cuda_unary_typecheck(copy_kernel_location, spec, f, types, li,
+                                 nin, nout, check_broadcast, ctx);
+ }
+
+ static const gm_kernel_set_t *
+ unary_invert_typecheck(ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
+                        const int64_t li[], int nin, int nout, bool check_broadcast,
+                        ndt_context_t *ctx)
+ {
+     return cuda_unary_typecheck(invert_kernel_location, spec, f, types, li,
+                                 nin, nout, check_broadcast, ctx);
+ }
+
+ static const gm_kernel_set_t *
+ unary_negative_typecheck(ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
+                          const int64_t li[], int nin, int nout, bool check_broadcast,
+                          ndt_context_t *ctx)
+ {
+     return cuda_unary_typecheck(negative_kernel_location, spec, f, types, li,
+                                 nin, nout, check_broadcast, ctx);
+ }
+
+ static const gm_kernel_set_t *
+ unary_math_typecheck(ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
+                      const int64_t li[], int nin, int nout, bool check_broadcast,
+                      ndt_context_t *ctx)
+ {
+     return cuda_unary_typecheck(math_kernel_location, spec, f, types, li,
+                                 nin, nout, check_broadcast, ctx);
+ }
+
+ int
+ gm_init_cuda_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx)
+ {
+     const gm_kernel_init_t *k;
+
+     for (k = unary_copy; k->name != NULL; k++) {
+         if (gm_add_kernel_typecheck(tbl, k, ctx, &unary_copy_typecheck) < 0) {
+             return -1;
+         }
+     }
+
+     for (k = unary_reduce; k->name != NULL; k++) {
+         if (gm_add_kernel(tbl, k, ctx) < 0) {
+             return -1;
+         }
+     }
+
+     for (k = unary_invert; k->name != NULL; k++) {
+         if (gm_add_kernel_typecheck(tbl, k, ctx, &unary_invert_typecheck) < 0) {
+             return -1;
+         }
+     }
+
+     for (k = unary_negative; k->name != NULL; k++) {
+         if (gm_add_kernel_typecheck(tbl, k, ctx, &unary_negative_typecheck) < 0) {
+             return -1;
+         }
+     }
+
+     for (k = unary_float; k->name != NULL; k++) {
+         if (gm_add_kernel_typecheck(tbl, k, ctx, &unary_math_typecheck) < 0) {
+             return -1;
+         }
+     }
+
+     return 0;
+ }
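
For orientation only, and not part of the changeset: a minimal sketch of how a CUDA-enabled libgumath build might register the kernels defined above. The helper function below is hypothetical, and gm_tbl_new() and the error-reporting conventions are assumed to behave as declared in the public gumath/ndtypes headers.

    /* Illustrative sketch only. Assumed API: gm_tbl_new() from libgumath;
       errors are reported through the ndt_context_t passed in. */
    #include "ndtypes.h"
    #include "gumath.h"

    /* Hypothetical helper: build a kernel table containing the CUDA unary
       kernels registered by gm_init_cuda_unary_kernels() above. */
    static gm_tbl_t *
    make_cuda_unary_table(ndt_context_t *ctx)
    {
        gm_tbl_t *tbl = gm_tbl_new(ctx);   /* empty function table */
        if (tbl == NULL) {
            return NULL;
        }

        /* Registers the copy/abs, reduce, invert, negative and unary math
           kernel tables defined in the file above; all but the reduce
           kernels are added together with their kernel-location typecheck. */
        if (gm_init_cuda_unary_kernels(tbl, ctx) < 0) {
            return NULL;   /* error details are recorded in *ctx */
        }

        return tbl;
    }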