gumath 0.2.0dev5 → 0.2.0dev8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. checksums.yaml +4 -4
  2. data/CONTRIBUTING.md +7 -2
  3. data/Gemfile +0 -3
  4. data/ext/ruby_gumath/GPATH +0 -0
  5. data/ext/ruby_gumath/GRTAGS +0 -0
  6. data/ext/ruby_gumath/GTAGS +0 -0
  7. data/ext/ruby_gumath/extconf.rb +0 -5
  8. data/ext/ruby_gumath/functions.c +10 -2
  9. data/ext/ruby_gumath/gufunc_object.c +15 -4
  10. data/ext/ruby_gumath/gufunc_object.h +9 -3
  11. data/ext/ruby_gumath/gumath/Makefile +63 -0
  12. data/ext/ruby_gumath/gumath/Makefile.in +1 -0
  13. data/ext/ruby_gumath/gumath/config.h +56 -0
  14. data/ext/ruby_gumath/gumath/config.h.in +3 -0
  15. data/ext/ruby_gumath/gumath/config.log +497 -0
  16. data/ext/ruby_gumath/gumath/config.status +1034 -0
  17. data/ext/ruby_gumath/gumath/configure +375 -4
  18. data/ext/ruby_gumath/gumath/configure.ac +47 -3
  19. data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
  20. data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
  21. data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
  22. data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
  23. data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
  24. data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
  25. data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
  26. data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
  27. data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
  28. data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
  29. data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
  30. data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
  31. data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
  32. data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
  33. data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
  34. data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
  35. data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
  36. data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
  37. data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
  38. data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
  39. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
  40. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
  41. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
  42. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
  43. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
  44. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
  45. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
  46. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
  47. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
  48. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
  49. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
  50. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
  51. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
  52. data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
  53. data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
  54. data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
  55. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
  56. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
  57. data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
  58. data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
  59. data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
  60. data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
  61. data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
  62. data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
  63. data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
  64. data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
  65. data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
  66. data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
  67. data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
  68. data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
  69. data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
  70. data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
  71. data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
  72. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
  73. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
  74. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
  75. data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
  76. data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
  77. data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
  78. data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
  79. data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
  80. data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
  81. data/ext/ruby_gumath/gumath/setup.py +67 -6
  82. data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
  83. data/ext/ruby_gumath/include/gumath.h +55 -14
  84. data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
  85. data/ext/ruby_gumath/lib/libgumath.a +0 -0
  86. data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
  87. data/ext/ruby_gumath/ruby_gumath.c +231 -70
  88. data/ext/ruby_gumath/ruby_gumath.h +4 -1
  89. data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
  90. data/ext/ruby_gumath/util.c +34 -0
  91. data/ext/ruby_gumath/util.h +9 -0
  92. data/gumath.gemspec +3 -2
  93. data/lib/gumath.rb +55 -1
  94. data/lib/gumath/version.rb +2 -2
  95. data/lib/ruby_gumath.so +0 -0
  96. metadata +63 -10
  97. data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
  98. data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
  99. data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
@@ -0,0 +1,1331 @@
1
+ /*
2
+ * BSD 3-Clause License
3
+ *
4
+ * Copyright (c) 2017-2018, plures
5
+ * All rights reserved.
6
+ *
7
+ * Redistribution and use in source and binary forms, with or without
8
+ * modification, are permitted provided that the following conditions are met:
9
+ *
10
+ * 1. Redistributions of source code must retain the above copyright notice,
11
+ * this list of conditions and the following disclaimer.
12
+ *
13
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
14
+ * this list of conditions and the following disclaimer in the documentation
15
+ * and/or other materials provided with the distribution.
16
+ *
17
+ * 3. Neither the name of the copyright holder nor the names of its
18
+ * contributors may be used to endorse or promote products derived from
19
+ * this software without specific prior written permission.
20
+ *
21
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+ */
32
+
33
+
34
+ #include <stdlib.h>
35
+ #include <stdint.h>
36
+ #include <string.h>
37
+ #include "ndtypes.h"
38
+ #include "xnd.h"
39
+ #include "gumath.h"
40
+ #include "common.h"
41
+ #include "cuda_device_unary.h"
42
+
43
+
44
+ /****************************************************************************/
45
+ /* Kernel locations for optimized lookup */
46
+ /****************************************************************************/
47
+
48
+ static int
49
+ copy_kernel_location(const ndt_t *in, const ndt_t *out, ndt_context_t *ctx)
50
+ {
51
+ const ndt_t *t = ndt_dtype(in);
52
+ const ndt_t *u = ndt_dtype(out);
53
+
54
+ switch (t->tag) {
55
+ case Bool: {
56
+ switch (u->tag) {
57
+ case Bool: return 0;
58
+ case Uint8: return 2;
59
+ case Uint16: return 4;
60
+ case Uint32: return 6;
61
+ case Uint64: return 8;
62
+ case Int8: return 10;
63
+ case Int16: return 12;
64
+ case Int32: return 14;
65
+ case Int64: return 16;
66
+ case BFloat16: return 18;
67
+ case Float16: return 20;
68
+ case Float32: return 22;
69
+ case Float64: return 24;
70
+ case Complex32: return 26;
71
+ case Complex64: return 28;
72
+ case Complex128: return 30;
73
+ default: goto invalid_combination;
74
+ }
75
+ }
76
+
77
+ case Uint8: {
78
+ switch (u->tag) {
79
+ case Uint8: return 32;
80
+ case Uint16: return 34;
81
+ case Uint32: return 36;
82
+ case Uint64: return 38;
83
+ case Int16: return 40;
84
+ case Int32: return 42;
85
+ case Int64: return 44;
86
+ case BFloat16: return 46;
87
+ case Float16: return 48;
88
+ case Float32: return 50;
89
+ case Float64: return 52;
90
+ case Complex32: return 54;
91
+ case Complex64: return 56;
92
+ case Complex128: return 58;
93
+ default: goto invalid_combination;
94
+ }
95
+ }
96
+
97
+ case Uint16: {
98
+ switch (u->tag) {
99
+ case Uint16: return 60;
100
+ case Uint32: return 62;
101
+ case Uint64: return 64;
102
+ case Int32: return 66;
103
+ case Int64: return 68;
104
+ case Float32: return 70;
105
+ case Float64: return 72;
106
+ case Complex64: return 74;
107
+ case Complex128: return 76;
108
+ default: goto invalid_combination;
109
+ }
110
+ }
111
+
112
+ case Uint32: {
113
+ switch (u->tag) {
114
+ case Uint32: return 78;
115
+ case Uint64: return 80;
116
+ case Int64: return 82;
117
+ case Float64: return 84;
118
+ case Complex128: return 86;
119
+ default: goto invalid_combination;
120
+ }
121
+ }
122
+
123
+ case Uint64: {
124
+ switch (u->tag) {
125
+ case Uint64: return 88;
126
+ default: goto invalid_combination;
127
+ }
128
+ }
129
+
130
+ case Int8: {
131
+ switch (u->tag) {
132
+ case Int8: return 90;
133
+ case Int16: return 92;
134
+ case Int32: return 94;
135
+ case Int64: return 96;
136
+ case BFloat16: return 98;
137
+ case Float16: return 100;
138
+ case Float32: return 102;
139
+ case Float64: return 104;
140
+ case Complex32: return 106;
141
+ case Complex64: return 108;
142
+ case Complex128: return 110;
143
+ default: goto invalid_combination;
144
+ }
145
+ }
146
+
147
+ case Int16: {
148
+ switch (u->tag) {
149
+ case Int16: return 112;
150
+ case Int32: return 114;
151
+ case Int64: return 116;
152
+ case Float32: return 118;
153
+ case Float64: return 120;
154
+ case Complex64: return 122;
155
+ case Complex128: return 124;
156
+ default: goto invalid_combination;
157
+ }
158
+ }
159
+
160
+ case Int32: {
161
+ switch (u->tag) {
162
+ case Int32: return 126;
163
+ case Int64: return 128;
164
+ case Float64: return 130;
165
+ case Complex128: return 132;
166
+ default: goto invalid_combination;
167
+ }
168
+ }
169
+
170
+ case Int64: {
171
+ switch (u->tag) {
172
+ case Int64: return 134;
173
+ default: goto invalid_combination;
174
+ }
175
+ }
176
+
177
+ case BFloat16: {
178
+ switch (u->tag) {
179
+ case BFloat16: return 136;
180
+ case Float32: return 138;
181
+ case Float64: return 140;
182
+ case Complex64: return 142;
183
+ case Complex128: return 144;
184
+ default: goto invalid_combination;
185
+ }
186
+ }
187
+
188
+ case Float16: {
189
+ switch (u->tag) {
190
+ case Float16: return 146;
191
+ case Float32: return 148;
192
+ case Float64: return 150;
193
+ case Complex32: return 152;
194
+ case Complex64: return 154;
195
+ case Complex128: return 156;
196
+ default: goto invalid_combination;
197
+ }
198
+ }
199
+
200
+ case Float32: {
201
+ switch (u->tag) {
202
+ case Float32: return 158;
203
+ case Float64: return 160;
204
+ case Complex64: return 162;
205
+ case Complex128: return 164;
206
+ default: goto invalid_combination;
207
+ }
208
+ }
209
+
210
+ case Float64: {
211
+ switch (u->tag) {
212
+ case Float64: return 166;
213
+ case Complex128: return 168;
214
+ default: goto invalid_combination;
215
+ }
216
+ }
217
+
218
+ case Complex32: {
219
+ switch (u->tag) {
220
+ case Complex32: return 170;
221
+ case Complex64: return 172;
222
+ case Complex128: return 174;
223
+ default: goto invalid_combination;
224
+ }
225
+ }
226
+
227
+ case Complex64: {
228
+ switch (u->tag) {
229
+ case Complex64: return 176;
230
+ case Complex128: return 178;
231
+ default: goto invalid_combination;
232
+ }
233
+ }
234
+
235
+ case Complex128: {
236
+ switch (u->tag) {
237
+ case Complex128: return 180;
238
+ default: goto invalid_combination;
239
+ }
240
+ }
241
+
242
+ default: goto invalid_combination;
243
+ }
244
+
245
+ invalid_combination:
246
+ ndt_err_format(ctx, NDT_ValueError, "invalid dtype");
247
+ return -1;
248
+ }
249
+
250
+ static int
251
+ invert_kernel_location(const ndt_t *in, const ndt_t *out, ndt_context_t *ctx)
252
+ {
253
+ const ndt_t *t = ndt_dtype(in);
254
+ (void)out;
255
+
256
+ switch (t->tag) {
257
+ case Bool: return 0;
258
+
259
+ case Uint8: return 2;
260
+ case Uint16: return 4;
261
+ case Uint32: return 6;
262
+ case Uint64: return 8;
263
+
264
+ case Int8: return 10;
265
+ case Int16: return 12;
266
+ case Int32: return 14;
267
+ case Int64: return 16;
268
+
269
+ default:
270
+ ndt_err_format(ctx, NDT_ValueError, "invalid dtype");
271
+ return -1;
272
+ }
273
+ }
274
+
275
+ static int
276
+ negative_kernel_location(const ndt_t *in, const ndt_t *out, ndt_context_t *ctx)
277
+ {
278
+ const ndt_t *t = ndt_dtype(in);
279
+ (void)out;
280
+
281
+ switch (t->tag) {
282
+ case Uint8: return 0;
283
+ case Uint16: return 2;
284
+ case Uint32: return 4;
285
+
286
+ case Int8: return 6;
287
+ case Int16: return 8;
288
+ case Int32: return 10;
289
+ case Int64: return 12;
290
+
291
+ case BFloat16: return 14;
292
+ case Float16: return 16;
293
+ case Float32: return 18;
294
+ case Float64: return 20;
295
+
296
+ case Complex32: return 22;
297
+ case Complex64: return 24;
298
+ case Complex128: return 26;
299
+
300
+ default:
301
+ ndt_err_format(ctx, NDT_ValueError, "invalid dtype");
302
+ return -1;
303
+ }
304
+ }
305
+
306
+ static int
307
+ math_kernel_location(const ndt_t *in, const ndt_t *out, ndt_context_t *ctx)
308
+ {
309
+ const ndt_t *t = ndt_dtype(in);
310
+ (void)out;
311
+
312
+ switch (t->tag) {
313
+ case Uint8: return 0;
314
+ case Int8: return 2;
315
+ case Float16: return 4;
316
+
317
+ case BFloat16: return 6;
318
+
319
+ case Uint16: return 8;
320
+ case Int16: return 10;
321
+ case Float32: return 12;
322
+
323
+ case Uint32: return 14;
324
+ case Int32: return 16;
325
+ case Float64: return 18;
326
+
327
+ case Complex32: return 20;
328
+ case Complex64: return 22;
329
+ case Complex128: return 24;
330
+
331
+ default:
332
+ ndt_err_format(ctx, NDT_ValueError, "invalid dtype");
333
+ return -1;
334
+ }
335
+ }
336
+
337
+
338
+ /*****************************************************************************/
339
+ /* CUDA-specific unary macros */
340
+ /*****************************************************************************/
341
+
342
/*
 * Emit the three host-side wrappers for the unary kernel "name" with
 * input dtype t0 and output dtype t1:
 *
 *   fixed_1D_C — contiguous 1-D input/output
 *   fixed_1D_S — strided 1-D input/output
 *   0D         — scalar input/output
 *
 * Each wrapper calls the matching gm_cuda_device_* function and, when the
 * output dtype is optional, updates the output's validity bitmap.
 */
