gumath 0.2.0dev5 → 0.2.0dev8

Sign up to get free protection for your applications and to get access to all the features.
Files changed (99) hide show
  1. checksums.yaml +4 -4
  2. data/CONTRIBUTING.md +7 -2
  3. data/Gemfile +0 -3
  4. data/ext/ruby_gumath/GPATH +0 -0
  5. data/ext/ruby_gumath/GRTAGS +0 -0
  6. data/ext/ruby_gumath/GTAGS +0 -0
  7. data/ext/ruby_gumath/extconf.rb +0 -5
  8. data/ext/ruby_gumath/functions.c +10 -2
  9. data/ext/ruby_gumath/gufunc_object.c +15 -4
  10. data/ext/ruby_gumath/gufunc_object.h +9 -3
  11. data/ext/ruby_gumath/gumath/Makefile +63 -0
  12. data/ext/ruby_gumath/gumath/Makefile.in +1 -0
  13. data/ext/ruby_gumath/gumath/config.h +56 -0
  14. data/ext/ruby_gumath/gumath/config.h.in +3 -0
  15. data/ext/ruby_gumath/gumath/config.log +497 -0
  16. data/ext/ruby_gumath/gumath/config.status +1034 -0
  17. data/ext/ruby_gumath/gumath/configure +375 -4
  18. data/ext/ruby_gumath/gumath/configure.ac +47 -3
  19. data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
  20. data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
  21. data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
  22. data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
  23. data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
  24. data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
  25. data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
  26. data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
  27. data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
  28. data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
  29. data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
  30. data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
  31. data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
  32. data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
  33. data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
  34. data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
  35. data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
  36. data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
  37. data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
  38. data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
  39. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
  40. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
  41. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
  42. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
  43. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
  44. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
  45. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
  46. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
  47. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
  48. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
  49. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
  50. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
  51. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
  52. data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
  53. data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
  54. data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
  55. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
  56. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
  57. data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
  58. data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
  59. data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
  60. data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
  61. data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
  62. data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
  63. data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
  64. data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
  65. data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
  66. data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
  67. data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
  68. data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
  69. data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
  70. data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
  71. data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
  72. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
  73. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
  74. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
  75. data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
  76. data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
  77. data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
  78. data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
  79. data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
  80. data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
  81. data/ext/ruby_gumath/gumath/setup.py +67 -6
  82. data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
  83. data/ext/ruby_gumath/include/gumath.h +55 -14
  84. data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
  85. data/ext/ruby_gumath/lib/libgumath.a +0 -0
  86. data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
  87. data/ext/ruby_gumath/ruby_gumath.c +231 -70
  88. data/ext/ruby_gumath/ruby_gumath.h +4 -1
  89. data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
  90. data/ext/ruby_gumath/util.c +34 -0
  91. data/ext/ruby_gumath/util.h +9 -0
  92. data/gumath.gemspec +3 -2
  93. data/lib/gumath.rb +55 -1
  94. data/lib/gumath/version.rb +2 -2
  95. data/lib/ruby_gumath.so +0 -0
  96. metadata +63 -10
  97. data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
  98. data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
  99. data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
@@ -0,0 +1,528 @@
1
+ /*
2
+ * BSD 3-Clause License
3
+ *
4
+ * Copyright (c) 2017-2018, plures
5
+ * All rights reserved.
6
+ *
7
+ * Redistribution and use in source and binary forms, with or without
8
+ * modification, are permitted provided that the following conditions are met:
9
+ *
10
+ * 1. Redistributions of source code must retain the above copyright notice,
11
+ * this list of conditions and the following disclaimer.
12
+ *
13
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
14
+ * this list of conditions and the following disclaimer in the documentation
15
+ * and/or other materials provided with the distribution.
16
+ *
17
+ * 3. Neither the name of the copyright holder nor the names of its
18
+ * contributors may be used to endorse or promote products derived from
19
+ * this software without specific prior written permission.
20
+ *
21
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+ */
32
+
33
+
34
+ #include <cinttypes>
35
+ #include <thrust/complex.h>
36
+ #include <thrust/device_vector.h>
37
+ #include "cuda_device_unary.h"
38
+ #include "contrib/bfloat16.h"
39
+ #include "device.hh"
40
+
41
+
42
+ /*****************************************************************************/
43
+ /* Half float */
44
+ /*****************************************************************************/
45
+
46
/*
 * Absolute value for half-precision floats (device only).
 *
 * The argument is negated with __hneg() when __hlt() reports it as less
 * than zero; otherwise it is returned unchanged.
 */
static inline __device__ half
half_abs(half a)
{
    if (__hlt(a, 0)) {
        return __hneg(a);
    }

    return a;
}
51
+
52
+
53
+ /*****************************************************************************/
54
+ /* CUDA device unary kernels */
55
+ /*****************************************************************************/
56
+
57
/*
 * Launch configuration shared by every launcher that CUDA_DEVICE_UNARY
 * generates.  GM_CUDA_UNARY_MAX_BLOCKS is the x-dimension grid limit
 * (2^31 - 1) documented for compute capability 3.0 and later.
 */
