gumath 0.2.0dev5 → 0.2.0dev8
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CONTRIBUTING.md +7 -2
- data/Gemfile +0 -3
- data/ext/ruby_gumath/GPATH +0 -0
- data/ext/ruby_gumath/GRTAGS +0 -0
- data/ext/ruby_gumath/GTAGS +0 -0
- data/ext/ruby_gumath/extconf.rb +0 -5
- data/ext/ruby_gumath/functions.c +10 -2
- data/ext/ruby_gumath/gufunc_object.c +15 -4
- data/ext/ruby_gumath/gufunc_object.h +9 -3
- data/ext/ruby_gumath/gumath/Makefile +63 -0
- data/ext/ruby_gumath/gumath/Makefile.in +1 -0
- data/ext/ruby_gumath/gumath/config.h +56 -0
- data/ext/ruby_gumath/gumath/config.h.in +3 -0
- data/ext/ruby_gumath/gumath/config.log +497 -0
- data/ext/ruby_gumath/gumath/config.status +1034 -0
- data/ext/ruby_gumath/gumath/configure +375 -4
- data/ext/ruby_gumath/gumath/configure.ac +47 -3
- data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
- data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
- data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
- data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
- data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
- data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
- data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
- data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
- data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
- data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
- data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
- data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
- data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
- data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
- data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
- data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
- data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
- data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
- data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
- data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
- data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
- data/ext/ruby_gumath/gumath/setup.py +67 -6
- data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
- data/ext/ruby_gumath/include/gumath.h +55 -14
- data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
- data/ext/ruby_gumath/lib/libgumath.a +0 -0
- data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/ruby_gumath.c +231 -70
- data/ext/ruby_gumath/ruby_gumath.h +4 -1
- data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
- data/ext/ruby_gumath/util.c +34 -0
- data/ext/ruby_gumath/util.h +9 -0
- data/gumath.gemspec +3 -2
- data/lib/gumath.rb +55 -1
- data/lib/gumath/version.rb +2 -2
- data/lib/ruby_gumath.so +0 -0
- metadata +63 -10
- data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
- data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
- data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
@@ -0,0 +1,155 @@
|
|
1
|
+
/*
|
2
|
+
* BSD 3-Clause License
|
3
|
+
*
|
4
|
+
* Copyright (c) 2017-2018, plures
|
5
|
+
* All rights reserved.
|
6
|
+
*
|
7
|
+
* Redistribution and use in source and binary forms, with or without
|
8
|
+
* modification, are permitted provided that the following conditions are met:
|
9
|
+
*
|
10
|
+
* 1. Redistributions of source code must retain the above copyright notice,
|
11
|
+
* this list of conditions and the following disclaimer.
|
12
|
+
*
|
13
|
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
14
|
+
* this list of conditions and the following disclaimer in the documentation
|
15
|
+
* and/or other materials provided with the distribution.
|
16
|
+
*
|
17
|
+
* 3. Neither the name of the copyright holder nor the names of its
|
18
|
+
* contributors may be used to endorse or promote products derived from
|
19
|
+
* this software without specific prior written permission.
|
20
|
+
*
|
21
|
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
22
|
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
23
|
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
24
|
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
25
|
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
26
|
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
27
|
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
28
|
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
29
|
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
30
|
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
31
|
+
*/
|
32
|
+
|
33
|
+
|
34
|
+
#include <stdlib.h>
|
35
|
+
#include <stdint.h>
|
36
|
+
#include <string.h>
|
37
|
+
#include <math.h>
|
38
|
+
#include <complex.h>
|
39
|
+
#include <inttypes.h>
|
40
|
+
#include "ndtypes.h"
|
41
|
+
#include "xnd.h"
|
42
|
+
#include "gumath.h"
|
43
|
+
|
44
|
+
|
45
|
+
#ifndef COMMON_H
|
46
|
+
#define COMMON_H
|
47
|
+
|
48
|
+
|
49
|
+
#define XSTRINGIZE(v) #v
|
50
|
+
#define STRINGIZE(v) XSTRINGIZE(v)
|
51
|
+
|
52
|
+
|
53
|
+
/*****************************************************************************/
|
54
|
+
/* Apply linear index to the data pointer (1D kernels) */
|
55
|
+
/*****************************************************************************/
|
56
|
+
|
57
|
+
static inline char *
|
58
|
+
apply_index(const xnd_t *x)
|
59
|
+
{
|
60
|
+
return xnd_fixed_apply_index(x);
|
61
|
+
}
|
62
|
+
|
63
|
+
|
64
|
+
/*****************************************************************************/
|
65
|
+
/* Optimized bitmap handling */
|
66
|
+
/*****************************************************************************/
|
67
|
+
|
68
|
+
static inline uint8_t *
|
69
|
+
get_bitmap(const xnd_t *x)
|
70
|
+
{
|
71
|
+
const ndt_t *t = x->type;
|
72
|
+
assert(t->ndim == 0);
|
73
|
+
return ndt_is_optional(t) ? x->bitmap.data : NULL;
|
74
|
+
}
|
75
|
+
|
76
|
+
static inline uint8_t *
|
77
|
+
get_bitmap1D(const xnd_t *x)
|
78
|
+
{
|
79
|
+
const ndt_t *t = x->type;
|
80
|
+
assert(t->ndim == 1 && t->tag == FixedDim);
|
81
|
+
return ndt_is_optional(ndt_dtype(t)) ? x->bitmap.data : NULL;
|
82
|
+
}
|
83
|
+
|
84
|
+
/*
 * Test bit 'n' of the bitmap 'data'.  Bits are numbered from the least
 * significant bit of data[0] upwards, eight bits per byte.
 *
 * Returns true if the bit is set, false otherwise.
 */
