gumath 0.2.0dev5 → 0.2.0dev8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CONTRIBUTING.md +7 -2
- data/Gemfile +0 -3
- data/ext/ruby_gumath/GPATH +0 -0
- data/ext/ruby_gumath/GRTAGS +0 -0
- data/ext/ruby_gumath/GTAGS +0 -0
- data/ext/ruby_gumath/extconf.rb +0 -5
- data/ext/ruby_gumath/functions.c +10 -2
- data/ext/ruby_gumath/gufunc_object.c +15 -4
- data/ext/ruby_gumath/gufunc_object.h +9 -3
- data/ext/ruby_gumath/gumath/Makefile +63 -0
- data/ext/ruby_gumath/gumath/Makefile.in +1 -0
- data/ext/ruby_gumath/gumath/config.h +56 -0
- data/ext/ruby_gumath/gumath/config.h.in +3 -0
- data/ext/ruby_gumath/gumath/config.log +497 -0
- data/ext/ruby_gumath/gumath/config.status +1034 -0
- data/ext/ruby_gumath/gumath/configure +375 -4
- data/ext/ruby_gumath/gumath/configure.ac +47 -3
- data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
- data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
- data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
- data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
- data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
- data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
- data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
- data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
- data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
- data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
- data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
- data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
- data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
- data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
- data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
- data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
- data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
- data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
- data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
- data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
- data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
- data/ext/ruby_gumath/gumath/setup.py +67 -6
- data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
- data/ext/ruby_gumath/include/gumath.h +55 -14
- data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
- data/ext/ruby_gumath/lib/libgumath.a +0 -0
- data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/ruby_gumath.c +231 -70
- data/ext/ruby_gumath/ruby_gumath.h +4 -1
- data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
- data/ext/ruby_gumath/util.c +34 -0
- data/ext/ruby_gumath/util.h +9 -0
- data/gumath.gemspec +3 -2
- data/lib/gumath.rb +55 -1
- data/lib/gumath/version.rb +2 -2
- data/lib/ruby_gumath.so +0 -0
- metadata +63 -10
- data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
- data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
- data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
|
@@ -43,6 +43,7 @@ from glob import glob
|
|
|
43
43
|
import platform
|
|
44
44
|
import subprocess
|
|
45
45
|
import shutil
|
|
46
|
+
import argparse
|
|
46
47
|
import warnings
|
|
47
48
|
|
|
48
49
|
|
|
@@ -55,6 +56,15 @@ LONG_DESCRIPTION = """\
|
|
|
55
56
|
|
|
56
57
|
warnings.simplefilter("ignore", UserWarning)
|
|
57
58
|
|
|
59
|
+
|
|
60
|
+
# Pre-parse and remove the '-j' argument from sys.argv.
|
|
61
|
+
parser = argparse.ArgumentParser()
|
|
62
|
+
parser.add_argument('-j', default=None)
|
|
63
|
+
values, rest = parser.parse_known_args()
|
|
64
|
+
PARALLEL = values.j
|
|
65
|
+
sys.argv = sys.argv[:1] + rest
|
|
66
|
+
|
|
67
|
+
|
|
58
68
|
if sys.platform == "darwin":
|
|
59
69
|
LIBNAME = "libgumath.dylib"
|
|
60
70
|
LIBSONAME = "libgumath.0.dylib"
|
|
@@ -154,7 +164,10 @@ if len(sys.argv) == 2:
|
|
|
154
164
|
path = module_path + ':' + python_path if python_path else module_path
|
|
155
165
|
env = os.environ.copy()
|
|
156
166
|
env['PYTHONPATH'] = path
|
|
157
|
-
ret = subprocess.call([sys.executable, "python/test_gumath.py"], env=env)
|
|
167
|
+
ret = subprocess.call([sys.executable, "python/test_gumath.py", "--long"], env=env)
|
|
168
|
+
if ret != 0:
|
|
169
|
+
sys.exit(ret)
|
|
170
|
+
ret = subprocess.call([sys.executable, "python/test_xndarray.py"], env=env)
|
|
158
171
|
sys.exit(ret)
|
|
159
172
|
elif sys.argv[1] == 'clean':
|
|
160
173
|
shutil.rmtree("build", ignore_errors=True)
|
|
@@ -173,11 +186,27 @@ if len(sys.argv) == 2:
|
|
|
173
186
|
else:
|
|
174
187
|
pass
|
|
175
188
|
|
|
189
|
+
def get_config_vars():
|
|
190
|
+
f = open("config.h")
|
|
191
|
+
config_vars = {}
|
|
192
|
+
for line in f:
|
|
193
|
+
if line.startswith("#define"):
|
|
194
|
+
l = line.split()
|
|
195
|
+
try:
|
|
196
|
+
config_vars[l[1]] = int(l[2])
|
|
197
|
+
except ValueError:
|
|
198
|
+
pass
|
|
199
|
+
elif line.startswith("/* #undef"):
|
|
200
|
+
l = line.split()
|
|
201
|
+
config_vars[l[2]] = 0
|
|
202
|
+
f.close()
|
|
203
|
+
return config_vars
|
|
176
204
|
|
|
177
205
|
def gumath_extensions():
|
|
178
206
|
add_include_dirs = [".", "libgumath", "ndtypes/python/ndtypes", "xnd/python/xnd"] + INCLUDES
|
|
179
207
|
add_library_dirs = ["libgumath", "ndtypes/libndtypes", "xnd/libxnd"] + LIBS
|
|
180
208
|
add_depends = []
|
|
209
|
+
config_vars = {}
|
|
181
210
|
|
|
182
211
|
if sys.platform == "win32":
|
|
183
212
|
add_libraries = ["libndtypes-0.2.0dev3.dll", "libxnd-0.2.0dev3.dll", "libgumath-0.2.0dev3.dll"]
|
|
@@ -199,6 +228,14 @@ def gumath_extensions():
|
|
|
199
228
|
os.system("vcbuild32.bat")
|
|
200
229
|
os.chdir("..")
|
|
201
230
|
else:
|
|
231
|
+
if BUILD_ALL:
|
|
232
|
+
cflags = '"-I%s -I%s"' % tuple(CONFIGURE_INCLUDES)
|
|
233
|
+
ldflags = '"-L%s -L%s"' % tuple(CONFIGURE_LIBS)
|
|
234
|
+
make = "make -j%d" % int(PARALLEL) if PARALLEL else "make"
|
|
235
|
+
os.system("./configure CFLAGS=%s LDFLAGS=%s && %s" % (cflags, ldflags, make))
|
|
236
|
+
|
|
237
|
+
config_vars = get_config_vars()
|
|
238
|
+
|
|
202
239
|
add_extra_compile_args = ["-Wextra", "-Wno-missing-field-initializers", "-std=c11"]
|
|
203
240
|
if sys.platform == "darwin":
|
|
204
241
|
add_libraries = ["ndtypes", "xnd", "gumath"]
|
|
@@ -209,10 +246,16 @@ def gumath_extensions():
|
|
|
209
246
|
add_extra_link_args = []
|
|
210
247
|
add_runtime_library_dirs = ["$ORIGIN"]
|
|
211
248
|
|
|
212
|
-
if
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
249
|
+
if config_vars["HAVE_CUDA"]:
|
|
250
|
+
add_libraries += ["cudart"]
|
|
251
|
+
|
|
252
|
+
for d in [
|
|
253
|
+
"/usr/cuda/lib",
|
|
254
|
+
"/usr/cuda/lib64",
|
|
255
|
+
"/usr/local/cuda/lib/",
|
|
256
|
+
"/usr/local/cuda/lib64"]:
|
|
257
|
+
if os.path.isdir(d):
|
|
258
|
+
add_library_dirs.append(d)
|
|
216
259
|
|
|
217
260
|
def gumath_ext():
|
|
218
261
|
sources = ["python/gumath/_gumath.c"]
|
|
@@ -244,6 +287,21 @@ def gumath_extensions():
|
|
|
244
287
|
runtime_library_dirs = add_runtime_library_dirs
|
|
245
288
|
)
|
|
246
289
|
|
|
290
|
+
def cuda_ext():
|
|
291
|
+
sources = ["python/gumath/cuda.c"]
|
|
292
|
+
|
|
293
|
+
return Extension (
|
|
294
|
+
"gumath.cuda",
|
|
295
|
+
include_dirs = add_include_dirs,
|
|
296
|
+
library_dirs = add_library_dirs,
|
|
297
|
+
depends = add_depends,
|
|
298
|
+
sources = sources,
|
|
299
|
+
libraries = add_libraries,
|
|
300
|
+
extra_compile_args = add_extra_compile_args,
|
|
301
|
+
extra_link_args = add_extra_link_args,
|
|
302
|
+
runtime_library_dirs = add_runtime_library_dirs
|
|
303
|
+
)
|
|
304
|
+
|
|
247
305
|
def examples_ext():
|
|
248
306
|
sources = ["python/gumath/examples.c"]
|
|
249
307
|
|
|
@@ -259,7 +317,10 @@ def gumath_extensions():
|
|
|
259
317
|
runtime_library_dirs = add_runtime_library_dirs
|
|
260
318
|
)
|
|
261
319
|
|
|
262
|
-
|
|
320
|
+
extensions = [gumath_ext(), functions_ext(), examples_ext()]
|
|
321
|
+
if config_vars.get("HAVE_CUDA"):
|
|
322
|
+
extensions += [cuda_ext()]
|
|
323
|
+
return extensions
|
|
263
324
|
|
|
264
325
|
setup (
|
|
265
326
|
name = "gumath",
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
#include <cstdio>
|
|
2
|
+
#include <cstdlib>
|
|
3
|
+
#include <cuda_runtime.h>
|
|
4
|
+
|
|
5
|
+
static void
|
|
6
|
+
check(cudaError_t err)
|
|
7
|
+
{
|
|
8
|
+
if (err != cudaSuccess) {
|
|
9
|
+
exit(1);
|
|
10
|
+
}
|
|
11
|
+
}
|
|
12
|
+
|
|
13
|
+
static int
|
|
14
|
+
min(int x, int y)
|
|
15
|
+
{
|
|
16
|
+
return x <= y ? x : y;
|
|
17
|
+
}
|
|
18
|
+
|
|
19
|
+
int main()
|
|
20
|
+
{
|
|
21
|
+
int res = INT_MAX;
|
|
22
|
+
cudaDeviceProp prop;
|
|
23
|
+
int count, i, n;
|
|
24
|
+
|
|
25
|
+
check(cudaGetDeviceCount(&count));
|
|
26
|
+
|
|
27
|
+
for (i = 0; i < count; i++) {
|
|
28
|
+
check(cudaGetDeviceProperties(&prop, i));
|
|
29
|
+
n = prop.major * 10 + prop.minor;
|
|
30
|
+
res = min(res, n);
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
printf("%d", res);
|
|
34
|
+
return 0;
|
|
35
|
+
}
|
|
@@ -34,6 +34,17 @@
|
|
|
34
34
|
#ifndef GUMATH_H
|
|
35
35
|
#define GUMATH_H
|
|
36
36
|
|
|
37
|
+
|
|
38
|
+
#ifdef __cplusplus
|
|
39
|
+
extern "C" {
|
|
40
|
+
#endif
|
|
41
|
+
|
|
42
|
+
#ifdef __cplusplus
|
|
43
|
+
#include <cstdint>
|
|
44
|
+
#else
|
|
45
|
+
#include <stdint.h>
|
|
46
|
+
#endif
|
|
47
|
+
|
|
37
48
|
#include "ndtypes.h"
|
|
38
49
|
#include "xnd.h"
|
|
39
50
|
|
|
@@ -65,7 +76,8 @@
|
|
|
65
76
|
#endif
|
|
66
77
|
|
|
67
78
|
|
|
68
|
-
#define GM_MAX_KERNELS
|
|
79
|
+
#define GM_MAX_KERNELS 8192
|
|
80
|
+
#define GM_THREAD_CUTOFF 1000000
|
|
69
81
|
|
|
70
82
|
typedef float float32_t;
|
|
71
83
|
typedef double float64_t;
|
|
@@ -74,15 +86,25 @@ typedef double float64_t;
|
|
|
74
86
|
typedef int (* gm_xnd_kernel_t)(xnd_t stack[], ndt_context_t *ctx);
|
|
75
87
|
typedef int (* gm_strided_kernel_t)(char **args, intptr_t *dimensions, intptr_t *steps, void *data);
|
|
76
88
|
|
|
77
|
-
/*
|
|
89
|
+
/*
|
|
90
|
+
* Collection of specialized kernels for a single function signature.
|
|
91
|
+
*
|
|
92
|
+
* NOTE: The specialized kernel lookup scheme is transitional and may
|
|
93
|
+
* be replaced by something else.
|
|
94
|
+
*
|
|
95
|
+
* This should be considered as a first version of a kernel request
|
|
96
|
+
* protocol.
|
|
97
|
+
*/
|
|
78
98
|
typedef struct {
|
|
79
|
-
ndt_t *sig;
|
|
99
|
+
const ndt_t *sig;
|
|
80
100
|
const ndt_constraint_t *constraint;
|
|
81
101
|
|
|
82
102
|
/* Xnd signatures */
|
|
83
|
-
gm_xnd_kernel_t
|
|
84
|
-
gm_xnd_kernel_t
|
|
85
|
-
gm_xnd_kernel_t
|
|
103
|
+
gm_xnd_kernel_t OptC; /* C in inner+1 dimensions */
|
|
104
|
+
gm_xnd_kernel_t OptZ; /* C in inner dimensions, C or zero stride in (inner+1)th. */
|
|
105
|
+
gm_xnd_kernel_t OptS; /* strided in (inner+1)th. */
|
|
106
|
+
gm_xnd_kernel_t C; /* C in inner dimensions */
|
|
107
|
+
gm_xnd_kernel_t Fortran; /* Fortran in inner dimensions */
|
|
86
108
|
gm_xnd_kernel_t Xnd; /* selected if non-contiguous or the other fields are NULL */
|
|
87
109
|
|
|
88
110
|
/* NumPy signature */
|
|
@@ -99,11 +121,17 @@ typedef struct {
|
|
|
99
121
|
const char *name;
|
|
100
122
|
const char *sig;
|
|
101
123
|
const ndt_constraint_t *constraint;
|
|
124
|
+
uint32_t cap;
|
|
102
125
|
|
|
103
|
-
|
|
126
|
+
/* Xnd signatures */
|
|
127
|
+
gm_xnd_kernel_t OptC;
|
|
128
|
+
gm_xnd_kernel_t OptZ;
|
|
129
|
+
gm_xnd_kernel_t OptS;
|
|
104
130
|
gm_xnd_kernel_t C;
|
|
105
131
|
gm_xnd_kernel_t Fortran;
|
|
106
132
|
gm_xnd_kernel_t Xnd;
|
|
133
|
+
|
|
134
|
+
/* NumPy signature */
|
|
107
135
|
gm_strided_kernel_t Strided;
|
|
108
136
|
} gm_kernel_init_t;
|
|
109
137
|
|
|
@@ -115,7 +143,10 @@ typedef struct {
|
|
|
115
143
|
|
|
116
144
|
/* Multimethod with associated kernels */
|
|
117
145
|
typedef struct gm_func gm_func_t;
|
|
118
|
-
typedef const gm_kernel_set_t *(*gm_typecheck_t)(ndt_apply_spec_t *spec, const gm_func_t *f,
|
|
146
|
+
typedef const gm_kernel_set_t *(*gm_typecheck_t)(ndt_apply_spec_t *spec, const gm_func_t *f,
|
|
147
|
+
const ndt_t *in[], const int64_t li[],
|
|
148
|
+
int nin, int nout, bool check_broadcast,
|
|
149
|
+
ndt_context_t *ctx);
|
|
119
150
|
struct gm_func {
|
|
120
151
|
char *name;
|
|
121
152
|
gm_typecheck_t typecheck; /* Experimental optimized type-checking, may be NULL. */
|
|
@@ -139,10 +170,10 @@ GM_API int gm_add_kernel(gm_tbl_t *tbl, const gm_kernel_init_t *kernel, ndt_cont
|
|
|
139
170
|
GM_API int gm_add_kernel_typecheck(gm_tbl_t *tbl, const gm_kernel_init_t *kernel, ndt_context_t *ctx, gm_typecheck_t f);
|
|
140
171
|
|
|
141
172
|
GM_API gm_kernel_t gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
|
|
142
|
-
const ndt_t *
|
|
143
|
-
ndt_context_t *ctx);
|
|
173
|
+
const ndt_t *types[], const int64_t li[], int nin, int nout,
|
|
174
|
+
bool check_broadcast, const xnd_t args[], ndt_context_t *ctx);
|
|
144
175
|
GM_API int gm_apply(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, ndt_context_t *ctx);
|
|
145
|
-
GM_API int gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
|
|
176
|
+
GM_API int gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, const int64_t nthreads, ndt_context_t *ctx);
|
|
146
177
|
|
|
147
178
|
|
|
148
179
|
/******************************************************************************/
|
|
@@ -171,6 +202,7 @@ GM_API int gm_np_map(const gm_strided_kernel_t f,
|
|
|
171
202
|
/* Xnd loops */
|
|
172
203
|
/******************************************************************************/
|
|
173
204
|
|
|
205
|
+
GM_API int array_shape_check(xnd_t *x, const int64_t shape, ndt_context_t *ctx);
|
|
174
206
|
GM_API int gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
|
|
175
207
|
const int outer_dims, ndt_context_t *ctx);
|
|
176
208
|
|
|
@@ -191,10 +223,14 @@ GM_API int gm_tbl_map(const gm_tbl_t *tbl, int (*f)(const gm_func_t *, void *sta
|
|
|
191
223
|
/******************************************************************************/
|
|
192
224
|
|
|
193
225
|
GM_API void gm_init(void);
|
|
194
|
-
GM_API int
|
|
195
|
-
GM_API int
|
|
226
|
+
GM_API int gm_init_cpu_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
|
227
|
+
GM_API int gm_init_cpu_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
|
228
|
+
GM_API int gm_init_bitwise_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
|
229
|
+
|
|
230
|
+
GM_API int gm_init_cuda_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
|
231
|
+
GM_API int gm_init_cuda_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
|
232
|
+
|
|
196
233
|
GM_API int gm_init_example_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
|
197
|
-
GM_API int gm_init_bfloat16_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
|
198
234
|
GM_API int gm_init_graph_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
|
199
235
|
GM_API int gm_init_quaternion_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
|
200
236
|
GM_API int gm_init_pdist_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
|
@@ -202,4 +238,9 @@ GM_API int gm_init_pdist_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
|
|
202
238
|
GM_API void gm_finalize(void);
|
|
203
239
|
|
|
204
240
|
|
|
241
|
+
#ifdef __cplusplus
|
|
242
|
+
} /* END extern "C" */
|
|
243
|
+
#endif
|
|
244
|
+
|
|
245
|
+
|
|
205
246
|
#endif /* GUMATH_H */
|
|
@@ -33,8 +33,11 @@
|
|
|
33
33
|
#define RUBY_GUMATH_H
|
|
34
34
|
|
|
35
35
|
/* Classes */
|
|
36
|
-
VALUE cGumath;
|
|
36
|
+
extern VALUE cGumath;
|
|
37
37
|
|
|
38
|
+
/* C API call for adding functions from a gumath kernel table to a Ruby module.
|
|
39
|
+
* Only adds CPU functions.
|
|
40
|
+
*/
|
|
38
41
|
int rb_gumath_add_functions(VALUE module, const gm_tbl_t *tbl);
|
|
39
42
|
#define GUMATH_FUNCTION_HASH rb_intern("@gumath_functions")
|
|
40
43
|
|
|
Binary file
|
|
Binary file
|
|
@@ -43,12 +43,14 @@ static gm_tbl_t *table = NULL;
|
|
|
43
43
|
/* Maximum number of threads */
|
|
44
44
|
static int64_t max_threads = 1;
|
|
45
45
|
static int initialized = 0;
|
|
46
|
-
|
|
46
|
+
VALUE cGumath;
|
|
47
47
|
|
|
48
48
|
/****************************************************************************/
|
|
49
49
|
/* Error handling */
|
|
50
50
|
/****************************************************************************/
|
|
51
51
|
|
|
52
|
+
static VALUE rb_eValueError;
|
|
53
|
+
|
|
52
54
|
VALUE
|
|
53
55
|
seterr(ndt_context_t *ctx)
|
|
54
56
|
{
|
|
@@ -59,117 +61,274 @@ seterr(ndt_context_t *ctx)
|
|
|
59
61
|
/* Instance methods */
|
|
60
62
|
/****************************************************************************/
|
|
61
63
|
|
|
64
|
+
/* Parse optional arguments passed to GuFuncObject#call.
|
|
65
|
+
*
|
|
66
|
+
* Populates the rbstack with all the input arguments. Then checks whether
|
|
67
|
+
* the 'out' kwarg has been specified and populates the rest of rbstack
|
|
68
|
+
* with contents of 'out'.
|
|
69
|
+
*/
|
|
70
|
+
void
|
|
71
|
+
parse_args(VALUE *rbstack, int *rb_nin, int *rb_nout, int *rb_nargs, int noptargs,
|
|
72
|
+
VALUE *argv, VALUE out)
|
|
73
|
+
{
|
|
74
|
+
size_t nin = noptargs, nout;
|
|
75
|
+
|
|
76
|
+
if (noptargs == 0) {
|
|
77
|
+
*rb_nin = 0;
|
|
78
|
+
}
|
|
79
|
+
|
|
80
|
+
for (int i = 0; i < nin; i++) {
|
|
81
|
+
if (!rb_is_a(argv[i], cXND)) {
|
|
82
|
+
rb_raise(rb_eArgError, "expected xnd arguments.");
|
|
83
|
+
}
|
|
84
|
+
rbstack[i] = argv[i];
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
if (out == Qnil) {
|
|
88
|
+
nout = 0;
|
|
89
|
+
}
|
|
90
|
+
else {
|
|
91
|
+
if (rb_xnd_check_type(out)) {
|
|
92
|
+
nout = 1;
|
|
93
|
+
if (nin + nout > NDT_MAX_ARGS) {
|
|
94
|
+
rb_raise(rb_eTypeError, "max number of arguments is %d, got %ld.",
|
|
95
|
+
NDT_MAX_ARGS, nin+nout);
|
|
96
|
+
}
|
|
97
|
+
rbstack[nin] = out;
|
|
98
|
+
}
|
|
99
|
+
else if (RB_TYPE_P(out, T_ARRAY)) {
|
|
100
|
+
nout = rb_ary_size(out);
|
|
101
|
+
if (nout > NDT_MAX_ARGS || nin+nout > NDT_MAX_ARGS) {
|
|
102
|
+
rb_raise(rb_eTypeError, "max number of arguments is %d, got %ld.",
|
|
103
|
+
NDT_MAX_ARGS, nin+nout);
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
for (int i = 0; i < nout; ++i) {
|
|
107
|
+
VALUE v = rb_ary_entry(out, i);
|
|
108
|
+
if (!rb_is_a(v, cXND)) {
|
|
109
|
+
rb_raise(rb_eTypeError, "expected xnd argument in all elements of out array.");
|
|
110
|
+
}
|
|
111
|
+
rbstack[nin+i] = v;
|
|
112
|
+
}
|
|
113
|
+
}
|
|
114
|
+
else {
|
|
115
|
+
rb_raise(rb_eTypeError, "'out' argument must of type XND or Array of XND objects.");
|
|
116
|
+
}
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
*rb_nin = (int)nin;
|
|
120
|
+
*rb_nout = (int)nout;
|
|
121
|
+
*rb_nargs = (int)nin + (int)nout;
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
/* Implement call method on the GufuncObject call. */
|
|
62
125
|
static VALUE
|
|
63
126
|
Gumath_GufuncObject_call(int argc, VALUE *argv, VALUE self)
|
|
64
127
|
{
|
|
128
|
+
VALUE out = Qnil;
|
|
129
|
+
VALUE dt = Qnil;
|
|
130
|
+
VALUE cls = Qnil;
|
|
131
|
+
|
|
65
132
|
NDT_STATIC_CONTEXT(ctx);
|
|
133
|
+
VALUE rbstack[NDT_MAX_ARGS], opts = Qnil;
|
|
66
134
|
xnd_t stack[NDT_MAX_ARGS];
|
|
67
|
-
const ndt_t *
|
|
135
|
+
const ndt_t *types[NDT_MAX_ARGS];
|
|
68
136
|
gm_kernel_t kernel;
|
|
69
137
|
ndt_apply_spec_t spec = ndt_apply_spec_empty;
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
int
|
|
73
|
-
|
|
138
|
+
int64_t li[NDT_MAX_ARGS];
|
|
139
|
+
NdtObject *dt_p;
|
|
140
|
+
int k;
|
|
141
|
+
ndt_t *dtype = NULL;
|
|
142
|
+
int nin = argc, nout, nargs;
|
|
143
|
+
bool have_cpu_device = false;
|
|
144
|
+
GufuncObject * self_p;
|
|
145
|
+
bool check_broadcast = true, enable_threads = true;
|
|
74
146
|
|
|
75
147
|
if (argc > NDT_MAX_ARGS) {
|
|
76
148
|
rb_raise(rb_eArgError, "too many arguments.");
|
|
77
149
|
}
|
|
150
|
+
|
|
151
|
+
/* parse keyword arguments. */
|
|
152
|
+
int noptargs = argc;
|
|
153
|
+
for (int i = 0; i < argc; ++i) {
|
|
154
|
+
if (RB_TYPE_P(argv[i], T_HASH)) {
|
|
155
|
+
noptargs = i;
|
|
156
|
+
opts = argv[i];
|
|
157
|
+
break;
|
|
158
|
+
}
|
|
159
|
+
}
|
|
160
|
+
|
|
161
|
+
if (NIL_P(opts)) { opts = rb_hash_new(); }
|
|
162
|
+
|
|
163
|
+
out = rb_hash_aref(opts, ID2SYM(rb_intern("out")));
|
|
164
|
+
dt = rb_hash_aref(opts, ID2SYM(rb_intern("dtype")));
|
|
165
|
+
cls = rb_hash_aref(opts, ID2SYM(rb_intern("cls")));
|
|
166
|
+
|
|
167
|
+
if (NIL_P(cls)) { cls = cXND; }
|
|
168
|
+
if (!NIL_P(dt)) {
|
|
169
|
+
if (!NIL_P(out)) {
|
|
170
|
+
rb_raise(rb_eArgError, "the 'out' and 'dtype' arguments are mutually exclusive.");
|
|
171
|
+
}
|
|
78
172
|
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
if (!rb_xnd_check_type(argv[i])) {
|
|
82
|
-
VALUE str = rb_funcall(argv[i], rb_intern("inspect"), 0, NULL);
|
|
83
|
-
rb_raise(rb_eArgError, "Args must be XND. Received %s.", RSTRING_PTR(str));
|
|
173
|
+
if (!rb_ndtypes_check_type(dt)) {
|
|
174
|
+
rb_raise(rb_eArgError, "'dtype' argument must be an NDT object.");
|
|
84
175
|
}
|
|
176
|
+
dtype = (ndt_t *)rb_ndtypes_const_ndt(dt);
|
|
177
|
+
ndt_incref(dtype);
|
|
178
|
+
}
|
|
85
179
|
|
|
86
|
-
|
|
87
|
-
|
|
180
|
+
if (!rb_klass_has_ancestor(cls, cXND)) {
|
|
181
|
+
rb_raise(rb_eTypeError, "the 'cls' argument must be a subtype of 'xnd'.");
|
|
88
182
|
}
|
|
89
183
|
|
|
90
|
-
/*
|
|
184
|
+
/* parse leading optional arguments */
|
|
185
|
+
parse_args(rbstack, &nin, &nout, &nargs, noptargs, argv, out);
|
|
186
|
+
|
|
187
|
+
for (k = 0; k < nargs; ++k) {
|
|
188
|
+
if (!rb_xnd_is_cuda_managed(rbstack[k])) {
|
|
189
|
+
have_cpu_device = true;
|
|
190
|
+
}
|
|
191
|
+
|
|
192
|
+
stack[k] = *rb_xnd_const_xnd(rbstack[k]);
|
|
193
|
+
types[k] = stack[k].type;
|
|
194
|
+
li[k] = stack[k].index;
|
|
195
|
+
}
|
|
91
196
|
GET_GUOBJ(self, self_p);
|
|
197
|
+
|
|
198
|
+
if (have_cpu_device) {
|
|
199
|
+
if (self_p->flags & GM_CUDA_MANAGED_FUNC) {
|
|
200
|
+
rb_raise(rb_eValueError,
|
|
201
|
+
"cannot run a cuda function on xnd objects with cpu memory.");
|
|
202
|
+
}
|
|
203
|
+
}
|
|
204
|
+
|
|
205
|
+
kernel = gm_select(&spec, self_p->table, self_p->name, types, li, nin, nout,
|
|
206
|
+
nout && check_broadcast, stack, &ctx);
|
|
92
207
|
|
|
93
|
-
kernel = gm_select(&spec, self_p->table, self_p->name, in_types, argc, stack, &ctx);
|
|
94
208
|
if (kernel.set == NULL) {
|
|
95
209
|
seterr(&ctx);
|
|
96
210
|
raise_error();
|
|
97
211
|
}
|
|
98
212
|
|
|
99
|
-
if (
|
|
100
|
-
|
|
101
|
-
|
|
213
|
+
if (dtype) {
|
|
214
|
+
if (spec.nout != 1) {
|
|
215
|
+
ndt_err_format(&ctx, NDT_TypeError,
|
|
216
|
+
"the 'dtype' argument is only supported for a single "
|
|
217
|
+
"return value.");
|
|
218
|
+
ndt_apply_spec_clear(&spec);
|
|
219
|
+
ndt_decref(dtype);
|
|
220
|
+
seterr(&ctx);
|
|
221
|
+
raise_error();
|
|
102
222
|
}
|
|
103
|
-
}
|
|
104
223
|
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
stack[nin+i] = *rb_xnd_const_xnd(x);
|
|
224
|
+
const ndt_t *u = spec.types[spec.nin];
|
|
225
|
+
const ndt_t *v = ndt_copy_contiguous_dtype(u, dtype, 0, &ctx);
|
|
226
|
+
|
|
227
|
+
ndt_apply_spec_clear(&spec);
|
|
228
|
+
ndt_decref(dtype);
|
|
229
|
+
|
|
230
|
+
if (v == NULL) {
|
|
231
|
+
seterr(&ctx);
|
|
232
|
+
raise_error();
|
|
115
233
|
}
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
234
|
+
|
|
235
|
+
types[nin] = v;
|
|
236
|
+
kernel = gm_select(&spec, self_p->table, self_p->name, types, li, nin, 1,
|
|
237
|
+
1 && check_broadcast, stack, &ctx);
|
|
238
|
+
if (kernel.set == NULL) {
|
|
239
|
+
seterr(&ctx);
|
|
240
|
+
raise_error();
|
|
119
241
|
}
|
|
120
242
|
}
|
|
121
243
|
|
|
122
|
-
/*
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
244
|
+
/*
|
|
245
|
+
* Replace args/kwargs types with types after substitution and broadcasting.
|
|
246
|
+
* This includes 'out' types, if explicitly passed as kwargs.
|
|
247
|
+
*/
|
|
248
|
+
for (int i = 0; i < spec.nargs; ++i) {
|
|
249
|
+
stack[i].type = spec.types[i];
|
|
250
|
+
}
|
|
251
|
+
|
|
252
|
+
if (nout == 0) {
|
|
253
|
+
/* 'out' types have been inferred, create new XndObjects. */
|
|
254
|
+
VALUE x;
|
|
255
|
+
for (int i = 0; i < spec.nout; ++i) {
|
|
256
|
+
if (ndt_is_concrete(spec.types[nin+i])) {
|
|
257
|
+
uint32_t flags = self_p->flags == GM_CUDA_MANAGED_FUNC ? XND_CUDA_MANAGED : 0;
|
|
258
|
+
x = rb_xnd_empty_from_type(cls, spec.types[nin+i], flags);
|
|
259
|
+
rbstack[nin+i] = x;
|
|
260
|
+
stack[nin+i] = *rb_xnd_const_xnd(x);
|
|
261
|
+
}
|
|
262
|
+
else {
|
|
263
|
+
rb_raise(rb_eValueError,
|
|
264
|
+
"args with abstract types are temporarily disabled.");
|
|
265
|
+
}
|
|
266
|
+
}
|
|
128
267
|
}
|
|
268
|
+
|
|
269
|
+
if (self_p->flags == GM_CUDA_MANAGED_FUNC) {
|
|
270
|
+
#ifdef HAVE_CUDA
|
|
271
|
+
/* populate with CUDA specific stuff */
|
|
129
272
|
#else
|
|
130
|
-
|
|
273
|
+
ndt_err_format(&ctx, NDT_RuntimeError,
|
|
274
|
+
"internal error: GM_CUDA_MANAGED_FUNC set in a build without cuda support");
|
|
275
|
+
ndt_apply_spec_clear(&spec);
|
|
131
276
|
seterr(&ctx);
|
|
132
277
|
raise_error();
|
|
278
|
+
#endif // HAVE_CUDA
|
|
133
279
|
}
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
if (x == NULL) {
|
|
143
|
-
for (k = i+i; k < spec.nout; k++) {
|
|
144
|
-
if (ndt_is_abstract(spec.out[k])) {
|
|
145
|
-
xnd_del_buffer(&stack[nin+k], XND_OWN_ALL);
|
|
146
|
-
}
|
|
147
|
-
}
|
|
148
|
-
}
|
|
149
|
-
result[i] = x;
|
|
150
|
-
}
|
|
151
|
-
}
|
|
280
|
+
else {
|
|
281
|
+
#ifdef HAVE_PTHREAD_H
|
|
282
|
+
const int rounding = fegetround();
|
|
283
|
+
fesetround(FE_TONEAREST);
|
|
284
|
+
|
|
285
|
+
const int64_t N = enable_threads ? max_threads : 1;
|
|
286
|
+
const int ret = gm_apply_thread(&kernel, stack, spec.outer_dims, N, &ctx);
|
|
287
|
+
fesetround(rounding);
|
|
152
288
|
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
289
|
+
if (ret < 0) {
|
|
290
|
+
ndt_apply_spec_clear(&spec);
|
|
291
|
+
seterr(&ctx);
|
|
292
|
+
raise_error();
|
|
156
293
|
}
|
|
294
|
+
#else
|
|
295
|
+
const int rounding = fegetround();
|
|
296
|
+
fesetround(FE_TONEAREST);
|
|
297
|
+
|
|
298
|
+
const int ret = gm_apply(&kernel, stack, spec.outer_dims, &ctx);
|
|
299
|
+
fesetround(rounding);
|
|
300
|
+
|
|
301
|
+
if (ret < 0) {
|
|
302
|
+
ndt_apply_spec_clear(&spec);
|
|
303
|
+
seterr(&ctx);
|
|
304
|
+
raise_error();
|
|
305
|
+
}
|
|
306
|
+
#endif // HAVE_PTHREAD_H
|
|
157
307
|
}
|
|
158
308
|
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
309
|
+
nin = spec.nin;
|
|
310
|
+
nout = spec.nout;
|
|
311
|
+
nargs = spec.nargs;
|
|
312
|
+
ndt_apply_spec_clear(&spec);
|
|
313
|
+
|
|
314
|
+
switch (nout) {
|
|
315
|
+
case 0: {
|
|
316
|
+
return Qnil;
|
|
317
|
+
}
|
|
318
|
+
case 1: {
|
|
319
|
+
return rbstack[nin];
|
|
320
|
+
}
|
|
163
321
|
default: {
|
|
164
|
-
VALUE
|
|
165
|
-
for (i = 0; i <
|
|
166
|
-
rb_ary_store(
|
|
322
|
+
VALUE arr = rb_ary_new2(nout);
|
|
323
|
+
for (int i = 0; i < nout; ++i) {
|
|
324
|
+
rb_ary_store(arr, i, rbstack[nin+i]);
|
|
167
325
|
}
|
|
168
|
-
return
|
|
326
|
+
return arr;
|
|
169
327
|
}
|
|
170
328
|
}
|
|
171
329
|
}
|
|
172
330
|
|
|
331
|
+
|
|
173
332
|
/****************************************************************************/
|
|
174
333
|
/* Singleton methods */
|
|
175
334
|
/****************************************************************************/
|
|
@@ -225,7 +384,7 @@ add_function(const gm_func_t *f, void *args)
|
|
|
225
384
|
struct map_args *a = (struct map_args *)args;
|
|
226
385
|
VALUE func, func_hash;
|
|
227
386
|
|
|
228
|
-
func = GufuncObject_alloc(a->table, f->name);
|
|
387
|
+
func = GufuncObject_alloc(a->table, f->name, GM_CPU_FUNC);
|
|
229
388
|
if (func == NULL) {
|
|
230
389
|
return -1;
|
|
231
390
|
}
|
|
@@ -236,7 +395,6 @@ add_function(const gm_func_t *f, void *args)
|
|
|
236
395
|
return 0;
|
|
237
396
|
}
|
|
238
397
|
|
|
239
|
-
/* C API call for adding functions from a gumath kernel table to */
|
|
240
398
|
int
|
|
241
399
|
rb_gumath_add_functions(VALUE module, const gm_tbl_t *tbl)
|
|
242
400
|
{
|
|
@@ -289,6 +447,9 @@ void Init_ruby_gumath(void)
|
|
|
289
447
|
|
|
290
448
|
/* Instance methods */
|
|
291
449
|
rb_define_method(cGumath_GufuncObject, "call", Gumath_GufuncObject_call,-1);
|
|
450
|
+
|
|
451
|
+
/* errors */
|
|
452
|
+
rb_eValueError = rb_define_class("ValueError", rb_eRuntimeError);
|
|
292
453
|
|
|
293
454
|
Init_gumath_functions();
|
|
294
455
|
Init_gumath_examples();
|