#define GM_CUDA_UNARY_BLOCK_SIZE 256
#define GM_CUDA_UNARY_MAX_BLOCKS 2147483647

/*
 * Number of blocks used to process N elements, N > 0.
 *
 * Computed as ceil(N / block size) without forming N + blockSize - 1 (which
 * could overflow for very large N) and clamped to the grid limit.  Clamping
 * does not change results: the 1D kernels advance by the total thread count
 * until the index reaches N.
 */
static inline unsigned int
gm_cuda_unary_num_blocks(const int64_t N)
{
    const int64_t blocks = (N - 1) / GM_CUDA_UNARY_BLOCK_SIZE + 1;

    if (blocks > GM_CUDA_UNARY_MAX_BLOCKS) {
        return (unsigned int)GM_CUDA_UNARY_MAX_BLOCKS;
    }

    return (unsigned int)blocks;
}

/*
 * CUDA_DEVICE_UNARY(name, func, t0, t1, common)
 *
 * Generates three kernel/launcher pairs that apply func to elements of type
 * t0##_t and store the result as t1##_t.  Every input element is cast to
 * common##_t before func is applied.
 *
 *   gm_cuda_device_fixed_1D_C_<name>_<t0>_<t1>(a0, a1, N)
 *       x1[i] = func(x0[i]) for i in [0, N).
 *
 *   gm_cuda_device_fixed_1D_S_<name>_<t0>_<t1>(a0, a1, s0, s1, N)
 *       x1[i*s1] = func(x0[i*s0]) for i in [0, N).  The strides index the
 *       typed pointers, so they count elements rather than bytes.
 *
 *   gm_cuda_device_0D_<name>_<t0>_<t1>(a0, a1)
 *       *x1 = func(*x0), launched with a single thread.
 *
 * The launchers take untyped char pointers and cast them to the element
 * types.  The 1D launchers return without launching when N <= 0, because a
 * grid of zero blocks is not a valid launch configuration.  Thread index and
 * step are widened to int64_t before multiplying so that they cannot wrap in
 * 32-bit unsigned arithmetic for large grids.
 *
 * func may be a function or a function-like macro; it is always invoked as
 * func((common##_t)value).
 */