#define CUDA_HOST_UNARY(name, t0, t1) \
static int \
gm_cuda_host_fixed_1D_C_##name##_##t0##_##t1(xnd_t stack[], ndt_context_t *ctx) \
{ \
    const char *a0 = apply_index(&stack[0]); \
    char *a1 = apply_index(&stack[1]); \
    const int64_t N = xnd_fixed_shape(&stack[0]); \
    (void)ctx; \
 \
    gm_cuda_device_fixed_1D_C_##name##_##t0##_##t1(a0, a1, N); \
 \
    if (ndt_is_optional(ndt_dtype(stack[1].type))) { \
        unary_update_bitmap_1D_S(stack); \
    } \
 \
    return 0; \
} \
 \
static int \
gm_cuda_host_fixed_1D_S_##name##_##t0##_##t1(xnd_t stack[], ndt_context_t *ctx) \
{ \
    const char *a0 = apply_index(&stack[0]); \
    char *a1 = apply_index(&stack[1]); \
    const int64_t N = xnd_fixed_shape(&stack[0]); \
    const int64_t s0 = xnd_fixed_step(&stack[0]); \
    const int64_t s1 = xnd_fixed_step(&stack[1]); \
    (void)ctx; \
 \
    gm_cuda_device_fixed_1D_S_##name##_##t0##_##t1(a0, a1, s0, s1, N); \
 \
    if (ndt_is_optional(ndt_dtype(stack[1].type))) { \
        unary_update_bitmap_1D_S(stack); \
    } \
 \
    return 0; \
} \
 \
static int \
gm_cuda_host_0D_##name##_##t0##_##t1(xnd_t stack[], ndt_context_t *ctx) \
{ \
    const char *a0 = stack[0].ptr; \
    char *a1 = stack[1].ptr; \
    (void)ctx; \
 \
    gm_cuda_device_0D_##name##_##t0##_##t1(a0, a1); \
 \
    if (ndt_is_optional(ndt_dtype(stack[1].type))) { \
        unary_update_bitmap_0D(stack); \
    } \
 \
    return 0; \
}
394
+
395
+
396
/*
 * Emit stub wrappers (same three names as CUDA_HOST_UNARY) that always
 * fail with NDT_NotImplementedError.  Used for t0 -> t1 conversions that
 * would require double rounding (e.g. targets involving complex32).
 */
#define CUDA_HOST_NOIMPL(name, t0, t1) \
static int \
gm_cuda_host_fixed_1D_C_##name##_##t0##_##t1(xnd_t stack[], ndt_context_t *ctx) \
{ \
    (void)stack; \
 \
    ndt_err_format(ctx, NDT_NotImplementedError, \
        "implementation for " STRINGIZE(name) " : " \
        STRINGIZE(t0) " -> " STRINGIZE(t1) \
        " currently requires double rounding"); \
 \
    return -1; \
} \
 \
static int \
gm_cuda_host_fixed_1D_S_##name##_##t0##_##t1(xnd_t stack[], ndt_context_t *ctx) \
{ \
    (void)stack; \
 \
    ndt_err_format(ctx, NDT_NotImplementedError, \
        "implementation for " STRINGIZE(name) " : " \
        STRINGIZE(t0) " -> " STRINGIZE(t1) \
        " currently requires double rounding"); \
 \
    return -1; \
} \
 \
static int \
gm_cuda_host_0D_##name##_##t0##_##t1(xnd_t stack[], ndt_context_t *ctx) \
{ \
    (void)stack; \
 \
    ndt_err_format(ctx, NDT_NotImplementedError, \
        "implementation for " STRINGIZE(name) " : " \
        STRINGIZE(t0) " -> " STRINGIZE(t1) \
        " currently requires double rounding"); \
 \
    return -1; \
}
435
+
436
+
437
/*
 * Emit the host-side wrapper for a 1-D contiguous reduction kernel:
 * reduces stack[0] (N elements of t0) into the scalar stack[1] (t1) via
 * the matching gm_cuda_device_1D_C_reduce_* function, updating the output
 * bitmap when the output dtype is optional.
 */
#define CUDA_HOST_UNARY_REDUCE(name, t0, t1) \
static int \
gm_cuda_host_1D_C_reduce_##name##_##t0##_##t1(xnd_t stack[], ndt_context_t *ctx) \
{ \
    const char *a0 = apply_index(&stack[0]); \
    char *a1 = stack[1].ptr; \
    const int64_t N = xnd_fixed_shape(&stack[0]); \
    (void)ctx; \
 \
    gm_cuda_device_1D_C_reduce_##name##_##t0##_##t1(a0, a1, N); \
 \
    if (ndt_is_optional(ndt_dtype(stack[1].type))) { \
        unary_reduce_bitmap_1D_S(stack); \
    } \
 \
    return 0; \
}

