gumath 0.2.0dev5 → 0.2.0dev8

Sign up to get free protection for your applications and to get access to all the features.
Files changed (99) hide show
  1. checksums.yaml +4 -4
  2. data/CONTRIBUTING.md +7 -2
  3. data/Gemfile +0 -3
  4. data/ext/ruby_gumath/GPATH +0 -0
  5. data/ext/ruby_gumath/GRTAGS +0 -0
  6. data/ext/ruby_gumath/GTAGS +0 -0
  7. data/ext/ruby_gumath/extconf.rb +0 -5
  8. data/ext/ruby_gumath/functions.c +10 -2
  9. data/ext/ruby_gumath/gufunc_object.c +15 -4
  10. data/ext/ruby_gumath/gufunc_object.h +9 -3
  11. data/ext/ruby_gumath/gumath/Makefile +63 -0
  12. data/ext/ruby_gumath/gumath/Makefile.in +1 -0
  13. data/ext/ruby_gumath/gumath/config.h +56 -0
  14. data/ext/ruby_gumath/gumath/config.h.in +3 -0
  15. data/ext/ruby_gumath/gumath/config.log +497 -0
  16. data/ext/ruby_gumath/gumath/config.status +1034 -0
  17. data/ext/ruby_gumath/gumath/configure +375 -4
  18. data/ext/ruby_gumath/gumath/configure.ac +47 -3
  19. data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
  20. data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
  21. data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
  22. data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
  23. data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
  24. data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
  25. data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
  26. data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
  27. data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
  28. data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
  29. data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
  30. data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
  31. data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
  32. data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
  33. data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
  34. data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
  35. data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
  36. data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
  37. data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
  38. data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
  39. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
  40. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
  41. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
  42. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
  43. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
  44. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
  45. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
  46. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
  47. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
  48. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
  49. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
  50. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
  51. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
  52. data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
  53. data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
  54. data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
  55. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
  56. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
  57. data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
  58. data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
  59. data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
  60. data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
  61. data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
  62. data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
  63. data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
  64. data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
  65. data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
  66. data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
  67. data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
  68. data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
  69. data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
  70. data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
  71. data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
  72. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
  73. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
  74. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
  75. data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
  76. data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
  77. data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
  78. data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
  79. data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
  80. data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
  81. data/ext/ruby_gumath/gumath/setup.py +67 -6
  82. data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
  83. data/ext/ruby_gumath/include/gumath.h +55 -14
  84. data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
  85. data/ext/ruby_gumath/lib/libgumath.a +0 -0
  86. data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
  87. data/ext/ruby_gumath/ruby_gumath.c +231 -70
  88. data/ext/ruby_gumath/ruby_gumath.h +4 -1
  89. data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
  90. data/ext/ruby_gumath/util.c +34 -0
  91. data/ext/ruby_gumath/util.h +9 -0
  92. data/gumath.gemspec +3 -2
  93. data/lib/gumath.rb +55 -1
  94. data/lib/gumath/version.rb +2 -2
  95. data/lib/ruby_gumath.so +0 -0
  96. metadata +63 -10
  97. data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
  98. data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
  99. data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