#define CUDA_DEVICE_UNARY(name, func, t0, t1, common) \
static __global__ void \
_1D_C_##name##_##t0##_##t1(const t0##_t *x0, t1##_t *x1, \
                           const int64_t N) \
{ \
    const int64_t index = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; \
    const int64_t stride = (int64_t)blockDim.x * gridDim.x; \
 \
    for (int64_t i = index; i < N; i += stride) { \
        x1[i] = func((common##_t)x0[i]); \
    } \
} \
 \
extern "C" void \
gm_cuda_device_fixed_1D_C_##name##_##t0##_##t1(const char *a0, char *a1, \
                                               const int64_t N) \
{ \
    const t0##_t *x0 = (const t0##_t *)a0; \
    t1##_t *x1 = (t1##_t *)a1; \
 \
    if (N <= 0) { \
        return; \
    } \
 \
    _1D_C_##name##_##t0##_##t1<<<gm_cuda_unary_num_blocks(N), \
                                 GM_CUDA_UNARY_BLOCK_SIZE>>>(x0, x1, N); \
} \
 \
static __global__ void \
_1D_S_##name##_##t0##_##t1(const t0##_t *x0, t1##_t *x1, \
                           const int64_t s0, const int64_t s1, \
                           const int64_t N) \
{ \
    const int64_t index = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; \
    const int64_t stride = (int64_t)blockDim.x * gridDim.x; \
 \
    for (int64_t i = index; i < N; i += stride) { \
        const int64_t k0 = i * s0; \
        const int64_t k1 = i * s1; \
        x1[k1] = func((common##_t)x0[k0]); \
    } \
} \
 \
extern "C" void \
gm_cuda_device_fixed_1D_S_##name##_##t0##_##t1(const char *a0, char *a1, \
                                               const int64_t s0, const int64_t s1, \
                                               const int64_t N) \
{ \
    const t0##_t *x0 = (const t0##_t *)a0; \
    t1##_t *x1 = (t1##_t *)a1; \
 \
    if (N <= 0) { \
        return; \
    } \
 \
    _1D_S_##name##_##t0##_##t1<<<gm_cuda_unary_num_blocks(N), \
                                 GM_CUDA_UNARY_BLOCK_SIZE>>>(x0, x1, s0, s1, N); \
} \
 \
static __global__ void \
_0D_##name##_##t0##_##t1(const t0##_t *x0, t1##_t *x1) \
{ \
    *x1 = func((common##_t)*x0); \
} \
 \
extern "C" void \
gm_cuda_device_0D_##name##_##t0##_##t1(const char *a0, char *a1) \
{ \
    const t0##_t *x0 = (const t0##_t *)a0; \
    t1##_t *x1 = (t1##_t *)a1; \
 \
    _0D_##name##_##t0##_##t1<<<1, 1>>>(x0, x1); \
}
124
+
125
/*
 * Identity element used to seed a reduction, selected by the functor type.
 * Seeding a product with 0 would force every result to 0, so the
 * multiplicative functor gets 1 while the additive functor keeps 0.
 * A functor without an overload here fails to compile instead of silently
 * reducing from the wrong starting value.
 */
template <class T>
static inline T
gm_reduce_identity(const thrust::plus<T> &)
{
    return (T)0;
}

template <class T>
static inline T
gm_reduce_identity(const thrust::multiplies<T> &)
{
    return (T)1;
}

/*
 * CUDA_DEVICE_UNARY_REDUCE(name, func, t0, t1)
 *
 * Generates gm_cuda_device_1D_C_reduce_<name>_<t0>_<t1>(a0, a1, N), which
 * reduces the N elements of type t0##_t starting at a0 with
 * thrust::func<t1##_t> and stores the single t1##_t result through a1.
 *
 * a0 is wrapped with thrust::device_pointer_cast, i.e. it is treated as a
 * device pointer.  The result is assigned through a1 from host code.
 * NOTE(review): this presumably requires a1 to be host-accessible (e.g.
 * managed memory) -- confirm against the callers.
 *
 * func must name a thrust functor template for which gm_reduce_identity()
 * is overloaded (thrust::plus, thrust::multiplies).
 */
#define CUDA_DEVICE_UNARY_REDUCE(name, func, t0, t1) \
extern "C" void \
gm_cuda_device_1D_C_reduce_##name##_##t0##_##t1(const char *a0, char *a1, \
                                                const int64_t N) \
{ \
    thrust::device_ptr<t0##_t> x0 = thrust::device_pointer_cast((t0##_t *)a0); \
    t1##_t *x1 = (t1##_t *)a1; \
 \
    *x1 = thrust::reduce(x0, x0+N, \
                         gm_reduce_identity(thrust::func<t1##_t>()), \
                         thrust::func<t1##_t>()); \
}
135
+
136
/*
 * Placeholder for type combinations that have no kernel: takes the same
 * arguments as CUDA_DEVICE_UNARY and expands to nothing.
 */
#define CUDA_DEVICE_NOIMPL(name, func, t0, t1, common) /* intentionally empty */
137
+
138
+
139
+ /*****************************************************************************/
140
+ /* Copy */
141
+ /*****************************************************************************/
142
+
143
/*
 * CUDA_DEVICE_ALL_UNARY(name, func, ufunc, tfunc, hfunc)
 *
 * Instantiates CUDA_DEVICE_UNARY for every supported (input, output) type
 * pair, grouped below by input type.  The function applied depends on the
 * output type:
 *
 *   ufunc  unsigned integer and bool outputs
 *   tfunc  bfloat16 outputs
 *   hfunc  float16 outputs
 *   func   all remaining outputs (signed integers, float32/64, complex64/128)
 *
 * Pairs with a complex32 input or output go through CUDA_DEVICE_NOIMPL and
 * generate no code.
 */
#define CUDA_DEVICE_ALL_UNARY(name, func, ufunc, tfunc, hfunc) \
    /* input: bool */ \
    CUDA_DEVICE_UNARY(name, ufunc, bool, bool, bool) \
    CUDA_DEVICE_UNARY(name, ufunc, bool, uint8, uint8) \
    CUDA_DEVICE_UNARY(name, ufunc, bool, uint16, uint16) \
    CUDA_DEVICE_UNARY(name, ufunc, bool, uint32, uint32) \
    CUDA_DEVICE_UNARY(name, ufunc, bool, uint64, uint64) \
    CUDA_DEVICE_UNARY(name, func, bool, int8, int8) \
    CUDA_DEVICE_UNARY(name, func, bool, int16, int16) \
    CUDA_DEVICE_UNARY(name, func, bool, int32, int32) \
    CUDA_DEVICE_UNARY(name, func, bool, int64, int64) \
    CUDA_DEVICE_UNARY(name, tfunc, bool, bfloat16, bfloat16) \
    CUDA_DEVICE_UNARY(name, hfunc, bool, float16, float16) \
    CUDA_DEVICE_UNARY(name, func, bool, float32, float32) \
    CUDA_DEVICE_UNARY(name, func, bool, float64, float64) \
    CUDA_DEVICE_NOIMPL(name, func, bool, complex32, complex32) \
    CUDA_DEVICE_UNARY(name, func, bool, complex64, complex64) \
    CUDA_DEVICE_UNARY(name, func, bool, complex128, complex128) \
 \
    /* input: uint8 */ \
    CUDA_DEVICE_UNARY(name, ufunc, uint8, uint8, uint8) \
    CUDA_DEVICE_UNARY(name, ufunc, uint8, uint16, uint16) \
    CUDA_DEVICE_UNARY(name, ufunc, uint8, uint32, uint32) \
    CUDA_DEVICE_UNARY(name, ufunc, uint8, uint64, uint64) \
    CUDA_DEVICE_UNARY(name, func, uint8, int16, int16) \
    CUDA_DEVICE_UNARY(name, func, uint8, int32, int32) \
    CUDA_DEVICE_UNARY(name, func, uint8, int64, int64) \
    CUDA_DEVICE_UNARY(name, tfunc, uint8, bfloat16, bfloat16) \
    CUDA_DEVICE_UNARY(name, hfunc, uint8, float16, float16) \
    CUDA_DEVICE_UNARY(name, func, uint8, float32, float32) \
    CUDA_DEVICE_UNARY(name, func, uint8, float64, float64) \
    CUDA_DEVICE_NOIMPL(name, func, uint8, complex32, complex32) \
    CUDA_DEVICE_UNARY(name, func, uint8, complex64, complex64) \
    CUDA_DEVICE_UNARY(name, func, uint8, complex128, complex128) \
 \
    /* input: uint16 */ \
    CUDA_DEVICE_UNARY(name, ufunc, uint16, uint16, uint16) \
    CUDA_DEVICE_UNARY(name, ufunc, uint16, uint32, uint32) \
    CUDA_DEVICE_UNARY(name, ufunc, uint16, uint64, uint64) \
    CUDA_DEVICE_UNARY(name, func, uint16, int32, int32) \
    CUDA_DEVICE_UNARY(name, func, uint16, int64, int64) \
    CUDA_DEVICE_UNARY(name, func, uint16, float32, float32) \
    CUDA_DEVICE_UNARY(name, func, uint16, float64, float64) \
    CUDA_DEVICE_UNARY(name, func, uint16, complex64, complex64) \
    CUDA_DEVICE_UNARY(name, func, uint16, complex128, complex128) \
 \
    /* input: uint32 */ \
    CUDA_DEVICE_UNARY(name, ufunc, uint32, uint32, uint32) \
    CUDA_DEVICE_UNARY(name, ufunc, uint32, uint64, uint64) \
    CUDA_DEVICE_UNARY(name, func, uint32, int64, int64) \
    CUDA_DEVICE_UNARY(name, func, uint32, float64, float64) \
    CUDA_DEVICE_UNARY(name, func, uint32, complex128, complex128) \
 \
    /* input: uint64 */ \
    CUDA_DEVICE_UNARY(name, ufunc, uint64, uint64, uint64) \
 \
    /* input: int8 */ \
    CUDA_DEVICE_UNARY(name, func, int8, int8, int8) \
    CUDA_DEVICE_UNARY(name, func, int8, int16, int16) \
    CUDA_DEVICE_UNARY(name, func, int8, int32, int32) \
    CUDA_DEVICE_UNARY(name, func, int8, int64, int64) \
    CUDA_DEVICE_UNARY(name, tfunc, int8, bfloat16, bfloat16) \
    CUDA_DEVICE_UNARY(name, hfunc, int8, float16, float16) \
    CUDA_DEVICE_UNARY(name, func, int8, float32, float32) \
    CUDA_DEVICE_UNARY(name, func, int8, float64, float64) \
    CUDA_DEVICE_NOIMPL(name, func, int8, complex32, complex32) \
    CUDA_DEVICE_UNARY(name, func, int8, complex64, complex64) \
    CUDA_DEVICE_UNARY(name, func, int8, complex128, complex128) \
 \
    /* input: int16 */ \
    CUDA_DEVICE_UNARY(name, func, int16, int16, int16) \
    CUDA_DEVICE_UNARY(name, func, int16, int32, int32) \
    CUDA_DEVICE_UNARY(name, func, int16, int64, int64) \
    CUDA_DEVICE_UNARY(name, func, int16, float32, float32) \
    CUDA_DEVICE_UNARY(name, func, int16, float64, float64) \
    CUDA_DEVICE_UNARY(name, func, int16, complex64, complex64) \
    CUDA_DEVICE_UNARY(name, func, int16, complex128, complex128) \
 \
    /* input: int32 */ \
    CUDA_DEVICE_UNARY(name, func, int32, int32, int32) \
    CUDA_DEVICE_UNARY(name, func, int32, int64, int64) \
    CUDA_DEVICE_UNARY(name, func, int32, float64, float64) \
    CUDA_DEVICE_UNARY(name, func, int32, complex128, complex128) \
 \
    /* input: int64 */ \
    CUDA_DEVICE_UNARY(name, func, int64, int64, int64) \
 \
    /* input: bfloat16 */ \
    CUDA_DEVICE_UNARY(name, tfunc, bfloat16, bfloat16, bfloat16) \
    CUDA_DEVICE_UNARY(name, func, bfloat16, float32, float32) \
    CUDA_DEVICE_UNARY(name, func, bfloat16, float64, float64) \
    CUDA_DEVICE_UNARY(name, func, bfloat16, complex64, complex64) \
    CUDA_DEVICE_UNARY(name, func, bfloat16, complex128, complex128) \
 \
    /* input: float16 */ \
    CUDA_DEVICE_UNARY(name, hfunc, float16, float16, float16) \
    CUDA_DEVICE_UNARY(name, func, float16, float32, float32) \
    CUDA_DEVICE_UNARY(name, func, float16, float64, float64) \
    CUDA_DEVICE_NOIMPL(name, func, float16, complex32, complex32) \
    CUDA_DEVICE_UNARY(name, func, float16, complex64, complex64) \
    CUDA_DEVICE_UNARY(name, func, float16, complex128, complex128) \
 \
    /* input: float32 */ \
    CUDA_DEVICE_UNARY(name, func, float32, float32, float32) \
    CUDA_DEVICE_UNARY(name, func, float32, float64, float64) \
    CUDA_DEVICE_UNARY(name, func, float32, complex64, complex64) \
    CUDA_DEVICE_UNARY(name, func, float32, complex128, complex128) \
 \
    /* input: float64 */ \
    CUDA_DEVICE_UNARY(name, func, float64, float64, float64) \
    CUDA_DEVICE_UNARY(name, func, float64, complex128, complex128) \
 \
    /* input: complex32 (not implemented) */ \
    CUDA_DEVICE_NOIMPL(name, func, complex32, complex32, complex32) \
    CUDA_DEVICE_NOIMPL(name, func, complex32, complex64, complex64) \
    CUDA_DEVICE_NOIMPL(name, func, complex32, complex128, complex128) \
 \
    /* input: complex64 */ \
    CUDA_DEVICE_UNARY(name, func, complex64, complex64, complex64) \
    CUDA_DEVICE_UNARY(name, func, complex64, complex128, complex128) \
 \
    /* input: complex128 */ \
    CUDA_DEVICE_UNARY(name, func, complex128, complex128, complex128)
250
+
251
+
252
/*
 * Identity "function" handed to the kernel generators.  The parameter is
 * parenthesized so that the expansion stays a single primary expression no
 * matter what argument expression is passed in.
 */
#define copy(x) (x)

/* copy: every type pair uses the identity. */
CUDA_DEVICE_ALL_UNARY(copy, copy, copy, copy, copy)

/*
 * abs: _abs for the generic outputs, the identity for unsigned/bool outputs,
 * tf::fabs for bfloat16 outputs and half_abs for float16 outputs.
 */
CUDA_DEVICE_ALL_UNARY(abs, _abs, copy, tf::fabs, half_abs)
255
+
256
+
257
+ /*****************************************************************************/
258
+ /* Reduce */
259
+ /*****************************************************************************/
260
+
261
/*
 * CUDA_DEVICE_ALL_UNARY_REDUCE(name, func, hfunc)
 *
 * Instantiates CUDA_DEVICE_UNARY_REDUCE for every supported (input, output)
 * type pair, grouped below by input type.  hfunc is the functor used when
 * the output is float16; func is used for every other implemented pair.
 * Pairs routed through CUDA_DEVICE_NOIMPL (all bfloat16 and complex
 * combinations, plus float16 widening to other types) generate no code.
 */
#define CUDA_DEVICE_ALL_UNARY_REDUCE(name, func, hfunc) \
    /* input: bool */ \
    CUDA_DEVICE_UNARY_REDUCE(name, func, bool, bool) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, bool, uint8) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, bool, uint16) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, bool, uint32) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, bool, uint64) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, bool, int8) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, bool, int16) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, bool, int32) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, bool, int64) \
    CUDA_DEVICE_NOIMPL(name, func, bool, bfloat16, bfloat16) \
    CUDA_DEVICE_UNARY_REDUCE(name, hfunc, bool, float16) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, bool, float32) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, bool, float64) \
    CUDA_DEVICE_NOIMPL(name, func, bool, complex32, complex32) \
    CUDA_DEVICE_NOIMPL(name, func, bool, complex64, complex64) \
    CUDA_DEVICE_NOIMPL(name, func, bool, complex128, complex128) \
 \
    /* input: uint8 */ \
    CUDA_DEVICE_UNARY_REDUCE(name, func, uint8, uint8) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, uint8, uint16) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, uint8, uint32) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, uint8, uint64) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, uint8, int16) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, uint8, int32) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, uint8, int64) \
    CUDA_DEVICE_NOIMPL(name, func, uint8, bfloat16, bfloat16) \
    CUDA_DEVICE_UNARY_REDUCE(name, hfunc, uint8, float16) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, uint8, float32) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, uint8, float64) \
    CUDA_DEVICE_NOIMPL(name, func, uint8, complex32, complex32) \
    CUDA_DEVICE_NOIMPL(name, func, uint8, complex64, complex64) \
    CUDA_DEVICE_NOIMPL(name, func, uint8, complex128, complex128) \
 \
    /* input: uint16 */ \
    CUDA_DEVICE_UNARY_REDUCE(name, func, uint16, uint16) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, uint16, uint32) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, uint16, uint64) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, uint16, int32) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, uint16, int64) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, uint16, float32) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, uint16, float64) \
    CUDA_DEVICE_NOIMPL(name, func, uint16, complex64, complex64) \
    CUDA_DEVICE_NOIMPL(name, func, uint16, complex128, complex128) \
 \
    /* input: uint32 */ \
    CUDA_DEVICE_UNARY_REDUCE(name, func, uint32, uint32) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, uint32, uint64) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, uint32, int64) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, uint32, float64) \
    CUDA_DEVICE_NOIMPL(name, func, uint32, complex128, complex128) \
 \
    /* input: uint64 */ \
    CUDA_DEVICE_UNARY_REDUCE(name, func, uint64, uint64) \
 \
    /* input: int8 */ \
    CUDA_DEVICE_UNARY_REDUCE(name, func, int8, int8) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, int8, int16) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, int8, int32) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, int8, int64) \
    CUDA_DEVICE_NOIMPL(name, func, int8, bfloat16, bfloat16) \
    CUDA_DEVICE_UNARY_REDUCE(name, hfunc, int8, float16) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, int8, float32) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, int8, float64) \
    CUDA_DEVICE_NOIMPL(name, func, int8, complex32, complex32) \
    CUDA_DEVICE_NOIMPL(name, func, int8, complex64, complex64) \
    CUDA_DEVICE_NOIMPL(name, func, int8, complex128, complex128) \
 \
    /* input: int16 */ \
    CUDA_DEVICE_UNARY_REDUCE(name, func, int16, int16) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, int16, int32) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, int16, int64) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, int16, float32) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, int16, float64) \
    CUDA_DEVICE_NOIMPL(name, func, int16, complex64, complex64) \
    CUDA_DEVICE_NOIMPL(name, func, int16, complex128, complex128) \
 \
    /* input: int32 */ \
    CUDA_DEVICE_UNARY_REDUCE(name, func, int32, int32) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, int32, int64) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, int32, float64) \
    CUDA_DEVICE_NOIMPL(name, func, int32, complex128, complex128) \
 \
    /* input: int64 */ \
    CUDA_DEVICE_UNARY_REDUCE(name, func, int64, int64) \
 \
    /* input: bfloat16 (not implemented) */ \
    CUDA_DEVICE_NOIMPL(name, func, bfloat16, bfloat16, bfloat16) \
    CUDA_DEVICE_NOIMPL(name, func, bfloat16, float32, float32) \
    CUDA_DEVICE_NOIMPL(name, func, bfloat16, float64, float64) \
    CUDA_DEVICE_NOIMPL(name, func, bfloat16, complex64, complex64) \
    CUDA_DEVICE_NOIMPL(name, func, bfloat16, complex128, complex128) \
 \
    /* input: float16 */ \
    CUDA_DEVICE_UNARY_REDUCE(name, hfunc, float16, float16) \
    CUDA_DEVICE_NOIMPL(name, func, float16, float32, float32) \
    CUDA_DEVICE_NOIMPL(name, func, float16, float64, float64) \
    CUDA_DEVICE_NOIMPL(name, func, float16, complex32, complex32) \
    CUDA_DEVICE_NOIMPL(name, func, float16, complex64, complex64) \
    CUDA_DEVICE_NOIMPL(name, func, float16, complex128, complex128) \
 \
    /* input: float32 */ \
    CUDA_DEVICE_UNARY_REDUCE(name, func, float32, float32) \
    CUDA_DEVICE_UNARY_REDUCE(name, func, float32, float64) \
    CUDA_DEVICE_NOIMPL(name, func, float32, complex64, complex64) \
    CUDA_DEVICE_NOIMPL(name, func, float32, complex128, complex128) \
 \
    /* input: float64 */ \
    CUDA_DEVICE_UNARY_REDUCE(name, func, float64, float64) \
    CUDA_DEVICE_NOIMPL(name, func, float64, complex128, complex128) \
 \
    /* input: complex32 (not implemented) */ \
    CUDA_DEVICE_NOIMPL(name, func, complex32, complex32, complex32) \
    CUDA_DEVICE_NOIMPL(name, func, complex32, complex64, complex64) \
    CUDA_DEVICE_NOIMPL(name, func, complex32, complex128, complex128) \
 \
    /* input: complex64 (not implemented) */ \
    CUDA_DEVICE_NOIMPL(name, func, complex64, complex64, complex64) \
    CUDA_DEVICE_NOIMPL(name, func, complex64, complex128, complex128) \
 \
    /* input: complex128 (not implemented) */ \
    CUDA_DEVICE_NOIMPL(name, func, complex128, complex128, complex128)