/*
 * Stub variant of CUDA_HOST_UNARY_REDUCE: always fails with
 * NDT_NotImplementedError for t0 -> t1 reductions that have no cuda
 * thrust implementation.
 */
#define CUDA_HOST_REDUCE_NOIMPL(name, t0, t1) \
static int \
gm_cuda_host_1D_C_reduce_##name##_##t0##_##t1(xnd_t stack[], ndt_context_t *ctx) \
{ \
    (void)stack; \
 \
    ndt_err_format(ctx, NDT_NotImplementedError, \
        "No cuda thrust implementation for: " STRINGIZE(name) " : " \
        STRINGIZE(t0) " -> " STRINGIZE(t1)); \
 \
    return -1; \
}
467
+
468
+
469
/*
 * Emit two gm_kernel_init_t table entries for a unary kernel: one with
 * the plain signature "... * t0 -> ... * t1" and one with the optional
 * signature "... * ?t0 -> ... * ?t1".  Both reference the wrappers
 * generated by CUDA_HOST_UNARY / CUDA_HOST_NOIMPL.  The two-entries-per-
 * pair layout is what the *_kernel_location functions' even offsets
 * index into.
 */
#define CUDA_HOST_UNARY_INIT(funcname, func, t0, t1) \
  { .name = STRINGIZE(funcname), \
    .sig = "... * " STRINGIZE(t0) " -> ... * " STRINGIZE(t1), \
    .OptC = gm_cuda_host_fixed_1D_C_##func##_##t0##_##t1, \
    .OptS = gm_cuda_host_fixed_1D_S_##func##_##t0##_##t1, \
    .C = gm_cuda_host_0D_##func##_##t0##_##t1 }, \
 \
  { .name = STRINGIZE(funcname), \
    .sig = "... * ?" STRINGIZE(t0) " -> ... * ?" STRINGIZE(t1), \
    .OptC = gm_cuda_host_fixed_1D_C_##func##_##t0##_##t1, \
    .OptS = gm_cuda_host_fixed_1D_S_##func##_##t0##_##t1, \
    .C = gm_cuda_host_0D_##func##_##t0##_##t1 }


/*
 * Emit two gm_kernel_init_t entries for a reduction kernel, named
 * "reduce_<funcname>", with plain and optional signatures.
 */
#define CUDA_HOST_UNARY_REDUCE_INIT(funcname, func, t0, t1) \
  { .name = "reduce_" STRINGIZE(funcname), \
    .sig = "N * " STRINGIZE(t0) " -> " STRINGIZE(t1), \
    .C = gm_cuda_host_1D_C_reduce_##func##_##t0##_##t1 }, \
 \
  { .name = "reduce_" STRINGIZE(funcname), \
    .sig = "N * ?" STRINGIZE(t0) " -> ?" STRINGIZE(t1), \
    .C = gm_cuda_host_1D_C_reduce_##func##_##t0##_##t1 }
491
+
492
+
493
/* 'bool' is used below as a dtype token in macro expansions; remove any
   'bool' macro (e.g. from stdbool.h) so token pasting is not disturbed.
   NOTE(review): bool_t is not referenced in this file — presumably used
   by the included kernel headers; confirm. */
#undef bool
#define bool_t _Bool
495
+
496
+
497
+ /*****************************************************************************/
498
+ /* Copy */
499
+ /*****************************************************************************/
500
+
501
/*
 * Instantiate CUDA_HOST_UNARY for "name" over every supported dtype
 * conversion.  The pair order must match copy_kernel_location exactly.
 * Conversions targeting complex32, and all complex32 inputs, are stubbed
 * with CUDA_HOST_NOIMPL (double-rounding problem).
 */
#define CUDA_HOST_ALL_UNARY(name) \
    CUDA_HOST_UNARY(name, bool, bool) \
    CUDA_HOST_UNARY(name, bool, uint8) \
    CUDA_HOST_UNARY(name, bool, uint16) \
    CUDA_HOST_UNARY(name, bool, uint32) \
    CUDA_HOST_UNARY(name, bool, uint64) \
    CUDA_HOST_UNARY(name, bool, int8) \
    CUDA_HOST_UNARY(name, bool, int16) \
    CUDA_HOST_UNARY(name, bool, int32) \
    CUDA_HOST_UNARY(name, bool, int64) \
    CUDA_HOST_UNARY(name, bool, bfloat16) \
    CUDA_HOST_UNARY(name, bool, float16) \
    CUDA_HOST_UNARY(name, bool, float32) \
    CUDA_HOST_UNARY(name, bool, float64) \
    CUDA_HOST_NOIMPL(name, bool, complex32) \
    CUDA_HOST_UNARY(name, bool, complex64) \
    CUDA_HOST_UNARY(name, bool, complex128) \
 \
    CUDA_HOST_UNARY(name, uint8, uint8) \
    CUDA_HOST_UNARY(name, uint8, uint16) \
    CUDA_HOST_UNARY(name, uint8, uint32) \
    CUDA_HOST_UNARY(name, uint8, uint64) \
    CUDA_HOST_UNARY(name, uint8, int16) \
    CUDA_HOST_UNARY(name, uint8, int32) \
    CUDA_HOST_UNARY(name, uint8, int64) \
    CUDA_HOST_UNARY(name, uint8, bfloat16) \
    CUDA_HOST_UNARY(name, uint8, float16) \
    CUDA_HOST_UNARY(name, uint8, float32) \
    CUDA_HOST_UNARY(name, uint8, float64) \
    CUDA_HOST_NOIMPL(name, uint8, complex32) \
    CUDA_HOST_UNARY(name, uint8, complex64) \
    CUDA_HOST_UNARY(name, uint8, complex128) \
 \
    CUDA_HOST_UNARY(name, uint16, uint16) \
    CUDA_HOST_UNARY(name, uint16, uint32) \
    CUDA_HOST_UNARY(name, uint16, uint64) \
    CUDA_HOST_UNARY(name, uint16, int32) \
    CUDA_HOST_UNARY(name, uint16, int64) \
    CUDA_HOST_UNARY(name, uint16, float32) \
    CUDA_HOST_UNARY(name, uint16, float64) \
    CUDA_HOST_UNARY(name, uint16, complex64) \
    CUDA_HOST_UNARY(name, uint16, complex128) \
 \
    CUDA_HOST_UNARY(name, uint32, uint32) \
    CUDA_HOST_UNARY(name, uint32, uint64) \
    CUDA_HOST_UNARY(name, uint32, int64) \
    CUDA_HOST_UNARY(name, uint32, float64) \
    CUDA_HOST_UNARY(name, uint32, complex128) \
 \
    CUDA_HOST_UNARY(name, uint64, uint64) \
 \
    CUDA_HOST_UNARY(name, int8, int8) \
    CUDA_HOST_UNARY(name, int8, int16) \
    CUDA_HOST_UNARY(name, int8, int32) \
    CUDA_HOST_UNARY(name, int8, int64) \
    CUDA_HOST_UNARY(name, int8, bfloat16) \
    CUDA_HOST_UNARY(name, int8, float16) \
    CUDA_HOST_UNARY(name, int8, float32) \
    CUDA_HOST_UNARY(name, int8, float64) \
    CUDA_HOST_NOIMPL(name, int8, complex32) \
    CUDA_HOST_UNARY(name, int8, complex64) \
    CUDA_HOST_UNARY(name, int8, complex128) \
 \
    CUDA_HOST_UNARY(name, int16, int16) \
    CUDA_HOST_UNARY(name, int16, int32) \
    CUDA_HOST_UNARY(name, int16, int64) \
    CUDA_HOST_UNARY(name, int16, float32) \
    CUDA_HOST_UNARY(name, int16, float64) \
    CUDA_HOST_UNARY(name, int16, complex64) \
    CUDA_HOST_UNARY(name, int16, complex128) \
 \
    CUDA_HOST_UNARY(name, int32, int32) \
    CUDA_HOST_UNARY(name, int32, int64) \
    CUDA_HOST_UNARY(name, int32, float64) \
    CUDA_HOST_UNARY(name, int32, complex128) \
 \
    CUDA_HOST_UNARY(name, int64, int64) \
 \
    CUDA_HOST_UNARY(name, bfloat16, bfloat16) \
    CUDA_HOST_UNARY(name, bfloat16, float32) \
    CUDA_HOST_UNARY(name, bfloat16, float64) \
    CUDA_HOST_UNARY(name, bfloat16, complex64) \
    CUDA_HOST_UNARY(name, bfloat16, complex128) \
 \
    CUDA_HOST_UNARY(name, float16, float16) \
    CUDA_HOST_UNARY(name, float16, float32) \
    CUDA_HOST_UNARY(name, float16, float64) \
    CUDA_HOST_NOIMPL(name, float16, complex32) \
    CUDA_HOST_UNARY(name, float16, complex64) \
    CUDA_HOST_UNARY(name, float16, complex128) \
 \
    CUDA_HOST_UNARY(name, float32, float32) \
    CUDA_HOST_UNARY(name, float32, float64) \
    CUDA_HOST_UNARY(name, float32, complex64) \
    CUDA_HOST_UNARY(name, float32, complex128) \
 \
    CUDA_HOST_UNARY(name, float64, float64) \
    CUDA_HOST_UNARY(name, float64, complex128) \
 \
    CUDA_HOST_NOIMPL(name, complex32, complex32) \
    CUDA_HOST_NOIMPL(name, complex32, complex64) \
    CUDA_HOST_NOIMPL(name, complex32, complex128) \
 \
    CUDA_HOST_UNARY(name, complex64, complex64) \
    CUDA_HOST_UNARY(name, complex64, complex128) \
 \
    CUDA_HOST_UNARY(name, complex128, complex128)
608
+
609
/*
 * Emit the full kernel-init table section for a unary kernel "name".
 * Entry order must mirror CUDA_HOST_ALL_UNARY / copy_kernel_location.
 * "hfunc" is used for entries whose output dtype is float16 (a separate
 * half-precision implementation); "func" is used everywhere else.
 */
#define CUDA_HOST_ALL_UNARY_INIT(name, func, hfunc) \
    CUDA_HOST_UNARY_INIT(name, func, bool, bool), \
    CUDA_HOST_UNARY_INIT(name, func, bool, uint8), \
    CUDA_HOST_UNARY_INIT(name, func, bool, uint16), \
    CUDA_HOST_UNARY_INIT(name, func, bool, uint32), \
    CUDA_HOST_UNARY_INIT(name, func, bool, uint64), \
    CUDA_HOST_UNARY_INIT(name, func, bool, int8), \
    CUDA_HOST_UNARY_INIT(name, func, bool, int16), \
    CUDA_HOST_UNARY_INIT(name, func, bool, int32), \
    CUDA_HOST_UNARY_INIT(name, func, bool, int64), \
    CUDA_HOST_UNARY_INIT(name, func, bool, bfloat16), \
    CUDA_HOST_UNARY_INIT(name, hfunc, bool, float16), \
    CUDA_HOST_UNARY_INIT(name, func, bool, float32), \
    CUDA_HOST_UNARY_INIT(name, func, bool, float64), \
    CUDA_HOST_UNARY_INIT(name, func, bool, complex32), \
    CUDA_HOST_UNARY_INIT(name, func, bool, complex64), \
    CUDA_HOST_UNARY_INIT(name, func, bool, complex128), \
 \
    CUDA_HOST_UNARY_INIT(name, func, uint8, uint8), \
    CUDA_HOST_UNARY_INIT(name, func, uint8, uint16), \
    CUDA_HOST_UNARY_INIT(name, func, uint8, uint32), \
    CUDA_HOST_UNARY_INIT(name, func, uint8, uint64), \
    CUDA_HOST_UNARY_INIT(name, func, uint8, int16), \
    CUDA_HOST_UNARY_INIT(name, func, uint8, int32), \
    CUDA_HOST_UNARY_INIT(name, func, uint8, int64), \
    CUDA_HOST_UNARY_INIT(name, func, uint8, bfloat16), \
    CUDA_HOST_UNARY_INIT(name, hfunc, uint8, float16), \
    CUDA_HOST_UNARY_INIT(name, func, uint8, float32), \
    CUDA_HOST_UNARY_INIT(name, func, uint8, float64), \
    CUDA_HOST_UNARY_INIT(name, func, uint8, complex32), \
    CUDA_HOST_UNARY_INIT(name, func, uint8, complex64), \
    CUDA_HOST_UNARY_INIT(name, func, uint8, complex128), \
 \
    CUDA_HOST_UNARY_INIT(name, func, uint16, uint16), \
    CUDA_HOST_UNARY_INIT(name, func, uint16, uint32), \
    CUDA_HOST_UNARY_INIT(name, func, uint16, uint64), \
    CUDA_HOST_UNARY_INIT(name, func, uint16, int32), \
    CUDA_HOST_UNARY_INIT(name, func, uint16, int64), \
    CUDA_HOST_UNARY_INIT(name, func, uint16, float32), \
    CUDA_HOST_UNARY_INIT(name, func, uint16, float64), \
    CUDA_HOST_UNARY_INIT(name, func, uint16, complex64), \
    CUDA_HOST_UNARY_INIT(name, func, uint16, complex128), \
 \
    CUDA_HOST_UNARY_INIT(name, func, uint32, uint32), \
    CUDA_HOST_UNARY_INIT(name, func, uint32, uint64), \
    CUDA_HOST_UNARY_INIT(name, func, uint32, int64), \
    CUDA_HOST_UNARY_INIT(name, func, uint32, float64), \
    CUDA_HOST_UNARY_INIT(name, func, uint32, complex128), \
 \
    CUDA_HOST_UNARY_INIT(name, func, uint64, uint64), \
 \
    CUDA_HOST_UNARY_INIT(name, func, int8, int8), \
    CUDA_HOST_UNARY_INIT(name, func, int8, int16), \
    CUDA_HOST_UNARY_INIT(name, func, int8, int32), \
    CUDA_HOST_UNARY_INIT(name, func, int8, int64), \
    CUDA_HOST_UNARY_INIT(name, func, int8, bfloat16), \
    CUDA_HOST_UNARY_INIT(name, hfunc, int8, float16), \
    CUDA_HOST_UNARY_INIT(name, func, int8, float32), \
    CUDA_HOST_UNARY_INIT(name, func, int8, float64), \
    CUDA_HOST_UNARY_INIT(name, func, int8, complex32), \
    CUDA_HOST_UNARY_INIT(name, func, int8, complex64), \
    CUDA_HOST_UNARY_INIT(name, func, int8, complex128), \
 \
    CUDA_HOST_UNARY_INIT(name, func, int16, int16), \
    CUDA_HOST_UNARY_INIT(name, func, int16, int32), \
    CUDA_HOST_UNARY_INIT(name, func, int16, int64), \
    CUDA_HOST_UNARY_INIT(name, func, int16, float32), \
    CUDA_HOST_UNARY_INIT(name, func, int16, float64), \
    CUDA_HOST_UNARY_INIT(name, func, int16, complex64), \
    CUDA_HOST_UNARY_INIT(name, func, int16, complex128), \
 \
    CUDA_HOST_UNARY_INIT(name, func, int32, int32), \
    CUDA_HOST_UNARY_INIT(name, func, int32, int64), \
    CUDA_HOST_UNARY_INIT(name, func, int32, float64), \
    CUDA_HOST_UNARY_INIT(name, func, int32, complex128), \
 \
    CUDA_HOST_UNARY_INIT(name, func, int64, int64), \
 \
    CUDA_HOST_UNARY_INIT(name, func, bfloat16, bfloat16), \
    CUDA_HOST_UNARY_INIT(name, func, bfloat16, float32), \
    CUDA_HOST_UNARY_INIT(name, func, bfloat16, float64), \
    CUDA_HOST_UNARY_INIT(name, func, bfloat16, complex64), \
    CUDA_HOST_UNARY_INIT(name, func, bfloat16, complex128), \
 \
    CUDA_HOST_UNARY_INIT(name, hfunc, float16, float16), \
    CUDA_HOST_UNARY_INIT(name, func, float16, float32), \
    CUDA_HOST_UNARY_INIT(name, func, float16, float64), \
    CUDA_HOST_UNARY_INIT(name, func, float16, complex32), \
    CUDA_HOST_UNARY_INIT(name, func, float16, complex64), \
    CUDA_HOST_UNARY_INIT(name, func, float16, complex128), \
 \
    CUDA_HOST_UNARY_INIT(name, func, float32, float32), \
    CUDA_HOST_UNARY_INIT(name, func, float32, float64), \
    CUDA_HOST_UNARY_INIT(name, func, float32, complex64), \
    CUDA_HOST_UNARY_INIT(name, func, float32, complex128), \
 \
    CUDA_HOST_UNARY_INIT(name, func, float64, float64), \
    CUDA_HOST_UNARY_INIT(name, func, float64, complex128), \
 \
    CUDA_HOST_UNARY_INIT(name, func, complex32, complex32), \
    CUDA_HOST_UNARY_INIT(name, func, complex32, complex64), \
    CUDA_HOST_UNARY_INIT(name, func, complex32, complex128), \
 \
    CUDA_HOST_UNARY_INIT(name, func, complex64, complex64), \
    CUDA_HOST_UNARY_INIT(name, func, complex64, complex128), \
 \
    CUDA_HOST_UNARY_INIT(name, func, complex128, complex128)
716
+
717
+
718
/* Instantiate the host wrappers for "copy" and "abs" over all supported
   dtype conversions (see CUDA_HOST_ALL_UNARY above). */
CUDA_HOST_ALL_UNARY(copy)
CUDA_HOST_ALL_UNARY(abs)


/* Kernel-init table for copy and abs; NULL entry terminates the table. */
static const gm_kernel_init_t unary_copy[] = {
  /* COPY */
  CUDA_HOST_ALL_UNARY_INIT(copy, copy, copy),
  /* ABS */
  CUDA_HOST_ALL_UNARY_INIT(abs, abs, abs),

  { .name = NULL, .sig = NULL }
};
729
+
730
+ /*****************************************************************************/
731
+ /* Reduce */
732
+ /*****************************************************************************/
733
+
/* Expand one reduce-kernel definition per supported (input, accumulator)
   type pair for the given kernel name.  CUDA_HOST_UNARY_REDUCE emits a real
   kernel; CUDA_HOST_REDUCE_NOIMPL emits a stub for pairs with no CUDA
   implementation (bfloat16 and complex accumulators).  The pair list follows
   the usual widening promotion lattice: each input type may only accumulate
   into itself or a wider type. */
734
+ #define CUDA_HOST_ALL_UNARY_REDUCE(name) \
735
+ CUDA_HOST_UNARY_REDUCE(name, bool, bool) \
736
+ CUDA_HOST_UNARY_REDUCE(name, bool, uint8) \
737
+ CUDA_HOST_UNARY_REDUCE(name, bool, uint16) \
738
+ CUDA_HOST_UNARY_REDUCE(name, bool, uint32) \
739
+ CUDA_HOST_UNARY_REDUCE(name, bool, uint64) \
740
+ CUDA_HOST_UNARY_REDUCE(name, bool, int8) \
741
+ CUDA_HOST_UNARY_REDUCE(name, bool, int16) \
742
+ CUDA_HOST_UNARY_REDUCE(name, bool, int32) \
743
+ CUDA_HOST_UNARY_REDUCE(name, bool, int64) \
744
+ CUDA_HOST_REDUCE_NOIMPL(name, bool, bfloat16) \
745
+ CUDA_HOST_UNARY_REDUCE(name, bool, float16) \
746
+ CUDA_HOST_UNARY_REDUCE(name, bool, float32) \
747
+ CUDA_HOST_UNARY_REDUCE(name, bool, float64) \
/* NOTE(review): cosmetic -- missing space after "name," in the next entry;
   harmless to the preprocessor but inconsistent with every other line. */
748
+ CUDA_HOST_REDUCE_NOIMPL(name,bool, complex32) \
749
+ CUDA_HOST_REDUCE_NOIMPL(name, bool, complex64) \
750
+ CUDA_HOST_REDUCE_NOIMPL(name, bool, complex128) \
751
+ \
752
+ CUDA_HOST_UNARY_REDUCE(name, uint8, uint8) \
753
+ CUDA_HOST_UNARY_REDUCE(name, uint8, uint16) \
754
+ CUDA_HOST_UNARY_REDUCE(name, uint8, uint32) \
755
+ CUDA_HOST_UNARY_REDUCE(name, uint8, uint64) \
756
+ CUDA_HOST_UNARY_REDUCE(name, uint8, int16) \
757
+ CUDA_HOST_UNARY_REDUCE(name, uint8, int32) \
758
+ CUDA_HOST_UNARY_REDUCE(name, uint8, int64) \
759
+ CUDA_HOST_REDUCE_NOIMPL(name, uint8, bfloat16) \
760
+ CUDA_HOST_UNARY_REDUCE(name, uint8, float16) \
761
+ CUDA_HOST_UNARY_REDUCE(name, uint8, float32) \
762
+ CUDA_HOST_UNARY_REDUCE(name, uint8, float64) \
763
+ CUDA_HOST_REDUCE_NOIMPL(name, uint8, complex32) \
764
+ CUDA_HOST_REDUCE_NOIMPL(name, uint8, complex64) \
765
+ CUDA_HOST_REDUCE_NOIMPL(name, uint8, complex128) \
766
+ \
767
+ CUDA_HOST_UNARY_REDUCE(name, uint16, uint16) \
768
+ CUDA_HOST_UNARY_REDUCE(name, uint16, uint32) \
769
+ CUDA_HOST_UNARY_REDUCE(name, uint16, uint64) \
770
+ CUDA_HOST_UNARY_REDUCE(name, uint16, int32) \
771
+ CUDA_HOST_UNARY_REDUCE(name, uint16, int64) \
772
+ CUDA_HOST_UNARY_REDUCE(name, uint16, float32) \
773
+ CUDA_HOST_UNARY_REDUCE(name, uint16, float64) \
774
+ CUDA_HOST_REDUCE_NOIMPL(name, uint16, complex64) \
775
+ CUDA_HOST_REDUCE_NOIMPL(name, uint16, complex128) \
776
+ \
777
+ CUDA_HOST_UNARY_REDUCE(name, uint32, uint32) \
778
+ CUDA_HOST_UNARY_REDUCE(name, uint32, uint64) \
779
+ CUDA_HOST_UNARY_REDUCE(name, uint32, int64) \
780
+ CUDA_HOST_UNARY_REDUCE(name, uint32, float64) \
781
+ CUDA_HOST_REDUCE_NOIMPL(name, uint32, complex128) \
782
+ \
783
+ CUDA_HOST_UNARY_REDUCE(name, uint64, uint64) \
784
+ \
785
+ CUDA_HOST_UNARY_REDUCE(name, int8, int8) \
786
+ CUDA_HOST_UNARY_REDUCE(name, int8, int16) \
787
+ CUDA_HOST_UNARY_REDUCE(name, int8, int32) \
788
+ CUDA_HOST_UNARY_REDUCE(name, int8, int64) \
789
+ CUDA_HOST_REDUCE_NOIMPL(name, int8, bfloat16) \
790
+ CUDA_HOST_UNARY_REDUCE(name, int8, float16) \
791
+ CUDA_HOST_UNARY_REDUCE(name, int8, float32) \
792
+ CUDA_HOST_UNARY_REDUCE(name, int8, float64) \
793
+ CUDA_HOST_REDUCE_NOIMPL(name, int8, complex32) \
794
+ CUDA_HOST_REDUCE_NOIMPL(name, int8, complex64) \
795
+ CUDA_HOST_REDUCE_NOIMPL(name, int8, complex128) \
796
+ \
797
+ CUDA_HOST_UNARY_REDUCE(name, int16, int16) \
798
+ CUDA_HOST_UNARY_REDUCE(name, int16, int32) \
799
+ CUDA_HOST_UNARY_REDUCE(name, int16, int64) \
800
+ CUDA_HOST_UNARY_REDUCE(name, int16, float32) \
801
+ CUDA_HOST_UNARY_REDUCE(name, int16, float64) \
802
+ CUDA_HOST_REDUCE_NOIMPL(name, int16, complex64) \
803
+ CUDA_HOST_REDUCE_NOIMPL(name, int16, complex128) \
804
+ \
805
+ CUDA_HOST_UNARY_REDUCE(name, int32, int32) \
806
+ CUDA_HOST_UNARY_REDUCE(name, int32, int64) \
807
+ CUDA_HOST_UNARY_REDUCE(name, int32, float64) \
808
+ CUDA_HOST_REDUCE_NOIMPL(name, int32, complex128) \
809
+ \
810
+ CUDA_HOST_UNARY_REDUCE(name, int64, int64) \
811
+ \
812
+ CUDA_HOST_REDUCE_NOIMPL(name, bfloat16, bfloat16) \
813
+ CUDA_HOST_REDUCE_NOIMPL(name, bfloat16, float32) \
814
+ CUDA_HOST_REDUCE_NOIMPL(name, bfloat16, float64) \
815
+ CUDA_HOST_REDUCE_NOIMPL(name, bfloat16, complex64) \
816
+ CUDA_HOST_REDUCE_NOIMPL(name, bfloat16, complex128) \
817
+ \
818
+ CUDA_HOST_UNARY_REDUCE(name, float16, float16) \
819
+ CUDA_HOST_REDUCE_NOIMPL(name, float16, float32) \
820
+ CUDA_HOST_REDUCE_NOIMPL(name, float16, float64) \
821
+ CUDA_HOST_REDUCE_NOIMPL(name, float16, complex32) \
822
+ CUDA_HOST_REDUCE_NOIMPL(name, float16, complex64) \
823
+ CUDA_HOST_REDUCE_NOIMPL(name, float16, complex128) \
824
+ \
825
+ CUDA_HOST_UNARY_REDUCE(name, float32, float32) \
826
+ CUDA_HOST_UNARY_REDUCE(name, float32, float64) \
827
+ CUDA_HOST_REDUCE_NOIMPL(name, float32, complex64) \
828
+ CUDA_HOST_REDUCE_NOIMPL(name, float32, complex128) \
829
+ \
830
+ CUDA_HOST_UNARY_REDUCE(name, float64, float64) \
831
+ CUDA_HOST_REDUCE_NOIMPL(name, float64, complex128) \
832
+ \
833
+ CUDA_HOST_REDUCE_NOIMPL(name, complex32, complex32) \
834
+ CUDA_HOST_REDUCE_NOIMPL(name, complex32, complex64) \
835
+ CUDA_HOST_REDUCE_NOIMPL(name, complex32, complex128) \
836
+ \
837
+ CUDA_HOST_REDUCE_NOIMPL(name, complex64, complex64) \
838
+ CUDA_HOST_REDUCE_NOIMPL(name, complex64, complex128) \
839
+ \
840
+ CUDA_HOST_REDUCE_NOIMPL(name, complex128, complex128)
841
+
/* Expand one gm_kernel_init_t table entry per (input, accumulator) pair for
   the given reduce kernel.  Unlike CUDA_HOST_ALL_UNARY_REDUCE above, this
   INIT list also names entries for the NOIMPL pairs -- presumably those
   resolve to the stub kernels the NOIMPL macro defines (TODO confirm against
   the NOIMPL macro definition earlier in the file). */
842
+ #define CUDA_HOST_ALL_UNARY_REDUCE_INIT(name, func) \
843
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, bool), \
844
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, uint8), \
845
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, uint16), \
846
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, uint32), \
847
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, uint64), \
848
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, int8), \
849
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, int16), \
850
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, int32), \
851
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, int64), \
852
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, bfloat16), \
853
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, float16), \
854
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, float32), \
855
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, float64), \
856
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, complex32), \
857
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, complex64), \
858
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bool, complex128), \
859
+ \
860
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, uint8), \
861
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, uint16), \
862
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, uint32), \
863
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, uint64), \
864
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, int16), \
865
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, int32), \
866
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, int64), \
867
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, bfloat16), \
868
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, float16), \
869
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, float32), \
870
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, float64), \
871
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, complex32), \
872
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, complex64), \
873
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint8, complex128), \
874
+ \
875
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint16, uint16), \
876
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint16, uint32), \
877
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint16, uint64), \
878
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint16, int32), \
879
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint16, int64), \
880
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint16, float32), \
881
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint16, float64), \
882
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint16, complex64), \
883
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint16, complex128), \
884
+ \
885
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint32, uint32), \
886
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint32, uint64), \
887
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint32, int64), \
888
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint32, float64), \
889
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint32, complex128), \
890
+ \
891
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, uint64, uint64), \
892
+ \
893
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int8, int8), \
894
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int8, int16), \
895
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int8, int32), \
896
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int8, int64), \
897
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int8, bfloat16), \
898
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int8, float16), \
899
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int8, float32), \
900
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int8, float64), \
901
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int8, complex32), \
902
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int8, complex64), \
903
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int8, complex128), \
904
+ \
905
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int16, int16), \
906
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int16, int32), \
907
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int16, int64), \
908
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int16, float32), \
909
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int16, float64), \
910
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int16, complex64), \
911
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int16, complex128), \
912
+ \
913
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int32, int32), \
914
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int32, int64), \
915
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int32, float64), \
916
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int32, complex128), \
917
+ \
918
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, int64, int64), \
919
+ \
920
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bfloat16, bfloat16), \
921
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bfloat16, float32), \
922
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bfloat16, float64), \
923
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bfloat16, complex64), \
924
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, bfloat16, complex128), \
925
+ \
926
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, float16, float16), \
927
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, float16, float32), \
928
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, float16, float64), \
929
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, float16, complex32), \
930
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, float16, complex64), \
931
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, float16, complex128), \
932
+ \
933
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, float32, float32), \
934
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, float32, float64), \
935
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, float32, complex64), \
936
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, float32, complex128), \
937
+ \
938
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, float64, float64), \
939
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, float64, complex128), \
940
+ \
941
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, complex32, complex32), \
942
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, complex32, complex64), \
943
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, complex32, complex128), \
944
+ \
945
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, complex64, complex64), \
946
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, complex64, complex128), \
947
+ \
948
+ CUDA_HOST_UNARY_REDUCE_INIT(name, func, complex128, complex128)
949
+
950
+
/* Instantiate the "add" and "multiply" reduce kernels for every type pair
   expanded by CUDA_HOST_ALL_UNARY_REDUCE above. */
