gumath 0.2.0dev5 → 0.2.0dev8

Sign up to get free protection for your applications and to get access to all the features.
Files changed (99) hide show
  1. checksums.yaml +4 -4
  2. data/CONTRIBUTING.md +7 -2
  3. data/Gemfile +0 -3
  4. data/ext/ruby_gumath/GPATH +0 -0
  5. data/ext/ruby_gumath/GRTAGS +0 -0
  6. data/ext/ruby_gumath/GTAGS +0 -0
  7. data/ext/ruby_gumath/extconf.rb +0 -5
  8. data/ext/ruby_gumath/functions.c +10 -2
  9. data/ext/ruby_gumath/gufunc_object.c +15 -4
  10. data/ext/ruby_gumath/gufunc_object.h +9 -3
  11. data/ext/ruby_gumath/gumath/Makefile +63 -0
  12. data/ext/ruby_gumath/gumath/Makefile.in +1 -0
  13. data/ext/ruby_gumath/gumath/config.h +56 -0
  14. data/ext/ruby_gumath/gumath/config.h.in +3 -0
  15. data/ext/ruby_gumath/gumath/config.log +497 -0
  16. data/ext/ruby_gumath/gumath/config.status +1034 -0
  17. data/ext/ruby_gumath/gumath/configure +375 -4
  18. data/ext/ruby_gumath/gumath/configure.ac +47 -3
  19. data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
  20. data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
  21. data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
  22. data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
  23. data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
  24. data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
  25. data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
  26. data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
  27. data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
  28. data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
  29. data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
  30. data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
  31. data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
  32. data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
  33. data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
  34. data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
  35. data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
  36. data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
  37. data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
  38. data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
  39. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
  40. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
  41. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
  42. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
  43. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
  44. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
  45. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
  46. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
  47. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
  48. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
  49. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
  50. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
  51. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
  52. data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
  53. data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
  54. data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
  55. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
  56. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
  57. data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
  58. data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
  59. data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
  60. data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
  61. data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
  62. data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
  63. data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
  64. data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
  65. data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
  66. data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
  67. data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
  68. data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
  69. data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
  70. data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
  71. data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
  72. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
  73. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
  74. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
  75. data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
  76. data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
  77. data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
  78. data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
  79. data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
  80. data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
  81. data/ext/ruby_gumath/gumath/setup.py +67 -6
  82. data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
  83. data/ext/ruby_gumath/include/gumath.h +55 -14
  84. data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
  85. data/ext/ruby_gumath/lib/libgumath.a +0 -0
  86. data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
  87. data/ext/ruby_gumath/ruby_gumath.c +231 -70
  88. data/ext/ruby_gumath/ruby_gumath.h +4 -1
  89. data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
  90. data/ext/ruby_gumath/util.c +34 -0
  91. data/ext/ruby_gumath/util.h +9 -0
  92. data/gumath.gemspec +3 -2
  93. data/lib/gumath.rb +55 -1
  94. data/lib/gumath/version.rb +2 -2
  95. data/lib/ruby_gumath.so +0 -0
  96. metadata +63 -10
  97. data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
  98. data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
  99. data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
@@ -43,6 +43,7 @@ from glob import glob
43
43
  import platform
44
44
  import subprocess
45
45
  import shutil
46
+ import argparse
46
47
  import warnings
47
48
 
48
49
 