/* Sum reductions use thrust::plus, product reductions thrust::multiplies. */
CUDA_DEVICE_ALL_UNARY_REDUCE(add, plus, plus)
CUDA_DEVICE_ALL_UNARY_REDUCE(multiply, multiplies, multiplies)
372
+
373
+
374
+ /*****************************************************************************/
375
+ /* Bitwise NOT */
376
+ /*****************************************************************************/
377
+
378
/*
 * invert for bool is logical NOT.  The macro body and its parameter are
 * parenthesized so the expansion cannot regroup with neighbouring operators.
 */
#define invert(x) (!(x))
CUDA_DEVICE_UNARY(invert, invert, bool, bool, bool)
#undef invert

/*
 * invert for integers is bitwise complement.  The operand is promoted by
 * the usual integer promotions and the result is narrowed again when it is
 * stored into the output element.
 */
#define invert(x) (~(x))
CUDA_DEVICE_UNARY(invert, invert, uint8, uint8, uint8)
CUDA_DEVICE_UNARY(invert, invert, uint16, uint16, uint16)
CUDA_DEVICE_UNARY(invert, invert, uint32, uint32, uint32)
CUDA_DEVICE_UNARY(invert, invert, uint64, uint64, uint64)

CUDA_DEVICE_UNARY(invert, invert, int8, int8, int8)
CUDA_DEVICE_UNARY(invert, invert, int16, int16, int16)
CUDA_DEVICE_UNARY(invert, invert, int32, int32, int32)
CUDA_DEVICE_UNARY(invert, invert, int64, int64, int64)
392
+
393
+
394
+ /*****************************************************************************/
395
+ /* Negative */
396
+ /*****************************************************************************/
397
+
398
/*
 * Arithmetic negation.  The macro body and its parameter are parenthesized
 * so the expansion cannot regroup with neighbouring operators.
 */