951
+ CUDA_HOST_ALL_UNARY_REDUCE(add)
952
+ CUDA_HOST_ALL_UNARY_REDUCE(multiply)
953
+
954
+
/* Sentinel-terminated table of reduce kernels.  Registered in
   gm_init_cuda_unary_kernels() via gm_add_kernel (no typecheck wrapper,
   unlike the other tables in this file). */
955
+ static const gm_kernel_init_t unary_reduce[] = {
956
+ /* REDUCE */
957
+ CUDA_HOST_ALL_UNARY_REDUCE_INIT(add, add),
958
+ CUDA_HOST_ALL_UNARY_REDUCE_INIT(multiply, multiply),
959
+
960
+ { .name = NULL, .sig = NULL }
961
+ };
962
+
963
+
964
+ /*****************************************************************************/
965
+ /* Bitwise NOT */
966
+ /*****************************************************************************/
967
+
/* Bitwise NOT ("invert"): defined only for bool and the fixed-width integer
   types -- inversion has no meaning for floating-point or complex values,
   so those type pairs are deliberately absent. */
968
+ CUDA_HOST_UNARY(invert, bool, bool)
969
+
970
+ CUDA_HOST_UNARY(invert, uint8, uint8)
971
+ CUDA_HOST_UNARY(invert, uint16, uint16)
972
+ CUDA_HOST_UNARY(invert, uint32, uint32)
973
+ CUDA_HOST_UNARY(invert, uint64, uint64)
974
+
975
+ CUDA_HOST_UNARY(invert, int8, int8)
976
+ CUDA_HOST_UNARY(invert, int16, int16)
977
+ CUDA_HOST_UNARY(invert, int32, int32)
978
+ CUDA_HOST_UNARY(invert, int64, int64)
979
+
980
+
/* Sentinel-terminated registration table for the invert kernels. */
981
+ static const gm_kernel_init_t unary_invert[] = {
982
+ /* INVERT */
983
+ CUDA_HOST_UNARY_INIT(invert, invert, bool, bool),
984
+
985
+ CUDA_HOST_UNARY_INIT(invert, invert, uint8, uint8),
986
+ CUDA_HOST_UNARY_INIT(invert, invert, uint16, uint16),
987
+ CUDA_HOST_UNARY_INIT(invert, invert, uint32, uint32),
988
+ CUDA_HOST_UNARY_INIT(invert, invert, uint64, uint64),
989
+
990
+ CUDA_HOST_UNARY_INIT(invert, invert, int8, int8),
991
+ CUDA_HOST_UNARY_INIT(invert, invert, int16, int16),
992
+ CUDA_HOST_UNARY_INIT(invert, invert, int32, int32),
993
+ CUDA_HOST_UNARY_INIT(invert, invert, int64, int64),
994
+
995
+ { .name = NULL, .sig = NULL }
996
+ };
997
+
998
+
999
+ /*****************************************************************************/
1000
+ /* Negative */
1001
+ /*****************************************************************************/
1002
+
/* Arithmetic negation.  Unsigned inputs widen to the next signed type so
   the result can represent the negated value (uint8 -> int16, etc.; uint64
   is absent because no wider signed type exists).  complex32 is declared
   NOIMPL -- presumably no half-precision complex support on this backend
   (TODO confirm against the NOIMPL stub definition). */
