gumath 0.2.0dev5 → 0.2.0dev8
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CONTRIBUTING.md +7 -2
- data/Gemfile +0 -3
- data/ext/ruby_gumath/GPATH +0 -0
- data/ext/ruby_gumath/GRTAGS +0 -0
- data/ext/ruby_gumath/GTAGS +0 -0
- data/ext/ruby_gumath/extconf.rb +0 -5
- data/ext/ruby_gumath/functions.c +10 -2
- data/ext/ruby_gumath/gufunc_object.c +15 -4
- data/ext/ruby_gumath/gufunc_object.h +9 -3
- data/ext/ruby_gumath/gumath/Makefile +63 -0
- data/ext/ruby_gumath/gumath/Makefile.in +1 -0
- data/ext/ruby_gumath/gumath/config.h +56 -0
- data/ext/ruby_gumath/gumath/config.h.in +3 -0
- data/ext/ruby_gumath/gumath/config.log +497 -0
- data/ext/ruby_gumath/gumath/config.status +1034 -0
- data/ext/ruby_gumath/gumath/configure +375 -4
- data/ext/ruby_gumath/gumath/configure.ac +47 -3
- data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
- data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
- data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
- data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
- data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
- data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
- data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
- data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
- data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
- data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
- data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
- data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
- data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
- data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
- data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
- data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
- data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
- data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
- data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
- data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
- data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
- data/ext/ruby_gumath/gumath/setup.py +67 -6
- data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
- data/ext/ruby_gumath/include/gumath.h +55 -14
- data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
- data/ext/ruby_gumath/lib/libgumath.a +0 -0
- data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/ruby_gumath.c +231 -70
- data/ext/ruby_gumath/ruby_gumath.h +4 -1
- data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
- data/ext/ruby_gumath/util.c +34 -0
- data/ext/ruby_gumath/util.h +9 -0
- data/gumath.gemspec +3 -2
- data/lib/gumath.rb +55 -1
- data/lib/gumath/version.rb +2 -2
- data/lib/ruby_gumath.so +0 -0
- metadata +63 -10
- data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
- data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
- data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
@@ -43,6 +43,7 @@ from glob import glob
|
|
43
43
|
import platform
|
44
44
|
import subprocess
|
45
45
|
import shutil
|
46
|
+
import argparse
|
46
47
|
import warnings
|
47
48
|
|
48
49
|
|
@@ -55,6 +56,15 @@ LONG_DESCRIPTION = """\
|
|
55
56
|
|
56
57
|
warnings.simplefilter("ignore", UserWarning)
|
57
58
|
|
59
|
+
|
60
|
+
# Pre-parse and remove the '-j' argument from sys.argv.
|
61
|
+
parser = argparse.ArgumentParser()
|
62
|
+
parser.add_argument('-j', default=None)
|
63
|
+
values, rest = parser.parse_known_args()
|
64
|
+
PARALLEL = values.j
|
65
|
+
sys.argv = sys.argv[:1] + rest
|
66
|
+
|
67
|
+
|
58
68
|
if sys.platform == "darwin":
|
59
69
|
LIBNAME = "libgumath.dylib"
|
60
70
|
LIBSONAME = "libgumath.0.dylib"
|
@@ -154,7 +164,10 @@ if len(sys.argv) == 2:
|
|
154
164
|
path = module_path + ':' + python_path if python_path else module_path
|
155
165
|
env = os.environ.copy()
|
156
166
|
env['PYTHONPATH'] = path
|
157
|
-
ret = subprocess.call([sys.executable, "python/test_gumath.py"], env=env)
|
167
|
+
ret = subprocess.call([sys.executable, "python/test_gumath.py", "--long"], env=env)
|
168
|
+
if ret != 0:
|
169
|
+
sys.exit(ret)
|
170
|
+
ret = subprocess.call([sys.executable, "python/test_xndarray.py"], env=env)
|
158
171
|
sys.exit(ret)
|
159
172
|
elif sys.argv[1] == 'clean':
|
160
173
|
shutil.rmtree("build", ignore_errors=True)
|
@@ -173,11 +186,27 @@ if len(sys.argv) == 2:
|
|
173
186
|
else:
|
174
187
|
pass
|
175
188
|
|
189
|
+
def get_config_vars():
|
190
|
+
f = open("config.h")
|
191
|
+
config_vars = {}
|
192
|
+
for line in f:
|
193
|
+
if line.startswith("#define"):
|
194
|
+
l = line.split()
|
195
|
+
try:
|
196
|
+
config_vars[l[1]] = int(l[2])
|
197
|
+
except ValueError:
|
198
|
+
pass
|
199
|
+
elif line.startswith("/* #undef"):
|
200
|
+
l = line.split()
|
201
|
+
config_vars[l[2]] = 0
|
202
|
+
f.close()
|
203
|
+
return config_vars
|
176
204
|
|
177
205
|
def gumath_extensions():
|
178
206
|
add_include_dirs = [".", "libgumath", "ndtypes/python/ndtypes", "xnd/python/xnd"] + INCLUDES
|
179
207
|
add_library_dirs = ["libgumath", "ndtypes/libndtypes", "xnd/libxnd"] + LIBS
|
180
208
|
add_depends = []
|
209
|
+
config_vars = {}
|
181
210
|
|
182
211
|
if sys.platform == "win32":
|
183
212
|
add_libraries = ["libndtypes-0.2.0dev3.dll", "libxnd-0.2.0dev3.dll", "libgumath-0.2.0dev3.dll"]
|
@@ -199,6 +228,14 @@ def gumath_extensions():
|
|
199
228
|
os.system("vcbuild32.bat")
|
200
229
|
os.chdir("..")
|
201
230
|
else:
|
231
|
+
if BUILD_ALL:
|
232
|
+
cflags = '"-I%s -I%s"' % tuple(CONFIGURE_INCLUDES)
|
233
|
+
ldflags = '"-L%s -L%s"' % tuple(CONFIGURE_LIBS)
|
234
|
+
make = "make -j%d" % int(PARALLEL) if PARALLEL else "make"
|
235
|
+
os.system("./configure CFLAGS=%s LDFLAGS=%s && %s" % (cflags, ldflags, make))
|
236
|
+
|
237
|
+
config_vars = get_config_vars()
|
238
|
+
|
202
239
|
add_extra_compile_args = ["-Wextra", "-Wno-missing-field-initializers", "-std=c11"]
|
203
240
|
if sys.platform == "darwin":
|
204
241
|
add_libraries = ["ndtypes", "xnd", "gumath"]
|
@@ -209,10 +246,16 @@ def gumath_extensions():
|
|
209
246
|
add_extra_link_args = []
|
210
247
|
add_runtime_library_dirs = ["$ORIGIN"]
|
211
248
|
|
212
|
-
if
|
213
|
-
|
214
|
-
|
215
|
-
|
249
|
+
if config_vars["HAVE_CUDA"]:
|
250
|
+
add_libraries += ["cudart"]
|
251
|
+
|
252
|
+
for d in [
|
253
|
+
"/usr/cuda/lib",
|
254
|
+
"/usr/cuda/lib64",
|
255
|
+
"/usr/local/cuda/lib/",
|
256
|
+
"/usr/local/cuda/lib64"]:
|
257
|
+
if os.path.isdir(d):
|
258
|
+
add_library_dirs.append(d)
|
216
259
|
|
217
260
|
def gumath_ext():
|
218
261
|
sources = ["python/gumath/_gumath.c"]
|
@@ -244,6 +287,21 @@ def gumath_extensions():
|
|
244
287
|
runtime_library_dirs = add_runtime_library_dirs
|
245
288
|
)
|
246
289
|
|
290
|
+
def cuda_ext():
|
291
|
+
sources = ["python/gumath/cuda.c"]
|
292
|
+
|
293
|
+
return Extension (
|
294
|
+
"gumath.cuda",
|
295
|
+
include_dirs = add_include_dirs,
|
296
|
+
library_dirs = add_library_dirs,
|
297
|
+
depends = add_depends,
|
298
|
+
sources = sources,
|
299
|
+
libraries = add_libraries,
|
300
|
+
extra_compile_args = add_extra_compile_args,
|
301
|
+
extra_link_args = add_extra_link_args,
|
302
|
+
runtime_library_dirs = add_runtime_library_dirs
|
303
|
+
)
|
304
|
+
|
247
305
|
def examples_ext():
|
248
306
|
sources = ["python/gumath/examples.c"]
|
249
307
|
|
@@ -259,7 +317,10 @@ def gumath_extensions():
|
|
259
317
|
runtime_library_dirs = add_runtime_library_dirs
|
260
318
|
)
|
261
319
|
|
262
|
-
|
320
|
+
extensions = [gumath_ext(), functions_ext(), examples_ext()]
|
321
|
+
if config_vars.get("HAVE_CUDA"):
|
322
|
+
extensions += [cuda_ext()]
|
323
|
+
return extensions
|
263
324
|
|
264
325
|
setup (
|
265
326
|
name = "gumath",
|
@@ -0,0 +1,35 @@
|
|
1
|
+
#include <cstdio>
|
2
|
+
#include <cstdlib>
|
3
|
+
#include <cuda_runtime.h>
|
4
|
+
|
5
|
+
static void
|
6
|
+
check(cudaError_t err)
|
7
|
+
{
|
8
|
+
if (err != cudaSuccess) {
|
9
|
+
exit(1);
|
10
|
+
}
|
11
|
+
}
|
12
|
+
|
13
|
+
static int
|
14
|
+
min(int x, int y)
|
15
|
+
{
|
16
|
+
return x <= y ? x : y;
|
17
|
+
}
|
18
|
+
|
19
|
+
int main()
|
20
|
+
{
|
21
|
+
int res = INT_MAX;
|
22
|
+
cudaDeviceProp prop;
|
23
|
+
int count, i, n;
|
24
|
+
|
25
|
+
check(cudaGetDeviceCount(&count));
|
26
|
+
|
27
|
+
for (i = 0; i < count; i++) {
|
28
|
+
check(cudaGetDeviceProperties(&prop, i));
|
29
|
+
n = prop.major * 10 + prop.minor;
|
30
|
+
res = min(res, n);
|
31
|
+
}
|
32
|
+
|
33
|
+
printf("%d", res);
|
34
|
+
return 0;
|
35
|
+
}
|
@@ -34,6 +34,17 @@
|
|
34
34
|
#ifndef GUMATH_H
|
35
35
|
#define GUMATH_H
|
36
36
|
|
37
|
+
|
38
|
+
#ifdef __cplusplus
|
39
|
+
extern "C" {
|
40
|
+
#endif
|
41
|
+
|
42
|
+
#ifdef __cplusplus
|
43
|
+
#include <cstdint>
|
44
|
+
#else
|
45
|
+
#include <stdint.h>
|
46
|
+
#endif
|
47
|
+
|
37
48
|
#include "ndtypes.h"
|
38
49
|
#include "xnd.h"
|
39
50
|
|
@@ -65,7 +76,8 @@
|
|
65
76
|
#endif
|
66
77
|
|
67
78
|
|
68
|
-
#define GM_MAX_KERNELS
|
79
|
+
#define GM_MAX_KERNELS 8192
|
80
|
+
#define GM_THREAD_CUTOFF 1000000
|
69
81
|
|
70
82
|
typedef float float32_t;
|
71
83
|
typedef double float64_t;
|
@@ -74,15 +86,25 @@ typedef double float64_t;
|
|
74
86
|
typedef int (* gm_xnd_kernel_t)(xnd_t stack[], ndt_context_t *ctx);
|
75
87
|
typedef int (* gm_strided_kernel_t)(char **args, intptr_t *dimensions, intptr_t *steps, void *data);
|
76
88
|
|
77
|
-
/*
|
89
|
+
/*
|
90
|
+
* Collection of specialized kernels for a single function signature.
|
91
|
+
*
|
92
|
+
* NOTE: The specialized kernel lookup scheme is transitional and may
|
93
|
+
* be replaced by something else.
|
94
|
+
*
|
95
|
+
* This should be considered as a first version of a kernel request
|
96
|
+
* protocol.
|
97
|
+
*/
|
78
98
|
typedef struct {
|
79
|
-
ndt_t *sig;
|
99
|
+
const ndt_t *sig;
|
80
100
|
const ndt_constraint_t *constraint;
|
81
101
|
|
82
102
|
/* Xnd signatures */
|
83
|
-
gm_xnd_kernel_t
|
84
|
-
gm_xnd_kernel_t
|
85
|
-
gm_xnd_kernel_t
|
103
|
+
gm_xnd_kernel_t OptC; /* C in inner+1 dimensions */
|
104
|
+
gm_xnd_kernel_t OptZ; /* C in inner dimensions, C or zero stride in (inner+1)th. */
|
105
|
+
gm_xnd_kernel_t OptS; /* strided in (inner+1)th. */
|
106
|
+
gm_xnd_kernel_t C; /* C in inner dimensions */
|
107
|
+
gm_xnd_kernel_t Fortran; /* Fortran in inner dimensions */
|
86
108
|
gm_xnd_kernel_t Xnd; /* selected if non-contiguous or the other fields are NULL */
|
87
109
|
|
88
110
|
/* NumPy signature */
|
@@ -99,11 +121,17 @@ typedef struct {
|
|
99
121
|
const char *name;
|
100
122
|
const char *sig;
|
101
123
|
const ndt_constraint_t *constraint;
|
124
|
+
uint32_t cap;
|
102
125
|
|
103
|
-
|
126
|
+
/* Xnd signatures */
|
127
|
+
gm_xnd_kernel_t OptC;
|
128
|
+
gm_xnd_kernel_t OptZ;
|
129
|
+
gm_xnd_kernel_t OptS;
|
104
130
|
gm_xnd_kernel_t C;
|
105
131
|
gm_xnd_kernel_t Fortran;
|
106
132
|
gm_xnd_kernel_t Xnd;
|
133
|
+
|
134
|
+
/* NumPy signature */
|
107
135
|
gm_strided_kernel_t Strided;
|
108
136
|
} gm_kernel_init_t;
|
109
137
|
|
@@ -115,7 +143,10 @@ typedef struct {
|
|
115
143
|
|
116
144
|
/* Multimethod with associated kernels */
|
117
145
|
typedef struct gm_func gm_func_t;
|
118
|
-
typedef const gm_kernel_set_t *(*gm_typecheck_t)(ndt_apply_spec_t *spec, const gm_func_t *f,
|
146
|
+
typedef const gm_kernel_set_t *(*gm_typecheck_t)(ndt_apply_spec_t *spec, const gm_func_t *f,
|
147
|
+
const ndt_t *in[], const int64_t li[],
|
148
|
+
int nin, int nout, bool check_broadcast,
|
149
|
+
ndt_context_t *ctx);
|
119
150
|
struct gm_func {
|
120
151
|
char *name;
|
121
152
|
gm_typecheck_t typecheck; /* Experimental optimized type-checking, may be NULL. */
|
@@ -139,10 +170,10 @@ GM_API int gm_add_kernel(gm_tbl_t *tbl, const gm_kernel_init_t *kernel, ndt_cont
|
|
139
170
|
GM_API int gm_add_kernel_typecheck(gm_tbl_t *tbl, const gm_kernel_init_t *kernel, ndt_context_t *ctx, gm_typecheck_t f);
|
140
171
|
|
141
172
|
GM_API gm_kernel_t gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
|
142
|
-
const ndt_t *
|
143
|
-
ndt_context_t *ctx);
|
173
|
+
const ndt_t *types[], const int64_t li[], int nin, int nout,
|
174
|
+
bool check_broadcast, const xnd_t args[], ndt_context_t *ctx);
|
144
175
|
GM_API int gm_apply(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, ndt_context_t *ctx);
|
145
|
-
GM_API int gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
|
176
|
+
GM_API int gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, const int64_t nthreads, ndt_context_t *ctx);
|
146
177
|
|
147
178
|
|
148
179
|
/******************************************************************************/
|
@@ -171,6 +202,7 @@ GM_API int gm_np_map(const gm_strided_kernel_t f,
|
|
171
202
|
/* Xnd loops */
|
172
203
|
/******************************************************************************/
|
173
204
|
|
205
|
+
GM_API int array_shape_check(xnd_t *x, const int64_t shape, ndt_context_t *ctx);
|
174
206
|
GM_API int gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
|
175
207
|
const int outer_dims, ndt_context_t *ctx);
|
176
208
|
|
@@ -191,10 +223,14 @@ GM_API int gm_tbl_map(const gm_tbl_t *tbl, int (*f)(const gm_func_t *, void *sta
|
|
191
223
|
/******************************************************************************/
|
192
224
|
|
193
225
|
GM_API void gm_init(void);
|
194
|
-
GM_API int
|
195
|
-
GM_API int
|
226
|
+
GM_API int gm_init_cpu_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
227
|
+
GM_API int gm_init_cpu_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
228
|
+
GM_API int gm_init_bitwise_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
229
|
+
|
230
|
+
GM_API int gm_init_cuda_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
231
|
+
GM_API int gm_init_cuda_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
232
|
+
|
196
233
|
GM_API int gm_init_example_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
197
|
-
GM_API int gm_init_bfloat16_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
198
234
|
GM_API int gm_init_graph_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
199
235
|
GM_API int gm_init_quaternion_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
200
236
|
GM_API int gm_init_pdist_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
@@ -202,4 +238,9 @@ GM_API int gm_init_pdist_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
|
|
202
238
|
GM_API void gm_finalize(void);
|
203
239
|
|
204
240
|
|
241
|
+
#ifdef __cplusplus
|
242
|
+
} /* END extern "C" */
|
243
|
+
#endif
|
244
|
+
|
245
|
+
|
205
246
|
#endif /* GUMATH_H */
|
@@ -33,8 +33,11 @@
|
|
33
33
|
#define RUBY_GUMATH_H
|
34
34
|
|
35
35
|
/* Classes */
|
36
|
-
VALUE cGumath;
|
36
|
+
extern VALUE cGumath;
|
37
37
|
|
38
|
+
/* C API call for adding functions from a gumath kernel table to a Ruby module.
|
39
|
+
* Only adds CPU functions.
|
40
|
+
*/
|
38
41
|
int rb_gumath_add_functions(VALUE module, const gm_tbl_t *tbl);
|
39
42
|
#define GUMATH_FUNCTION_HASH rb_intern("@gumath_functions")
|
40
43
|
|
Binary file
|
Binary file
|
@@ -43,12 +43,14 @@ static gm_tbl_t *table = NULL;
|
|
43
43
|
/* Maximum number of threads */
|
44
44
|
static int64_t max_threads = 1;
|
45
45
|
static int initialized = 0;
|
46
|
-
|
46
|
+
VALUE cGumath;
|
47
47
|
|
48
48
|
/****************************************************************************/
|
49
49
|
/* Error handling */
|
50
50
|
/****************************************************************************/
|
51
51
|
|
52
|
+
static VALUE rb_eValueError;
|
53
|
+
|
52
54
|
VALUE
|
53
55
|
seterr(ndt_context_t *ctx)
|
54
56
|
{
|
@@ -59,117 +61,274 @@ seterr(ndt_context_t *ctx)
|
|
59
61
|
/* Instance methods */
|
60
62
|
/****************************************************************************/
|
61
63
|
|
64
|
+
/* Parse optional arguments passed to GuFuncObject#call.
|
65
|
+
*
|
66
|
+
* Populates the rbstack with all the input arguments. Then checks whether
|
67
|
+
* the 'out' kwarg has been specified and populates the rest of rbstack
|
68
|
+
* with contents of 'out'.
|
69
|
+
*/
|
70
|
+
void
|
71
|
+
parse_args(VALUE *rbstack, int *rb_nin, int *rb_nout, int *rb_nargs, int noptargs,
|
72
|
+
VALUE *argv, VALUE out)
|
73
|
+
{
|
74
|
+
size_t nin = noptargs, nout;
|
75
|
+
|
76
|
+
if (noptargs == 0) {
|
77
|
+
*rb_nin = 0;
|
78
|
+
}
|
79
|
+
|
80
|
+
for (int i = 0; i < nin; i++) {
|
81
|
+
if (!rb_is_a(argv[i], cXND)) {
|
82
|
+
rb_raise(rb_eArgError, "expected xnd arguments.");
|
83
|
+
}
|
84
|
+
rbstack[i] = argv[i];
|
85
|
+
}
|
86
|
+
|
87
|
+
if (out == Qnil) {
|
88
|
+
nout = 0;
|
89
|
+
}
|
90
|
+
else {
|
91
|
+
if (rb_xnd_check_type(out)) {
|
92
|
+
nout = 1;
|
93
|
+
if (nin + nout > NDT_MAX_ARGS) {
|
94
|
+
rb_raise(rb_eTypeError, "max number of arguments is %d, got %ld.",
|
95
|
+
NDT_MAX_ARGS, nin+nout);
|
96
|
+
}
|
97
|
+
rbstack[nin] = out;
|
98
|
+
}
|
99
|
+
else if (RB_TYPE_P(out, T_ARRAY)) {
|
100
|
+
nout = rb_ary_size(out);
|
101
|
+
if (nout > NDT_MAX_ARGS || nin+nout > NDT_MAX_ARGS) {
|
102
|
+
rb_raise(rb_eTypeError, "max number of arguments is %d, got %ld.",
|
103
|
+
NDT_MAX_ARGS, nin+nout);
|
104
|
+
}
|
105
|
+
|
106
|
+
for (int i = 0; i < nout; ++i) {
|
107
|
+
VALUE v = rb_ary_entry(out, i);
|
108
|
+
if (!rb_is_a(v, cXND)) {
|
109
|
+
rb_raise(rb_eTypeError, "expected xnd argument in all elements of out array.");
|
110
|
+
}
|
111
|
+
rbstack[nin+i] = v;
|
112
|
+
}
|
113
|
+
}
|
114
|
+
else {
|
115
|
+
rb_raise(rb_eTypeError, "'out' argument must of type XND or Array of XND objects.");
|
116
|
+
}
|
117
|
+
}
|
118
|
+
|
119
|
+
*rb_nin = (int)nin;
|
120
|
+
*rb_nout = (int)nout;
|
121
|
+
*rb_nargs = (int)nin + (int)nout;
|
122
|
+
}
|
123
|
+
|
124
|
+
/* Implement call method on the GufuncObject call. */
|
62
125
|
static VALUE
|
63
126
|
Gumath_GufuncObject_call(int argc, VALUE *argv, VALUE self)
|
64
127
|
{
|
128
|
+
VALUE out = Qnil;
|
129
|
+
VALUE dt = Qnil;
|
130
|
+
VALUE cls = Qnil;
|
131
|
+
|
65
132
|
NDT_STATIC_CONTEXT(ctx);
|
133
|
+
VALUE rbstack[NDT_MAX_ARGS], opts = Qnil;
|
66
134
|
xnd_t stack[NDT_MAX_ARGS];
|
67
|
-
const ndt_t *
|
135
|
+
const ndt_t *types[NDT_MAX_ARGS];
|
68
136
|
gm_kernel_t kernel;
|
69
137
|
ndt_apply_spec_t spec = ndt_apply_spec_empty;
|
70
|
-
|
71
|
-
|
72
|
-
int
|
73
|
-
|
138
|
+
int64_t li[NDT_MAX_ARGS];
|
139
|
+
NdtObject *dt_p;
|
140
|
+
int k;
|
141
|
+
ndt_t *dtype = NULL;
|
142
|
+
int nin = argc, nout, nargs;
|
143
|
+
bool have_cpu_device = false;
|
144
|
+
GufuncObject * self_p;
|
145
|
+
bool check_broadcast = true, enable_threads = true;
|
74
146
|
|
75
147
|
if (argc > NDT_MAX_ARGS) {
|
76
148
|
rb_raise(rb_eArgError, "too many arguments.");
|
77
149
|
}
|
150
|
+
|
151
|
+
/* parse keyword arguments. */
|
152
|
+
int noptargs = argc;
|
153
|
+
for (int i = 0; i < argc; ++i) {
|
154
|
+
if (RB_TYPE_P(argv[i], T_HASH)) {
|
155
|
+
noptargs = i;
|
156
|
+
opts = argv[i];
|
157
|
+
break;
|
158
|
+
}
|
159
|
+
}
|
160
|
+
|
161
|
+
if (NIL_P(opts)) { opts = rb_hash_new(); }
|
162
|
+
|
163
|
+
out = rb_hash_aref(opts, ID2SYM(rb_intern("out")));
|
164
|
+
dt = rb_hash_aref(opts, ID2SYM(rb_intern("dtype")));
|
165
|
+
cls = rb_hash_aref(opts, ID2SYM(rb_intern("cls")));
|
166
|
+
|
167
|
+
if (NIL_P(cls)) { cls = cXND; }
|
168
|
+
if (!NIL_P(dt)) {
|
169
|
+
if (!NIL_P(out)) {
|
170
|
+
rb_raise(rb_eArgError, "the 'out' and 'dtype' arguments are mutually exclusive.");
|
171
|
+
}
|
78
172
|
|
79
|
-
|
80
|
-
|
81
|
-
if (!rb_xnd_check_type(argv[i])) {
|
82
|
-
VALUE str = rb_funcall(argv[i], rb_intern("inspect"), 0, NULL);
|
83
|
-
rb_raise(rb_eArgError, "Args must be XND. Received %s.", RSTRING_PTR(str));
|
173
|
+
if (!rb_ndtypes_check_type(dt)) {
|
174
|
+
rb_raise(rb_eArgError, "'dtype' argument must be an NDT object.");
|
84
175
|
}
|
176
|
+
dtype = (ndt_t *)rb_ndtypes_const_ndt(dt);
|
177
|
+
ndt_incref(dtype);
|
178
|
+
}
|
85
179
|
|
86
|
-
|
87
|
-
|
180
|
+
if (!rb_klass_has_ancestor(cls, cXND)) {
|
181
|
+
rb_raise(rb_eTypeError, "the 'cls' argument must be a subtype of 'xnd'.");
|
88
182
|
}
|
89
183
|
|
90
|
-
/*
|
184
|
+
/* parse leading optional arguments */
|
185
|
+
parse_args(rbstack, &nin, &nout, &nargs, noptargs, argv, out);
|
186
|
+
|
187
|
+
for (k = 0; k < nargs; ++k) {
|
188
|
+
if (!rb_xnd_is_cuda_managed(rbstack[k])) {
|
189
|
+
have_cpu_device = true;
|
190
|
+
}
|
191
|
+
|
192
|
+
stack[k] = *rb_xnd_const_xnd(rbstack[k]);
|
193
|
+
types[k] = stack[k].type;
|
194
|
+
li[k] = stack[k].index;
|
195
|
+
}
|
91
196
|
GET_GUOBJ(self, self_p);
|
197
|
+
|
198
|
+
if (have_cpu_device) {
|
199
|
+
if (self_p->flags & GM_CUDA_MANAGED_FUNC) {
|
200
|
+
rb_raise(rb_eValueError,
|
201
|
+
"cannot run a cuda function on xnd objects with cpu memory.");
|
202
|
+
}
|
203
|
+
}
|
204
|
+
|
205
|
+
kernel = gm_select(&spec, self_p->table, self_p->name, types, li, nin, nout,
|
206
|
+
nout && check_broadcast, stack, &ctx);
|
92
207
|
|
93
|
-
kernel = gm_select(&spec, self_p->table, self_p->name, in_types, argc, stack, &ctx);
|
94
208
|
if (kernel.set == NULL) {
|
95
209
|
seterr(&ctx);
|
96
210
|
raise_error();
|
97
211
|
}
|
98
212
|
|
99
|
-
if (
|
100
|
-
|
101
|
-
|
213
|
+
if (dtype) {
|
214
|
+
if (spec.nout != 1) {
|
215
|
+
ndt_err_format(&ctx, NDT_TypeError,
|
216
|
+
"the 'dtype' argument is only supported for a single "
|
217
|
+
"return value.");
|
218
|
+
ndt_apply_spec_clear(&spec);
|
219
|
+
ndt_decref(dtype);
|
220
|
+
seterr(&ctx);
|
221
|
+
raise_error();
|
102
222
|
}
|
103
|
-
}
|
104
223
|
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
stack[nin+i] = *rb_xnd_const_xnd(x);
|
224
|
+
const ndt_t *u = spec.types[spec.nin];
|
225
|
+
const ndt_t *v = ndt_copy_contiguous_dtype(u, dtype, 0, &ctx);
|
226
|
+
|
227
|
+
ndt_apply_spec_clear(&spec);
|
228
|
+
ndt_decref(dtype);
|
229
|
+
|
230
|
+
if (v == NULL) {
|
231
|
+
seterr(&ctx);
|
232
|
+
raise_error();
|
115
233
|
}
|
116
|
-
|
117
|
-
|
118
|
-
|
234
|
+
|
235
|
+
types[nin] = v;
|
236
|
+
kernel = gm_select(&spec, self_p->table, self_p->name, types, li, nin, 1,
|
237
|
+
1 && check_broadcast, stack, &ctx);
|
238
|
+
if (kernel.set == NULL) {
|
239
|
+
seterr(&ctx);
|
240
|
+
raise_error();
|
119
241
|
}
|
120
242
|
}
|
121
243
|
|
122
|
-
/*
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
244
|
+
/*
|
245
|
+
* Replace args/kwargs types with types after substitution and broadcasting.
|
246
|
+
* This includes 'out' types, if explicitly passed as kwargs.
|
247
|
+
*/
|
248
|
+
for (int i = 0; i < spec.nargs; ++i) {
|
249
|
+
stack[i].type = spec.types[i];
|
250
|
+
}
|
251
|
+
|
252
|
+
if (nout == 0) {
|
253
|
+
/* 'out' types have been inferred, create new XndObjects. */
|
254
|
+
VALUE x;
|
255
|
+
for (int i = 0; i < spec.nout; ++i) {
|
256
|
+
if (ndt_is_concrete(spec.types[nin+i])) {
|
257
|
+
uint32_t flags = self_p->flags == GM_CUDA_MANAGED_FUNC ? XND_CUDA_MANAGED : 0;
|
258
|
+
x = rb_xnd_empty_from_type(cls, spec.types[nin+i], flags);
|
259
|
+
rbstack[nin+i] = x;
|
260
|
+
stack[nin+i] = *rb_xnd_const_xnd(x);
|
261
|
+
}
|
262
|
+
else {
|
263
|
+
rb_raise(rb_eValueError,
|
264
|
+
"args with abstract types are temporarily disabled.");
|
265
|
+
}
|
266
|
+
}
|
128
267
|
}
|
268
|
+
|
269
|
+
if (self_p->flags == GM_CUDA_MANAGED_FUNC) {
|
270
|
+
#ifdef HAVE_CUDA
|
271
|
+
/* populate with CUDA specific stuff */
|
129
272
|
#else
|
130
|
-
|
273
|
+
ndt_err_format(&ctx, NDT_RuntimeError,
|
274
|
+
"internal error: GM_CUDA_MANAGED_FUNC set in a build without cuda support");
|
275
|
+
ndt_apply_spec_clear(&spec);
|
131
276
|
seterr(&ctx);
|
132
277
|
raise_error();
|
278
|
+
#endif // HAVE_CUDA
|
133
279
|
}
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
if (x == NULL) {
|
143
|
-
for (k = i+i; k < spec.nout; k++) {
|
144
|
-
if (ndt_is_abstract(spec.out[k])) {
|
145
|
-
xnd_del_buffer(&stack[nin+k], XND_OWN_ALL);
|
146
|
-
}
|
147
|
-
}
|
148
|
-
}
|
149
|
-
result[i] = x;
|
150
|
-
}
|
151
|
-
}
|
280
|
+
else {
|
281
|
+
#ifdef HAVE_PTHREAD_H
|
282
|
+
const int rounding = fegetround();
|
283
|
+
fesetround(FE_TONEAREST);
|
284
|
+
|
285
|
+
const int64_t N = enable_threads ? max_threads : 1;
|
286
|
+
const int ret = gm_apply_thread(&kernel, stack, spec.outer_dims, N, &ctx);
|
287
|
+
fesetround(rounding);
|
152
288
|
|
153
|
-
|
154
|
-
|
155
|
-
|
289
|
+
if (ret < 0) {
|
290
|
+
ndt_apply_spec_clear(&spec);
|
291
|
+
seterr(&ctx);
|
292
|
+
raise_error();
|
156
293
|
}
|
294
|
+
#else
|
295
|
+
const int rounding = fegetround();
|
296
|
+
fesetround(FE_TONEAREST);
|
297
|
+
|
298
|
+
const int ret = gm_apply(&kernel, stack, spec.outer_dims, &ctx);
|
299
|
+
fesetround(rounding);
|
300
|
+
|
301
|
+
if (ret < 0) {
|
302
|
+
ndt_apply_spec_clear(&spec);
|
303
|
+
seterr(&ctx);
|
304
|
+
raise_error();
|
305
|
+
}
|
306
|
+
#endif // HAVE_PTHREAD_H
|
157
307
|
}
|
158
308
|
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
309
|
+
nin = spec.nin;
|
310
|
+
nout = spec.nout;
|
311
|
+
nargs = spec.nargs;
|
312
|
+
ndt_apply_spec_clear(&spec);
|
313
|
+
|
314
|
+
switch (nout) {
|
315
|
+
case 0: {
|
316
|
+
return Qnil;
|
317
|
+
}
|
318
|
+
case 1: {
|
319
|
+
return rbstack[nin];
|
320
|
+
}
|
163
321
|
default: {
|
164
|
-
VALUE
|
165
|
-
for (i = 0; i <
|
166
|
-
rb_ary_store(
|
322
|
+
VALUE arr = rb_ary_new2(nout);
|
323
|
+
for (int i = 0; i < nout; ++i) {
|
324
|
+
rb_ary_store(arr, i, rbstack[nin+i]);
|
167
325
|
}
|
168
|
-
return
|
326
|
+
return arr;
|
169
327
|
}
|
170
328
|
}
|
171
329
|
}
|
172
330
|
|
331
|
+
|
173
332
|
/****************************************************************************/
|
174
333
|
/* Singleton methods */
|
175
334
|
/****************************************************************************/
|
@@ -225,7 +384,7 @@ add_function(const gm_func_t *f, void *args)
|
|
225
384
|
struct map_args *a = (struct map_args *)args;
|
226
385
|
VALUE func, func_hash;
|
227
386
|
|
228
|
-
func = GufuncObject_alloc(a->table, f->name);
|
387
|
+
func = GufuncObject_alloc(a->table, f->name, GM_CPU_FUNC);
|
229
388
|
if (func == NULL) {
|
230
389
|
return -1;
|
231
390
|
}
|
@@ -236,7 +395,6 @@ add_function(const gm_func_t *f, void *args)
|
|
236
395
|
return 0;
|
237
396
|
}
|
238
397
|
|
239
|
-
/* C API call for adding functions from a gumath kernel table to */
|
240
398
|
int
|
241
399
|
rb_gumath_add_functions(VALUE module, const gm_tbl_t *tbl)
|
242
400
|
{
|
@@ -289,6 +447,9 @@ void Init_ruby_gumath(void)
|
|
289
447
|
|
290
448
|
/* Instance methods */
|
291
449
|
rb_define_method(cGumath_GufuncObject, "call", Gumath_GufuncObject_call,-1);
|
450
|
+
|
451
|
+
/* errors */
|
452
|
+
rb_eValueError = rb_define_class("ValueError", rb_eRuntimeError);
|
292
453
|
|
293
454
|
Init_gumath_functions();
|
294
455
|
Init_gumath_examples();
|