#define negative(x) (-(x))

/* Unsigned inputs are negated in the next wider signed type. */
CUDA_DEVICE_UNARY(negative, negative, uint8, int16, int16)
CUDA_DEVICE_UNARY(negative, negative, uint16, int32, int32)
CUDA_DEVICE_UNARY(negative, negative, uint32, int64, int64)

CUDA_DEVICE_UNARY(negative, negative, int8, int8, int8)
CUDA_DEVICE_UNARY(negative, negative, int16, int16, int16)
CUDA_DEVICE_UNARY(negative, negative, int32, int32, int32)
CUDA_DEVICE_UNARY(negative, negative, int64, int64, int64)

/* float16 uses the __hneg intrinsic instead of the macro. */
CUDA_DEVICE_UNARY(negative, negative, bfloat16, bfloat16, bfloat16)
CUDA_DEVICE_UNARY(negative, __hneg, float16, float16, float16)
CUDA_DEVICE_UNARY(negative, negative, float32, float32, float32)
CUDA_DEVICE_UNARY(negative, negative, float64, float64, float64)

CUDA_DEVICE_NOIMPL(negative, negative, complex32, complex32, complex32)
CUDA_DEVICE_UNARY(negative, negative, complex64, complex64, complex64)
CUDA_DEVICE_UNARY(negative, negative, complex128, complex128, complex128)
416
+
417
+
418
+ /*****************************************************************************/
419
+ /* Math */
420
+ /*****************************************************************************/
421
+
422
/*
 * CUDA_DEVICE_UNARY_ALL_REAL_MATH(name)
 *
 * Instantiates the real-valued kernels of math function `name`:
 *
 *   name##f    (kernel and function)   uint16, int16, float32 -> float32
 *   name##b16  (calls tf::name)        bfloat16 -> bfloat16
 *   name       (kernel and function)   uint32, int32, float64 -> float64
 */