@@ -0,0 +1,155 @@
1
+ /*
2
+ * BSD 3-Clause License
3
+ *
4
+ * Copyright (c) 2017-2018, plures
5
+ * All rights reserved.
6
+ *
7
+ * Redistribution and use in source and binary forms, with or without
8
+ * modification, are permitted provided that the following conditions are met:
9
+ *
10
+ * 1. Redistributions of source code must retain the above copyright notice,
11
+ * this list of conditions and the following disclaimer.
12
+ *
13
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
14
+ * this list of conditions and the following disclaimer in the documentation
15
+ * and/or other materials provided with the distribution.
16
+ *
17
+ * 3. Neither the name of the copyright holder nor the names of its
18
+ * contributors may be used to endorse or promote products derived from
19
+ * this software without specific prior written permission.
20
+ *
21
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+ */
32
+
33
+
34
+ #include <stdlib.h>
35
+ #include <stdint.h>
36
+ #include <string.h>
37
+ #include <math.h>
38
+ #include <complex.h>
39
+ #include <inttypes.h>
40
+ #include "ndtypes.h"
41
+ #include "xnd.h"
42
+ #include "gumath.h"
43
+
44
+
45
+ #ifndef COMMON_H
46
+ #define COMMON_H
47
+
48
+
49
+ #define XSTRINGIZE(v) #v
50
+ #define STRINGIZE(v) XSTRINGIZE(v)
51
+
52
+
53
+ /*****************************************************************************/
54
+ /* Apply linear index to the data pointer (1D kernels) */
55
+ /*****************************************************************************/
56
+
57
+ static inline char * /* Thin wrapper that forwards to xnd_fixed_apply_index(). */
58
+ apply_index(const xnd_t *x) /* x is passed through unchanged */
59
+ {
60
+ return xnd_fixed_apply_index(x); /* presumably the data pointer with x's linear index applied, per the section heading -- confirm in libxnd */
61
+ }
62
+
63
+
64
+ /*****************************************************************************/
65
+ /* Optimized bitmap handling */
66
+ /*****************************************************************************/
67
+
68
+ static inline uint8_t * /* Return x->bitmap.data if x's type is optional, otherwise NULL. */
70
+ get_bitmap(const xnd_t *x) /* x: must have a 0-dimensional type (asserted below) */
71
+ {
72
+ const ndt_t *t = x->type;
73
+ assert(t->ndim == 0); /* NOTE(review): <assert.h> is not among this header's visible includes; presumably pulled in via ndtypes.h or xnd.h -- confirm */
74
+ return ndt_is_optional(t) ? x->bitmap.data : NULL; /* non-optional types yield NULL */
75
+ }
75
+
76
+ static inline uint8_t * /* Return x->bitmap.data if the element type of a 1D fixed array is optional, otherwise NULL. */
77
+ get_bitmap1D(const xnd_t *x) /* x: must have ndim == 1 and tag FixedDim (asserted below) */
78
+ {
79
+ const ndt_t *t = x->type;
80
+ assert(t->ndim == 1 && t->tag == FixedDim); /* only checked in builds where assert is active */
81
+ return ndt_is_optional(ndt_dtype(t)) ? x->bitmap.data : NULL; /* optionality is tested on ndt_dtype(t), not on t itself */
82
+ }
83
+
84
+ static inline bool /* Return true iff bit n of the bitmap at data is set. */
85
+ is_valid(const uint8_t *data, int64_t n) /* data: bitmap bytes; n: bit index, bit 0 being the least significant bit of data[0] */
86
+ {
87
+ int64_t pos = n / 8; /* byte containing bit n */
88
+ int64_t shift = n % 8; /* bit position inside that byte; negative (making the shift below undefined) if n < 0 */
89
+ uint8_t mask = (uint8_t)1 << shift;; /* NOTE(review): the doubled ';' is a stray empty statement */
90
+
91
+ return data[pos] & mask; /* any non-zero result converts to true */
92
+ }
93
+
94
+ static inline void /* Set bit n of the bitmap at data to x, leaving all other bits untouched. */
95
+ set_bit(uint8_t *data, int64_t n, bool x) /* data: bitmap bytes (modified in place); n: bit index; x: new bit value */
96
+ {
97
+ int64_t pos = n / 8; /* byte containing bit n */
98
+ int64_t shift = n % 8; /* bit position inside that byte; negative (making the shifts below undefined) if n < 0 */
99
+ uint8_t dmask = ((uint8_t)1) << shift; /* selects the current value of bit n */
100
+ uint8_t xmask = ((uint8_t)x) << shift; /* desired value of bit n, already in position */
101
+
102
+ data[pos] ^= ((data[pos] & dmask) ^ xmask); /* (old ^ new) is non-zero only when bit n differs from x, so the XOR flips exactly that bit when needed */
103
+ }
104
+
105
+ static inline int64_t /* Return x->index advanced by i steps of x's fixed dimension. */
106
+ linear_index1D(const xnd_t *x, const int64_t i) /* i: position along the dimension */
107
+ {
108
+ const ndt_t *t = x->type;
109
+ const int64_t step = i * t->Concrete.FixedDim.step; /* reads the FixedDim member without checking t->tag; presumably callers guarantee a concrete FixedDim type -- confirm */
110
+ return x->index + step; /* no overflow check on the multiplication or the addition */
111
+ }
112
+
113
+
114
+ /*****************************************************************************/
115
+ /* Bitmap update helpers and unary/binary typecheck declarations */
116
+ /*****************************************************************************/
117
+
118
+ /* LOCAL SCOPE */
119
+ NDT_PRAGMA(NDT_HIDE_SYMBOLS_START)
120
+
121
+ void unary_update_bitmap_1D_S(xnd_t stack[]);
122
+ void unary_reduce_bitmap_1D_S(xnd_t stack[]);
123
+ void unary_update_bitmap_0D(xnd_t stack[]);
124
+
125
+ void binary_update_bitmap_1D_S(xnd_t stack[]);
126
+ void binary_update_bitmap_0D(xnd_t stack[]);
127
+
128
+ void binary_update_bitmap_1D_S_bool(xnd_t stack[]);
129
+ void binary_update_bitmap_0D_bool(xnd_t stack[]);
130
+
131
+ const gm_kernel_set_t *cpu_unary_typecheck(int (*kernel_location)(const ndt_t *, const ndt_t *, ndt_context_t *),
132
+ ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
133
+ const int64_t li[], int nin, int nout, bool check_broadcast,
134
+ ndt_context_t *ctx);
135
+
136
+ const gm_kernel_set_t *cuda_unary_typecheck(int (*kernel_location)(const ndt_t *, const ndt_t *, ndt_context_t *),
137
+ ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
138
+ const int64_t li[], int nin, int nout, bool check_broadcast,
139
+ ndt_context_t *ctx);
140
+
141
+ const gm_kernel_set_t *cpu_binary_typecheck(int (*kernel_location)(const ndt_t *in0, const ndt_t *in1, ndt_context_t *ctx),
142
+ ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
143
+ const int64_t li[], int nin, int nout, bool check_broadcast,
144
+ ndt_context_t *ctx);
145
+
146
+ const gm_kernel_set_t *cuda_binary_typecheck(int (* kernel_location)(const ndt_t *in0, const ndt_t *in1, ndt_context_t *ctx),
147
+ ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
148
+ const int64_t li[], int nin, int nout, bool check_broadcast,
149
+ ndt_context_t *ctx);
150
+
151
+ /* END LOCAL SCOPE */
152
+ NDT_PRAGMA(NDT_HIDE_SYMBOLS_END)
153
+
154
+
155
+ #endif /* COMMON_H */
@@ -0,0 +1,520 @@
1
+ /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
14
+ ==============================================================================*/
15
+
16
+ /* Modified and adapted for gumath. */
17
+
18
+ #ifndef BFLOAT16_H
19
+ #define BFLOAT16_H
20
+
21
+
22
+ // Byte order defines provided by gcc. MSVC doesn't define those so
23
+ // we define them here.
24
+ // We assume that all Windows platforms out there are little endian.
25
+ #if defined(_MSC_VER) && !defined(__clang__)
26
+ #define __ORDER_LITTLE_ENDIAN__ 0x4d2
27
+ #define __ORDER_BIG_ENDIAN__ 0x10e1
28
+ #define __BYTE_ORDER__ __ORDER_LITTLE_ENDIAN__
29
+ #endif
30
+
31
+ #ifdef __CUDACC__
32
+ // All functions callable from CUDA code must be qualified with __device__
33
+ #define B16_DEVICE_FUNC __host__ __device__
34
+ #include <thrust/complex.h>
35
+ #include <math.h>
36
+ typedef thrust::complex<float> complex64;
37
+ typedef thrust::complex<double> complex128;
38
+ #else
39
+ #define B16_DEVICE_FUNC
40
+ #include <cmath>
41
+ #include <complex>
42
+ typedef std::complex<float> complex64;
43
+ typedef std::complex<double> complex128;
44
+ #endif
45
+
46
+
47
+ namespace tf {
48
+
49
+ // see framework/bfloat16.h for description.
50
+ struct bfloat16 {
51
+ // The default constructor must yield a zero value, not an uninitialized
52
+ // value; some TF kernels use T() as a zero value.
53
+ B16_DEVICE_FUNC bfloat16() : value(ZERO_VALUE) {} // ZERO_VALUE is 0, which widens to +0.0f
54
+
55
+ B16_DEVICE_FUNC static bfloat16 truncate_to_bfloat16(const float v) {
56
+ // Convert v by keeping the high 16 bits of its float32 pattern (no rounding).
57
+ bfloat16 output;
58
+ if (float_isnan(v)) {
59
+ output.value = NAN_VALUE; // truncating a NaN could clear all fraction bits and yield Inf
60
+ return output;
61
+ }
62
+ // Read the bits through the FP32 union, as round_to_bfloat16 does, and shift:
63
+ // no uint16_t* aliasing, no reliance on __BYTE_ORDER__ being defined.
64
+ FP32 bits;
65
+ bits.f = v;
66
+ output.value = static_cast<uint16_t>(bits.u >> 16);
67
+ return output;
68
+ }
69
+
70
+ B16_DEVICE_FUNC explicit bfloat16(const float v) { // Construct from a float via round_to_bfloat16 (round-to-nearest-even).
71
+ value = round_to_bfloat16(v).value; // keep only the 16-bit pattern of the rounded result
72
+ }
73
+
74
+ B16_DEVICE_FUNC explicit bfloat16(const double val) // Construct from a double: narrow to float first,
75
+ : bfloat16(static_cast<float>(val)) {} // then delegate to the float constructor (so the value is rounded twice).
76
+ // Following the convention of numpy, converting from a complex value to a
77
+ // real type discards the imaginary part.
78
+ B16_DEVICE_FUNC explicit bfloat16(const complex64& val) // Construct from the real part only.
79
+ : bfloat16(val.real()) {} // real() is already a float here
80
+
81
+ B16_DEVICE_FUNC explicit bfloat16(const complex128& val) // Construct from the real part only.
82
+ : bfloat16(static_cast<float>(val.real())) {} // real() is a double, so it is narrowed to float first
83
+
84
+ B16_DEVICE_FUNC explicit bfloat16(const unsigned short val) // Integer constructors: each converts to float and delegates to the float constructor.
85
+ : bfloat16(static_cast<float>(val)) {}
86
+
87
+ B16_DEVICE_FUNC explicit bfloat16(const unsigned int val) // values that float cannot represent exactly are rounded by the cast
88
+ : bfloat16(static_cast<float>(val)) {}
89
+
90
+ B16_DEVICE_FUNC explicit bfloat16(const int val) // also used by the mixed bfloat16/int operator+ overloads via bfloat16(1)-style calls
91
+ : bfloat16(static_cast<float>(val)) {}
92
+
93
+ B16_DEVICE_FUNC explicit bfloat16(const long val)
94
+ : bfloat16(static_cast<float>(val)) {}
95
+
96
+ B16_DEVICE_FUNC explicit bfloat16(const long long val)
97
+ : bfloat16(static_cast<float>(val)) {}
98
+
99
+ template <class T> // Catch-all for any other type that can be static_cast to float.
100
+ B16_DEVICE_FUNC explicit bfloat16(const T& val)
101
+ : bfloat16(static_cast<float>(val)) {} // delegates to the float constructor like the overloads above
102
+
103
+ B16_DEVICE_FUNC explicit operator float() const {
104
+ // Widen to float32: the 16 stored bits become the high half of the float
105
+ // pattern and the low 16 fraction bits are zero, so the conversion is exact.
106
+ //
107
+ // The pattern is built with a shift and read back through the FP32 union
108
+ // (as in round_to_bfloat16), so it does not alias a float through uint16_t*
109
+ // and does not depend on byte order or on __BYTE_ORDER__ being defined.
110
+ FP32 bits;
111
+ bits.u = static_cast<unsigned int>(value) << 16; // cast first: avoids shifting a promoted signed int
112
+
113
+ return bits.f;
114
+ }
115
+
116
+ B16_DEVICE_FUNC explicit operator bool() const { // True iff the widened float compares unequal to zero.
117
+ return static_cast<bool>(!!(float(*this))); // both +0.0 and -0.0 give false; NaN gives true
118
+ }
119
+
120
+ B16_DEVICE_FUNC explicit operator short() const { // Scalar conversions: each widens to float, then static_casts to the target type.
121
+ return static_cast<short>(float(*this)); // float-to-integer casts truncate toward zero; NaN, Inf and out-of-range values are undefined behaviour in C++
122
+ }
123
+
124
+ B16_DEVICE_FUNC explicit operator int() const {
125
+ return static_cast<int>(float(*this));
126
+ }
127
+
128
+ B16_DEVICE_FUNC explicit operator long() const {
129
+ return static_cast<long>(float(*this));
130
+ }
131
+
132
+ B16_DEVICE_FUNC explicit operator char() const { // signedness of plain char is implementation-defined
133
+ return static_cast<char>(float(*this));
134
+ }
135
+
136
+ B16_DEVICE_FUNC explicit operator signed char() const {
137
+ return static_cast<signed char>(float(*this));
138
+ }
139
+
140
+ B16_DEVICE_FUNC explicit operator unsigned char() const { // unsigned targets: negative inputs are out of range for the cast
141
+ return static_cast<unsigned char>(float(*this));
142
+ }
143
+
144
+ B16_DEVICE_FUNC explicit operator unsigned short() const {
145
+ return static_cast<unsigned short>(float(*this));
146
+ }
147
+
148
+ B16_DEVICE_FUNC explicit operator unsigned int() const {
149
+ return static_cast<unsigned int>(float(*this));
150
+ }
151
+
152
+ B16_DEVICE_FUNC explicit operator unsigned long() const {
153
+ return static_cast<unsigned long>(float(*this));
154
+ }
155
+
156
+ B16_DEVICE_FUNC explicit operator unsigned long long() const {
157
+ return static_cast<unsigned long long>(float(*this));
158
+ }
159
+
160
+ B16_DEVICE_FUNC explicit operator long long() const {
161
+ return static_cast<long long>(float(*this));
162
+ }
163
+
164
+ B16_DEVICE_FUNC explicit operator double() const { // float to double is exact
165
+ return static_cast<double>(float(*this));
166
+ }
167
+
168
+ B16_DEVICE_FUNC explicit operator complex64() const { // Convert to a complex value with this number as the real part.
169
+ return complex64(float(*this), float(0.0)); // imaginary part is zero
170
+ }
171
+
172
+ B16_DEVICE_FUNC explicit operator complex128() const { // Same, with double components.
173
+ return complex128(double(*this), double(0.0)); // double(*this) goes through operator double above
174
+ }
175
+
176
+ union FP32 { // Overlay used by round_to_bfloat16 to read the bit pattern of a float.
177
+ unsigned int u; // integer view; the users of this union treat it as 32 bits wide
178
+ float f; // floating-point view; assumes float and unsigned int have the same size -- confirm for new targets
179
+ };
180
+
181
+ // Converts a float point to bfloat16, with round-nearest-to-even as rounding
182
+ // method.
183
+ // TODO: There is a slightly faster implementation (8% faster on CPU)
184
+ // than this (documented in cl/175987786), that is exponentially harder to
185
+ // understand and document. Switch to the faster version when converting to
186
+ // BF16 becomes compute-bound.
187
+ B16_DEVICE_FUNC static bfloat16 round_to_bfloat16(float v) {
188
+ uint32_t input;
189
+ FP32 f;
190
+ f.f = v;
191
+ input = f.u;
192
+ bfloat16 output;
193
+
194
+ if (float_isnan(v)) {
195
+ // If the value is a NaN, squash it to a qNaN with msb of fraction set,
196
+ // this makes sure after truncation we don't end up with an inf.
197
+ //
198
+ // qNaN magic: All exponent bits set + most significant bit of fraction
199
+ // set.
200
+ output.value = 0x7fc0;
201
+ } else {
202
+ // Fast rounding algorithm that rounds a half value to nearest even. This
203
+ // reduces expected error when we convert a large number of floats. Here
204
+ // is how it works:
205
+ //
206
+ // Definitions:
207
+ // To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
208
+ // with the following tags:
209
+ //
210
+ // Sign | Exp (8 bits) | Frac (23 bits)
211
+ // S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT
212
+ //
213
+ // S: Sign bit.
214
+ // E: Exponent bits.
215
+ // F: First 6 bits of fraction.
216
+ // L: Least significant bit of resulting bfloat16 if we truncate away the
217
+ // rest of the float32. This is also the 7th bit of fraction
218
+ // R: Rounding bit, 8th bit of fraction.
219
+ // T: Sticky bits, rest of fraction, 15 bits.
220
+ //
221
+ // To round half to nearest even, there are 3 cases where we want to round
222
+ // down (simply truncate the result of the bits away, which consists of
223
+ // rounding bit and sticky bits) and two cases where we want to round up
224
+ // (truncate then add one to the result).
225
+ //
226
+ // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
227
+ // 1s) as the rounding bias, adds the rounding bias to the input, then
228
+ // truncates the last 16 bits away.
229
+ //
230
+ // To understand how it works, we can analyze this algorithm case by case:
231
+ //
232
+ // 1. L = 0, R = 0:
233
+ // Expect: round down, this is less than half value.
234
+ //
235
+ // Algorithm:
236
+ // - Rounding bias: 0x7fff + 0 = 0x7fff
237
+ // - Adding rounding bias to input may or may not carry into R, depending on
238
+ // whether there is any value set to 1 in T bits.
239
+ // - R may be set to 1 if there is a carry.
240
+ // - L remains 0.
241
+ // - Note that this case also handles Inf and -Inf, where all fraction
242
+ // bits, including L, R and Ts are all 0. The output remains Inf after
243
+ // this algorithm.
244
+ //
245
+ // 2. L = 1, R = 0:
246
+ // Expect: round down, this is less than half value.
247
+ //
248
+ // Algorithm:
249
+ // - Rounding bias: 0x7fff + 1 = 0x8000
250
+ // - Adding rounding bias to input doesn't change sticky bits but
251
+ // adds 1 to rounding bit.
252
+ // - L remains 1.
253
+ //
254
+ // 3. L = 0, R = 1, all of T are 0:
255
+ // Expect: round down, this is exactly at half, the result is already
256
+ // even (L=0).
257
+ //
258
+ // Algorithm:
259
+ // - Rounding bias: 0x7fff + 0 = 0x7fff
260
+ // - Adding rounding bias to input sets all sticky bits to 1, but
261
+ // doesn't create a carry.
262
+ // - R remains 1.
263
+ // - L remains 0.
264
+ //
265
+ // 4. L = 1, R = 1:
266
+ // Expect: round up, this is at or above half, the result needs to be
267
+ // rounded up to the next even number.
268
+ //
269
+ // Algorithm:
270
+ // - Rounding bias: 0x7fff + 1 = 0x8000
271
+ // - Adding rounding bias to input doesn't change sticky bits, but
272
+ // creates a carry from rounding bit.
273
+ // - The carry sets L to 0, creates another carry bit and propagates
274
+ // forward to F bits.
275
+ // - If all the F bits are 1, a carry then propagates to the exponent
276
+ // bits, which then creates the minimum value with the next exponent
277
+ // value. Note that we won't have the case where exponents are all 1,
278
+ // since that's either a NaN (handled in the other if condition) or inf
279
+ // (handled in case 1).
280
+ //
281
+ // 5. L = 0, R = 1, any of T is 1:
282
+ // Expect: round up, this is greater than half.
283
+ //
284
+ // Algorithm:
285
+ // - Rounding bias: 0x7fff + 0 = 0x7fff
286
+ // - Adding rounding bias to input creates a carry from sticky bits,
287
+ // sets rounding bit to 0, then creates another carry.
288
+ // - The second carry sets L to 1.
289
+ //
290
+ // Examples:
291
+ //
292
+ // Exact half value that is already even:
293
+ // Input:
294
+ // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
295
+ // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
296
+ // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000
297
+ //
298
+ // This falls into case 3. We truncate the rest of 16 bits and no
299
+ // carry is created into F and L:
300
+ //
301
+ // Output:
302
+ // Sign | Exp (8 bit) | Frac (first 7 bit)
303
+ // S E E E E E E E E F F F F F F L
304
+ // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
305
+ //
306
+ // Exact half value, round to next even number:
307
+ // Input:
308
+ // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
309
+ // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
310
+ // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000
311
+ //
312
+ // This falls into case 4. We create a carry from R and T,
313
+ // which then propagates into L and F:
314
+ //
315
+ // Output:
316
+ // Sign | Exp (8 bit) | Frac (first 7 bit)
317
+ // S E E E E E E E E F F F F F F L
318
+ // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
319
+ //
320
+ //
321
+ // Max denormal value round to min normal value:
322
+ // Input:
323
+ // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
324
+ // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
325
+ // 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111
326
+ //
327
+ // This falls into case 4. We create a carry from R and T,
328
+ // propagate into L and F, which then propagates into exponent
329
+ // bits:
330
+ //
331
+ // Output:
332
+ // Sign | Exp (8 bit) | Frac (first 7 bit)
333
+ // S E E E E E E E E F F F F F F L
334
+ // 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
335
+ //
336
+ // Max normal value round to Inf:
337
+ // Input:
338
+ // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
339
+ // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
340
+ // 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111
341
+ //
342
+ // This falls into case 4. We create a carry from R and T,
343
+ // propagate into L and F, which then propagates into exponent
344
+ // bits:
345
+ //
346
+ // Sign | Exp (8 bit) | Frac (first 7 bit)
347
+ // S E E E E E E E E F F F F F F L
348
+ // 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0
349
+ //
350
+ //
351
+ // Least significant bit of resulting bfloat.
352
+ uint32_t lsb = (input >> 16) & 1;
353
+ uint32_t rounding_bias = 0x7fff + lsb;
354
+ input += rounding_bias;
355
+ output.value = static_cast<uint16_t>(input >> 16);
356
+ }
357
+ return output;
358
+ }
359
+
360
+ static bfloat16 epsilon() { // Return 2^-7 (bit pattern 0x3c00). NOTE(review): unlike most members, these three helpers are not marked B16_DEVICE_FUNC.
361
+ bfloat16 x;
362
+ x.value = 0x3c00; // 0x1.0p-7
363
+ return x;
364
+ }
365
+
366
+ static bfloat16 highest() { // Return the largest finite value: maximum non-Inf exponent with all fraction bits set.
367
+ bfloat16 x;
368
+ x.value = 0x7F7F; // 0x1.FEp127
369
+ return x;
370
+ }
371
+
372
+ static bfloat16 lowest() { // Return the most negative finite value: highest() with the sign bit set.
373
+ bfloat16 x;
374
+ x.value = 0xFF7F; // -0x1.FEp127
375
+ return x;
376
+ }
377
+
378
+ uint16_t value;
379
+
380
+ // A value that represents "not a number".
381
+ static const uint16_t NAN_VALUE = 0x7FC0;
382
+
383
+ private:
384
+ // A value that represents "zero".
385
+ static const uint16_t ZERO_VALUE = 0;
386
+
387
+ B16_DEVICE_FUNC static bool float_isnan(const float& x) { // Return true iff x is NaN; private helper for the float-to-bfloat16 conversions.
388
+ #ifdef __CUDA_ARCH__ // set by nvcc while compiling the device side
389
+ return ::isnan(x); // global-namespace isnan is used on the device path
390
+ #else
391
+ return std::isnan(x); // host path uses the <cmath> overload
392
+ #endif
393
+ }
394
+ };
395
+
396
+ B16_DEVICE_FUNC inline std::ostream& operator<<(std::ostream& os, // Write dt to os as its float value.
397
+ const bfloat16& dt) { // NOTE(review): no <ostream>/<iosfwd> include is visible in this header; presumably supplied transitively -- confirm
398
+ os << static_cast<float>(dt); // formatting follows the stream's current float settings
399
+ return os; // returned so insertions can be chained
400
+ }
401
+
402
+ B16_DEVICE_FUNC inline bfloat16 operator+(bfloat16 a, bfloat16 b) { // Arithmetic: operands are widened to float, combined, and the result rounded back by the float constructor.
403
+ return bfloat16(static_cast<float>(a) + static_cast<float>(b));
404
+ }
405
+ B16_DEVICE_FUNC inline bfloat16 operator+(bfloat16 a, int b) { // mixed overloads exist for + only
406
+ return bfloat16(static_cast<float>(a) + static_cast<float>(b)); // the int is converted to float directly, not through bfloat16
407
+ }
408
+ B16_DEVICE_FUNC inline bfloat16 operator+(int a, bfloat16 b) {
409
+ return bfloat16(static_cast<float>(a) + static_cast<float>(b));
410
+ }
411
+ B16_DEVICE_FUNC inline bfloat16 operator-(bfloat16 a, bfloat16 b) {
412
+ return bfloat16(static_cast<float>(a) - static_cast<float>(b));
413
+ }
414
+ B16_DEVICE_FUNC inline bfloat16 operator*(bfloat16 a, bfloat16 b) {
415
+ return bfloat16(static_cast<float>(a) * static_cast<float>(b));
416
+ }
417
+ B16_DEVICE_FUNC inline bfloat16 operator/(bfloat16 a, bfloat16 b) {
418
+ return bfloat16(static_cast<float>(a) / static_cast<float>(b)); // division by zero follows float semantics (Inf or NaN)
419
+ }
420
+ B16_DEVICE_FUNC inline bfloat16 operator-(bfloat16 a) { // Unary minus works on the stored bits, with no float round trip.
421
+ a.value ^= 0x8000; // flip the sign bit; this also changes the sign bit of zeros and NaNs
422
+ return a;
423
+ }
424
+ B16_DEVICE_FUNC inline bool operator<(bfloat16 a, bfloat16 b) { // Comparisons: both operands are widened to float and compared as floats.
425
+ return static_cast<float>(a) < static_cast<float>(b); // any comparison with NaN is false, except != which is true
426
+ }
427
+ B16_DEVICE_FUNC inline bool operator<=(bfloat16 a, bfloat16 b) {
428
+ return static_cast<float>(a) <= static_cast<float>(b);
429
+ }
430
+ B16_DEVICE_FUNC inline bool operator==(bfloat16 a, bfloat16 b) {
431
+ return static_cast<float>(a) == static_cast<float>(b); // value equality, not bit equality: +0.0 == -0.0, and NaN != NaN
432
+ }
433
+ B16_DEVICE_FUNC inline bool operator!=(bfloat16 a, bfloat16 b) {
434
+ return static_cast<float>(a) != static_cast<float>(b);
435
+ }
436
+ B16_DEVICE_FUNC inline bool operator>(bfloat16 a, bfloat16 b) {
437
+ return static_cast<float>(a) > static_cast<float>(b);
438
+ }
439
+ B16_DEVICE_FUNC inline bool operator>=(bfloat16 a, bfloat16 b) {
440
+ return static_cast<float>(a) >= static_cast<float>(b);
441
+ }
442
+ B16_DEVICE_FUNC inline bfloat16& operator+=(bfloat16& a, bfloat16 b) { // Compound assignments are defined through the binary operators and return a by reference.
443
+ a = a + b;
444
+ return a;
445
+ }
446
+ B16_DEVICE_FUNC inline bfloat16& operator-=(bfloat16& a, bfloat16 b) {
447
+ a = a - b;
448
+ return a;
449
+ }
450
+ B16_DEVICE_FUNC inline bfloat16 operator++(bfloat16& a) { // Pre-increment. NOTE(review): returns a copy rather than the conventional bfloat16&.
451
+ a += bfloat16(1); // the result is rounded to bfloat16, so a may not change when its magnitude is large
452
+ return a;
453
+ }
454
+ B16_DEVICE_FUNC inline bfloat16 operator--(bfloat16& a) { // Pre-decrement; also returns a copy.
455
+ a -= bfloat16(1);
456
+ return a;
457
+ }
458
+ B16_DEVICE_FUNC inline bfloat16 operator++(bfloat16& a, int) { // Post-increment: returns the value held before the update.
459
+ bfloat16 original_value = a;
460
+ ++a;
461
+ return original_value;
462
+ }
463
+ B16_DEVICE_FUNC inline bfloat16 operator--(bfloat16& a, int) { // Post-decrement: returns the value held before the update.
464
+ bfloat16 original_value = a;
465
+ --a;
466
+ return original_value;
467
+ }
468
+ B16_DEVICE_FUNC inline bfloat16& operator*=(bfloat16& a, bfloat16 b) {
469
+ a = a * b;
470
+ return a;
471
+ }
472
+ B16_DEVICE_FUNC inline bfloat16& operator/=(bfloat16& a, bfloat16 b) {
473
+ a = a / b;
474
+ return a;
475
+ }
476
+ } // end namespace tf
477
+
478
+ namespace tf {
479
+ B16_DEVICE_FUNC inline bfloat16 fabs(const bfloat16& a) { return bfloat16(fabsf(float(a))); } // Every wrapper here widens to float, calls the float math function, and rounds the result back to bfloat16.
480
+
481
+ B16_DEVICE_FUNC inline bfloat16 exp(const bfloat16& a) { return bfloat16(expf(float(a))); } // NOTE(review): the *f functions are called unqualified; the non-CUDA branch includes only <cmath>, so this relies on it also declaring them in the global namespace -- confirm
482
+ B16_DEVICE_FUNC inline bfloat16 exp2(const bfloat16& a) { return bfloat16(exp2f(float(a))); }
483
+ B16_DEVICE_FUNC inline bfloat16 expm1(const bfloat16& a) { return bfloat16(expm1f(float(a))); }
484
+
485
+ B16_DEVICE_FUNC inline bfloat16 log(const bfloat16& a) { return bfloat16(logf(float(a))); } // logarithms
486
+ B16_DEVICE_FUNC inline bfloat16 log10(const bfloat16& a) { return bfloat16(log10f(float(a))); }
487
+ B16_DEVICE_FUNC inline bfloat16 log2(const bfloat16& a) { return bfloat16(log2f(float(a))); }
488
+ B16_DEVICE_FUNC inline bfloat16 log1p(const bfloat16& a) { return bfloat16(log1pf(float(a))); }
489
+ B16_DEVICE_FUNC inline bfloat16 logb(const bfloat16& a) { return bfloat16(logbf(float(a))); }
490
+
491
+ B16_DEVICE_FUNC inline bfloat16 sqrt(const bfloat16& a) { return bfloat16(sqrtf(float(a))); } // roots
492
+ B16_DEVICE_FUNC inline bfloat16 cbrt(const bfloat16& a) { return bfloat16(cbrtf(float(a))); }
493
+
494
+ B16_DEVICE_FUNC inline bfloat16 sin(const bfloat16& a) { return bfloat16(sinf(float(a))); } // trigonometric functions and their inverses
495
+ B16_DEVICE_FUNC inline bfloat16 cos(const bfloat16& a) { return bfloat16(cosf(float(a))); }
496
+ B16_DEVICE_FUNC inline bfloat16 tan(const bfloat16& a) { return bfloat16(tanf(float(a))); }
497
+ B16_DEVICE_FUNC inline bfloat16 asin(const bfloat16& a) { return bfloat16(asinf(float(a))); }
498
+ B16_DEVICE_FUNC inline bfloat16 acos(const bfloat16& a) { return bfloat16(acosf(float(a))); }
499
+ B16_DEVICE_FUNC inline bfloat16 atan(const bfloat16& a) { return bfloat16(atanf(float(a))); }
500
+
501
+ B16_DEVICE_FUNC inline bfloat16 sinh(const bfloat16& a) { return bfloat16(sinhf(float(a))); } // hyperbolic functions and their inverses
502
+ B16_DEVICE_FUNC inline bfloat16 cosh(const bfloat16& a) { return bfloat16(coshf(float(a))); }
503
+ B16_DEVICE_FUNC inline bfloat16 tanh(const bfloat16& a) { return bfloat16(tanhf(float(a))); }
504
+ B16_DEVICE_FUNC inline bfloat16 asinh(const bfloat16& a) { return bfloat16(asinhf(float(a))); }
505
+ B16_DEVICE_FUNC inline bfloat16 acosh(const bfloat16& a) { return bfloat16(acoshf(float(a))); }
506
+ B16_DEVICE_FUNC inline bfloat16 atanh(const bfloat16& a) { return bfloat16(atanhf(float(a))); }
507
+
508
+ B16_DEVICE_FUNC inline bfloat16 erf(const bfloat16& a) { return bfloat16(erff(float(a))); } // error and gamma functions
509
+ B16_DEVICE_FUNC inline bfloat16 erfc(const bfloat16& a) { return bfloat16(erfcf(float(a))); }
510
+ B16_DEVICE_FUNC inline bfloat16 lgamma(const bfloat16& a) { return bfloat16(lgammaf(float(a))); }
511
+ B16_DEVICE_FUNC inline bfloat16 tgamma(const bfloat16& a) { return bfloat16(tgammaf(float(a))); }
512
+
513
+ B16_DEVICE_FUNC inline bfloat16 floor(const bfloat16& a) { return bfloat16(floorf(float(a))); } // rounding to integral values
514
+ B16_DEVICE_FUNC inline bfloat16 ceil(const bfloat16& a) { return bfloat16(ceilf(float(a))); }
515
+ B16_DEVICE_FUNC inline bfloat16 trunc(const bfloat16& a) { return bfloat16(truncf(float(a))); }
516
+ B16_DEVICE_FUNC inline bfloat16 round(const bfloat16& a) { return bfloat16(roundf(float(a))); }
517
+ B16_DEVICE_FUNC inline bfloat16 nearbyint(const bfloat16& a) { return bfloat16(nearbyintf(float(a))); }
518
+ } // namespace tf
519
+
520
+ #endif // BFLOAT16_H