static inline bool
is_valid(const uint8_t *data, int64_t n)
{
    const uint8_t byte = data[n / 8];

    return (byte >> (n % 8)) & 1;
}
|
93
|
+
|
94
|
+
/*
 * Set bit 'n' of the bitmap 'data' to the value 'x', leaving all other
 * bits untouched.  Bit numbering matches is_valid(): bit n lives in byte
 * n / 8 at position n % 8, counted from the least significant bit.
 */
static inline void
set_bit(uint8_t *data, int64_t n, bool x)
{
    uint8_t *byte = &data[n / 8];
    const uint8_t bit = (uint8_t)((uint8_t)1 << (n % 8));

    /* Clear the target bit, then put it back only if 'x' is true. */
    *byte = (uint8_t)((*byte & ~bit) | (x ? bit : 0));
}
|
104
|
+
|
105
|
+
static inline int64_t
|
106
|
+
linear_index1D(const xnd_t *x, const int64_t i)
|
107
|
+
{
|
108
|
+
const ndt_t *t = x->type;
|
109
|
+
const int64_t step = i * t->Concrete.FixedDim.step;
|
110
|
+
return x->index + step;
|
111
|
+
}
|
112
|
+
|
113
|
+
|
114
|
+
/*****************************************************************************/
|
115
|
+
/* Bitmap update and unary/binary typecheck declarations */
|
116
|
+
/*****************************************************************************/
|
117
|
+
|
118
|
+
/* LOCAL SCOPE */
|
119
|
+
NDT_PRAGMA(NDT_HIDE_SYMBOLS_START)
|
120
|
+
|
121
|
+
void unary_update_bitmap_1D_S(xnd_t stack[]);
|
122
|
+
void unary_reduce_bitmap_1D_S(xnd_t stack[]);
|
123
|
+
void unary_update_bitmap_0D(xnd_t stack[]);
|
124
|
+
|
125
|
+
void binary_update_bitmap_1D_S(xnd_t stack[]);
|
126
|
+
void binary_update_bitmap_0D(xnd_t stack[]);
|
127
|
+
|
128
|
+
void binary_update_bitmap_1D_S_bool(xnd_t stack[]);
|
129
|
+
void binary_update_bitmap_0D_bool(xnd_t stack[]);
|
130
|
+
|
131
|
+
const gm_kernel_set_t *cpu_unary_typecheck(int (*kernel_location)(const ndt_t *, const ndt_t *, ndt_context_t *),
|
132
|
+
ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
|
133
|
+
const int64_t li[], int nin, int nout, bool check_broadcast,
|
134
|
+
ndt_context_t *ctx);
|
135
|
+
|
136
|
+
const gm_kernel_set_t *cuda_unary_typecheck(int (*kernel_location)(const ndt_t *, const ndt_t *, ndt_context_t *),
|
137
|
+
ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
|
138
|
+
const int64_t li[], int nin, int nout, bool check_broadcast,
|
139
|
+
ndt_context_t *ctx);
|
140
|
+
|
141
|
+
const gm_kernel_set_t *cpu_binary_typecheck(int (*kernel_location)(const ndt_t *in0, const ndt_t *in1, ndt_context_t *ctx),
|
142
|
+
ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
|
143
|
+
const int64_t li[], int nin, int nout, bool check_broadcast,
|
144
|
+
ndt_context_t *ctx);
|
145
|
+
|
146
|
+
const gm_kernel_set_t *cuda_binary_typecheck(int (* kernel_location)(const ndt_t *in0, const ndt_t *in1, ndt_context_t *ctx),
|
147
|
+
ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *types[],
|
148
|
+
const int64_t li[], int nin, int nout, bool check_broadcast,
|
149
|
+
ndt_context_t *ctx);
|
150
|
+
|
151
|
+
/* END LOCAL SCOPE */
|
152
|
+
NDT_PRAGMA(NDT_HIDE_SYMBOLS_END)
|
153
|
+
|
154
|
+
|
155
|
+
#endif /* COMMON_H */
|
@@ -0,0 +1,520 @@
|
|
1
|
+
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
2
|
+
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
==============================================================================*/
|
15
|
+
|
16
|
+
/* Modified and adapted for gumath. */
|
17
|
+
|
18
|
+
#ifndef BFLOAT16_H
|
19
|
+
#define BFLOAT16_H
|
20
|
+
|
21
|
+
|
22
|
+
// Byte order defines provided by gcc. MSVC doesn't define those so
|
23
|
+
// we define them here.
|
24
|
+
// We assume that all Windows platforms out there are little-endian.
|
25
|
+
#if defined(_MSC_VER) && !defined(__clang__)
|
26
|
+
#define __ORDER_LITTLE_ENDIAN__ 0x4d2
|
27
|
+
#define __ORDER_BIG_ENDIAN__ 0x10e1
|
28
|
+
#define __BYTE_ORDER__ __ORDER_LITTLE_ENDIAN__
|
29
|
+
#endif
|
30
|
+
|
31
|
+
#ifdef __CUDACC__
|
32
|
+
// All functions callable from CUDA code must be qualified with __device__
|
33
|
+
#define B16_DEVICE_FUNC __host__ __device__
|
34
|
+
#include <thrust/complex.h>
|
35
|
+
#include <math.h>
|
36
|
+
typedef thrust::complex<float> complex64;
|
37
|
+
typedef thrust::complex<double> complex128;
|
38
|
+
#else
|
39
|
+
#define B16_DEVICE_FUNC
|
40
|
+
#include <cmath>
|
41
|
+
#include <complex>
|
42
|
+
typedef std::complex<float> complex64;
|
43
|
+
typedef std::complex<double> complex128;
|
44
|
+
#endif
|
45
|
+
|
46
|
+
|
47
|
+
namespace tf {
|
48
|
+
|
49
|
+
// see framework/bfloat16.h for description.
|
50
|
+
struct bfloat16 {
|
51
|
+
// The default constructor must yield a zero value, not an uninitialized
|
52
|
+
// value; some TF kernels use T() as a zero value.
|
53
|
+
B16_DEVICE_FUNC bfloat16() : value(ZERO_VALUE) {}
|
54
|
+
|
55
|
+
B16_DEVICE_FUNC static bfloat16 truncate_to_bfloat16(const float v) {
|
56
|
+
bfloat16 output;
|
57
|
+
if (float_isnan(v)) {
|
58
|
+
output.value = NAN_VALUE;
|
59
|
+
return output;
|
60
|
+
}
|
61
|
+
const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
|
62
|
+
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
|
63
|
+
output.value = p[0];
|
64
|
+
#else
|
65
|
+
output.value = p[1];
|
66
|
+
#endif
|
67
|
+
return output;
|
68
|
+
}
|
69
|
+
|
70
|
+
B16_DEVICE_FUNC explicit bfloat16(const float v) {
|
71
|
+
value = round_to_bfloat16(v).value;
|
72
|
+
}
|
73
|
+
|
74
|
+
B16_DEVICE_FUNC explicit bfloat16(const double val)
|
75
|
+
: bfloat16(static_cast<float>(val)) {}
|
76
|
+
// Following the convention of numpy, converting between complex and
|
77
|
+
// float will lead to loss of imag value.
|
78
|
+
B16_DEVICE_FUNC explicit bfloat16(const complex64& val)
|
79
|
+
: bfloat16(val.real()) {}
|
80
|
+
|
81
|
+
B16_DEVICE_FUNC explicit bfloat16(const complex128& val)
|
82
|
+
: bfloat16(static_cast<float>(val.real())) {}
|
83
|
+
|
84
|
+
B16_DEVICE_FUNC explicit bfloat16(const unsigned short val)
|
85
|
+
: bfloat16(static_cast<float>(val)) {}
|
86
|
+
|
87
|
+
B16_DEVICE_FUNC explicit bfloat16(const unsigned int val)
|
88
|
+
: bfloat16(static_cast<float>(val)) {}
|
89
|
+
|
90
|
+
B16_DEVICE_FUNC explicit bfloat16(const int val)
|
91
|
+
: bfloat16(static_cast<float>(val)) {}
|
92
|
+
|
93
|
+
B16_DEVICE_FUNC explicit bfloat16(const long val)
|
94
|
+
: bfloat16(static_cast<float>(val)) {}
|
95
|
+
|
96
|
+
B16_DEVICE_FUNC explicit bfloat16(const long long val)
|
97
|
+
: bfloat16(static_cast<float>(val)) {}
|
98
|
+
|
99
|
+
template <class T>
|
100
|
+
B16_DEVICE_FUNC explicit bfloat16(const T& val)
|
101
|
+
: bfloat16(static_cast<float>(val)) {}
|
102
|
+
|
103
|
+
B16_DEVICE_FUNC explicit operator float() const {
|
104
|
+
float result = 0;
|
105
|
+
|
106
|
+
uint16_t* q = reinterpret_cast<uint16_t*>(&result);
|
107
|
+
|
108
|
+
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
|
109
|
+
q[0] = value;
|
110
|
+
#else
|
111
|
+
q[1] = value;
|
112
|
+
#endif
|
113
|
+
return result;
|
114
|
+
}
|
115
|
+
|
116
|
+
B16_DEVICE_FUNC explicit operator bool() const {
|
117
|
+
return static_cast<bool>(!!(float(*this)));
|
118
|
+
}
|
119
|
+
|
120
|
+
B16_DEVICE_FUNC explicit operator short() const {
|
121
|
+
return static_cast<short>(float(*this));
|
122
|
+
}
|
123
|
+
|
124
|
+
B16_DEVICE_FUNC explicit operator int() const {
|
125
|
+
return static_cast<int>(float(*this));
|
126
|
+
}
|
127
|
+
|
128
|
+
B16_DEVICE_FUNC explicit operator long() const {
|
129
|
+
return static_cast<long>(float(*this));
|
130
|
+
}
|
131
|
+
|
132
|
+
B16_DEVICE_FUNC explicit operator char() const {
|
133
|
+
return static_cast<char>(float(*this));
|
134
|
+
}
|
135
|
+
|
136
|
+
B16_DEVICE_FUNC explicit operator signed char() const {
|
137
|
+
return static_cast<signed char>(float(*this));
|
138
|
+
}
|
139
|
+
|
140
|
+
B16_DEVICE_FUNC explicit operator unsigned char() const {
|
141
|
+
return static_cast<unsigned char>(float(*this));
|
142
|
+
}
|
143
|
+
|
144
|
+
B16_DEVICE_FUNC explicit operator unsigned short() const {
|
145
|
+
return static_cast<unsigned short>(float(*this));
|
146
|
+
}
|
147
|
+
|
148
|
+
B16_DEVICE_FUNC explicit operator unsigned int() const {
|
149
|
+
return static_cast<unsigned int>(float(*this));
|
150
|
+
}
|
151
|
+
|
152
|
+
B16_DEVICE_FUNC explicit operator unsigned long() const {
|
153
|
+
return static_cast<unsigned long>(float(*this));
|
154
|
+
}
|
155
|
+
|
156
|
+
B16_DEVICE_FUNC explicit operator unsigned long long() const {
|
157
|
+
return static_cast<unsigned long long>(float(*this));
|
158
|
+
}
|
159
|
+
|
160
|
+
B16_DEVICE_FUNC explicit operator long long() const {
|
161
|
+
return static_cast<long long>(float(*this));
|
162
|
+
}
|
163
|
+
|
164
|
+
B16_DEVICE_FUNC explicit operator double() const {
|
165
|
+
return static_cast<double>(float(*this));
|
166
|
+
}
|
167
|
+
|
168
|
+
B16_DEVICE_FUNC explicit operator complex64() const {
|
169
|
+
return complex64(float(*this), float(0.0));
|
170
|
+
}
|
171
|
+
|
172
|
+
B16_DEVICE_FUNC explicit operator complex128() const {
|
173
|
+
return complex128(double(*this), double(0.0));
|
174
|
+
}
|
175
|
+
|
176
|
+
union FP32 {
|
177
|
+
unsigned int u;
|
178
|
+
float f;
|
179
|
+
};
|
180
|
+
|
181
|
+
// Converts a float point to bfloat16, with round-nearest-to-even as rounding
|
182
|
+
// method.
|
183
|
+
// TODO: There is a slightly faster implementation (8% faster on CPU)
|
184
|
+
// than this (documented in cl/175987786), that is exponentially harder to
|
185
|
+
// understand and document. Switch to the faster version when converting to
|
186
|
+
// BF16 becomes compute-bound.
|
187
|
+
B16_DEVICE_FUNC static bfloat16 round_to_bfloat16(float v) {
|
188
|
+
uint32_t input;
|
189
|
+
FP32 f;
|
190
|
+
f.f = v;
|
191
|
+
input = f.u;
|
192
|
+
bfloat16 output;
|
193
|
+
|
194
|
+
if (float_isnan(v)) {
|
195
|
+
// If the value is a NaN, squash it to a qNaN with msb of fraction set,
|
196
|
+
// this makes sure after truncation we don't end up with an inf.
|
197
|
+
//
|
198
|
+
// qNaN magic: All exponent bits set + most significant bit of fraction
|
199
|
+
// set.
|
200
|
+
output.value = 0x7fc0;
|
201
|
+
} else {
|
202
|
+
// Fast rounding algorithm that rounds a half value to nearest even. This
|
203
|
+
// reduces expected error when we convert a large number of floats. Here
|
204
|
+
// is how it works:
|
205
|
+
//
|
206
|
+
// Definitions:
|
207
|
+
// To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
|
208
|
+
// with the following tags:
|
209
|
+
//
|
210
|
+
// Sign | Exp (8 bits) | Frac (23 bits)
|
211
|
+
// S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT
|
212
|
+
//
|
213
|
+
// S: Sign bit.
|
214
|
+
// E: Exponent bits.
|
215
|
+
// F: First 6 bits of fraction.
|
216
|
+
// L: Least significant bit of resulting bfloat16 if we truncate away the
|
217
|
+
// rest of the float32. This is also the 7th bit of fraction
|
218
|
+
// R: Rounding bit, 8th bit of fraction.
|
219
|
+
// T: Sticky bits, rest of fraction, 15 bits.
|
220
|
+
//
|
221
|
+
// To round half to nearest even, there are 3 cases where we want to round
|
222
|
+
// down (simply truncate the result of the bits away, which consists of
|
223
|
+
// rounding bit and sticky bits) and two cases where we want to round up
|
224
|
+
// (truncate then add one to the result).
|
225
|
+
//
|
226
|
+
// The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
|
227
|
+
// 1s) as the rounding bias, adds the rounding bias to the input, then
|
228
|
+
// truncates the last 16 bits away.
|
229
|
+
//
|
230
|
+
// To understand how it works, we can analyze this algorithm case by case:
|
231
|
+
//
|
232
|
+
// 1. L = 0, R = 0:
|
233
|
+
// Expect: round down, this is less than half value.
|
234
|
+
//
|
235
|
+
// Algorithm:
|
236
|
+
// - Rounding bias: 0x7fff + 0 = 0x7fff
|
237
|
+
// - Adding rounding bias to input may create any carry, depending on
|
238
|
+
// whether there is any value set to 1 in T bits.
|
239
|
+
// - R may be set to 1 if there is a carry.
|
240
|
+
// - L remains 0.
|
241
|
+
// - Note that this case also handles Inf and -Inf, where all fraction
|
242
|
+
// bits, including L, R and Ts are all 0. The output remains Inf after
|
243
|
+
// this algorithm.
|
244
|
+
//
|
245
|
+
// 2. L = 1, R = 0:
|
246
|
+
// Expect: round down, this is less than half value.
|
247
|
+
//
|
248
|
+
// Algorithm:
|
249
|
+
// - Rounding bias: 0x7fff + 1 = 0x8000
|
250
|
+
// - Adding rounding bias to input doesn't change sticky bits but
|
251
|
+
// adds 1 to rounding bit.
|
252
|
+
// - L remains 1.
|
253
|
+
//
|
254
|
+
// 3. L = 0, R = 1, all of T are 0:
|
255
|
+
// Expect: round down, this is exactly at half, the result is already
|
256
|
+
// even (L=0).
|
257
|
+
//
|
258
|
+
// Algorithm:
|
259
|
+
// - Rounding bias: 0x7fff + 0 = 0x7fff
|
260
|
+
// - Adding rounding bias to input sets all sticky bits to 1, but
|
261
|
+
// doesn't create a carry.
|
262
|
+
// - R remains 1.
|
263
|
+
// - L remains 0.
|
264
|
+
//
|
265
|
+
// 4. L = 1, R = 1:
|
266
|
+
// Expect: round up, this is exactly at half, the result needs to be
|
267
|
+
// round to the next even number.
|
268
|
+
//
|
269
|
+
// Algorithm:
|
270
|
+
// - Rounding bias: 0x7fff + 1 = 0x8000
|
271
|
+
// - Adding rounding bias to input doesn't change sticky bits, but
|
272
|
+
// creates a carry from rounding bit.
|
273
|
+
// - The carry sets L to 0, creates another carry bit and propagate
|
274
|
+
// forward to F bits.
|
275
|
+
// - If all the F bits are 1, a carry then propagates to the exponent
|
276
|
+
// bits, which then creates the minimum value with the next exponent
|
277
|
+
// value. Note that we won't have the case where exponents are all 1,
|
278
|
+
// since that's either a NaN (handled in the other if condition) or inf
|
279
|
+
// (handled in case 1).
|
280
|
+
//
|
281
|
+
// 5. L = 0, R = 1, any of T is 1:
|
282
|
+
// Expect: round up, this is greater than half.
|
283
|
+
//
|
284
|
+
// Algorithm:
|
285
|
+
// - Rounding bias: 0x7fff + 0 = 0x7fff
|
286
|
+
// - Adding rounding bias to input creates a carry from sticky bits,
|
287
|
+
// sets rounding bit to 0, then create another carry.
|
288
|
+
// - The second carry sets L to 1.
|
289
|
+
//
|
290
|
+
// Examples:
|
291
|
+
//
|
292
|
+
// Exact half value that is already even:
|
293
|
+
// Input:
|
294
|
+
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
|
295
|
+
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
|
296
|
+
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000
|
297
|
+
//
|
298
|
+
// This falls into case 3. We truncate the rest of 16 bits and no
|
299
|
+
// carry is created into F and L:
|
300
|
+
//
|
301
|
+
// Output:
|
302
|
+
// Sign | Exp (8 bit) | Frac (first 7 bit)
|
303
|
+
// S E E E E E E E E F F F F F F L
|
304
|
+
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
|
305
|
+
//
|
306
|
+
// Exact half value, round to next even number:
|
307
|
+
// Input:
|
308
|
+
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
|
309
|
+
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
|
310
|
+
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000
|
311
|
+
//
|
312
|
+
// This falls into case 4. We create a carry from R and T,
|
313
|
+
// which then propagates into L and F:
|
314
|
+
//
|
315
|
+
// Output:
|
316
|
+
// Sign | Exp (8 bit) | Frac (first 7 bit)
|
317
|
+
// S E E E E E E E E F F F F F F L
|
318
|
+
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
|
319
|
+
//
|
320
|
+
//
|
321
|
+
// Max denormal value round to min normal value:
|
322
|
+
// Input:
|
323
|
+
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
|
324
|
+
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
|
325
|
+
// 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111
|
326
|
+
//
|
327
|
+
// This falls into case 4. We create a carry from R and T,
|
328
|
+
// propagate into L and F, which then propagates into exponent
|
329
|
+
// bits:
|
330
|
+
//
|
331
|
+
// Output:
|
332
|
+
// Sign | Exp (8 bit) | Frac (first 7 bit)
|
333
|
+
// S E E E E E E E E F F F F F F L
|
334
|
+
// 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
|
335
|
+
//
|
336
|
+
// Max normal value round to Inf:
|
337
|
+
// Input:
|
338
|
+
// Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
|
339
|
+
// S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT
|
340
|
+
// 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111
|
341
|
+
//
|
342
|
+
// This falls into case 4. We create a carry from R and T,
|
343
|
+
// propagate into L and F, which then propagates into exponent
|
344
|
+
// bits:
|
345
|
+
//
|
346
|
+
// Sign | Exp (8 bit) | Frac (first 7 bit)
|
347
|
+
// S E E E E E E E E F F F F F F L
|
348
|
+
// 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0
|
349
|
+
//
|
350
|
+
//
|
351
|
+
// Least significant bit of resulting bfloat.
|
352
|
+
uint32_t lsb = (input >> 16) & 1;
|
353
|
+
uint32_t rounding_bias = 0x7fff + lsb;
|
354
|
+
input += rounding_bias;
|
355
|
+
output.value = static_cast<uint16_t>(input >> 16);
|
356
|
+
}
|
357
|
+
return output;
|
358
|
+
}
|
359
|
+
|
360
|
+
static bfloat16 epsilon() {
|
361
|
+
bfloat16 x;
|
362
|
+
x.value = 0x3c00; // 0x1.0p-7
|
363
|
+
return x;
|
364
|
+
}
|
365
|
+
|
366
|
+
static bfloat16 highest() {
|
367
|
+
bfloat16 x;
|
368
|
+
x.value = 0x7F7F; // 0x1.FEp127
|
369
|
+
return x;
|
370
|
+
}
|
371
|
+
|
372
|
+
static bfloat16 lowest() {
|
373
|
+
bfloat16 x;
|
374
|
+
x.value = 0xFF7F; // -0x1.FEp127
|
375
|
+
return x;
|
376
|
+
}
|
377
|
+
|
378
|
+
uint16_t value;
|
379
|
+
|
380
|
+
// A value that represents "not a number".
|
381
|
+
static const uint16_t NAN_VALUE = 0x7FC0;
|
382
|
+
|
383
|
+
private:
|
384
|
+
// A value that represents "zero".
|
385
|
+
static const uint16_t ZERO_VALUE = 0;
|
386
|
+
|
387
|
+
B16_DEVICE_FUNC static bool float_isnan(const float& x) {
|
388
|
+
#ifdef __CUDA_ARCH__
|
389
|
+
return ::isnan(x);
|
390
|
+
#else
|
391
|
+
return std::isnan(x);
|
392
|
+
#endif
|
393
|
+
}
|
394
|
+
};
|
395
|
+
|
396
|
+
B16_DEVICE_FUNC inline std::ostream& operator<<(std::ostream& os,
|
397
|
+
const bfloat16& dt) {
|
398
|
+
os << static_cast<float>(dt);
|
399
|
+
return os;
|
400
|
+
}
|
401
|
+
|
402
|
+
B16_DEVICE_FUNC inline bfloat16 operator+(bfloat16 a, bfloat16 b) {
|
403
|
+
return bfloat16(static_cast<float>(a) + static_cast<float>(b));
|
404
|
+
}
|
405
|
+
B16_DEVICE_FUNC inline bfloat16 operator+(bfloat16 a, int b) {
|
406
|
+
return bfloat16(static_cast<float>(a) + static_cast<float>(b));
|
407
|
+
}
|
408
|
+
B16_DEVICE_FUNC inline bfloat16 operator+(int a, bfloat16 b) {
|
409
|
+
return bfloat16(static_cast<float>(a) + static_cast<float>(b));
|
410
|
+
}
|
411
|
+
B16_DEVICE_FUNC inline bfloat16 operator-(bfloat16 a, bfloat16 b) {
|
412
|
+
return bfloat16(static_cast<float>(a) - static_cast<float>(b));
|
413
|
+
}
|
414
|
+
B16_DEVICE_FUNC inline bfloat16 operator*(bfloat16 a, bfloat16 b) {
|
415
|
+
return bfloat16(static_cast<float>(a) * static_cast<float>(b));
|
416
|
+
}
|
417
|
+
B16_DEVICE_FUNC inline bfloat16 operator/(bfloat16 a, bfloat16 b) {
|
418
|
+
return bfloat16(static_cast<float>(a) / static_cast<float>(b));
|
419
|
+
}
|
420
|
+
B16_DEVICE_FUNC inline bfloat16 operator-(bfloat16 a) {
|
421
|
+
a.value ^= 0x8000;
|
422
|
+
return a;
|
423
|
+
}
|
424
|
+
B16_DEVICE_FUNC inline bool operator<(bfloat16 a, bfloat16 b) {
|
425
|
+
return static_cast<float>(a) < static_cast<float>(b);
|
426
|
+
}
|
427
|
+
B16_DEVICE_FUNC inline bool operator<=(bfloat16 a, bfloat16 b) {
|
428
|
+
return static_cast<float>(a) <= static_cast<float>(b);
|
429
|
+
}
|
430
|
+
B16_DEVICE_FUNC inline bool operator==(bfloat16 a, bfloat16 b) {
|
431
|
+
return static_cast<float>(a) == static_cast<float>(b);
|
432
|
+
}
|
433
|
+
B16_DEVICE_FUNC inline bool operator!=(bfloat16 a, bfloat16 b) {
|
434
|
+
return static_cast<float>(a) != static_cast<float>(b);
|
435
|
+
}
|
436
|
+
B16_DEVICE_FUNC inline bool operator>(bfloat16 a, bfloat16 b) {
|
437
|
+
return static_cast<float>(a) > static_cast<float>(b);
|
438
|
+
}
|
439
|
+
B16_DEVICE_FUNC inline bool operator>=(bfloat16 a, bfloat16 b) {
|
440
|
+
return static_cast<float>(a) >= static_cast<float>(b);
|
441
|
+
}
|
442
|
+
B16_DEVICE_FUNC inline bfloat16& operator+=(bfloat16& a, bfloat16 b) {
|
443
|
+
a = a + b;
|
444
|
+
return a;
|
445
|
+
}
|
446
|
+
B16_DEVICE_FUNC inline bfloat16& operator-=(bfloat16& a, bfloat16 b) {
|
447
|
+
a = a - b;
|
448
|
+
return a;
|
449
|
+
}
|
450
|
+
B16_DEVICE_FUNC inline bfloat16 operator++(bfloat16& a) {
|
451
|
+
a += bfloat16(1);
|
452
|
+
return a;
|
453
|
+
}
|
454
|
+
B16_DEVICE_FUNC inline bfloat16 operator--(bfloat16& a) {
|
455
|
+
a -= bfloat16(1);
|
456
|
+
return a;
|
457
|
+
}
|
458
|
+
B16_DEVICE_FUNC inline bfloat16 operator++(bfloat16& a, int) {
|
459
|
+
bfloat16 original_value = a;
|
460
|
+
++a;
|
461
|
+
return original_value;
|
462
|
+
}
|
463
|
+
B16_DEVICE_FUNC inline bfloat16 operator--(bfloat16& a, int) {
|
464
|
+
bfloat16 original_value = a;
|
465
|
+
--a;
|
466
|
+
return original_value;
|
467
|
+
}
|
468
|
+
B16_DEVICE_FUNC inline bfloat16& operator*=(bfloat16& a, bfloat16 b) {
|
469
|
+
a = a * b;
|
470
|
+
return a;
|
471
|
+
}
|
472
|
+
B16_DEVICE_FUNC inline bfloat16& operator/=(bfloat16& a, bfloat16 b) {
|
473
|
+
a = a / b;
|
474
|
+
return a;
|
475
|
+
}
|
476
|
+
} // end namespace tf
|
477
|
+
|
478
|
+
namespace tf {
|
479
|
+
B16_DEVICE_FUNC inline bfloat16 fabs(const bfloat16& a) { return bfloat16(fabsf(float(a))); }
|
480
|
+
|
481
|
+
B16_DEVICE_FUNC inline bfloat16 exp(const bfloat16& a) { return bfloat16(expf(float(a))); }
|
482
|
+
B16_DEVICE_FUNC inline bfloat16 exp2(const bfloat16& a) { return bfloat16(exp2f(float(a))); }
|
483
|
+
B16_DEVICE_FUNC inline bfloat16 expm1(const bfloat16& a) { return bfloat16(expm1f(float(a))); }
|
484
|
+
|
485
|
+
B16_DEVICE_FUNC inline bfloat16 log(const bfloat16& a) { return bfloat16(logf(float(a))); }
|
486
|
+
B16_DEVICE_FUNC inline bfloat16 log10(const bfloat16& a) { return bfloat16(log10f(float(a))); }
|
487
|
+
B16_DEVICE_FUNC inline bfloat16 log2(const bfloat16& a) { return bfloat16(log2f(float(a))); }
|
488
|
+
B16_DEVICE_FUNC inline bfloat16 log1p(const bfloat16& a) { return bfloat16(log1pf(float(a))); }
|
489
|
+
B16_DEVICE_FUNC inline bfloat16 logb(const bfloat16& a) { return bfloat16(logbf(float(a))); }
|
490
|
+
|
491
|
+
B16_DEVICE_FUNC inline bfloat16 sqrt(const bfloat16& a) { return bfloat16(sqrtf(float(a))); }
|
492
|
+
B16_DEVICE_FUNC inline bfloat16 cbrt(const bfloat16& a) { return bfloat16(cbrtf(float(a))); }
|
493
|
+
|
494
|
+
B16_DEVICE_FUNC inline bfloat16 sin(const bfloat16& a) { return bfloat16(sinf(float(a))); }
|
495
|
+
B16_DEVICE_FUNC inline bfloat16 cos(const bfloat16& a) { return bfloat16(cosf(float(a))); }
|
496
|
+
B16_DEVICE_FUNC inline bfloat16 tan(const bfloat16& a) { return bfloat16(tanf(float(a))); }
|
497
|
+
B16_DEVICE_FUNC inline bfloat16 asin(const bfloat16& a) { return bfloat16(asinf(float(a))); }
|
498
|
+
B16_DEVICE_FUNC inline bfloat16 acos(const bfloat16& a) { return bfloat16(acosf(float(a))); }
|
499
|
+
B16_DEVICE_FUNC inline bfloat16 atan(const bfloat16& a) { return bfloat16(atanf(float(a))); }
|
500
|
+
|
501
|
+
B16_DEVICE_FUNC inline bfloat16 sinh(const bfloat16& a) { return bfloat16(sinhf(float(a))); }
|
502
|
+
B16_DEVICE_FUNC inline bfloat16 cosh(const bfloat16& a) { return bfloat16(coshf(float(a))); }
|
503
|
+
B16_DEVICE_FUNC inline bfloat16 tanh(const bfloat16& a) { return bfloat16(tanhf(float(a))); }
|
504
|
+
B16_DEVICE_FUNC inline bfloat16 asinh(const bfloat16& a) { return bfloat16(asinhf(float(a))); }
|
505
|
+
B16_DEVICE_FUNC inline bfloat16 acosh(const bfloat16& a) { return bfloat16(acoshf(float(a))); }
|
506
|
+
B16_DEVICE_FUNC inline bfloat16 atanh(const bfloat16& a) { return bfloat16(atanhf(float(a))); }
|
507
|
+
|
508
|
+
B16_DEVICE_FUNC inline bfloat16 erf(const bfloat16& a) { return bfloat16(erff(float(a))); }
|
509
|
+
B16_DEVICE_FUNC inline bfloat16 erfc(const bfloat16& a) { return bfloat16(erfcf(float(a))); }
|
510
|
+
B16_DEVICE_FUNC inline bfloat16 lgamma(const bfloat16& a) { return bfloat16(lgammaf(float(a))); }
|
511
|
+
B16_DEVICE_FUNC inline bfloat16 tgamma(const bfloat16& a) { return bfloat16(tgammaf(float(a))); }
|
512
|
+
|
513
|
+
B16_DEVICE_FUNC inline bfloat16 floor(const bfloat16& a) { return bfloat16(floorf(float(a))); }
|
514
|
+
B16_DEVICE_FUNC inline bfloat16 ceil(const bfloat16& a) { return bfloat16(ceilf(float(a))); }
|
515
|
+
B16_DEVICE_FUNC inline bfloat16 trunc(const bfloat16& a) { return bfloat16(truncf(float(a))); }
|
516
|
+
B16_DEVICE_FUNC inline bfloat16 round(const bfloat16& a) { return bfloat16(roundf(float(a))); }
|
517
|
+
B16_DEVICE_FUNC inline bfloat16 nearbyint(const bfloat16& a) { return bfloat16(nearbyintf(float(a))); }
|
518
|
+
} // namespace tf
|
519
|
+
|
520
|
+
#endif // BFLOAT16_H
|