1003
+ CUDA_HOST_UNARY(negative, uint8, int16)
1004
+ CUDA_HOST_UNARY(negative, uint16, int32)
1005
+ CUDA_HOST_UNARY(negative, uint32, int64)
1006
+
1007
+ CUDA_HOST_UNARY(negative, int8, int8)
1008
+ CUDA_HOST_UNARY(negative, int16, int16)
1009
+ CUDA_HOST_UNARY(negative, int32, int32)
1010
+ CUDA_HOST_UNARY(negative, int64, int64)
1011
+
1012
+ CUDA_HOST_UNARY(negative, bfloat16, bfloat16)
1013
+ CUDA_HOST_UNARY(negative, float16, float16)
1014
+ CUDA_HOST_UNARY(negative, float32, float32)
1015
+ CUDA_HOST_UNARY(negative, float64, float64)
1016
+
1017
+ CUDA_HOST_NOIMPL(negative, complex32, complex32)
1018
+ CUDA_HOST_UNARY(negative, complex64, complex64)
1019
+ CUDA_HOST_UNARY(negative, complex128, complex128)
1020
+
1021
+
/* Sentinel-terminated registration table; the complex32 entry points at the
   NOIMPL stub defined above. */
1022
+ static const gm_kernel_init_t unary_negative[] = {
1023
+ /* NEGATIVE */
1024
+ CUDA_HOST_UNARY_INIT(negative, negative, uint8, int16),
1025
+ CUDA_HOST_UNARY_INIT(negative, negative, uint16, int32),
1026
+ CUDA_HOST_UNARY_INIT(negative, negative, uint32, int64),
1027
+
1028
+ CUDA_HOST_UNARY_INIT(negative, negative, int8, int8),
1029
+ CUDA_HOST_UNARY_INIT(negative, negative, int16, int16),
1030
+ CUDA_HOST_UNARY_INIT(negative, negative, int32, int32),
1031
+ CUDA_HOST_UNARY_INIT(negative, negative, int64, int64),
1032
+
1033
+ CUDA_HOST_UNARY_INIT(negative, negative, bfloat16, bfloat16),
1034
+ CUDA_HOST_UNARY_INIT(negative, negative, float16, float16),
1035
+ CUDA_HOST_UNARY_INIT(negative, negative, float32, float32),
1036
+ CUDA_HOST_UNARY_INIT(negative, negative, float64, float64),
1037
+
1038
+ CUDA_HOST_UNARY_INIT(negative, negative, complex32, complex32),
1039
+ CUDA_HOST_UNARY_INIT(negative, negative, complex64, complex64),
1040
+ CUDA_HOST_UNARY_INIT(negative, negative, complex128, complex128),
1041
+
1042
+ { .name = NULL, .sig = NULL }
1043
+ };
1044
+
1045
+
1046
+ /*****************************************************************************/
1047
+ /* Math */
1048
+ /*****************************************************************************/
1049
+
/* Building-block macros for unary math kernels.  The leading-underscore
   macros expand one family of type pairs each (half, real, complex); the
   public CUDA_ALL_* macros compose them.  Suffix conventions follow libm:
   name##f16 = half precision, name##b16 = bfloat16, name##f = float,
   plain name = double / complex. */
