gumath 0.2.0dev5 → 0.2.0dev8

Files changed (99)
  1. checksums.yaml +4 -4
  2. data/CONTRIBUTING.md +7 -2
  3. data/Gemfile +0 -3
  4. data/ext/ruby_gumath/GPATH +0 -0
  5. data/ext/ruby_gumath/GRTAGS +0 -0
  6. data/ext/ruby_gumath/GTAGS +0 -0
  7. data/ext/ruby_gumath/extconf.rb +0 -5
  8. data/ext/ruby_gumath/functions.c +10 -2
  9. data/ext/ruby_gumath/gufunc_object.c +15 -4
  10. data/ext/ruby_gumath/gufunc_object.h +9 -3
  11. data/ext/ruby_gumath/gumath/Makefile +63 -0
  12. data/ext/ruby_gumath/gumath/Makefile.in +1 -0
  13. data/ext/ruby_gumath/gumath/config.h +56 -0
  14. data/ext/ruby_gumath/gumath/config.h.in +3 -0
  15. data/ext/ruby_gumath/gumath/config.log +497 -0
  16. data/ext/ruby_gumath/gumath/config.status +1034 -0
  17. data/ext/ruby_gumath/gumath/configure +375 -4
  18. data/ext/ruby_gumath/gumath/configure.ac +47 -3
  19. data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
  20. data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
  21. data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
  22. data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
  23. data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
  24. data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
  25. data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
  26. data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
  27. data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
  28. data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
  29. data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
  30. data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
  31. data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
  32. data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
  33. data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
  34. data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
  35. data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
  36. data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
  37. data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
  38. data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
  39. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
  40. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
  41. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
  42. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
  43. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
  44. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
  45. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
  46. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
  47. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
  48. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
  49. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
  50. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
  51. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
  52. data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
  53. data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
  54. data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
  55. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
  56. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
  57. data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
  58. data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
  59. data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
  60. data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
  61. data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
  62. data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
  63. data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
  64. data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
  65. data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
  66. data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
  67. data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
  68. data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
  69. data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
  70. data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
  71. data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
  72. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
  73. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
  74. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
  75. data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
  76. data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
  77. data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
  78. data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
  79. data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
  80. data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
  81. data/ext/ruby_gumath/gumath/setup.py +67 -6
  82. data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
  83. data/ext/ruby_gumath/include/gumath.h +55 -14
  84. data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
  85. data/ext/ruby_gumath/lib/libgumath.a +0 -0
  86. data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
  87. data/ext/ruby_gumath/ruby_gumath.c +231 -70
  88. data/ext/ruby_gumath/ruby_gumath.h +4 -1
  89. data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
  90. data/ext/ruby_gumath/util.c +34 -0
  91. data/ext/ruby_gumath/util.h +9 -0
  92. data/gumath.gemspec +3 -2
  93. data/lib/gumath.rb +55 -1
  94. data/lib/gumath/version.rb +2 -2
  95. data/lib/ruby_gumath.so +0 -0
  96. metadata +63 -10
  97. data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
  98. data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
  99. data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
@@ -0,0 +1 @@
+ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3
@@ -0,0 +1 @@
+ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3
@@ -72,7 +72,7 @@ clear_all_slices(xnd_t *slices[], int *nslices, int stop)
 {
     for (int i = 0; i < stop; i++) {
         for (int k = 0; k < nslices[i]; k++) {
-            ndt_del((ndt_t *)slices[i][k].type);
+            ndt_decref(slices[i][k].type);
         }
         ndt_free(slices[i]);
     }
@@ -94,16 +94,27 @@ apply_thread(void *arg)
 
 int
 gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
-                uint32_t flags, const int64_t nthreads, ndt_context_t *ctx)
+                const int64_t nthreads, ndt_context_t *ctx)
 {
     const int nrows = (int)kernel->set->sig->Function.nargs;
     ALLOCA(xnd_t *, slices, nrows);
     ALLOCA(int, nslices, nrows);
     struct thread_info *tinfo;
     int ncols, tnum;
+    bool use_threads = true;
 
-    if (nthreads <= 1 || nrows == 0 || outer_dims == 0 ||
-        !(flags & NDT_STRIDED)) {
+    if (nthreads <= 1 || nrows == 0 || outer_dims == 0) {
+        use_threads = false;
+    }
+
+    for (int i = 0; i < nrows; i++) {
+        const ndt_t *t = stack[i].type;
+        if (!ndt_is_ndarray(t) || ndt_nelem(t) < GM_THREAD_CUTOFF) {
+            use_threads = false;
+        }
+    }
+
+    if (!use_threads) {
         return gm_apply(kernel, stack, outer_dims, ctx);
     }
 
@@ -147,6 +158,7 @@ gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
                            &tinfo[tnum]);
         if (ret != 0) {
             clear_all_slices(slices, nslices, nrows);
+            ndt_free(tinfo);
             ndt_err_format(ctx, NDT_RuntimeError, "could not create thread");
             return -1;
         }
@@ -169,6 +181,7 @@ gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
     }
 
     clear_all_slices(slices, nslices, nrows);
+    ndt_free(tinfo);
 
     return ndt_err_occurred(ctx) ? -1 : 0;
 }
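
The hunks above also change the threading policy: gm_apply_thread() loses its flags argument and now falls back to the single-threaded gm_apply() unless every operand is an ndarray with at least GM_THREAD_CUTOFF elements. From Python the thread budget is still set globally; a minimal sketch, assuming gumath, xnd and the standard function set are installed (array values are illustrative only):

    import gumath as gm
    import gumath.functions as fn
    from xnd import xnd

    gm.set_max_threads(8)        # upper bound only; small operands
                                 # still take the single-threaded path
    x = xnd([1.0, 2.0, 3.0])
    y = fn.sin(x)                # below the cutoff: runs via gm_apply()
    print(gm.get_max_threads())  # -> 8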
@@ -38,11 +38,99 @@
 #include "ndtypes.h"
 #include "xnd.h"
 #include "gumath.h"
+#include "overflow.h"
 
 
+static int _gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
+                       const int outer_dims, ndt_context_t *ctx);
+
+int
+array_shape_check(xnd_t *x, const int64_t shape, ndt_context_t *ctx)
+{
+    const ndt_t *t = x->type;
+
+    if (t->tag != Array) {
+        ndt_err_format(ctx, NDT_RuntimeError,
+            "type mismatch in outer dimensions");
+        return -1;
+    }
+
+    if (XND_ARRAY_DATA(x->ptr) == NULL) {
+        bool overflow = false;
+        const int64_t size = MULi64(shape, t->Array.itemsize, &overflow);
+        if (overflow) {
+            ndt_err_format(ctx, NDT_ValueError,
+                "datasize of flexible array is too large");
+            return -1;
+        }
+
+        char *data = ndt_aligned_calloc(t->align, size);
+        if (data == NULL) {
+            ndt_err_format(ctx, NDT_MemoryError, "out of memory");
+            return -1;
+        }
+
+        XND_ARRAY_SHAPE(x->ptr) = shape;
+        XND_ARRAY_DATA(x->ptr) = data;
+
+        return 0;
+    }
+    else if (XND_ARRAY_SHAPE(x->ptr) != shape) {
+        ndt_err_format(ctx, NDT_RuntimeError,
+            "shape mismatch in outer dimensions");
+        return -1;
+    }
+    else {
+        return 0;
+    }
+}
+
+static inline bool
+any_stored_index(xnd_t stack[], const int nargs)
+{
+    for (int i = 0; i < nargs; i++) {
+        if (stack[i].ptr == NULL) {
+            continue;
+        }
+
+        const ndt_t *t = stack[i].type;
+        if (have_stored_index(t)) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
 int
 gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
            const int outer_dims, ndt_context_t *ctx)
+{
+    if (any_stored_index(stack, nargs)) {
+        ALLOCA(xnd_t, next, nargs);
+
+        for (int i = 0; i < nargs; i++) {
+            const ndt_t *t = stack[i].type;
+            if (have_stored_index(t)) {
+                next[i] = apply_stored_indices(&stack[i], ctx);
+                if (xnd_err_occurred(&next[i])) {
+                    return -1;
+                }
+            }
+            else {
+                next[i] = stack[i];
+            }
+        }
+
+        return _gm_xnd_map(f, next, nargs, outer_dims, ctx);
+    }
+
+    return _gm_xnd_map(f, stack, nargs, outer_dims, ctx);
+}
+
+static int
+_gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
+            const int outer_dims, ndt_context_t *ctx)
 {
     ALLOCA(xnd_t, next, nargs);
     const ndt_t *t;
@@ -123,6 +211,28 @@ gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
         return 0;
     }
 
+    case Array: {
+        const int64_t shape = XND_ARRAY_SHAPE(stack[0].ptr);
+
+        for (int k = 1; k < nargs; k++) {
+            if (array_shape_check(&stack[k], shape, ctx) < 0) {
+                return -1;
+            }
+        }
+
+        for (int64_t i = 0; i < shape; i++) {
+            for (int k = 0; k < nargs; k++) {
+                next[k] = xnd_array_next(&stack[k], i);
+            }
+
+            if (gm_xnd_map(f, next, nargs, outer_dims-1, ctx) < 0) {
+                return -1;
+            }
+        }
+
+        return 0;
+    }
+
     default:
         ndt_err_format(ctx, NDT_NotImplementedError, "unsupported type");
         return -1;
@@ -33,7 +33,157 @@
 from ndtypes import ndt
 from xnd import xnd
 from ._gumath import *
+from . import functions as _fn
 
+try:
+    from . import cuda as _cd
+except ImportError:
+    _cd = None
+
+
+__all__ = ['cuda', 'fold', 'functions', 'get_max_threads', 'gufunc', 'reduce',
+           'set_max_threads', 'unsafe_add_kernel', 'vfold', 'xndvectorize']
+
+
+# ==============================================================================
+# Init identity elements
+# ==============================================================================
+
+# This is done here now, perhaps it should be on the C level.
+_fn.add.identity = 0
+_fn.multiply.identity = 1
+
+
+# ==============================================================================
+# General fold function
+# ==============================================================================
+
+def fold(f, acc, x):
+    return vfold(x, f=f, acc=acc)
+
+
+# ==============================================================================
+# NumPy's reduce in terms of fold
+# ==============================================================================
+
+def _get_axes(axes, ndim):
+    type_err = "'axes' must be None, a single integer or a tuple of integers"
+    value_err = "axis with value %d out of range"
+    duplicate_err = "'axes' argument contains duplicate values"
+    if axes is None:
+        axes = tuple(range(ndim))
+    elif isinstance(axes, int):
+        axes = (axes,)
+    elif not isinstance(axes, tuple) or \
+         any(not isinstance(v, int) for v in axes):
+        raise TypeError(type_err)
+
+    if any(n >= ndim for n in axes):
+        raise ValueError(value_err % n)
+
+    if len(set(axes)) != len(axes):
+        raise ValueError(duplicate_err)
+
+    return list(axes)
+
+def _copyto(dest, value):
+    x = xnd(value, dtype=dest.dtype)
+    _fn.copy(x, out=dest)
+
+def reduce_cpu(f, x, axes, dtype):
+    """NumPy's reduce in terms of fold."""
+    axes = _get_axes(axes, x.ndim)
+    if not axes:
+        return x
+
+    permute = [n for n in range(x.ndim) if n not in axes]
+    permute = axes + permute
+
+    T = x.transpose(permute=permute)
+
+    N = len(axes)
+    t = T.type.at(N, dtype=dtype)
+    acc = x.empty(t, device=x.device)
+
+    if f.identity is not None:
+        _copyto(acc, f.identity)
+        tl = T
+    elif N == 1 and T.type.shape[0] > 0:
+        hd, tl = T[0], T[1:]
+        acc[()] = hd
+    else:
+        raise ValueError(
+            "reduction not possible for function without an identity element")
+
+    return fold(f, acc, tl)
+
+def reduce_cuda(g, x, axes, dtype):
+    """Reductions in CUDA use the thrust library for speed and have limited
+    functionality."""
+    if axes != 0:
+        raise NotImplementedError("'axes' keyword is not implemented for CUDA")
+
+    return g(x, dtype=dtype)
+
+def get_cuda_reduction_func(f):
+    if _cd is None:
+        return None
+    elif f == _cd.add:
+        return _cd.reduce_add
+    elif f == _cd.multiply:
+        return _cd.reduce_multiply
+    else:
+        return None
+
+def reduce(f, x, axes=0, dtype=None):
+    if dtype is None:
+        dtype = maxcast[x.dtype]
+
+    g = get_cuda_reduction_func(f)
+    if g is not None:
+        return reduce_cuda(g, x, axes, dtype)
+
+    return reduce_cpu(f, x, axes, dtype)
+
+
+maxcast = {
+    ndt("int8"): ndt("int64"),
+    ndt("int16"): ndt("int64"),
+    ndt("int32"): ndt("int64"),
+    ndt("int64"): ndt("int64"),
+    ndt("uint8"): ndt("uint64"),
+    ndt("uint16"): ndt("uint64"),
+    ndt("uint32"): ndt("uint64"),
+    ndt("uint64"): ndt("uint64"),
+    ndt("bfloat16"): ndt("float64"),
+    ndt("float16"): ndt("float64"),
+    ndt("float32"): ndt("float64"),
+    ndt("float64"): ndt("float64"),
+    ndt("complex32"): ndt("complex128"),
+    ndt("complex64"): ndt("complex128"),
+    ndt("complex128"): ndt("complex128"),
+
+    ndt("?int8"): ndt("?int64"),
+    ndt("?int16"): ndt("?int64"),
+    ndt("?int32"): ndt("?int64"),
+    ndt("?int64"): ndt("?int64"),
+    ndt("?uint8"): ndt("?uint64"),
+    ndt("?uint16"): ndt("?uint64"),
+    ndt("?uint32"): ndt("?uint64"),
+    ndt("?uint64"): ndt("?uint64"),
+    ndt("?bfloat16"): ndt("?float64"),
+    ndt("?float16"): ndt("?float64"),
+    ndt("?float32"): ndt("?float64"),
+    ndt("?float64"): ndt("?float64"),
+    ndt("?complex32"): ndt("?complex128"),
+    ndt("?complex64"): ndt("?complex128"),
+    ndt("?complex128"): ndt("?complex128"),
+}
+
+
+# ==============================================================================
+# Numba's GUVectorize on xnd arrays
+# ==============================================================================
 
 try:
     import numpy as np
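
The new Python layer above gives gumath a NumPy-style reduction built on fold/vfold. A short usage sketch, assuming the standard fn.add and fn.multiply kernels are present (array values are illustrative only):

    import gumath as gm
    import gumath.functions as fn
    from ndtypes import ndt
    from xnd import xnd

    x = xnd([[1, 2], [3, 4]])
    gm.reduce(fn.add, x, axes=0)              # dtype defaults to maxcast[x.dtype]
    gm.reduce(fn.multiply, x, axes=(0, 1),
              dtype=ndt("int64"))             # reduce over both axes

Note that reduce() upcasts by default via the maxcast table, so integer inputs accumulate in int64 and floating-point inputs in float64.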
@@ -45,16 +45,24 @@
 #define GUMATH_MODULE
 #include "pygumath.h"
 
+
 #ifdef _MSC_VER
   #ifndef UNUSED
     #define UNUSED
   #endif
+  #include <float.h>
+  #include <fenv.h>
+  #pragma fenv_access(on)
 #else
   #if defined(__GNUC__) && !defined(__INTEL_COMPILER)
     #define UNUSED __attribute__((unused))
   #else
     #define UNUSED
   #endif
+  #include <fenv.h>
+  #if 0 /* Not supported by gcc and clang. */
+    #pragma STDC FENV_ACCESS ON
+  #endif
 #endif
 
 
@@ -73,6 +81,9 @@ static gm_tbl_t *table = NULL;
 /* Xnd type */
 static PyTypeObject *xnd = NULL;
 
+/* Empty positional arguments */
+static PyObject *positional_empty = NULL;
+
 /* Maximum number of threads */
 static int64_t max_threads = 1;
 
@@ -95,7 +106,7 @@ seterr(ndt_context_t *ctx)
 static PyTypeObject Gufunc_Type;
 
 static PyObject *
-gufunc_new(const gm_tbl_t *tbl, const char *name)
+gufunc_new(const gm_tbl_t *tbl, const char *name, const uint32_t flags)
 {
     NDT_STATIC_CONTEXT(ctx);
     GufuncObject *self;
@@ -106,12 +117,16 @@ gufunc_new(const gm_tbl_t *tbl, const char *name)
     }
 
     self->tbl = tbl;
+    self->flags = flags;
 
     self->name = ndt_strdup(name, &ctx);
     if (self->name == NULL) {
        return seterr(&ctx);
     }
 
+    self->identity = Py_None;
+    Py_INCREF(self->identity);
+
     return (PyObject *)self;
 }
 
@@ -119,6 +134,7 @@ static void
 gufunc_dealloc(GufuncObject *self)
 {
     ndt_free(self->name);
+    Py_DECREF(self->identity);
     PyObject_Del(self);
 }
 
@@ -128,124 +144,317 @@ gufunc_dealloc(GufuncObject *self)
 /****************************************************************************/
 
 static void
-clear_objects(PyObject **a, Py_ssize_t len)
+clear_pystack(PyObject *pystack[], Py_ssize_t len)
+{
+    for (Py_ssize_t i = 0; i < len; i++) {
+        Py_CLEAR(pystack[i]);
+    }
+}
+
+static int
+parse_args(PyObject *pystack[NDT_MAX_ARGS], int *py_nin, int *py_nout, int *py_nargs,
+           PyObject *args, PyObject *out)
 {
-    Py_ssize_t i;
+    Py_ssize_t nin;
+    Py_ssize_t nout;
 
-    for (i = 0; i < len; i++) {
-        Py_CLEAR(a[i]);
+    if (!args || !PyTuple_Check(args)) {
+        const char *name = args ? Py_TYPE(args)->tp_name : "NULL";
+        PyErr_Format(PyExc_SystemError,
+            "internal error: expected tuple, got '%.200s'", name);
+        return -1;
     }
+
+    nin = PyTuple_GET_SIZE(args);
+    if (nin > NDT_MAX_ARGS) {
+        PyErr_Format(PyExc_TypeError,
+            "maximum number of arguments is %d, got %n", NDT_MAX_ARGS, nin);
+        return -1;
+    }
+
+    for (Py_ssize_t i = 0; i < nin; i++) {
+        PyObject *v = PyTuple_GET_ITEM(args, i);
+        if (!Xnd_Check(v)) {
+            PyErr_Format(PyExc_TypeError,
+                "expected xnd argument, got '%.200s'", Py_TYPE(v)->tp_name);
+            return -1;
+        }
+
+        pystack[i] = v;
+    }
+
+    if (out == NULL) {
+        nout = 0;
+    }
+    else {
+        if (Xnd_Check(out)) {
+            nout = 1;
+            if (nin+nout > NDT_MAX_ARGS) {
+                PyErr_Format(PyExc_TypeError,
+                    "maximum number of arguments is %d, got %n", NDT_MAX_ARGS, nin+nout);
+                return -1;
+            }
+            pystack[nin] = out;
+        }
+        else if (PyTuple_Check(out)) {
+            nout = PyTuple_GET_SIZE(out);
+            if (nout > NDT_MAX_ARGS || nin+nout > NDT_MAX_ARGS) {
+                PyErr_Format(PyExc_TypeError,
+                    "maximum number of arguments is %d, got %n", NDT_MAX_ARGS, nin+nout);
+                return -1;
+            }
+
+            for (Py_ssize_t i = 0; i < nout; i++) {
+                PyObject *v = PyTuple_GET_ITEM(out, i);
+                if (!Xnd_Check(v)) {
+                    PyErr_Format(PyExc_TypeError,
+                        "expected xnd argument, got '%.200s'", Py_TYPE(v)->tp_name);
+                    return -1;
+                }
+
+                pystack[nin+i] = v;
+            }
+        }
+        else {
+            PyErr_Format(PyExc_TypeError,
+                "'out' argument must be xnd or a tuple of xnd, got '%.200s'",
+                Py_TYPE(out)->tp_name);
+            return -1;
+        }
+    }
+
+    for (int i = 0; i < nin+nout; i++) {
+        Py_INCREF(pystack[i]);
+    }
+
+    *py_nin = (int)nin;
+    *py_nout = (int)nout;
+    *py_nargs = (int)nin+(int)nout;
+
+    return 0;
 }
 
 static PyObject *
-gufunc_call(GufuncObject *self, PyObject *args, PyObject *kwds)
+_gufunc_call(GufuncObject *self, PyObject *args, PyObject *kwargs,
+             bool enable_threads, bool check_broadcast)
 {
+    static char *kwlist[] = {"out", "dtype", "cls", NULL};
+    PyObject *out = Py_None;
+    PyObject *dt = Py_None;
+    PyObject *cls = Py_None;
+
     NDT_STATIC_CONTEXT(ctx);
-    const Py_ssize_t nin = PyTuple_GET_SIZE(args);
-    PyObject **a = &PyTuple_GET_ITEM(args, 0);
-    PyObject *result[NDT_MAX_ARGS];
-    ndt_apply_spec_t spec = ndt_apply_spec_empty;
-    const ndt_t *in_types[NDT_MAX_ARGS];
+    PyObject *pystack[NDT_MAX_ARGS];
     xnd_t stack[NDT_MAX_ARGS];
+    const ndt_t *types[NDT_MAX_ARGS];
+    int64_t li[NDT_MAX_ARGS];
+    ndt_apply_spec_t spec = ndt_apply_spec_empty;
     gm_kernel_t kernel;
-    int i, k;
+    bool have_cpu_device = false;
+    ndt_t *dtype = NULL;
+    int nin, nout, nargs;
+    int k;
 
-    if (kwds && PyDict_Size(kwds) > 0) {
-        PyErr_SetString(PyExc_TypeError,
-            "gufunc calls do not support keywords");
+    if (!PyArg_ParseTupleAndKeywords(positional_empty, kwargs, "|$OOO", kwlist,
+                                     &out, &dt, &cls)) {
         return NULL;
     }
 
-    if (nin > NDT_MAX_ARGS) {
+    out = out == Py_None ? NULL : out;
+    dt = dt == Py_None ? NULL : dt;
+    cls = cls == Py_None ? (PyObject *)xnd : cls;
+
+    if (dt != NULL) {
+        if (out != NULL) {
+            PyErr_SetString(PyExc_TypeError,
+                "the 'out' and 'dtype' arguments are mutually exclusive");
+            return NULL;
+        }
+        if (!Ndt_Check(dt)) {
+            PyErr_Format(PyExc_TypeError,
+                "'dtype' argument must be ndt, got '%.200s'",
+                Py_TYPE(dt)->tp_name);
+            return NULL;
+        }
+        dtype = (ndt_t *)NDT(dt);
+        ndt_incref(dtype);
+    }
+
+    if (!PyType_Check(cls) || !PyType_IsSubtype((PyTypeObject *)cls, xnd)) {
         PyErr_SetString(PyExc_TypeError,
-            "invalid number of arguments");
+            "the 'cls' argument must be a subtype of 'xnd'");
+        return NULL;
+    }
+
+    if (parse_args(pystack, &nin, &nout, &nargs, args, out) < 0) {
         return NULL;
     }
+    assert(nout == 0 || dtype == NULL);
 
-    for (i = 0; i < nin; i++) {
-        if (!Xnd_Check(a[i])) {
-            PyErr_SetString(PyExc_TypeError, "arguments must be xnd");
+    for (k = 0; k < nargs; k++) {
+        const XndObject *x = (XndObject *)pystack[k];
+        if (!(x->mblock->xnd->flags&XND_CUDA_MANAGED)) {
+            have_cpu_device = true;
+        }
+
+        stack[k] = *CONST_XND((PyObject *)x);
+        types[k] = stack[k].type;
+        li[k] = stack[k].index;
+    }
+
+    if (have_cpu_device) {
+        if (self->flags & GM_CUDA_MANAGED_FUNC) {
+            PyErr_SetString(PyExc_ValueError,
+                "cannot run a cuda function on xnd objects with cpu memory");
+            clear_pystack(pystack, nargs);
             return NULL;
         }
-        stack[i] = *CONST_XND(a[i]);
-        in_types[i] = stack[i].type;
     }
 
-    kernel = gm_select(&spec, self->tbl, self->name, in_types, (int)nin, stack, &ctx);
+    kernel = gm_select(&spec, self->tbl, self->name, types, li, nin, nout,
+                       nout && check_broadcast, stack, &ctx);
     if (kernel.set == NULL) {
         return seterr(&ctx);
     }
 
-    if (spec.nbroadcast > 0) {
-        for (i = 0; i < nin; i++) {
-            stack[i].type = spec.broadcast[i];
+    if (dtype) {
+        if (spec.nout != 1) {
+            ndt_err_format(&ctx, NDT_TypeError,
+                "the 'dtype' argument is only supported for a single "
+                "return value");
+            ndt_apply_spec_clear(&spec);
+            ndt_decref(dtype);
+            return seterr(&ctx);
+        }
+
+        const ndt_t *u = spec.types[spec.nin];
+        const ndt_t *v = ndt_copy_contiguous_dtype(u, dtype, 0, &ctx);
+
+        ndt_apply_spec_clear(&spec);
+        ndt_decref(dtype);
+
+        if (v == NULL) {
+            return seterr(&ctx);
+        }
+
+        types[nin] = v;
+        kernel = gm_select(&spec, self->tbl, self->name, types, li, nin, 1,
+                           1 && check_broadcast, stack, &ctx);
+        if (kernel.set == NULL) {
+            return seterr(&ctx);
         }
     }
 
186
- if (ndt_is_concrete(spec.out[i])) {
187
- PyObject *x = Xnd_EmptyFromType(xnd, spec.out[i]);
188
- if (x == NULL) {
189
- clear_objects(result, i);
190
- ndt_apply_spec_clear(&spec);
348
+ /*
349
+ * Replace args/kwargs types with types after substitution and broadcasting.
350
+ * This includes 'out' types, if explicitly passed as kwargs.
351
+ */
352
+ for (int i = 0; i < spec.nargs; i++) {
353
+ stack[i].type = spec.types[i];
354
+ }
355
+
356
+ if (nout == 0) {
357
+ /* 'out' types have been inferred, create new XndObjects. */
358
+ for (int i = 0; i < spec.nout; i++) {
359
+ if (ndt_is_concrete(spec.types[nin+i])) {
360
+ uint32_t flags = self->flags == GM_CUDA_MANAGED_FUNC ? XND_CUDA_MANAGED : 0;
361
+ PyObject *x = Xnd_EmptyFromType((PyTypeObject *)cls, spec.types[nin+i], flags);
362
+ if (x == NULL) {
363
+ clear_pystack(pystack, nin+i);
364
+ ndt_apply_spec_clear(&spec);
191
365
  return NULL;
192
366
  }
193
- result[i] = x;
367
+ pystack[nin+i] = x;
194
368
  stack[nin+i] = *CONST_XND(x);
195
- }
196
- else {
197
- result[i] = NULL;
198
- stack[nin+i] = xnd_error;
199
- }
369
+ }
370
+ else {
371
+ clear_pystack(pystack, nin+i);
372
+ ndt_apply_spec_clear(&spec);
373
+ PyErr_SetString(PyExc_ValueError,
374
+ "arguments with abstract types are temporarily disabled");
375
+ return NULL;
376
+ }
377
+ }
200
378
  }
201
379
 
202
- #ifdef HAVE_PTHREAD_H
203
- if (gm_apply_thread(&kernel, stack, spec.outer_dims, spec.flags,
204
- max_threads, &ctx) < 0) {
205
- clear_objects(result, spec.nout);
206
- return seterr(&ctx);
207
- }
208
- #else
209
- if (gm_apply(&kernel, stack, spec.outer_dims, &ctx) < 0) {
210
- clear_objects(result, spec.nout);
211
- return seterr(&ctx);
212
- }
213
- #endif
380
+ if (self->flags == GM_CUDA_MANAGED_FUNC) {
381
+ #if HAVE_CUDA
382
+ if (!check_broadcast) {
383
+ ndt_err_format(&ctx, NDT_NotImplementedError,
384
+ "fold() is currently not supported on cuda");
385
+ clear_pystack(pystack, spec.nargs);
386
+ ndt_apply_spec_clear(&spec);
387
+ return seterr(&ctx);
388
+ }
214
389
 
215
- for (i = 0; i < spec.nout; i++) {
216
- if (ndt_is_abstract(spec.out[i])) {
217
- ndt_del(spec.out[i]);
218
- PyObject *x = Xnd_FromXnd(xnd, &stack[nin+i]);
219
- stack[nin+i] = xnd_error;
220
- if (x == NULL) {
221
- clear_objects(result, i);
222
- for (k = i+1; k < spec.nout; k++) {
223
- if (ndt_is_abstract(spec.out[k])) {
224
- xnd_del_buffer(&stack[nin+k], XND_OWN_ALL);
225
- }
226
- }
227
- }
228
- result[i] = x;
390
+ const int ret = gm_apply(&kernel, stack, spec.outer_dims, &ctx);
391
+
392
+ if (xnd_cuda_device_synchronize(&ctx) < 0 || ret < 0) {
393
+ clear_pystack(pystack, spec.nargs);
394
+ ndt_apply_spec_clear(&spec);
395
+ return seterr(&ctx);
229
396
  }
397
+ #else
398
+ ndt_err_format(&ctx, NDT_RuntimeError,
399
+ "internal error: GM_CUDA_MANAGED_FUNC set in a build without cuda support");
400
+ clear_pystack(pystack, spec.nargs);
401
+ ndt_apply_spec_clear(&spec);
402
+ return seterr(&ctx);
403
+ #endif
230
404
  }
405
+ else {
406
+ #ifdef HAVE_PTHREAD_H
407
+ const int rounding = fegetround();
408
+ fesetround(FE_TONEAREST);
409
+
410
+ const int64_t N = enable_threads ? max_threads : 1;
411
+ const int ret = gm_apply_thread(&kernel, stack, spec.outer_dims, N,
412
+ &ctx);
413
+ fesetround(rounding);
414
+
415
+ if (ret < 0) {
416
+ clear_pystack(pystack, spec.nargs);
417
+ ndt_apply_spec_clear(&spec);
418
+ return seterr(&ctx);
419
+ }
420
+ #else
421
+ const int rounding = fegetround();
422
+ fesetround(FE_TONEAREST);
423
+
424
+ const int ret = gm_apply(&kernel, stack, spec.outer_dims, &ctx);
231
425
 
232
- if (spec.nbroadcast > 0) {
233
- for (i = 0; i < nin; i++) {
234
- ndt_del(spec.broadcast[i]);
426
+ fesetround(rounding);
427
+
428
+ if (ret < 0) {
429
+ clear_pystack(pystack, spec.nargs);
430
+ ndt_apply_spec_clear(&spec);
431
+ return seterr(&ctx);
235
432
  }
433
+ #endif
236
434
  }
237
435
 
238
- switch (spec.nout) {
239
- case 0: Py_RETURN_NONE;
240
- case 1: return result[0];
436
+ nin = spec.nin;
437
+ nout = spec.nout;
438
+ nargs = spec.nargs;
439
+ ndt_apply_spec_clear(&spec);
440
+
441
+ switch (nout) {
442
+ case 0: {
443
+ clear_pystack(pystack, nargs);
444
+ Py_RETURN_NONE;
445
+ }
446
+ case 1: {
447
+ clear_pystack(pystack, nin);
448
+ return pystack[nin];
449
+ }
241
450
  default: {
242
- PyObject *tuple = PyTuple_New(spec.nout);
451
+ PyObject *tuple = PyTuple_New(nout);
243
452
  if (tuple == NULL) {
244
- clear_objects(result, spec.nout);
453
+ clear_pystack(pystack, nargs);
245
454
  return NULL;
246
455
  }
247
- for (i = 0; i < spec.nout; i++) {
248
- PyTuple_SET_ITEM(tuple, i, result[i]);
456
+ for (int i = 0; i < nout; i++) {
457
+ PyTuple_SET_ITEM(tuple, i, pystack[nin+i]);
249
458
  }
250
459
  return tuple;
251
460
  }
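
The rewritten call path above means every gufunc now accepts the keyword-only arguments out, dtype and cls ('out' and 'dtype' are mutually exclusive, and 'cls' must be a subtype of xnd). A hedged sketch of the call surface; obtaining the result buffer via x.empty() mirrors what reduce_cpu() does in gumath/__init__.py:

    import gumath.functions as fn
    from ndtypes import ndt
    from xnd import xnd

    x = xnd([1.5, 2.5])
    out = x.empty(x.type)                # preallocated result container
    fn.sin(x, out=out)                   # write into an existing xnd
    y = fn.sin(x, dtype=ndt("float32"))  # or request a result dtype instead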
@@ -253,7 +462,39 @@ gufunc_call(GufuncObject *self, PyObject *args, PyObject *kwds)
 }
 
 static PyObject *
-gufunc_kernels(GufuncObject *self, PyObject *args GM_UNUSED)
+gufunc_call(GufuncObject *self, PyObject *args, PyObject *kwargs)
+{
+    return _gufunc_call(self, args, kwargs, true, true);
+}
+
+static PyObject *
+gufunc_getdevice(GufuncObject *self, PyObject *args GM_UNUSED)
+{
+    if (self->flags & GM_CUDA_MANAGED_FUNC) {
+        return PyUnicode_FromString("cuda:managed");
+    }
+
+    Py_RETURN_NONE;
+}
+
+static PyObject *
+gufunc_getidentity(GufuncObject *self, PyObject *args GM_UNUSED)
+{
+    Py_INCREF(self->identity);
+    return self->identity;
+}
+
+static int
+gufunc_setidentity(GufuncObject *self, PyObject *value, void *closure GM_UNUSED)
+{
+    Py_DECREF(self->identity);
+    Py_INCREF(value);
+    self->identity = value;
+    return 0;
+}
+
+static PyObject *
+gufunc_getkernels(GufuncObject *self, PyObject *args GM_UNUSED)
 {
     NDT_STATIC_CONTEXT(ctx);
     PyObject *list, *tmp;
@@ -294,7 +535,9 @@ gufunc_kernels(GufuncObject *self, PyObject *args GM_UNUSED)
 
 static PyGetSetDef gufunc_getsets [] =
 {
-  { "kernels", (getter)gufunc_kernels, NULL, NULL, NULL},
+  { "device", (getter)gufunc_getdevice, NULL, NULL, NULL},
+  { "identity", (getter)gufunc_getidentity, (setter)gufunc_setidentity, NULL, NULL},
+  { "kernels", (getter)gufunc_getkernels, NULL, NULL, NULL},
   {NULL}
 };
 
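
With the getset table above, each gufunc object now exposes device, identity and kernels from Python. A small sketch (printed values depend on the build and on the kernels registered):

    import gumath.functions as fn

    fn.add.identity = 0     # settable, as gumath/__init__.py now does on import
    print(fn.add.identity)  # -> 0
    print(fn.add.device)    # -> None on CPU builds; 'cuda:managed' for cuda functions
    print(fn.add.kernels)   # -> list of registered kernel signatures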
@@ -323,13 +566,25 @@ struct map_args {
     const gm_tbl_t *tbl;
 };
 
+static int
+Gufunc_CheckExact(const PyObject *v)
+{
+    return Py_TYPE(v) == &Gufunc_Type;
+}
+
+static int
+Gufunc_Check(const PyObject *v)
+{
+    return PyObject_TypeCheck(v, &Gufunc_Type);
+}
+
 static int
 add_function(const gm_func_t *f, void *args)
 {
     struct map_args *a = (struct map_args *)args;
     PyObject *func;
 
-    func = gufunc_new(a->tbl, f->name);
+    func = gufunc_new(a->tbl, f->name, GM_CPU_FUNC);
     if (func == NULL) {
         return -1;
     }
@@ -349,10 +604,40 @@ Gumath_AddFunctions(PyObject *m, const gm_tbl_t *tbl)
     return 0;
 }
 
+static int
+add_cuda_function(const gm_func_t *f, void *args)
+{
+    struct map_args *a = (struct map_args *)args;
+    PyObject *func;
+
+    func = gufunc_new(a->tbl, f->name, GM_CUDA_MANAGED_FUNC);
+    if (func == NULL) {
+        return -1;
+    }
+
+    return PyModule_AddObject(a->module, f->name, func);
+}
+
+static int
+Gumath_AddCudaFunctions(PyObject *m, const gm_tbl_t *tbl)
+{
+    struct map_args args = {m, tbl};
+
+    if (gm_tbl_map(tbl, add_cuda_function, &args) < 0) {
+        return -1;
+    }
+
+    return 0;
+}
+
 static PyObject *
 init_api(void)
 {
+    gumath_api[Gufunc_CheckExact_INDEX] = (void *)Gufunc_CheckExact;
+    gumath_api[Gufunc_Check_INDEX] = (void *)Gufunc_Check;
     gumath_api[Gumath_AddFunctions_INDEX] = (void *)Gumath_AddFunctions;
+    gumath_api[Gumath_AddFunctions_INDEX] = (void *)Gumath_AddFunctions;
+    gumath_api[Gumath_AddCudaFunctions_INDEX] = (void *)Gumath_AddCudaFunctions;
 
     return PyCapsule_New(gumath_api, "gumath._gumath._API", NULL);
 }
@@ -362,6 +647,75 @@ init_api(void)
 /****************************************************************************/
 /* Module */
 /****************************************************************************/
+static PyObject *
+gufunc_vfold(PyObject *m GM_UNUSED, PyObject *args, PyObject *kwargs)
+{
+    static char *kwlist[] = {"f", "acc", NULL};
+    PyObject *func = Py_None;
+    PyObject *acc = Py_None;
+    PyObject *tuple;
+    PyObject *dict;
+    PyObject *res;
+    Py_ssize_t size, i;
+    int ret;
+
+    ret = PyArg_ParseTupleAndKeywords(positional_empty, kwargs, "|$OO", kwlist,
+                                      &func, &acc);
+    if (ret < 0) {
+        return NULL;
+    }
+
+    if (!Gufunc_Check(func)) {
+        PyErr_Format(PyExc_TypeError,
+            "vfold: expected gufunc object, got '%.200s'", Py_TYPE(func));
+        return NULL;
+    }
+
+    if (!Xnd_Check(acc)) {
+        PyErr_Format(PyExc_TypeError,
+            "vfold: expected xnd instance, got '%.200s'", Py_TYPE(acc));
+        return NULL;
+    }
+
+    /* Push the accumulator onto the argument stack. */
+    size = PyTuple_Size(args);
+    tuple = PyTuple_New(size+1);
+    if (tuple == NULL) {
+        return NULL;
+    }
+
+    Py_INCREF(acc);
+    PyTuple_SET_ITEM(tuple, 0, acc);
+    for (i = 0; i < size; i++) {
+        PyObject *v = PyTuple_GET_ITEM(args, i);
+        Py_INCREF(v);
+        PyTuple_SET_ITEM(tuple, i+1, v);
+    }
+
+    /* Simultaneously use the accumulator as the 'out' argument. */
+    dict = PyDict_New();
+    if (dict == NULL) {
+        Py_DECREF(tuple);
+        return NULL;
+    }
+    if (PyDict_SetItemString(dict, "out", acc) < 0) {
+        Py_DECREF(dict);
+        Py_DECREF(tuple);
+        return NULL;
+    }
+    if (PyDict_SetItemString(dict, "cls", (PyObject *)(Py_TYPE(acc))) < 0) {
+        Py_DECREF(dict);
+        Py_DECREF(tuple);
+        return NULL;
+    }
+
+    res = _gufunc_call((GufuncObject *)func, tuple, dict, false, false);
+    Py_DECREF(tuple);
+    Py_DECREF(dict);
+
+    return res;
+}
+
 
 static PyObject *
 unsafe_add_kernel(PyObject *m GM_UNUSED, PyObject *args, PyObject *kwds)
@@ -388,8 +742,8 @@ unsafe_add_kernel(PyObject *m GM_UNUSED, PyObject *args, PyObject *kwds)
     k.name = name;
     k.sig = sig;
 
-    if (strcmp(tag, "Opt") == 0) {
-        k.Opt = p;
+    if (strcmp(tag, "Opt") == 0) { /* XXX */
+        k.OptC = p;
     }
     else if (strcmp(tag, "C") == 0) {
         k.C = p;
@@ -418,7 +772,7 @@ unsafe_add_kernel(PyObject *m GM_UNUSED, PyObject *args, PyObject *kwds)
         return seterr(&ctx);
     }
 
-    return gufunc_new(table, f->name);
+    return gufunc_new(table, f->name, GM_CPU_FUNC);
 }
 
 static void
@@ -490,6 +844,7 @@ set_max_threads(PyObject *m UNUSED, PyObject *obj)
 static PyMethodDef gumath_methods [] =
 {
   /* Methods */
+  { "vfold", (PyCFunction)gufunc_vfold, METH_VARARGS|METH_KEYWORDS, NULL },
  { "unsafe_add_kernel", (PyCFunction)unsafe_add_kernel, METH_VARARGS|METH_KEYWORDS, NULL },
  { "get_max_threads", (PyCFunction)get_max_threads, METH_NOARGS, NULL },
  { "set_max_threads", (PyCFunction)set_max_threads, METH_O, NULL },
@@ -554,11 +909,21 @@ PyInit__gumath(void)
         goto error;
     }
 
+    positional_empty = PyTuple_New(0);
+    if (positional_empty == NULL) {
+        goto error;
+    }
+
     m = PyModule_Create(&gumath_module);
     if (m == NULL) {
         goto error;
     }
 
+    Py_INCREF(&Gufunc_Type);
+    if (PyModule_AddObject(m, "gufunc", (PyObject *)&Gufunc_Type) < 0) {
+        goto error;
+    }
+
     Py_INCREF(capsule);
     if (PyModule_AddObject(m, "_API", capsule) < 0) {
         goto error;
@@ -571,6 +936,7 @@ PyInit__gumath(void)
     return m;
 
 error:
+    Py_CLEAR(positional_empty);
     Py_CLEAR(xnd);
     Py_CLEAR(m);
     return NULL;