@@ -55,6 +56,15 @@ LONG_DESCRIPTION = """\
55
56
 
56
57
  warnings.simplefilter("ignore", UserWarning)
57
58
 
59
+
60
+ # Pre-parse and remove the '-j' argument from sys.argv.
61
+ parser = argparse.ArgumentParser()
62
+ parser.add_argument('-j', default=None)
63
+ values, rest = parser.parse_known_args()
64
+ PARALLEL = values.j
65
+ sys.argv = sys.argv[:1] + rest
66
+
67
+
58
68
  if sys.platform == "darwin":
59
69
  LIBNAME = "libgumath.dylib"
60
70
  LIBSONAME = "libgumath.0.dylib"
@@ -154,7 +164,10 @@ if len(sys.argv) == 2:
154
164
  path = module_path + ':' + python_path if python_path else module_path
155
165
  env = os.environ.copy()
156
166
  env['PYTHONPATH'] = path
157
- ret = subprocess.call([sys.executable, "python/test_gumath.py"], env=env)
167
+ ret = subprocess.call([sys.executable, "python/test_gumath.py", "--long"], env=env)
168
+ if ret != 0:
169
+ sys.exit(ret)
170
+ ret = subprocess.call([sys.executable, "python/test_xndarray.py"], env=env)
158
171
  sys.exit(ret)
159
172
  elif sys.argv[1] == 'clean':
160
173
  shutil.rmtree("build", ignore_errors=True)
@@ -173,11 +186,27 @@ if len(sys.argv) == 2:
173
186
  else:
174
187
  pass
175
188
 
189
def get_config_vars():
    """Parse ./config.h (generated by configure) into a dict of int flags.

    '#define NAME 123' lines map NAME -> 123 (lines whose value is not an
    integer, e.g. string defines, are skipped); '/* #undef NAME */' lines
    map NAME -> 0.  Returns the resulting dict.
    """
    config_vars = {}
    # 'with' guarantees the file handle is closed even if parsing raises
    # (the original left the handle open on any exception).
    with open("config.h") as f:
        for line in f:
            if line.startswith("#define"):
                fields = line.split()
                # Guard against a value-less '#define NAME' line, which
                # would otherwise raise IndexError.
                if len(fields) >= 3:
                    try:
                        config_vars[fields[1]] = int(fields[2])
                    except ValueError:
                        pass
            elif line.startswith("/* #undef"):
                config_vars[line.split()[2]] = 0
    return config_vars
176
204
 
177
205
  def gumath_extensions():
178
206
  add_include_dirs = [".", "libgumath", "ndtypes/python/ndtypes", "xnd/python/xnd"] + INCLUDES
179
207
  add_library_dirs = ["libgumath", "ndtypes/libndtypes", "xnd/libxnd"] + LIBS
180
208
  add_depends = []
209
+ config_vars = {}
181
210
 
182
211
  if sys.platform == "win32":
183
212
  add_libraries = ["libndtypes-0.2.0dev3.dll", "libxnd-0.2.0dev3.dll", "libgumath-0.2.0dev3.dll"]
@@ -199,6 +228,14 @@ def gumath_extensions():
199
228
  os.system("vcbuild32.bat")
200
229
  os.chdir("..")
201
230
  else:
231
+ if BUILD_ALL:
232
+ cflags = '"-I%s -I%s"' % tuple(CONFIGURE_INCLUDES)
233
+ ldflags = '"-L%s -L%s"' % tuple(CONFIGURE_LIBS)
234
+ make = "make -j%d" % int(PARALLEL) if PARALLEL else "make"
235
+ os.system("./configure CFLAGS=%s LDFLAGS=%s && %s" % (cflags, ldflags, make))
236
+
237
+ config_vars = get_config_vars()
238
+
202
239
  add_extra_compile_args = ["-Wextra", "-Wno-missing-field-initializers", "-std=c11"]
203
240
  if sys.platform == "darwin":
204
241
  add_libraries = ["ndtypes", "xnd", "gumath"]
@@ -209,10 +246,16 @@ def gumath_extensions():
209
246
  add_extra_link_args = []
210
247
  add_runtime_library_dirs = ["$ORIGIN"]
211
248
 
212
- if BUILD_ALL:
213
- cflags = '"-I%s -I%s"' % tuple(CONFIGURE_INCLUDES)
214
- ldflags = '"-L%s -L%s"' % tuple(CONFIGURE_LIBS)
215
- os.system("./configure CFLAGS=%s LDFLAGS=%s && make" % (cflags, ldflags))
249
+ if config_vars["HAVE_CUDA"]:
250
+ add_libraries += ["cudart"]
251
+
252
+ for d in [
253
+ "/usr/cuda/lib",
254
+ "/usr/cuda/lib64",
255
+ "/usr/local/cuda/lib/",
256
+ "/usr/local/cuda/lib64"]:
257
+ if os.path.isdir(d):
258
+ add_library_dirs.append(d)
216
259
 
217
260
  def gumath_ext():
218
261
  sources = ["python/gumath/_gumath.c"]
@@ -244,6 +287,21 @@ def gumath_extensions():
244
287
  runtime_library_dirs = add_runtime_library_dirs
245
288
  )
246
289
 
290
+ def cuda_ext():
291
+ sources = ["python/gumath/cuda.c"]
292
+
293
+ return Extension (
294
+ "gumath.cuda",
295
+ include_dirs = add_include_dirs,
296
+ library_dirs = add_library_dirs,
297
+ depends = add_depends,
298
+ sources = sources,
299
+ libraries = add_libraries,
300
+ extra_compile_args = add_extra_compile_args,
301
+ extra_link_args = add_extra_link_args,
302
+ runtime_library_dirs = add_runtime_library_dirs
303
+ )
304
+
247
305
  def examples_ext():
248
306
  sources = ["python/gumath/examples.c"]
249
307
 
@@ -259,7 +317,10 @@ def gumath_extensions():
259
317
  runtime_library_dirs = add_runtime_library_dirs
260
318
  )
261
319
 
262
- return [gumath_ext(), functions_ext(), examples_ext()]
320
+ extensions = [gumath_ext(), functions_ext(), examples_ext()]
321
+ if config_vars.get("HAVE_CUDA"):
322
+ extensions += [cuda_ext()]
323
+ return extensions
263
324
 
264
325
  setup (
265
326
  name = "gumath",
@@ -0,0 +1,35 @@
1
#include <climits>
#include <cstdio>
#include <cstdlib>

#include <cuda_runtime.h>
4
+
5
+ static void
6
+ check(cudaError_t err)
7
+ {
8
+ if (err != cudaSuccess) {
9
+ exit(1);
10
+ }
11
+ }
12
+
13
/* Return the smaller of two ints (returns x on ties). */
static int
min(int x, int y)
{
    if (y < x) {
        return y;
    }
    return x;
}
18
+
19
+ int main()
20
+ {
21
+ int res = INT_MAX;
22
+ cudaDeviceProp prop;
23
+ int count, i, n;
24
+
25
+ check(cudaGetDeviceCount(&count));
26
+
27
+ for (i = 0; i < count; i++) {
28
+ check(cudaGetDeviceProperties(&prop, i));
29
+ n = prop.major * 10 + prop.minor;
30
+ res = min(res, n);
31
+ }
32
+
33
+ printf("%d", res);
34
+ return 0;
35
+ }
@@ -34,6 +34,17 @@
34
34
  #ifndef GUMATH_H
35
35
  #define GUMATH_H
36
36
 
37
+
38
+ #ifdef __cplusplus
39
+ extern "C" {
40
+ #endif
41
+
42
+ #ifdef __cplusplus
43
+ #include <cstdint>
44
+ #else
45
+ #include <stdint.h>
46
+ #endif
47
+
37
48
  #include "ndtypes.h"
38
49
  #include "xnd.h"
39
50
 
@@ -65,7 +76,8 @@
65
76
  #endif
66
77
 
67
78
 
68
- #define GM_MAX_KERNELS 512
79
+ #define GM_MAX_KERNELS 8192
80
+ #define GM_THREAD_CUTOFF 1000000
69
81
 
70
82
  typedef float float32_t;
71
83
  typedef double float64_t;
@@ -74,15 +86,25 @@ typedef double float64_t;
74
86
  typedef int (* gm_xnd_kernel_t)(xnd_t stack[], ndt_context_t *ctx);
75
87
  typedef int (* gm_strided_kernel_t)(char **args, intptr_t *dimensions, intptr_t *steps, void *data);
76
88
 
77
- /* Collection of specialized kernels for a single function signature. */
89
+ /*
90
+ * Collection of specialized kernels for a single function signature.
91
+ *
92
+ * NOTE: The specialized kernel lookup scheme is transitional and may
93
+ * be replaced by something else.
94
+ *
95
+ * This should be considered as a first version of a kernel request
96
+ * protocol.
97
+ */
78
98
  typedef struct {
79
- ndt_t *sig;
99
+ const ndt_t *sig;
80
100
  const ndt_constraint_t *constraint;
81
101
 
82
102
  /* Xnd signatures */
83
- gm_xnd_kernel_t Opt; /* dispatch ensures elementwise, at least 1D, contiguous in last dimensions */
84
- gm_xnd_kernel_t C; /* dispatch ensures c-contiguous in inner dimensions */
85
- gm_xnd_kernel_t Fortran; /* dispatch ensures f-contiguous in inner dimensions */
103
+ gm_xnd_kernel_t OptC; /* C in inner+1 dimensions */
104
+ gm_xnd_kernel_t OptZ; /* C in inner dimensions, C or zero stride in (inner+1)th. */
105
+ gm_xnd_kernel_t OptS; /* strided in (inner+1)th. */
106
+ gm_xnd_kernel_t C; /* C in inner dimensions */
107
+ gm_xnd_kernel_t Fortran; /* Fortran in inner dimensions */
86
108
  gm_xnd_kernel_t Xnd; /* selected if non-contiguous or the other fields are NULL */
87
109
 
88
110
  /* NumPy signature */
@@ -99,11 +121,17 @@ typedef struct {
99
121
  const char *name;
100
122
  const char *sig;
101
123
  const ndt_constraint_t *constraint;
124
+ uint32_t cap;
102
125
 
103
- gm_xnd_kernel_t Opt;
126
+ /* Xnd signatures */
127
+ gm_xnd_kernel_t OptC;
128
+ gm_xnd_kernel_t OptZ;
129
+ gm_xnd_kernel_t OptS;
104
130
  gm_xnd_kernel_t C;
105
131
  gm_xnd_kernel_t Fortran;
106
132
  gm_xnd_kernel_t Xnd;
133
+
134
+ /* NumPy signature */
107
135
  gm_strided_kernel_t Strided;
108
136
  } gm_kernel_init_t;
109
137
 
@@ -115,7 +143,10 @@ typedef struct {
115
143
 
116
144
  /* Multimethod with associated kernels */
117
145
  typedef struct gm_func gm_func_t;
118
- typedef const gm_kernel_set_t *(*gm_typecheck_t)(ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *in[], int nin, ndt_context_t *ctx);
146
+ typedef const gm_kernel_set_t *(*gm_typecheck_t)(ndt_apply_spec_t *spec, const gm_func_t *f,
147
+ const ndt_t *in[], const int64_t li[],
148
+ int nin, int nout, bool check_broadcast,
149
+ ndt_context_t *ctx);
119
150
  struct gm_func {
120
151
  char *name;
121
152
  gm_typecheck_t typecheck; /* Experimental optimized type-checking, may be NULL. */
@@ -139,10 +170,10 @@ GM_API int gm_add_kernel(gm_tbl_t *tbl, const gm_kernel_init_t *kernel, ndt_cont
139
170
  GM_API int gm_add_kernel_typecheck(gm_tbl_t *tbl, const gm_kernel_init_t *kernel, ndt_context_t *ctx, gm_typecheck_t f);
140
171
 
141
172
  GM_API gm_kernel_t gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
142
- const ndt_t *in_types[], int nin, const xnd_t args[],
143
- ndt_context_t *ctx);
173
+ const ndt_t *types[], const int64_t li[], int nin, int nout,
174
+ bool check_broadcast, const xnd_t args[], ndt_context_t *ctx);
144
175
  GM_API int gm_apply(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, ndt_context_t *ctx);
145
- GM_API int gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, uint32_t flags, const int64_t nthreads, ndt_context_t *ctx);
176
+ GM_API int gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, const int64_t nthreads, ndt_context_t *ctx);
146
177
 
147
178
 
148
179
  /******************************************************************************/
@@ -171,6 +202,7 @@ GM_API int gm_np_map(const gm_strided_kernel_t f,
171
202
  /* Xnd loops */
172
203
  /******************************************************************************/
173
204
 
205
+ GM_API int array_shape_check(xnd_t *x, const int64_t shape, ndt_context_t *ctx);
174
206
  GM_API int gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
175
207
  const int outer_dims, ndt_context_t *ctx);
176
208
 
@@ -191,10 +223,14 @@ GM_API int gm_tbl_map(const gm_tbl_t *tbl, int (*f)(const gm_func_t *, void *sta
191
223
  /******************************************************************************/
192
224
 
193
225
  GM_API void gm_init(void);
194
- GM_API int gm_init_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
195
- GM_API int gm_init_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
226
+ GM_API int gm_init_cpu_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
227
+ GM_API int gm_init_cpu_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
228
+ GM_API int gm_init_bitwise_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
229
+
230
+ GM_API int gm_init_cuda_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
231
+ GM_API int gm_init_cuda_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
232
+
196
233
  GM_API int gm_init_example_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
197
- GM_API int gm_init_bfloat16_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
198
234
  GM_API int gm_init_graph_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
199
235
  GM_API int gm_init_quaternion_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
200
236
  GM_API int gm_init_pdist_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
@@ -202,4 +238,9 @@ GM_API int gm_init_pdist_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
202
238
  GM_API void gm_finalize(void);
203
239
 
204
240
 
241
+ #ifdef __cplusplus
242
+ } /* END extern "C" */
243
+ #endif
244
+
245
+
205
246
  #endif /* GUMATH_H */
@@ -33,8 +33,11 @@
33
33
  #define RUBY_GUMATH_H
34
34
 
35
35
  /* Classes */
36
- VALUE cGumath;
36
+ extern VALUE cGumath;
37
37
 
38
+ /* C API call for adding functions from a gumath kernel table to a Ruby module.
39
+ * Only adds CPU functions.
40
+ */
38
41
  int rb_gumath_add_functions(VALUE module, const gm_tbl_t *tbl);
39
42
  #define GUMATH_FUNCTION_HASH rb_intern("@gumath_functions")
40
43
 
@@ -43,12 +43,14 @@ static gm_tbl_t *table = NULL;
43
43
  /* Maximum number of threads */
44
44
  static int64_t max_threads = 1;
45
45
  static int initialized = 0;
46
- extern VALUE cGumath;
46
+ VALUE cGumath;
47
47
 
48
48
  /****************************************************************************/
49
49
  /* Error handling */
50
50
  /****************************************************************************/
51
51
 
52
+ static VALUE rb_eValueError;
53
+
52
54
  VALUE
53
55
  seterr(ndt_context_t *ctx)
54
56
  {
@@ -59,117 +61,274 @@ seterr(ndt_context_t *ctx)
59
61
  /* Instance methods */
60
62
  /****************************************************************************/
61
63
 
64
+ /* Parse optional arguments passed to GuFuncObject#call.
65
+ *
66
+ * Populates the rbstack with all the input arguments. Then checks whether
67
+ * the 'out' kwarg has been specified and populates the rest of rbstack
68
+ * with contents of 'out'.
69
+ */
70
+ void
71
+ parse_args(VALUE *rbstack, int *rb_nin, int *rb_nout, int *rb_nargs, int noptargs,
72
+ VALUE *argv, VALUE out)
73
+ {
74
+ size_t nin = noptargs, nout;
75
+
76
+ if (noptargs == 0) {
77
+ *rb_nin = 0;
78
+ }
79
+
80
+ for (int i = 0; i < nin; i++) {
81
+ if (!rb_is_a(argv[i], cXND)) {
82
+ rb_raise(rb_eArgError, "expected xnd arguments.");
83
+ }
84
+ rbstack[i] = argv[i];
85
+ }
86
+
87
+ if (out == Qnil) {
88
+ nout = 0;
89
+ }
90
+ else {
91
+ if (rb_xnd_check_type(out)) {
92
+ nout = 1;
93
+ if (nin + nout > NDT_MAX_ARGS) {
94
+ rb_raise(rb_eTypeError, "max number of arguments is %d, got %ld.",
95
+ NDT_MAX_ARGS, nin+nout);
96
+ }
97
+ rbstack[nin] = out;
98
+ }
99
+ else if (RB_TYPE_P(out, T_ARRAY)) {
100
+ nout = rb_ary_size(out);
101
+ if (nout > NDT_MAX_ARGS || nin+nout > NDT_MAX_ARGS) {
102
+ rb_raise(rb_eTypeError, "max number of arguments is %d, got %ld.",
103
+ NDT_MAX_ARGS, nin+nout);
104
+ }
105
+
106
+ for (int i = 0; i < nout; ++i) {
107
+ VALUE v = rb_ary_entry(out, i);
108
+ if (!rb_is_a(v, cXND)) {
109
+ rb_raise(rb_eTypeError, "expected xnd argument in all elements of out array.");
110
+ }
111
+ rbstack[nin+i] = v;
112
+ }
113
+ }
114
+ else {
115
+ rb_raise(rb_eTypeError, "'out' argument must of type XND or Array of XND objects.");
116
+ }
117
+ }
118
+
119
+ *rb_nin = (int)nin;
120
+ *rb_nout = (int)nout;
121
+ *rb_nargs = (int)nin + (int)nout;
122
+ }
123
+
124
+ /* Implement call method on the GufuncObject call. */
62
125
  static VALUE
63
126
  Gumath_GufuncObject_call(int argc, VALUE *argv, VALUE self)
64
127
  {
128
+ VALUE out = Qnil;
129
+ VALUE dt = Qnil;
130
+ VALUE cls = Qnil;
131
+
65
132
  NDT_STATIC_CONTEXT(ctx);
133
+ VALUE rbstack[NDT_MAX_ARGS], opts = Qnil;
66
134
  xnd_t stack[NDT_MAX_ARGS];
67
- const ndt_t *in_types[NDT_MAX_ARGS];
135
+ const ndt_t *types[NDT_MAX_ARGS];
68
136
  gm_kernel_t kernel;
69
137
  ndt_apply_spec_t spec = ndt_apply_spec_empty;
70
- GufuncObject *self_p;
71
- VALUE result[NDT_MAX_ARGS];
72
- int i, k;
73
- size_t nin = argc;
138
+ int64_t li[NDT_MAX_ARGS];
139
+ NdtObject *dt_p;
140
+ int k;
141
+ ndt_t *dtype = NULL;
142
+ int nin = argc, nout, nargs;
143
+ bool have_cpu_device = false;
144
+ GufuncObject * self_p;
145
+ bool check_broadcast = true, enable_threads = true;
74
146
 
75
147
  if (argc > NDT_MAX_ARGS) {
76
148
  rb_raise(rb_eArgError, "too many arguments.");
77
149
  }
150
+
151
+ /* parse keyword arguments. */
152
+ int noptargs = argc;
153
+ for (int i = 0; i < argc; ++i) {
154
+ if (RB_TYPE_P(argv[i], T_HASH)) {
155
+ noptargs = i;
156
+ opts = argv[i];
157
+ break;
158
+ }
159
+ }
160
+
161
+ if (NIL_P(opts)) { opts = rb_hash_new(); }
162
+
163
+ out = rb_hash_aref(opts, ID2SYM(rb_intern("out")));
164
+ dt = rb_hash_aref(opts, ID2SYM(rb_intern("dtype")));
165
+ cls = rb_hash_aref(opts, ID2SYM(rb_intern("cls")));
166
+
167
+ if (NIL_P(cls)) { cls = cXND; }
168
+ if (!NIL_P(dt)) {
169
+ if (!NIL_P(out)) {
170
+ rb_raise(rb_eArgError, "the 'out' and 'dtype' arguments are mutually exclusive.");
171
+ }
78
172
 
79
- /* Prepare arguments for sending into gumath function. */
80
- for (i = 0; i < argc; i++) {
81
- if (!rb_xnd_check_type(argv[i])) {
82
- VALUE str = rb_funcall(argv[i], rb_intern("inspect"), 0, NULL);
83
- rb_raise(rb_eArgError, "Args must be XND. Received %s.", RSTRING_PTR(str));
173
+ if (!rb_ndtypes_check_type(dt)) {
174
+ rb_raise(rb_eArgError, "'dtype' argument must be an NDT object.");
84
175
  }
176
+ dtype = (ndt_t *)rb_ndtypes_const_ndt(dt);
177
+ ndt_incref(dtype);
178
+ }
85
179
 
86
- stack[i] = *rb_xnd_const_xnd(argv[i]);
87
- in_types[i] = stack[i].type;
180
+ if (!rb_klass_has_ancestor(cls, cXND)) {
181
+ rb_raise(rb_eTypeError, "the 'cls' argument must be a subtype of 'xnd'.");
88
182
  }
89
183
 
90
- /* Select the gumath function to be called from the function table. */
184
+ /* parse leading optional arguments */
185
+ parse_args(rbstack, &nin, &nout, &nargs, noptargs, argv, out);
186
+
187
+ for (k = 0; k < nargs; ++k) {
188
+ if (!rb_xnd_is_cuda_managed(rbstack[k])) {
189
+ have_cpu_device = true;
190
+ }
191
+
192
+ stack[k] = *rb_xnd_const_xnd(rbstack[k]);
193
+ types[k] = stack[k].type;
194
+ li[k] = stack[k].index;
195
+ }
91
196
  GET_GUOBJ(self, self_p);
197
+
198
+ if (have_cpu_device) {
199
+ if (self_p->flags & GM_CUDA_MANAGED_FUNC) {
200
+ rb_raise(rb_eValueError,
201
+ "cannot run a cuda function on xnd objects with cpu memory.");
202
+ }
203
+ }
204
+
205
+ kernel = gm_select(&spec, self_p->table, self_p->name, types, li, nin, nout,
206
+ nout && check_broadcast, stack, &ctx);
92
207
 
93
- kernel = gm_select(&spec, self_p->table, self_p->name, in_types, argc, stack, &ctx);
94
208
  if (kernel.set == NULL) {
95
209
  seterr(&ctx);
96
210
  raise_error();
97
211
  }
98
212
 
99
- if (spec.nbroadcast > 0) {
100
- for (i = 0; i < argc; i++) {
101
- stack[i].type = spec.broadcast[i];
213
+ if (dtype) {
214
+ if (spec.nout != 1) {
215
+ ndt_err_format(&ctx, NDT_TypeError,
216
+ "the 'dtype' argument is only supported for a single "
217
+ "return value.");
218
+ ndt_apply_spec_clear(&spec);
219
+ ndt_decref(dtype);
220
+ seterr(&ctx);
221
+ raise_error();
102
222
  }
103
- }
104
223
 
105
- /* Populate output values with empty XND objects. */
106
- for (i = 0; i < spec.nout; i++) {
107
- if (ndt_is_concrete(spec.out[i])) {
108
- VALUE x = rb_xnd_empty_from_type(spec.out[i]);
109
- if (x == NULL) {
110
- ndt_apply_spec_clear(&spec);
111
- rb_raise(rb_eNoMemError, "could not allocate empty XND object.");
112
- }
113
- result[i] = x;
114
- stack[nin+i] = *rb_xnd_const_xnd(x);
224
+ const ndt_t *u = spec.types[spec.nin];
225
+ const ndt_t *v = ndt_copy_contiguous_dtype(u, dtype, 0, &ctx);
226
+
227
+ ndt_apply_spec_clear(&spec);
228
+ ndt_decref(dtype);
229
+
230
+ if (v == NULL) {
231
+ seterr(&ctx);
232
+ raise_error();
115
233
  }
116
- else {
117
- result[i] = NULL;
118
- stack[nin+i] = xnd_error;
234
+
235
+ types[nin] = v;
236
+ kernel = gm_select(&spec, self_p->table, self_p->name, types, li, nin, 1,
237
+ 1 && check_broadcast, stack, &ctx);
238
+ if (kernel.set == NULL) {
239
+ seterr(&ctx);
240
+ raise_error();
119
241
  }
120
242
  }
121
243
 
122
- /* Actually call the kernel function with prepared input and output args. */
123
- #ifdef HAVE_PTHREAD_H
124
- if (gm_apply_thread(&kernel, stack, spec.outer_dims, spec.flags,
125
- max_threads, &ctx) < 0) {
126
- seterr(&ctx);
127
- raise_error();
244
+ /*
245
+ * Replace args/kwargs types with types after substitution and broadcasting.
246
+ * This includes 'out' types, if explicitly passed as kwargs.
247
+ */
248
+ for (int i = 0; i < spec.nargs; ++i) {
249
+ stack[i].type = spec.types[i];
250
+ }
251
+
252
+ if (nout == 0) {
253
+ /* 'out' types have been inferred, create new XndObjects. */
254
+ VALUE x;
255
+ for (int i = 0; i < spec.nout; ++i) {
256
+ if (ndt_is_concrete(spec.types[nin+i])) {
257
+ uint32_t flags = self_p->flags == GM_CUDA_MANAGED_FUNC ? XND_CUDA_MANAGED : 0;
258
+ x = rb_xnd_empty_from_type(cls, spec.types[nin+i], flags);
259
+ rbstack[nin+i] = x;
260
+ stack[nin+i] = *rb_xnd_const_xnd(x);
261
+ }
262
+ else {
263
+ rb_raise(rb_eValueError,
264
+ "args with abstract types are temporarily disabled.");
265
+ }
266
+ }
128
267
  }
268
+
269
+ if (self_p->flags == GM_CUDA_MANAGED_FUNC) {
270
+ #ifdef HAVE_CUDA
271
+ /* populate with CUDA specific stuff */
129
272
  #else
130
- if (gm_apply(&kernel, stack, spec.outer_dims, &ctx) < 0) {
273
+ ndt_err_format(&ctx, NDT_RuntimeError,
274
+ "internal error: GM_CUDA_MANAGED_FUNC set in a build without cuda support");
275
+ ndt_apply_spec_clear(&spec);
131
276
  seterr(&ctx);
132
277
  raise_error();
278
+ #endif // HAVE_CUDA
133
279
  }
134
- #endif
135
-
136
- /* Prepare output XND objects. */
137
- for (i = 0; i < spec.nout; i++) {
138
- if (ndt_is_abstract(spec.out[i])) {
139
- ndt_del(spec.out[i]);
140
- VALUE x = rb_xnd_from_xnd(&stack[nin+i]);
141
- stack[nin+i] = xnd_error;
142
- if (x == NULL) {
143
- for (k = i+i; k < spec.nout; k++) {
144
- if (ndt_is_abstract(spec.out[k])) {
145
- xnd_del_buffer(&stack[nin+k], XND_OWN_ALL);
146
- }
147
- }
148
- }
149
- result[i] = x;
150
- }
151
- }
280
+ else {
281
+ #ifdef HAVE_PTHREAD_H
282
+ const int rounding = fegetround();
283
+ fesetround(FE_TONEAREST);
284
+
285
+ const int64_t N = enable_threads ? max_threads : 1;
286
+ const int ret = gm_apply_thread(&kernel, stack, spec.outer_dims, N, &ctx);
287
+ fesetround(rounding);
152
288
 
153
- if (spec.nbroadcast > 0) {
154
- for (i = 0; i < nin; ++i) {
155
- ndt_del(spec.broadcast[i]);
289
+ if (ret < 0) {
290
+ ndt_apply_spec_clear(&spec);
291
+ seterr(&ctx);
292
+ raise_error();
156
293
  }
294
+ #else
295
+ const int rounding = fegetround();
296
+ fesetround(FE_TONEAREST);
297
+
298
+ const int ret = gm_apply(&kernel, stack, spec.outer_dims, &ctx);
299
+ fesetround(rounding);
300
+
301
+ if (ret < 0) {
302
+ ndt_apply_spec_clear(&spec);
303
+ seterr(&ctx);
304
+ raise_error();
305
+ }
306
+ #endif // HAVE_PTHREAD_H
157
307
  }
158
308
 
159
- /* Return result */
160
- switch(spec.nout) {
161
- case 0: return Qnil;
162
- case 1: return result[0];
309
+ nin = spec.nin;
310
+ nout = spec.nout;
311
+ nargs = spec.nargs;
312
+ ndt_apply_spec_clear(&spec);
313
+
314
+ switch (nout) {
315
+ case 0: {
316
+ return Qnil;
317
+ }
318
+ case 1: {
319
+ return rbstack[nin];
320
+ }
163
321
  default: {
164
- VALUE tuple = array_new(spec.nout);
165
- for (i = 0; i < spec.nout; ++i) {
166
- rb_ary_store(tuple, i, result[i]);
322
+ VALUE arr = rb_ary_new2(nout);
323
+ for (int i = 0; i < nout; ++i) {
324
+ rb_ary_store(arr, i, rbstack[nin+i]);
167
325
  }
168
- return tuple;
326
+ return arr;
169
327
  }
170
328
  }
171
329
  }
172
330
 
331
+
173
332
  /****************************************************************************/
174
333
  /* Singleton methods */
175
334
  /****************************************************************************/
@@ -225,7 +384,7 @@ add_function(const gm_func_t *f, void *args)
225
384
  struct map_args *a = (struct map_args *)args;
226
385
  VALUE func, func_hash;
227
386
 
228
- func = GufuncObject_alloc(a->table, f->name);
387
+ func = GufuncObject_alloc(a->table, f->name, GM_CPU_FUNC);
229
388
  if (func == NULL) {
230
389
  return -1;
231
390
  }
@@ -236,7 +395,6 @@ add_function(const gm_func_t *f, void *args)
236
395
  return 0;
237
396
  }
238
397
 
239
- /* C API call for adding functions from a gumath kernel table to */
240
398
  int
241
399
  rb_gumath_add_functions(VALUE module, const gm_tbl_t *tbl)
242
400
  {
@@ -289,6 +447,9 @@ void Init_ruby_gumath(void)
289
447
 
290
448
  /* Instance methods */
291
449
  rb_define_method(cGumath_GufuncObject, "call", Gumath_GufuncObject_call,-1);
450
+
451
+ /* errors */
452
+ rb_eValueError = rb_define_class("ValueError", rb_eRuntimeError);
292
453
 
293
454
  Init_gumath_functions();
294
455
  Init_gumath_examples();