gumath 0.2.0dev5 → 0.2.0dev8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. checksums.yaml +4 -4
  2. data/CONTRIBUTING.md +7 -2
  3. data/Gemfile +0 -3
  4. data/ext/ruby_gumath/GPATH +0 -0
  5. data/ext/ruby_gumath/GRTAGS +0 -0
  6. data/ext/ruby_gumath/GTAGS +0 -0
  7. data/ext/ruby_gumath/extconf.rb +0 -5
  8. data/ext/ruby_gumath/functions.c +10 -2
  9. data/ext/ruby_gumath/gufunc_object.c +15 -4
  10. data/ext/ruby_gumath/gufunc_object.h +9 -3
  11. data/ext/ruby_gumath/gumath/Makefile +63 -0
  12. data/ext/ruby_gumath/gumath/Makefile.in +1 -0
  13. data/ext/ruby_gumath/gumath/config.h +56 -0
  14. data/ext/ruby_gumath/gumath/config.h.in +3 -0
  15. data/ext/ruby_gumath/gumath/config.log +497 -0
  16. data/ext/ruby_gumath/gumath/config.status +1034 -0
  17. data/ext/ruby_gumath/gumath/configure +375 -4
  18. data/ext/ruby_gumath/gumath/configure.ac +47 -3
  19. data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
  20. data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
  21. data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
  22. data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
  23. data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
  24. data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
  25. data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
  26. data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
  27. data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
  28. data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
  29. data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
  30. data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
  31. data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
  32. data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
  33. data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
  34. data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
  35. data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
  36. data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
  37. data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
  38. data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
  39. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
  40. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
  41. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
  42. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
  43. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
  44. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
  45. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
  46. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
  47. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
  48. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
  49. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
  50. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
  51. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
  52. data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
  53. data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
  54. data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
  55. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
  56. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
  57. data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
  58. data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
  59. data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
  60. data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
  61. data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
  62. data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
  63. data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
  64. data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
  65. data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
  66. data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
  67. data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
  68. data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
  69. data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
  70. data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
  71. data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
  72. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
  73. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
  74. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
  75. data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
  76. data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
  77. data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
  78. data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
  79. data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
  80. data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
  81. data/ext/ruby_gumath/gumath/setup.py +67 -6
  82. data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
  83. data/ext/ruby_gumath/include/gumath.h +55 -14
  84. data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
  85. data/ext/ruby_gumath/lib/libgumath.a +0 -0
  86. data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
  87. data/ext/ruby_gumath/ruby_gumath.c +231 -70
  88. data/ext/ruby_gumath/ruby_gumath.h +4 -1
  89. data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
  90. data/ext/ruby_gumath/util.c +34 -0
  91. data/ext/ruby_gumath/util.h +9 -0
  92. data/gumath.gemspec +3 -2
  93. data/lib/gumath.rb +55 -1
  94. data/lib/gumath/version.rb +2 -2
  95. data/lib/ruby_gumath.so +0 -0
  96. metadata +63 -10
  97. data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
  98. data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
  99. data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
@@ -43,6 +43,7 @@ from glob import glob
43
43
  import platform
44
44
  import subprocess
45
45
  import shutil
46
+ import argparse
46
47
  import warnings
47
48
 
48
49
 