#define CUDA_DEVICE_UNARY_ALL_REAL_MATH(name) \
    /* single precision */ \
    CUDA_DEVICE_UNARY(name##f, name##f, uint16, float32, float32) \
    CUDA_DEVICE_UNARY(name##f, name##f, int16, float32, float32) \
    /* bfloat16 */ \
    CUDA_DEVICE_UNARY(name##b16, tf::name, bfloat16, bfloat16, bfloat16) \
    /* single precision */ \
    CUDA_DEVICE_UNARY(name##f, name##f, float32, float32, float32) \
    /* double precision */ \
    CUDA_DEVICE_UNARY(name, name, uint32, float64, float64) \
    CUDA_DEVICE_UNARY(name, name, int32, float64, float64) \
    CUDA_DEVICE_UNARY(name, name, float64, float64, float64)
430
+
431
/*
 * CUDA_DEVICE_UNARY_ALL_COMPLEX_MATH(name)
 *
 * Everything CUDA_DEVICE_UNARY_ALL_REAL_MATH(name) generates, plus the
 * complex64 and complex128 kernels of `name`.  complex32 is routed through
 * CUDA_DEVICE_NOIMPL and generates no code.
 */
#define CUDA_DEVICE_UNARY_ALL_COMPLEX_MATH(name) \
    /* real-valued kernels */ \
    CUDA_DEVICE_UNARY_ALL_REAL_MATH(name) \
    /* complex-valued kernels */ \
    CUDA_DEVICE_NOIMPL(name, name, complex32, complex32, complex32) \
    CUDA_DEVICE_UNARY(name, name, complex64, complex64, complex64) \
    CUDA_DEVICE_UNARY(name, name, complex128, complex128, complex128)
436
+
437
/*
 * CUDA_DEVICE_UNARY_ALL_HALF_MATH(name, hfunc)
 *
 * Instantiates the float16 kernels of math function `name` under the kernel
 * name name##f16.  hfunc is the half-precision function applied to uint8,
 * int8 and float16 inputs after they are cast to float16.
 */
#define CUDA_DEVICE_UNARY_ALL_HALF_MATH(name, hfunc) \
    /* 8-bit integers widen to float16 */ \
    CUDA_DEVICE_UNARY(name##f16, hfunc, uint8, float16, float16) \
    CUDA_DEVICE_UNARY(name##f16, hfunc, int8, float16, float16) \
    /* native float16 */ \
    CUDA_DEVICE_UNARY(name##f16, hfunc, float16, float16, float16)
441
+
442
/*
 * CUDA_DEVICE_UNARY_ALL_REAL_MATH_WITH_HALF(name, hfunc)
 *
 * Real-valued kernels of `name` plus the float16 kernels implemented by
 * hfunc.  The last line deliberately carries no line continuation, so the
 * macro definition ends here and cannot absorb whatever line follows it.
 */
#define CUDA_DEVICE_UNARY_ALL_REAL_MATH_WITH_HALF(name, hfunc) \
    CUDA_DEVICE_UNARY_ALL_HALF_MATH(name, hfunc) \
    CUDA_DEVICE_UNARY_ALL_REAL_MATH(name)
445
+
446
/*
 * CUDA_DEVICE_UNARY_ALL_COMPLEX_MATH_WITH_HALF(name, hfunc)
 *
 * Real and complex kernels of `name` plus the float16 kernels implemented
 * by hfunc.  The last line deliberately carries no line continuation, so the
 * macro definition ends here and cannot absorb whatever line follows it.
 */
#define CUDA_DEVICE_UNARY_ALL_COMPLEX_MATH_WITH_HALF(name, hfunc) \
    CUDA_DEVICE_UNARY_ALL_HALF_MATH(name, hfunc) \
    CUDA_DEVICE_UNARY_ALL_COMPLEX_MATH(name)
449
+
450
+
451
+ /*****************************************************************************/
452
+ /* Abs functions */
453
+ /*****************************************************************************/
454
+
455
/* fabs: real types only; float16 goes through half_abs. */
CUDA_DEVICE_UNARY_ALL_REAL_MATH_WITH_HALF(fabs, half_abs)


/*****************************************************************************/
/*                           Exponential functions                           */
/*****************************************************************************/

/* exp also has complex kernels; exp2 and expm1 are real only. */
CUDA_DEVICE_UNARY_ALL_COMPLEX_MATH_WITH_HALF(exp, hexp)
CUDA_DEVICE_UNARY_ALL_REAL_MATH_WITH_HALF(exp2, hexp2)
CUDA_DEVICE_UNARY_ALL_REAL_MATH(expm1)


/*****************************************************************************/
/*                            Logarithm functions                            */
/*****************************************************************************/

/* log and log10 also have complex kernels; the rest are real only. */
CUDA_DEVICE_UNARY_ALL_COMPLEX_MATH_WITH_HALF(log, hlog)
CUDA_DEVICE_UNARY_ALL_COMPLEX_MATH_WITH_HALF(log10, hlog10)
CUDA_DEVICE_UNARY_ALL_REAL_MATH_WITH_HALF(log2, hlog2)
CUDA_DEVICE_UNARY_ALL_REAL_MATH(log1p)
CUDA_DEVICE_UNARY_ALL_REAL_MATH(logb)


/*****************************************************************************/
/*                              Power functions                              */
/*****************************************************************************/

CUDA_DEVICE_UNARY_ALL_COMPLEX_MATH_WITH_HALF(sqrt, hsqrt)
CUDA_DEVICE_UNARY_ALL_REAL_MATH(cbrt)


/*****************************************************************************/
/*                          Trigonometric functions                          */
/*****************************************************************************/

/* Only sin and cos have float16 kernels. */
CUDA_DEVICE_UNARY_ALL_COMPLEX_MATH_WITH_HALF(sin, hsin)
CUDA_DEVICE_UNARY_ALL_COMPLEX_MATH_WITH_HALF(cos, hcos)
CUDA_DEVICE_UNARY_ALL_COMPLEX_MATH(tan)
CUDA_DEVICE_UNARY_ALL_COMPLEX_MATH(asin)
CUDA_DEVICE_UNARY_ALL_COMPLEX_MATH(acos)
CUDA_DEVICE_UNARY_ALL_COMPLEX_MATH(atan)


/*****************************************************************************/
/*                            Hyperbolic functions                           */
/*****************************************************************************/

CUDA_DEVICE_UNARY_ALL_COMPLEX_MATH(sinh)
CUDA_DEVICE_UNARY_ALL_COMPLEX_MATH(cosh)
CUDA_DEVICE_UNARY_ALL_COMPLEX_MATH(tanh)
CUDA_DEVICE_UNARY_ALL_COMPLEX_MATH(asinh)
CUDA_DEVICE_UNARY_ALL_COMPLEX_MATH(acosh)
CUDA_DEVICE_UNARY_ALL_COMPLEX_MATH(atanh)


/*****************************************************************************/
/*                         Error and gamma functions                         */
/*****************************************************************************/

CUDA_DEVICE_UNARY_ALL_REAL_MATH(erf)
CUDA_DEVICE_UNARY_ALL_REAL_MATH(erfc)
CUDA_DEVICE_UNARY_ALL_REAL_MATH(lgamma)
CUDA_DEVICE_UNARY_ALL_REAL_MATH(tgamma)


/*****************************************************************************/
/*                  Ceiling, floor, trunc, round, nearbyint                  */
/*****************************************************************************/

CUDA_DEVICE_UNARY_ALL_REAL_MATH(ceil)
CUDA_DEVICE_UNARY_ALL_REAL_MATH(floor)
CUDA_DEVICE_UNARY_ALL_REAL_MATH(trunc)
CUDA_DEVICE_UNARY_ALL_REAL_MATH(round)
CUDA_DEVICE_UNARY_ALL_REAL_MATH(nearbyint)