1050
+ #define _CUDA_ALL_HALF_MATH(name) \
1051
+ CUDA_HOST_UNARY(name##f16, uint8, float16) \
1052
+ CUDA_HOST_UNARY(name##f16, int8, float16) \
1053
+ CUDA_HOST_UNARY(name##f16, float16, float16)
1054
+
1055
+ #define _CUDA_ALL_HALF_MATH_NOIMPL(name) \
1056
+ CUDA_HOST_NOIMPL(name##f16, uint8, float16) \
1057
+ CUDA_HOST_NOIMPL(name##f16, int8, float16) \
1058
+ CUDA_HOST_NOIMPL(name##f16, float16, float16)
1059
+
1060
+ #define _CUDA_ALL_COMPLEX_MATH(name) \
1061
+ CUDA_HOST_NOIMPL(name, complex32, complex32) \
1062
+ CUDA_HOST_UNARY(name, complex64, complex64) \
1063
+ CUDA_HOST_UNARY(name, complex128, complex128)
1064
+
1065
+ #define _CUDA_ALL_COMPLEX_MATH_NOIMPL(name) \
1066
+ CUDA_HOST_NOIMPL(name, complex32, complex32) \
1067
+ CUDA_HOST_NOIMPL(name, complex64, complex64) \
1068
+ CUDA_HOST_NOIMPL(name, complex128, complex128)
1069
+
/* NOTE(review): the last line of _CUDA_ALL_REAL_MATH (entry 1077) ends in a
   line-continuation backslash, so the macro silently absorbs the following
   blank line.  Harmless today, but a latent hazard if code is ever added
   directly after the macro -- the stray backslash should be removed. */
1070
+ #define _CUDA_ALL_REAL_MATH(name) \
1071
+ CUDA_HOST_UNARY(name##b16, bfloat16, bfloat16) \
1072
+ CUDA_HOST_UNARY(name##f, uint16, float32) \
1073
+ CUDA_HOST_UNARY(name##f, int16, float32) \
1074
+ CUDA_HOST_UNARY(name##f, float32, float32) \
1075
+ CUDA_HOST_UNARY(name, uint32, float64) \
1076
+ CUDA_HOST_UNARY(name, int32, float64) \
1077
+ CUDA_HOST_UNARY(name, float64, float64) \
1078
+
1079
+ #define CUDA_ALL_REAL_MATH(name) \
1080
+ _CUDA_ALL_HALF_MATH_NOIMPL(name) \
1081
+ _CUDA_ALL_REAL_MATH(name) \
1082
+ _CUDA_ALL_COMPLEX_MATH_NOIMPL(name)
1083
+
1084
+ #define CUDA_ALL_REAL_MATH_WITH_HALF(name) \
1085
+ _CUDA_ALL_HALF_MATH(name) \
1086
+ _CUDA_ALL_REAL_MATH(name) \
1087
+ _CUDA_ALL_COMPLEX_MATH_NOIMPL(name)
1088
+
1089
+ #define CUDA_ALL_COMPLEX_MATH(name) \
1090
+ _CUDA_ALL_HALF_MATH_NOIMPL(name) \
1091
+ _CUDA_ALL_REAL_MATH(name) \
1092
+ _CUDA_ALL_COMPLEX_MATH(name)
1093
+
/* NOTE(review): same stray trailing backslash at entry 1097 -- this macro
   also swallows the blank line that follows it. */
1094
+ #define CUDA_ALL_COMPLEX_MATH_WITH_HALF(name) \
1095
+ _CUDA_ALL_HALF_MATH(name) \
1096
+ _CUDA_ALL_REAL_MATH(name) \
1097
+ _CUDA_ALL_COMPLEX_MATH(name) \
1098
+
1099
+
/* One registration-table fragment covering every type pair emitted by the
   CUDA_ALL_* definition macros above; used for each math function in the
   unary_float[] table below. */
1100
+ #define CUDA_ALL_UNARY_MATH_INIT(name) \
1101
+ CUDA_HOST_UNARY_INIT(name, name##f16, uint8, float16), \
1102
+ CUDA_HOST_UNARY_INIT(name, name##f16, int8, float16), \
1103
+ CUDA_HOST_UNARY_INIT(name, name##f16, float16, float16), \
1104
+ \
1105
+ CUDA_HOST_UNARY_INIT(name, name##b16, bfloat16, bfloat16), \
1106
+ \
1107
+ CUDA_HOST_UNARY_INIT(name, name##f, uint16, float32), \
1108
+ CUDA_HOST_UNARY_INIT(name, name##f, int16, float32), \
1109
+ CUDA_HOST_UNARY_INIT(name, name##f, float32, float32), \
1110
+ \
1111
+ CUDA_HOST_UNARY_INIT(name, name, uint32, float64), \
1112
+ CUDA_HOST_UNARY_INIT(name, name, int32, float64), \
1113
+ CUDA_HOST_UNARY_INIT(name, name, float64, float64), \
1114
+ \
1115
+ CUDA_HOST_UNARY_INIT(name, name, complex32, complex32), \
1116
+ CUDA_HOST_UNARY_INIT(name, name, complex64, complex64), \
1117
+ CUDA_HOST_UNARY_INIT(name, name, complex128, complex128)
1118
+
1119
+
1120
+ /*****************************************************************************/
1121
+ /* Abs functions */
1122
+ /*****************************************************************************/
1123
+
/* Instantiate the unary math kernels.  _WITH_HALF variants get real float16
   implementations; the others stub out half precision.  CUDA_ALL_COMPLEX_MATH
   variants additionally get complex64/complex128 implementations. */
1124
+ CUDA_ALL_REAL_MATH_WITH_HALF(fabs)
1125
+
1126
+
1127
+ /*****************************************************************************/
1128
+ /* Exponential functions */
1129
+ /*****************************************************************************/
1130
+
1131
+ CUDA_ALL_COMPLEX_MATH_WITH_HALF(exp)
1132
+ CUDA_ALL_REAL_MATH_WITH_HALF(exp2)
1133
+ CUDA_ALL_REAL_MATH(expm1)
1134
+
1135
+
1136
+ /*****************************************************************************/
1137
+ /* Logarithm functions */
1138
+ /*****************************************************************************/
1139
+
1140
+ CUDA_ALL_COMPLEX_MATH_WITH_HALF(log)
1141
+ CUDA_ALL_COMPLEX_MATH_WITH_HALF(log10)
1142
+ CUDA_ALL_REAL_MATH_WITH_HALF(log2)
1143
+ CUDA_ALL_REAL_MATH(log1p)
1144
+ CUDA_ALL_REAL_MATH(logb)
1145
+
1146
+
1147
+ /*****************************************************************************/
1148
+ /* Power functions */
1149
+ /*****************************************************************************/
1150
+
1151
+ CUDA_ALL_COMPLEX_MATH_WITH_HALF(sqrt)
1152
+ CUDA_ALL_REAL_MATH(cbrt)
1153
+
1154
+
1155
+ /*****************************************************************************/
1156
+ /* Trigonometric functions */
1157
+ /*****************************************************************************/
1158
+
1159
+ CUDA_ALL_COMPLEX_MATH_WITH_HALF(sin)
1160
+ CUDA_ALL_COMPLEX_MATH_WITH_HALF(cos)
1161
+ CUDA_ALL_COMPLEX_MATH(tan)
1162
+ CUDA_ALL_COMPLEX_MATH(asin)
1163
+ CUDA_ALL_COMPLEX_MATH(acos)
1164
+ CUDA_ALL_COMPLEX_MATH(atan)
1165
+
1166
+
1167
+ /*****************************************************************************/
1168
+ /* Hyperbolic functions */
1169
+ /*****************************************************************************/
1170
+
1171
+ CUDA_ALL_COMPLEX_MATH(sinh)
1172
+ CUDA_ALL_COMPLEX_MATH(cosh)
1173
+ CUDA_ALL_COMPLEX_MATH(tanh)
1174
+ CUDA_ALL_COMPLEX_MATH(asinh)
1175
+ CUDA_ALL_COMPLEX_MATH(acosh)
1176
+ CUDA_ALL_COMPLEX_MATH(atanh)
1177
+
1178
+
1179
+ /*****************************************************************************/
1180
+ /* Error and gamma functions */
1181
+ /*****************************************************************************/
1182
+
1183
+ CUDA_ALL_REAL_MATH(erf)
1184
+ CUDA_ALL_REAL_MATH(erfc)
1185
+ CUDA_ALL_REAL_MATH(lgamma)
1186
+ CUDA_ALL_REAL_MATH(tgamma)
1187
+
1188
+
1189
+ /*****************************************************************************/
1190
+ /* Ceiling, floor, trunc */
1191
+ /*****************************************************************************/
1192
+
1193
+ CUDA_ALL_REAL_MATH(ceil)
1194
+ CUDA_ALL_REAL_MATH(floor)
1195
+ CUDA_ALL_REAL_MATH(trunc)
1196
+ CUDA_ALL_REAL_MATH(round)
1197
+ CUDA_ALL_REAL_MATH(nearbyint)
1198
+
1199
+
/* Sentinel-terminated registration table for all unary math kernels defined
   above.  Each CUDA_ALL_UNARY_MATH_INIT entry expands to the full set of
   per-type table rows for that function. */
1200
+ static const gm_kernel_init_t unary_float[] = {
1201
+ /* ABS */
1202
+ CUDA_ALL_UNARY_MATH_INIT(fabs),
1203
+
1204
+ /* EXPONENTIAL */
1205
+ CUDA_ALL_UNARY_MATH_INIT(exp),
1206
+ CUDA_ALL_UNARY_MATH_INIT(exp2),
1207
+ CUDA_ALL_UNARY_MATH_INIT(expm1),
1208
+
1209
+ /* LOGARITHM */
1210
+ CUDA_ALL_UNARY_MATH_INIT(log),
1211
+ CUDA_ALL_UNARY_MATH_INIT(log2),
1212
+ CUDA_ALL_UNARY_MATH_INIT(log10),
1213
+ CUDA_ALL_UNARY_MATH_INIT(log1p),
1214
+ CUDA_ALL_UNARY_MATH_INIT(logb),
1215
+
1216
+ /* POWER */
1217
+ CUDA_ALL_UNARY_MATH_INIT(sqrt),
1218
+ CUDA_ALL_UNARY_MATH_INIT(cbrt),
1219
+
1220
+ /* TRIGONOMETRIC */
1221
+ CUDA_ALL_UNARY_MATH_INIT(sin),
1222
+ CUDA_ALL_UNARY_MATH_INIT(cos),
1223
+ CUDA_ALL_UNARY_MATH_INIT(tan),
1224
+ CUDA_ALL_UNARY_MATH_INIT(asin),
1225
+ CUDA_ALL_UNARY_MATH_INIT(acos),
1226
+ CUDA_ALL_UNARY_MATH_INIT(atan),
1227
+
1228
+ /* HYPERBOLIC */
1229
+ CUDA_ALL_UNARY_MATH_INIT(sinh),
1230
+ CUDA_ALL_UNARY_MATH_INIT(cosh),
1231
+ CUDA_ALL_UNARY_MATH_INIT(tanh),
1232
+ CUDA_ALL_UNARY_MATH_INIT(asinh),
1233
+ CUDA_ALL_UNARY_MATH_INIT(acosh),
1234
+ CUDA_ALL_UNARY_MATH_INIT(atanh),
1235
+
1236
+ /* ERROR AND GAMMA */
1237
+ CUDA_ALL_UNARY_MATH_INIT(erf),
1238
+ CUDA_ALL_UNARY_MATH_INIT(erfc),
1239
+ CUDA_ALL_UNARY_MATH_INIT(lgamma),
1240
+ CUDA_ALL_UNARY_MATH_INIT(tgamma),
1241
+
1242
+ /* CEILING, FLOOR, TRUNC */
1243
+ CUDA_ALL_UNARY_MATH_INIT(ceil),
1244
+ CUDA_ALL_UNARY_MATH_INIT(floor),
1245
+ CUDA_ALL_UNARY_MATH_INIT(trunc),
1246
+ CUDA_ALL_UNARY_MATH_INIT(round),
1247
+ CUDA_ALL_UNARY_MATH_INIT(nearbyint),
1248
+
1249
+ { .name = NULL, .sig = NULL }
1250
+ };
1251
+
1252
+
1253
+ /****************************************************************************/
1254
+ /* Initialize kernel table */
1255
+ /****************************************************************************/
1256
+
/* NOTE(review): aliases C's native _Bool as "bool" -- presumably because
   this translation unit does not include <stdbool.h>; confirm it is never
   compiled alongside a header that defines bool, or this typedef clashes. */
1257
+ typedef _Bool bool;
1258
+
/* Typecheck hook for the copy/abs kernel table: forwards to the shared
   cuda_unary_typecheck() with this family's kernel-location table.
   Signature matches the gm_add_kernel_typecheck callback contract. */
1259
+ static const gm_kernel_set_t *
1260
+ unary_copy_typecheck(ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
1261
+ const int64_t li[], int nin, int nout, bool check_broadcast,
1262
+ ndt_context_t *ctx)
1263
+ {
1264
+ return cuda_unary_typecheck(copy_kernel_location, spec, f, types, li,
1265
+ nin, nout, check_broadcast, ctx);
1266
+ }
1267
+
/* Typecheck hook for the invert kernel table; same pattern as
   unary_copy_typecheck but dispatching on invert_kernel_location. */
1268
+ static const gm_kernel_set_t *
1269
+ unary_invert_typecheck(ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
1270
+ const int64_t li[], int nin, int nout, bool check_broadcast,
1271
+ ndt_context_t *ctx)
1272
+ {
1273
+ return cuda_unary_typecheck(invert_kernel_location, spec, f, types, li,
1274
+ nin, nout, check_broadcast, ctx);
1275
+ }
1276
+
/* Typecheck hook for the negative kernel table; dispatches on
   negative_kernel_location. */
1277
+ static const gm_kernel_set_t *
1278
+ unary_negative_typecheck(ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
1279
+ const int64_t li[], int nin, int nout, bool check_broadcast,
1280
+ ndt_context_t *ctx)
1281
+ {
1282
+ return cuda_unary_typecheck(negative_kernel_location, spec, f, types, li,
1283
+ nin, nout, check_broadcast, ctx);
1284
+ }
1285
+
/* Typecheck hook shared by every math kernel in unary_float[]; dispatches
   on math_kernel_location. */
1286
+ static const gm_kernel_set_t *
1287
+ unary_math_typecheck(ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
1288
+ const int64_t li[], int nin, int nout, bool check_broadcast,
1289
+ ndt_context_t *ctx)
1290
+ {
1291
+ return cuda_unary_typecheck(math_kernel_location, spec, f, types, li,
1292
+ nin, nout, check_broadcast, ctx);
1293
+ }
1294
+
/* Register every CUDA unary kernel table with the function table `tbl`.
   Returns 0 on success, -1 on the first registration failure (error details
   in `ctx`).  On failure, kernels registered before the failing entry remain
   in the table -- no rollback is performed.
   NOTE(review): unary_reduce is registered via gm_add_kernel, i.e. without a
   custom typecheck callback, while all other tables install one -- confirm
   this asymmetry is intentional. */
1295
+ int
1296
+ gm_init_cuda_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx)
1297
+ {
1298
+ const gm_kernel_init_t *k;
1299
+
1300
+ for (k = unary_copy; k->name != NULL; k++) {
1301
+ if (gm_add_kernel_typecheck(tbl, k, ctx, &unary_copy_typecheck) < 0) {
1302
+ return -1;
1303
+ }
1304
+ }
1305
+
1306
+ for (k = unary_reduce; k->name != NULL; k++) {
1307
+ if (gm_add_kernel(tbl, k, ctx) < 0) {
1308
+ return -1;
1309
+ }
1310
+ }
1311
+
1312
+ for (k = unary_invert; k->name != NULL; k++) {
1313
+ if (gm_add_kernel_typecheck(tbl, k, ctx, &unary_invert_typecheck) < 0) {
1314
+ return -1;
1315
+ }
1316
+ }
1317
+
1318
+ for (k = unary_negative; k->name != NULL; k++) {
1319
+ if (gm_add_kernel_typecheck(tbl, k, ctx, &unary_negative_typecheck) < 0) {
1320
+ return -1;
1321
+ }
1322
+ }
1323
+
1324
+ for (k = unary_float; k->name != NULL; k++) {
1325
+ if (gm_add_kernel_typecheck(tbl, k, ctx, &unary_math_typecheck) < 0) {
1326
+ return -1;
1327
+ }
1328
+ }
1329
+
1330
+ return 0;
1331
+ }