@@ -55,6 +56,15 @@ LONG_DESCRIPTION = """\
55
56
 
56
57
  warnings.simplefilter("ignore", UserWarning)
57
58
 
59
+
60
+ # Pre-parse and remove the '-j' argument from sys.argv.
61
+ parser = argparse.ArgumentParser()
62
+ parser.add_argument('-j', default=None)
63
+ values, rest = parser.parse_known_args()
64
+ PARALLEL = values.j
65
+ sys.argv = sys.argv[:1] + rest
66
+
67
+
58
68
  if sys.platform == "darwin":
59
69
  LIBNAME = "libgumath.dylib"
60
70
  LIBSONAME = "libgumath.0.dylib"
@@ -154,7 +164,10 @@ if len(sys.argv) == 2:
154
164
  path = module_path + ':' + python_path if python_path else module_path
155
165
  env = os.environ.copy()
156
166
  env['PYTHONPATH'] = path
157
- ret = subprocess.call([sys.executable, "python/test_gumath.py"], env=env)
167
+ ret = subprocess.call([sys.executable, "python/test_gumath.py", "--long"], env=env)
168
+ if ret != 0:
169
+ sys.exit(ret)
170
+ ret = subprocess.call([sys.executable, "python/test_xndarray.py"], env=env)
158
171
  sys.exit(ret)
159
172
  elif sys.argv[1] == 'clean':
160
173
  shutil.rmtree("build", ignore_errors=True)
@@ -173,11 +186,27 @@ if len(sys.argv) == 2:
173
186
  else:
174
187
  pass
175
188
 
189
+ def get_config_vars():
190
+ f = open("config.h")
191
+ config_vars = {}
192
+ for line in f:
193
+ if line.startswith("#define"):
194
+ l = line.split()
195
+ try:
196
+ config_vars[l[1]] = int(l[2])
197
+ except ValueError:
198
+ pass
199
+ elif line.startswith("/* #undef"):
200
+ l = line.split()
201
+ config_vars[l[2]] = 0
202
+ f.close()
203
+ return config_vars
176
204
 
177
205
  def gumath_extensions():
178
206
  add_include_dirs = [".", "libgumath", "ndtypes/python/ndtypes", "xnd/python/xnd"] + INCLUDES
179
207
  add_library_dirs = ["libgumath", "ndtypes/libndtypes", "xnd/libxnd"] + LIBS
180
208
  add_depends = []
209
+ config_vars = {}
181
210
 
182
211
  if sys.platform == "win32":
183
212
  add_libraries = ["libndtypes-0.2.0dev3.dll", "libxnd-0.2.0dev3.dll", "libgumath-0.2.0dev3.dll"]
@@ -199,6 +228,14 @@ def gumath_extensions():
199
228
  os.system("vcbuild32.bat")
200
229
  os.chdir("..")
201
230
  else:
231
+ if BUILD_ALL:
232
+ cflags = '"-I%s -I%s"' % tuple(CONFIGURE_INCLUDES)
233
+ ldflags = '"-L%s -L%s"' % tuple(CONFIGURE_LIBS)
234
+ make = "make -j%d" % int(PARALLEL) if PARALLEL else "make"
235
+ os.system("./configure CFLAGS=%s LDFLAGS=%s && %s" % (cflags, ldflags, make))
236
+
237
+ config_vars = get_config_vars()
238
+
202
239
  add_extra_compile_args = ["-Wextra", "-Wno-missing-field-initializers", "-std=c11"]
203
240
  if sys.platform == "darwin":
204
241
  add_libraries = ["ndtypes", "xnd", "gumath"]
@@ -209,10 +246,16 @@ def gumath_extensions():
209
246
  add_extra_link_args = []
210
247
  add_runtime_library_dirs = ["$ORIGIN"]
211
248
 
212
- if BUILD_ALL:
213
- cflags = '"-I%s -I%s"' % tuple(CONFIGURE_INCLUDES)
214
- ldflags = '"-L%s -L%s"' % tuple(CONFIGURE_LIBS)
215
- os.system("./configure CFLAGS=%s LDFLAGS=%s && make" % (cflags, ldflags))
249
+ if config_vars["HAVE_CUDA"]:
250
+ add_libraries += ["cudart"]
251
+
252
+ for d in [
253
+ "/usr/cuda/lib",
254
+ "/usr/cuda/lib64",
255
+ "/usr/local/cuda/lib/",
256
+ "/usr/local/cuda/lib64"]:
257
+ if os.path.isdir(d):
258
+ add_library_dirs.append(d)
216
259
 
217
260
  def gumath_ext():
218
261
  sources = ["python/gumath/_gumath.c"]
@@ -244,6 +287,21 @@ def gumath_extensions():
244
287
  runtime_library_dirs = add_runtime_library_dirs
245
288
  )
246
289
 
290
+ def cuda_ext():
291
+ sources = ["python/gumath/cuda.c"]
292
+
293
+ return Extension (
294
+ "gumath.cuda",
295
+ include_dirs = add_include_dirs,
296
+ library_dirs = add_library_dirs,
297
+ depends = add_depends,
298
+ sources = sources,
299
+ libraries = add_libraries,
300
+ extra_compile_args = add_extra_compile_args,
301
+ extra_link_args = add_extra_link_args,
302
+ runtime_library_dirs = add_runtime_library_dirs
303
+ )
304
+
247
305
  def examples_ext():
248
306
  sources = ["python/gumath/examples.c"]
249
307
 
@@ -259,7 +317,10 @@ def gumath_extensions():
259
317
  runtime_library_dirs = add_runtime_library_dirs
260
318
  )
261
319
 
262
- return [gumath_ext(), functions_ext(), examples_ext()]
320
+ extensions = [gumath_ext(), functions_ext(), examples_ext()]
321
+ if config_vars.get("HAVE_CUDA"):
322
+ extensions += [cuda_ext()]
323
+ return extensions
263
324
 
264
325
  setup (
265
326
  name = "gumath",
@@ -0,0 +1,35 @@
1
+ #include <cstdio>
2
+ #include <cstdlib>
3
+ #include <cuda_runtime.h>
4
+
5
+ static void
6
+ check(cudaError_t err)
7
+ {
8
+ if (err != cudaSuccess) {
9
+ exit(1);
10
+ }
11
+ }
12
+
13
+ static int
14
+ min(int x, int y)
15
+ {
16
+ return x <= y ? x : y;
17
+ }
18
+
19
+ int main()
20
+ {
21
+ int res = INT_MAX;
22
+ cudaDeviceProp prop;
23
+ int count, i, n;
24
+
25
+ check(cudaGetDeviceCount(&count));
26
+
27
+ for (i = 0; i < count; i++) {
28
+ check(cudaGetDeviceProperties(&prop, i));
29
+ n = prop.major * 10 + prop.minor;
30
+ res = min(res, n);
31
+ }
32
+
33
+ printf("%d", res);
34
+ return 0;
35
+ }
@@ -34,6 +34,17 @@
34
34
  #ifndef GUMATH_H
35
35
  #define GUMATH_H
36
36
 
37
+
38
+ #ifdef __cplusplus
39
+ extern "C" {
40
+ #endif
41
+
42
+ #ifdef __cplusplus
43
+ #include <cstdint>
44
+ #else
45
+ #include <stdint.h>
46
+ #endif
47
+
37
48
  #include "ndtypes.h"
38
49
  #include "xnd.h"
39
50
 
@@ -65,7 +76,8 @@
65
76
  #endif
66
77
 
67
78
 
68
- #define GM_MAX_KERNELS 512
79
+ #define GM_MAX_KERNELS 8192
80
+ #define GM_THREAD_CUTOFF 1000000
69
81
 
70
82
  typedef float float32_t;
71
83
  typedef double float64_t;
@@ -74,15 +86,25 @@ typedef double float64_t;
74
86
  typedef int (* gm_xnd_kernel_t)(xnd_t stack[], ndt_context_t *ctx);
75
87
  typedef int (* gm_strided_kernel_t)(char **args, intptr_t *dimensions, intptr_t *steps, void *data);
76
88
 
77
- /* Collection of specialized kernels for a single function signature. */
89
+ /*
90
+ * Collection of specialized kernels for a single function signature.
91
+ *
92
+ * NOTE: The specialized kernel lookup scheme is transitional and may
93
+ * be replaced by something else.
94
+ *
95
+ * This should be considered as a first version of a kernel request
96
+ * protocol.
97
+ */
78
98
  typedef struct {
79
- ndt_t *sig;
99
+ const ndt_t *sig;
80
100
  const ndt_constraint_t *constraint;
81
101
 
82
102
  /* Xnd signatures */
83
- gm_xnd_kernel_t Opt; /* dispatch ensures elementwise, at least 1D, contiguous in last dimensions */
84
- gm_xnd_kernel_t C; /* dispatch ensures c-contiguous in inner dimensions */
85
- gm_xnd_kernel_t Fortran; /* dispatch ensures f-contiguous in inner dimensions */
103
+ gm_xnd_kernel_t OptC; /* C in inner+1 dimensions */
104
+ gm_xnd_kernel_t OptZ; /* C in inner dimensions, C or zero stride in (inner+1)th. */
105
+ gm_xnd_kernel_t OptS; /* strided in (inner+1)th. */
106
+ gm_xnd_kernel_t C; /* C in inner dimensions */
107
+ gm_xnd_kernel_t Fortran; /* Fortran in inner dimensions */
86
108
  gm_xnd_kernel_t Xnd; /* selected if non-contiguous or the other fields are NULL */
87
109
 
88
110
  /* NumPy signature */
@@ -99,11 +121,17 @@ typedef struct {
99
121
  const char *name;
100
122
  const char *sig;
101
123
  const ndt_constraint_t *constraint;
124
+ uint32_t cap;
102
125
 
103
- gm_xnd_kernel_t Opt;
126
+ /* Xnd signatures */
127
+ gm_xnd_kernel_t OptC;
128
+ gm_xnd_kernel_t OptZ;
129
+ gm_xnd_kernel_t OptS;
104
130
  gm_xnd_kernel_t C;
105
131
  gm_xnd_kernel_t Fortran;
106
132
  gm_xnd_kernel_t Xnd;
133
+
134
+ /* NumPy signature */
107
135
  gm_strided_kernel_t Strided;
108
136
  } gm_kernel_init_t;
109
137
 
@@ -115,7 +143,10 @@ typedef struct {
115
143
 
116
144
  /* Multimethod with associated kernels */
117
145
  typedef struct gm_func gm_func_t;
118
- typedef const gm_kernel_set_t *(*gm_typecheck_t)(ndt_apply_spec_t *spec, const gm_func_t *f, const ndt_t *in[], int nin, ndt_context_t *ctx);
146
+ typedef const gm_kernel_set_t *(*gm_typecheck_t)(ndt_apply_spec_t *spec, const gm_func_t *f,
147
+ const ndt_t *in[], const int64_t li[],
148
+ int nin, int nout, bool check_broadcast,
149
+ ndt_context_t *ctx);
119
150
  struct gm_func {
120
151
  char *name;
121
152
  gm_typecheck_t typecheck; /* Experimental optimized type-checking, may be NULL. */
@@ -139,10 +170,10 @@ GM_API int gm_add_kernel(gm_tbl_t *tbl, const gm_kernel_init_t *kernel, ndt_cont
139
170
  GM_API int gm_add_kernel_typecheck(gm_tbl_t *tbl, const gm_kernel_init_t *kernel, ndt_context_t *ctx, gm_typecheck_t f);
140
171
 
141
172
  GM_API gm_kernel_t gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
142
- const ndt_t *in_types[], int nin, const xnd_t args[],
143
- ndt_context_t *ctx);
173
+ const ndt_t *types[], const int64_t li[], int nin, int nout,
174
+ bool check_broadcast, const xnd_t args[], ndt_context_t *ctx);
144
175
  GM_API int gm_apply(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, ndt_context_t *ctx);
145
- GM_API int gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, uint32_t flags, const int64_t nthreads, ndt_context_t *ctx);
176
+ GM_API int gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, const int64_t nthreads, ndt_context_t *ctx);
146
177
 
147
178
 
148
179
  /******************************************************************************/
@@ -171,6 +202,7 @@ GM_API int gm_np_map(const gm_strided_kernel_t f,
171
202
  /* Xnd loops */
172
203
  /******************************************************************************/
173
204
 
205
+ GM_API int array_shape_check(xnd_t *x, const int64_t shape, ndt_context_t *ctx);
174
206
  GM_API int gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
175
207
  const int outer_dims, ndt_context_t *ctx);
176
208
 
@@ -191,10 +223,14 @@ GM_API int gm_tbl_map(const gm_tbl_t *tbl, int (*f)(const gm_func_t *, void *sta
191
223
  /******************************************************************************/
192
224
 
193
225
  GM_API void gm_init(void);
194
- GM_API int gm_init_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
195
- GM_API int gm_init_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
226
+ GM_API int gm_init_cpu_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
227
+ GM_API int gm_init_cpu_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
228
+ GM_API int gm_init_bitwise_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
229
+
230
+ GM_API int gm_init_cuda_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
231
+ GM_API int gm_init_cuda_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
232
+
196
233
  GM_API int gm_init_example_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
197
- GM_API int gm_init_bfloat16_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
198
234
  GM_API int gm_init_graph_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
199
235
  GM_API int gm_init_quaternion_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
200
236
  GM_API int gm_init_pdist_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
@@ -202,4 +238,9 @@ GM_API int gm_init_pdist_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
202
238
  GM_API void gm_finalize(void);
203
239
 
204
240
 
241
+ #ifdef __cplusplus
242
+ } /* END extern "C" */
243
+ #endif
244
+
245
+
205
246
  #endif /* GUMATH_H */
@@ -33,8 +33,11 @@
33
33
  #define RUBY_GUMATH_H
34
34
 
35
35
  /* Classes */
36
- VALUE cGumath;
36
+ extern VALUE cGumath;
37
37
 
38
+ /* C API call for adding functions from a gumath kernel table to a Ruby module.
39
+ * Only adds CPU functions.
40
+ */
38
41
  int rb_gumath_add_functions(VALUE module, const gm_tbl_t *tbl);
39
42
  #define GUMATH_FUNCTION_HASH rb_intern("@gumath_functions")
40
43
 
@@ -43,12 +43,14 @@ static gm_tbl_t *table = NULL;
43
43
  /* Maximum number of threads */
44
44
  static int64_t max_threads = 1;
45
45
  static int initialized = 0;
46
- extern VALUE cGumath;
46
+ VALUE cGumath;
47
47
 
48
48
  /****************************************************************************/
49
49
  /* Error handling */
50
50
  /****************************************************************************/
51
51
 
52
+ static VALUE rb_eValueError;
53
+
52
54
  VALUE
53
55
  seterr(ndt_context_t *ctx)
54
56
  {
@@ -59,117 +61,274 @@ seterr(ndt_context_t *ctx)
59
61
  /* Instance methods */
60
62
  /****************************************************************************/
61
63
 
64
+ /* Parse optional arguments passed to GuFuncObject#call.
65
+ *
66
+ * Populates the rbstack with all the input arguments. Then checks whether
67
+ * the 'out' kwarg has been specified and populates the rest of rbstack
68
+ * with contents of 'out'.
69
+ */
70
+ void
71
+ parse_args(VALUE *rbstack, int *rb_nin, int *rb_nout, int *rb_nargs, int noptargs,
72
+ VALUE *argv, VALUE out)
73
+ {
74
+ size_t nin = noptargs, nout;
75
+
76
+ if (noptargs == 0) {
77
+ *rb_nin = 0;
78
+ }
79
+
80
+ for (int i = 0; i < nin; i++) {
81
+ if (!rb_is_a(argv[i], cXND)) {
82
+ rb_raise(rb_eArgError, "expected xnd arguments.");
83
+ }
84
+ rbstack[i] = argv[i];
85
+ }
86
+
87
+ if (out == Qnil) {
88
+ nout = 0;
89
+ }
90
+ else {
91
+ if (rb_xnd_check_type(out)) {
92
+ nout = 1;
93
+ if (nin + nout > NDT_MAX_ARGS) {
94
+ rb_raise(rb_eTypeError, "max number of arguments is %d, got %ld.",
95
+ NDT_MAX_ARGS, nin+nout);
96
+ }
97
+ rbstack[nin] = out;
98
+ }
99
+ else if (RB_TYPE_P(out, T_ARRAY)) {
100
+ nout = rb_ary_size(out);
101
+ if (nout > NDT_MAX_ARGS || nin+nout > NDT_MAX_ARGS) {
102
+ rb_raise(rb_eTypeError, "max number of arguments is %d, got %ld.",
103
+ NDT_MAX_ARGS, nin+nout);
104
+ }
105
+
106
+ for (int i = 0; i < nout; ++i) {
107
+ VALUE v = rb_ary_entry(out, i);
108
+ if (!rb_is_a(v, cXND)) {
109
+ rb_raise(rb_eTypeError, "expected xnd argument in all elements of out array.");
110
+ }
111
+ rbstack[nin+i] = v;
112
+ }
113
+ }
114
+ else {
115
+ rb_raise(rb_eTypeError, "'out' argument must of type XND or Array of XND objects.");
116
+ }
117
+ }
118
+
119
+ *rb_nin = (int)nin;
120
+ *rb_nout = (int)nout;
121
+ *rb_nargs = (int)nin + (int)nout;
122
+ }
123
+
124
+ /* Implement call method on the GufuncObject call. */
62
125
  static VALUE
63
126
  Gumath_GufuncObject_call(int argc, VALUE *argv, VALUE self)
64
127
  {
128
+ VALUE out = Qnil;
129
+ VALUE dt = Qnil;
130
+ VALUE cls = Qnil;
131
+
65
132
  NDT_STATIC_CONTEXT(ctx);
133
+ VALUE rbstack[NDT_MAX_ARGS], opts = Qnil;
66
134
  xnd_t stack[NDT_MAX_ARGS];
67
- const ndt_t *in_types[NDT_MAX_ARGS];
135
+ const ndt_t *types[NDT_MAX_ARGS];
68
136
  gm_kernel_t kernel;
69
137
  ndt_apply_spec_t spec = ndt_apply_spec_empty;
70
- GufuncObject *self_p;
71
- VALUE result[NDT_MAX_ARGS];
72
- int i, k;
73
- size_t nin = argc;
138
+ int64_t li[NDT_MAX_ARGS];
139
+ NdtObject *dt_p;
140
+ int k;
141
+ ndt_t *dtype = NULL;
142
+ int nin = argc, nout, nargs;
143
+ bool have_cpu_device = false;
144
+ GufuncObject * self_p;
145
+ bool check_broadcast = true, enable_threads = true;
74
146
 
75
147
  if (argc > NDT_MAX_ARGS) {
76
148
  rb_raise(rb_eArgError, "too many arguments.");
77
149
  }
150
+
151
+ /* parse keyword arguments. */
152
+ int noptargs = argc;
153
+ for (int i = 0; i < argc; ++i) {
154
+ if (RB_TYPE_P(argv[i], T_HASH)) {
155
+ noptargs = i;
156
+ opts = argv[i];
157
+ break;
158
+ }
159
+ }
160
+
161
+ if (NIL_P(opts)) { opts = rb_hash_new(); }
162
+
163
+ out = rb_hash_aref(opts, ID2SYM(rb_intern("out")));
164
+ dt = rb_hash_aref(opts, ID2SYM(rb_intern("dtype")));
165
+ cls = rb_hash_aref(opts, ID2SYM(rb_intern("cls")));
166
+
167
+ if (NIL_P(cls)) { cls = cXND; }
168
+ if (!NIL_P(dt)) {
169
+ if (!NIL_P(out)) {
170
+ rb_raise(rb_eArgError, "the 'out' and 'dtype' arguments are mutually exclusive.");
171
+ }
78
172
 
79
- /* Prepare arguments for sending into gumath function. */
80
- for (i = 0; i < argc; i++) {
81
- if (!rb_xnd_check_type(argv[i])) {
82
- VALUE str = rb_funcall(argv[i], rb_intern("inspect"), 0, NULL);
83
- rb_raise(rb_eArgError, "Args must be XND. Received %s.", RSTRING_PTR(str));
173
+ if (!rb_ndtypes_check_type(dt)) {
174
+ rb_raise(rb_eArgError, "'dtype' argument must be an NDT object.");
84
175
  }
176
+ dtype = (ndt_t *)rb_ndtypes_const_ndt(dt);
177
+ ndt_incref(dtype);
178
+ }
85
179
 
86
- stack[i] = *rb_xnd_const_xnd(argv[i]);
87
- in_types[i] = stack[i].type;
180
+ if (!rb_klass_has_ancestor(cls, cXND)) {
181
+ rb_raise(rb_eTypeError, "the 'cls' argument must be a subtype of 'xnd'.");
88
182
  }
89
183
 
90
- /* Select the gumath function to be called from the function table. */
184
+ /* parse leading optional arguments */
185
+ parse_args(rbstack, &nin, &nout, &nargs, noptargs, argv, out);
186
+
187
+ for (k = 0; k < nargs; ++k) {
188
+ if (!rb_xnd_is_cuda_managed(rbstack[k])) {
189
+ have_cpu_device = true;
190
+ }
191
+
192
+ stack[k] = *rb_xnd_const_xnd(rbstack[k]);
193
+ types[k] = stack[k].type;
194
+ li[k] = stack[k].index;
195
+ }
91
196
  GET_GUOBJ(self, self_p);
197
+
198
+ if (have_cpu_device) {
199
+ if (self_p->flags & GM_CUDA_MANAGED_FUNC) {
200
+ rb_raise(rb_eValueError,
201
+ "cannot run a cuda function on xnd objects with cpu memory.");
202
+ }
203
+ }
204
+
205
+ kernel = gm_select(&spec, self_p->table, self_p->name, types, li, nin, nout,
206
+ nout && check_broadcast, stack, &ctx);
92
207
 
93
- kernel = gm_select(&spec, self_p->table, self_p->name, in_types, argc, stack, &ctx);
94
208
  if (kernel.set == NULL) {
95
209
  seterr(&ctx);
96
210
  raise_error();
97
211
  }
98
212
 
99
- if (spec.nbroadcast > 0) {
100
- for (i = 0; i < argc; i++) {
101
- stack[i].type = spec.broadcast[i];
213
+ if (dtype) {
214
+ if (spec.nout != 1) {
215
+ ndt_err_format(&ctx, NDT_TypeError,
216
+ "the 'dtype' argument is only supported for a single "
217
+ "return value.");
218
+ ndt_apply_spec_clear(&spec);
219
+ ndt_decref(dtype);
220
+ seterr(&ctx);
221
+ raise_error();
102
222
  }
103
- }
104
223
 
105
- /* Populate output values with empty XND objects. */
106
- for (i = 0; i < spec.nout; i++) {
107
- if (ndt_is_concrete(spec.out[i])) {
108
- VALUE x = rb_xnd_empty_from_type(spec.out[i]);
109
- if (x == NULL) {
110
- ndt_apply_spec_clear(&spec);
111
- rb_raise(rb_eNoMemError, "could not allocate empty XND object.");
112
- }
113
- result[i] = x;
114
- stack[nin+i] = *rb_xnd_const_xnd(x);
224
+ const ndt_t *u = spec.types[spec.nin];
225
+ const ndt_t *v = ndt_copy_contiguous_dtype(u, dtype, 0, &ctx);
226
+
227
+ ndt_apply_spec_clear(&spec);
228
+ ndt_decref(dtype);
229
+
230
+ if (v == NULL) {
231
+ seterr(&ctx);
232
+ raise_error();
115
233
  }
116
- else {
117
- result[i] = NULL;
118
- stack[nin+i] = xnd_error;
234
+
235
+ types[nin] = v;
236
+ kernel = gm_select(&spec, self_p->table, self_p->name, types, li, nin, 1,
237
+ 1 && check_broadcast, stack, &ctx);
238
+ if (kernel.set == NULL) {
239
+ seterr(&ctx);
240
+ raise_error();
119
241
  }
120
242
  }
121
243
 
122
- /* Actually call the kernel function with prepared input and output args. */
123
- #ifdef HAVE_PTHREAD_H
124
- if (gm_apply_thread(&kernel, stack, spec.outer_dims, spec.flags,
125
- max_threads, &ctx) < 0) {
126
- seterr(&ctx);
127
- raise_error();
244
+ /*
245
+ * Replace args/kwargs types with types after substitution and broadcasting.
246
+ * This includes 'out' types, if explicitly passed as kwargs.
247
+ */
248
+ for (int i = 0; i < spec.nargs; ++i) {
249
+ stack[i].type = spec.types[i];
250
+ }
251
+
252
+ if (nout == 0) {
253
+ /* 'out' types have been inferred, create new XndObjects. */
254
+ VALUE x;
255
+ for (int i = 0; i < spec.nout; ++i) {
256
+ if (ndt_is_concrete(spec.types[nin+i])) {
257
+ uint32_t flags = self_p->flags == GM_CUDA_MANAGED_FUNC ? XND_CUDA_MANAGED : 0;
258
+ x = rb_xnd_empty_from_type(cls, spec.types[nin+i], flags);
259
+ rbstack[nin+i] = x;
260
+ stack[nin+i] = *rb_xnd_const_xnd(x);
261
+ }
262
+ else {
263
+ rb_raise(rb_eValueError,
264
+ "args with abstract types are temporarily disabled.");
265
+ }
266
+ }
128
267
  }
268
+
269
+ if (self_p->flags == GM_CUDA_MANAGED_FUNC) {
270
+ #ifdef HAVE_CUDA
271
+ /* populate with CUDA specific stuff */
129
272
  #else
130
- if (gm_apply(&kernel, stack, spec.outer_dims, &ctx) < 0) {
273
+ ndt_err_format(&ctx, NDT_RuntimeError,
274
+ "internal error: GM_CUDA_MANAGED_FUNC set in a build without cuda support");
275
+ ndt_apply_spec_clear(&spec);
131
276
  seterr(&ctx);
132
277
  raise_error();
278
+ #endif // HAVE_CUDA
133
279
  }
134
- #endif
135
-
136
- /* Prepare output XND objects. */
137
- for (i = 0; i < spec.nout; i++) {
138
- if (ndt_is_abstract(spec.out[i])) {
139
- ndt_del(spec.out[i]);
140
- VALUE x = rb_xnd_from_xnd(&stack[nin+i]);
141
- stack[nin+i] = xnd_error;
142
- if (x == NULL) {
143
- for (k = i+i; k < spec.nout; k++) {
144
- if (ndt_is_abstract(spec.out[k])) {
145
- xnd_del_buffer(&stack[nin+k], XND_OWN_ALL);
146
- }
147
- }
148
- }
149
- result[i] = x;
150
- }
151
- }
280
+ else {
281
+ #ifdef HAVE_PTHREAD_H
282
+ const int rounding = fegetround();
283
+ fesetround(FE_TONEAREST);
284
+
285
+ const int64_t N = enable_threads ? max_threads : 1;
286
+ const int ret = gm_apply_thread(&kernel, stack, spec.outer_dims, N, &ctx);
287
+ fesetround(rounding);
152
288
 
153
- if (spec.nbroadcast > 0) {
154
- for (i = 0; i < nin; ++i) {
155
- ndt_del(spec.broadcast[i]);
289
+ if (ret < 0) {
290
+ ndt_apply_spec_clear(&spec);
291
+ seterr(&ctx);
292
+ raise_error();
156
293
  }
294
+ #else
295
+ const int rounding = fegetround();
296
+ fesetround(FE_TONEAREST);
297
+
298
+ const int ret = gm_apply(&kernel, stack, spec.outer_dims, &ctx);
299
+ fesetround(rounding);
300
+
301
+ if (ret < 0) {
302
+ ndt_apply_spec_clear(&spec);
303
+ seterr(&ctx);
304
+ raise_error();
305
+ }
306
+ #endif // HAVE_PTHREAD_H
157
307
  }
158
308
 
159
- /* Return result */
160
- switch(spec.nout) {
161
- case 0: return Qnil;
162
- case 1: return result[0];
309
+ nin = spec.nin;
310
+ nout = spec.nout;
311
+ nargs = spec.nargs;
312
+ ndt_apply_spec_clear(&spec);
313
+
314
+ switch (nout) {
315
+ case 0: {
316
+ return Qnil;
317
+ }
318
+ case 1: {
319
+ return rbstack[nin];
320
+ }
163
321
  default: {
164
- VALUE tuple = array_new(spec.nout);
165
- for (i = 0; i < spec.nout; ++i) {
166
- rb_ary_store(tuple, i, result[i]);
322
+ VALUE arr = rb_ary_new2(nout);
323
+ for (int i = 0; i < nout; ++i) {
324
+ rb_ary_store(arr, i, rbstack[nin+i]);
167
325
  }
168
- return tuple;
326
+ return arr;
169
327
  }
170
328
  }
171
329
  }
172
330
 
331
+
173
332
  /****************************************************************************/
174
333
  /* Singleton methods */
175
334
  /****************************************************************************/
@@ -225,7 +384,7 @@ add_function(const gm_func_t *f, void *args)
225
384
  struct map_args *a = (struct map_args *)args;
226
385
  VALUE func, func_hash;
227
386
 
228
- func = GufuncObject_alloc(a->table, f->name);
387
+ func = GufuncObject_alloc(a->table, f->name, GM_CPU_FUNC);
229
388
  if (func == NULL) {
230
389
  return -1;
231
390
  }
@@ -236,7 +395,6 @@ add_function(const gm_func_t *f, void *args)
236
395
  return 0;
237
396
  }
238
397
 
239
- /* C API call for adding functions from a gumath kernel table to */
240
398
  int
241
399
  rb_gumath_add_functions(VALUE module, const gm_tbl_t *tbl)
242
400
  {
@@ -289,6 +447,9 @@ void Init_ruby_gumath(void)
289
447
 
290
448
  /* Instance methods */
291
449
  rb_define_method(cGumath_GufuncObject, "call", Gumath_GufuncObject_call,-1);
450
+
451
+ /* errors */
452
+ rb_eValueError = rb_define_class("ValueError", rb_eRuntimeError);
292
453
 
293
454
  Init_gumath_functions();
294
455
  Init_gumath_examples();