gumath 0.2.0dev5 → 0.2.0dev8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99) hide show
  1. checksums.yaml +4 -4
  2. data/CONTRIBUTING.md +7 -2
  3. data/Gemfile +0 -3
  4. data/ext/ruby_gumath/GPATH +0 -0
  5. data/ext/ruby_gumath/GRTAGS +0 -0
  6. data/ext/ruby_gumath/GTAGS +0 -0
  7. data/ext/ruby_gumath/extconf.rb +0 -5
  8. data/ext/ruby_gumath/functions.c +10 -2
  9. data/ext/ruby_gumath/gufunc_object.c +15 -4
  10. data/ext/ruby_gumath/gufunc_object.h +9 -3
  11. data/ext/ruby_gumath/gumath/Makefile +63 -0
  12. data/ext/ruby_gumath/gumath/Makefile.in +1 -0
  13. data/ext/ruby_gumath/gumath/config.h +56 -0
  14. data/ext/ruby_gumath/gumath/config.h.in +3 -0
  15. data/ext/ruby_gumath/gumath/config.log +497 -0
  16. data/ext/ruby_gumath/gumath/config.status +1034 -0
  17. data/ext/ruby_gumath/gumath/configure +375 -4
  18. data/ext/ruby_gumath/gumath/configure.ac +47 -3
  19. data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
  20. data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
  21. data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
  22. data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
  23. data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
  24. data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
  25. data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
  26. data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
  27. data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
  28. data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
  29. data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
  30. data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
  31. data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
  32. data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
  33. data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
  34. data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
  35. data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
  36. data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
  37. data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
  38. data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
  39. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
  40. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
  41. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
  42. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
  43. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
  44. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
  45. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
  46. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
  47. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
  48. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
  49. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
  50. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
  51. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
  52. data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
  53. data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
  54. data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
  55. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
  56. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
  57. data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
  58. data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
  59. data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
  60. data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
  61. data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
  62. data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
  63. data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
  64. data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
  65. data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
  66. data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
  67. data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
  68. data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
  69. data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
  70. data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
  71. data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
  72. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
  73. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
  74. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
  75. data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
  76. data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
  77. data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
  78. data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
  79. data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
  80. data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
  81. data/ext/ruby_gumath/gumath/setup.py +67 -6
  82. data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
  83. data/ext/ruby_gumath/include/gumath.h +55 -14
  84. data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
  85. data/ext/ruby_gumath/lib/libgumath.a +0 -0
  86. data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
  87. data/ext/ruby_gumath/ruby_gumath.c +231 -70
  88. data/ext/ruby_gumath/ruby_gumath.h +4 -1
  89. data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
  90. data/ext/ruby_gumath/util.c +34 -0
  91. data/ext/ruby_gumath/util.h +9 -0
  92. data/gumath.gemspec +3 -2
  93. data/lib/gumath.rb +55 -1
  94. data/lib/gumath/version.rb +2 -2
  95. data/lib/ruby_gumath.so +0 -0
  96. metadata +63 -10
  97. data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
  98. data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
  99. data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
@@ -0,0 +1 @@
1
+ ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3
@@ -0,0 +1 @@
1
+ ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3
@@ -72,7 +72,7 @@ clear_all_slices(xnd_t *slices[], int *nslices, int stop)
72
72
  {
73
73
  for (int i = 0; i < stop; i++) {
74
74
  for (int k = 0; k < nslices[i]; k++) {
75
- ndt_del((ndt_t *)slices[i][k].type);
75
+ ndt_decref(slices[i][k].type);
76
76
  }
77
77
  ndt_free(slices[i]);
78
78
  }
@@ -94,16 +94,27 @@ apply_thread(void *arg)
94
94
 
95
95
  int
96
96
  gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
97
- uint32_t flags, const int64_t nthreads, ndt_context_t *ctx)
97
+ const int64_t nthreads, ndt_context_t *ctx)
98
98
  {
99
99
  const int nrows = (int)kernel->set->sig->Function.nargs;
100
100
  ALLOCA(xnd_t *, slices, nrows);
101
101
  ALLOCA(int, nslices, nrows);
102
102
  struct thread_info *tinfo;
103
103
  int ncols, tnum;
104
+ bool use_threads = true;
104
105
 
105
- if (nthreads <= 1 || nrows == 0 || outer_dims == 0 ||
106
- !(flags & NDT_STRIDED)) {
106
+ if (nthreads <= 1 || nrows == 0 || outer_dims == 0) {
107
+ use_threads = false;
108
+ }
109
+
110
+ for (int i = 0; i < nrows; i++) {
111
+ const ndt_t *t = stack[i].type;
112
+ if (!ndt_is_ndarray(t) || ndt_nelem(t) < GM_THREAD_CUTOFF) {
113
+ use_threads = false;
114
+ }
115
+ }
116
+
117
+ if (!use_threads) {
107
118
  return gm_apply(kernel, stack, outer_dims, ctx);
108
119
  }
109
120
 
@@ -147,6 +158,7 @@ gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
147
158
  &tinfo[tnum]);
148
159
  if (ret != 0) {
149
160
  clear_all_slices(slices, nslices, nrows);
161
+ ndt_free(tinfo);
150
162
  ndt_err_format(ctx, NDT_RuntimeError, "could not create thread");
151
163
  return -1;
152
164
  }
@@ -169,6 +181,7 @@ gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
169
181
  }
170
182
 
171
183
  clear_all_slices(slices, nslices, nrows);
184
+ ndt_free(tinfo);
172
185
 
173
186
  return ndt_err_occurred(ctx) ? -1 : 0;
174
187
  }
@@ -38,11 +38,99 @@
38
38
  #include "ndtypes.h"
39
39
  #include "xnd.h"
40
40
  #include "gumath.h"
41
+ #include "overflow.h"
41
42
 
42
43
 
44
+ static int _gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
45
+ const int outer_dims, ndt_context_t *ctx);
46
+
47
+ int
48
+ array_shape_check(xnd_t *x, const int64_t shape, ndt_context_t *ctx)
49
+ {
50
+ const ndt_t *t = x->type;
51
+
52
+ if (t->tag != Array) {
53
+ ndt_err_format(ctx, NDT_RuntimeError,
54
+ "type mismatch in outer dimensions");
55
+ return -1;
56
+ }
57
+
58
+ if (XND_ARRAY_DATA(x->ptr) == NULL) {
59
+ bool overflow = false;
60
+ const int64_t size = MULi64(shape, t->Array.itemsize, &overflow);
61
+ if (overflow) {
62
+ ndt_err_format(ctx, NDT_ValueError,
63
+ "datasize of flexible array is too large");
64
+ return -1;
65
+ }
66
+
67
+ char *data = ndt_aligned_calloc(t->align, size);
68
+ if (data == NULL) {
69
+ ndt_err_format(ctx, NDT_MemoryError, "out of memory");
70
+ return -1;
71
+ }
72
+
73
+ XND_ARRAY_SHAPE(x->ptr) = shape;
74
+ XND_ARRAY_DATA(x->ptr) = data;
75
+
76
+ return 0;
77
+ }
78
+ else if (XND_ARRAY_SHAPE(x->ptr) != shape) {
79
+ ndt_err_format(ctx, NDT_RuntimeError,
80
+ "shape mismatch in outer dimensions");
81
+ return -1;
82
+ }
83
+ else {
84
+ return 0;
85
+ }
86
+ }
87
+
88
+ static inline bool
89
+ any_stored_index(xnd_t stack[], const int nargs)
90
+ {
91
+ for (int i = 0; i < nargs; i++) {
92
+ if (stack[i].ptr == NULL) {
93
+ continue;
94
+ }
95
+
96
+ const ndt_t *t = stack[i].type;
97
+ if (have_stored_index(t)) {
98
+ return true;
99
+ }
100
+ }
101
+
102
+ return false;
103
+ }
104
+
43
105
  int
44
106
  gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
45
107
  const int outer_dims, ndt_context_t *ctx)
108
+ {
109
+ if (any_stored_index(stack, nargs)) {
110
+ ALLOCA(xnd_t, next, nargs);
111
+
112
+ for (int i = 0; i < nargs; i++) {
113
+ const ndt_t *t = stack[i].type;
114
+ if (have_stored_index(t)) {
115
+ next[i] = apply_stored_indices(&stack[i], ctx);
116
+ if (xnd_err_occurred(&next[i])) {
117
+ return -1;
118
+ }
119
+ }
120
+ else {
121
+ next[i] = stack[i];
122
+ }
123
+ }
124
+
125
+ return _gm_xnd_map(f, next, nargs, outer_dims, ctx);
126
+ }
127
+
128
+ return _gm_xnd_map(f, stack, nargs, outer_dims, ctx);
129
+ }
130
+
131
+ static int
132
+ _gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
133
+ const int outer_dims, ndt_context_t *ctx)
46
134
  {
47
135
  ALLOCA(xnd_t, next, nargs);
48
136
  const ndt_t *t;
@@ -123,6 +211,28 @@ gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
123
211
  return 0;
124
212
  }
125
213
 
214
+ case Array: {
215
+ const int64_t shape = XND_ARRAY_SHAPE(stack[0].ptr);
216
+
217
+ for (int k = 1; k < nargs; k++) {
218
+ if (array_shape_check(&stack[k], shape, ctx) < 0) {
219
+ return -1;
220
+ }
221
+ }
222
+
223
+ for (int64_t i = 0; i < shape; i++) {
224
+ for (int k = 0; k < nargs; k++) {
225
+ next[k] = xnd_array_next(&stack[k], i);
226
+ }
227
+
228
+ if (gm_xnd_map(f, next, nargs, outer_dims-1, ctx) < 0) {
229
+ return -1;
230
+ }
231
+ }
232
+
233
+ return 0;
234
+ }
235
+
126
236
  default:
127
237
  ndt_err_format(ctx, NDT_NotImplementedError, "unsupported type");
128
238
  return -1;
@@ -33,7 +33,157 @@
33
33
  from ndtypes import ndt
34
34
  from xnd import xnd
35
35
  from ._gumath import *
36
+ from . import functions as _fn
36
37
 
38
+ try:
39
+ from . import cuda as _cd
40
+ except ImportError:
41
+ _cd = None
42
+
43
+
44
+ __all__ = ['cuda', 'fold', 'functions', 'get_max_threads', 'gufunc', 'reduce',
45
+ 'set_max_threads', 'unsafe_add_kernel', 'vfold', 'xndvectorize']
46
+
47
+
48
+ # ==============================================================================
49
+ # Init identity elements
50
+ # ==============================================================================
51
+
52
+ # This is done here now, perhaps it should be on the C level.
53
+ _fn.add.identity = 0
54
+ _fn.multiply.identity = 1
55
+
56
+
57
+ # ==============================================================================
58
+ # General fold function
59
+ # ==============================================================================
60
+
61
+ def fold(f, acc, x):
62
+ return vfold(x, f=f, acc=acc)
63
+
64
+
65
+ # ==============================================================================
66
+ # NumPy's reduce in terms of fold
67
+ # ==============================================================================
68
+
69
+ def _get_axes(axes, ndim):
70
+ type_err = "'axes' must be None, a single integer or a tuple of integers"
71
+ value_err = "axis with value %d out of range"
72
+ duplicate_err = "'axes' argument contains duplicate values"
73
+ if axes is None:
74
+ axes = tuple(range(ndim))
75
+ elif isinstance(axes, int):
76
+ axes = (axes,)
77
+ elif not isinstance(axes, tuple) or \
78
+ any(not isinstance(v, int) for v in axes):
79
+ raise TypeError(type_err)
80
+
81
+ if any(n >= ndim for n in axes):
82
+ raise ValueError(value_err % n)
83
+
84
+ if len(set(axes)) != len(axes):
85
+ raise ValueError(duplicate_err)
86
+
87
+ return list(axes)
88
+
89
+ def _copyto(dest, value):
90
+ x = xnd(value, dtype=dest.dtype)
91
+ _fn.copy(x, out=dest)
92
+
93
+ def reduce_cpu(f, x, axes, dtype):
94
+ """NumPy's reduce in terms of fold."""
95
+ axes = _get_axes(axes, x.ndim)
96
+ if not axes:
97
+ return x
98
+
99
+ permute = [n for n in range(x.ndim) if n not in axes]
100
+ permute = axes + permute
101
+
102
+ T = x.transpose(permute=permute)
103
+
104
+ N = len(axes)
105
+ t = T.type.at(N, dtype=dtype)
106
+ acc = x.empty(t, device=x.device)
107
+
108
+ if f.identity is not None:
109
+ _copyto(acc, f.identity)
110
+ tl = T
111
+ elif N == 1 and T.type.shape[0] > 0:
112
+ hd, tl = T[0], T[1:]
113
+ acc[()] = hd
114
+ else:
115
+ raise ValueError(
116
+ "reduction not possible for function without an identity element")
117
+
118
+ return fold(f, acc, tl)
119
+
120
+ def reduce_cuda(g, x, axes, dtype):
121
+ """Reductions in CUDA use the thrust library for speed and have limited
122
+ functionality."""
123
+ if axes != 0:
124
+ raise NotImplementedError("'axes' keyword is not implemented for CUDA")
125
+
126
+ return g(x, dtype=dtype)
127
+
128
+ def get_cuda_reduction_func(f):
129
+ if _cd is None:
130
+ return None
131
+ elif f == _cd.add:
132
+ return _cd.reduce_add
133
+ elif f == _cd.multiply:
134
+ return _cd.reduce_multiply
135
+ else:
136
+ return None
137
+
138
+ def reduce(f, x, axes=0, dtype=None):
139
+ if dtype is None:
140
+ dtype = maxcast[x.dtype]
141
+
142
+ g = get_cuda_reduction_func(f)
143
+ if g is not None:
144
+ return reduce_cuda(g, x, axes, dtype)
145
+
146
+ return reduce_cpu(f, x, axes, dtype)
147
+
148
+
149
+ maxcast = {
150
+ ndt("int8"): ndt("int64"),
151
+ ndt("int16"): ndt("int64"),
152
+ ndt("int32"): ndt("int64"),
153
+ ndt("int64"): ndt("int64"),
154
+ ndt("uint8"): ndt("uint64"),
155
+ ndt("uint16"): ndt("uint64"),
156
+ ndt("uint32"): ndt("uint64"),
157
+ ndt("uint64"): ndt("uint64"),
158
+ ndt("bfloat16"): ndt("float64"),
159
+ ndt("float16"): ndt("float64"),
160
+ ndt("float32"): ndt("float64"),
161
+ ndt("float64"): ndt("float64"),
162
+ ndt("complex32"): ndt("complex128"),
163
+ ndt("complex64"): ndt("complex128"),
164
+ ndt("complex128"): ndt("complex128"),
165
+
166
+ ndt("?int8"): ndt("?int64"),
167
+ ndt("?int16"): ndt("?int64"),
168
+ ndt("?int32"): ndt("?int64"),
169
+ ndt("?int64"): ndt("?int64"),
170
+ ndt("?uint8"): ndt("?uint64"),
171
+ ndt("?uint16"): ndt("?uint64"),
172
+ ndt("?uint32"): ndt("?uint64"),
173
+ ndt("?uint64"): ndt("?uint64"),
174
+ ndt("?bfloat16"): ndt("?float64"),
175
+ ndt("?float16"): ndt("?float64"),
176
+ ndt("?float32"): ndt("?float64"),
177
+ ndt("?float64"): ndt("?float64"),
178
+ ndt("?complex32"): ndt("?complex128"),
179
+ ndt("?complex64"): ndt("?complex128"),
180
+ ndt("?complex128"): ndt("?complex128"),
181
+ }
182
+
183
+
184
+ # ==============================================================================
185
+ # Numba's GUVectorize on xnd arrays
186
+ # ==============================================================================
37
187
 
38
188
  try:
39
189
  import numpy as np
@@ -45,16 +45,24 @@
45
45
  #define GUMATH_MODULE
46
46
  #include "pygumath.h"
47
47
 
48
+
48
49
  #ifdef _MSC_VER
49
50
  #ifndef UNUSED
50
51
  #define UNUSED
51
52
  #endif
53
+ #include <float.h>
54
+ #include <fenv.h>
55
+ #pragma fenv_access(on)
52
56
  #else
53
57
  #if defined(__GNUC__) && !defined(__INTEL_COMPILER)
54
58
  #define UNUSED __attribute__((unused))
55
59
  #else
56
60
  #define UNUSED
57
61
  #endif
62
+ #include <fenv.h>
63
+ #if 0 /* Not supported by gcc and clang. */
64
+ #pragma STDC FENV_ACCESS ON
65
+ #endif
58
66
  #endif
59
67
 
60
68
 
@@ -73,6 +81,9 @@ static gm_tbl_t *table = NULL;
73
81
  /* Xnd type */
74
82
  static PyTypeObject *xnd = NULL;
75
83
 
84
+ /* Empty positional arguments */
85
+ static PyObject *positional_empty = NULL;
86
+
76
87
  /* Maximum number of threads */
77
88
  static int64_t max_threads = 1;
78
89
 
@@ -95,7 +106,7 @@ seterr(ndt_context_t *ctx)
95
106
  static PyTypeObject Gufunc_Type;
96
107
 
97
108
  static PyObject *
98
- gufunc_new(const gm_tbl_t *tbl, const char *name)
109
+ gufunc_new(const gm_tbl_t *tbl, const char *name, const uint32_t flags)
99
110
  {
100
111
  NDT_STATIC_CONTEXT(ctx);
101
112
  GufuncObject *self;
@@ -106,12 +117,16 @@ gufunc_new(const gm_tbl_t *tbl, const char *name)
106
117
  }
107
118
 
108
119
  self->tbl = tbl;
120
+ self->flags = flags;
109
121
 
110
122
  self->name = ndt_strdup(name, &ctx);
111
123
  if (self->name == NULL) {
112
124
  return seterr(&ctx);
113
125
  }
114
126
 
127
+ self->identity = Py_None;
128
+ Py_INCREF(self->identity);
129
+
115
130
  return (PyObject *)self;
116
131
  }
117
132
 
@@ -119,6 +134,7 @@ static void
119
134
  gufunc_dealloc(GufuncObject *self)
120
135
  {
121
136
  ndt_free(self->name);
137
+ Py_DECREF(self->identity);
122
138
  PyObject_Del(self);
123
139
  }
124
140
 
@@ -128,124 +144,317 @@ gufunc_dealloc(GufuncObject *self)
128
144
  /****************************************************************************/
129
145
 
130
146
  static void
131
- clear_objects(PyObject **a, Py_ssize_t len)
147
+ clear_pystack(PyObject *pystack[], Py_ssize_t len)
148
+ {
149
+ for (Py_ssize_t i = 0; i < len; i++) {
150
+ Py_CLEAR(pystack[i]);
151
+ }
152
+ }
153
+
154
+ static int
155
+ parse_args(PyObject *pystack[NDT_MAX_ARGS], int *py_nin, int *py_nout, int *py_nargs,
156
+ PyObject *args, PyObject *out)
132
157
  {
133
- Py_ssize_t i;
158
+ Py_ssize_t nin;
159
+ Py_ssize_t nout;
134
160
 
135
- for (i = 0; i < len; i++) {
136
- Py_CLEAR(a[i]);
161
+ if (!args || !PyTuple_Check(args)) {
162
+ const char *name = args ? Py_TYPE(args)->tp_name : "NULL";
163
+ PyErr_Format(PyExc_SystemError,
164
+ "internal error: expected tuple, got '%.200s'", name);
165
+ return -1;
137
166
  }
167
+
168
+ nin = PyTuple_GET_SIZE(args);
169
+ if (nin > NDT_MAX_ARGS) {
170
+ PyErr_Format(PyExc_TypeError,
171
+ "maximum number of arguments is %d, got %n", NDT_MAX_ARGS, nin);
172
+ return -1;
173
+ }
174
+
175
+ for (Py_ssize_t i = 0; i < nin; i++) {
176
+ PyObject *v = PyTuple_GET_ITEM(args, i);
177
+ if (!Xnd_Check(v)) {
178
+ PyErr_Format(PyExc_TypeError,
179
+ "expected xnd argument, got '%.200s'", Py_TYPE(v)->tp_name);
180
+ return -1;
181
+ }
182
+
183
+ pystack[i] = v;
184
+ }
185
+
186
+ if (out == NULL) {
187
+ nout = 0;
188
+ }
189
+ else {
190
+ if (Xnd_Check(out)) {
191
+ nout = 1;
192
+ if (nin+nout > NDT_MAX_ARGS) {
193
+ PyErr_Format(PyExc_TypeError,
194
+ "maximum number of arguments is %d, got %n", NDT_MAX_ARGS, nin+nout);
195
+ return -1;
196
+ }
197
+ pystack[nin] = out;
198
+ }
199
+ else if (PyTuple_Check(out)) {
200
+ nout = PyTuple_GET_SIZE(out);
201
+ if (nout > NDT_MAX_ARGS || nin+nout > NDT_MAX_ARGS) {
202
+ PyErr_Format(PyExc_TypeError,
203
+ "maximum number of arguments is %d, got %n", NDT_MAX_ARGS, nin+nout);
204
+ return -1;
205
+ }
206
+
207
+ for (Py_ssize_t i = 0; i < nout; i++) {
208
+ PyObject *v = PyTuple_GET_ITEM(out, i);
209
+ if (!Xnd_Check(v)) {
210
+ PyErr_Format(PyExc_TypeError,
211
+ "expected xnd argument, got '%.200s'", Py_TYPE(v)->tp_name);
212
+ return -1;
213
+ }
214
+
215
+ pystack[nin+i] = v;
216
+ }
217
+ }
218
+ else {
219
+ PyErr_Format(PyExc_TypeError,
220
+ "'out' argument must be xnd or a tuple of xnd, got '%.200s'",
221
+ Py_TYPE(out)->tp_name);
222
+ return -1;
223
+ }
224
+ }
225
+
226
+ for (int i = 0; i < nin+nout; i++) {
227
+ Py_INCREF(pystack[i]);
228
+ }
229
+
230
+ *py_nin = (int)nin;
231
+ *py_nout = (int)nout;
232
+ *py_nargs = (int)nin+(int)nout;
233
+
234
+ return 0;
138
235
  }
139
236
 
140
237
  static PyObject *
141
- gufunc_call(GufuncObject *self, PyObject *args, PyObject *kwds)
238
+ _gufunc_call(GufuncObject *self, PyObject *args, PyObject *kwargs,
239
+ bool enable_threads, bool check_broadcast)
142
240
  {
241
+ static char *kwlist[] = {"out", "dtype", "cls", NULL};
242
+ PyObject *out = Py_None;
243
+ PyObject *dt = Py_None;
244
+ PyObject *cls = Py_None;
245
+
143
246
  NDT_STATIC_CONTEXT(ctx);
144
- const Py_ssize_t nin = PyTuple_GET_SIZE(args);
145
- PyObject **a = &PyTuple_GET_ITEM(args, 0);
146
- PyObject *result[NDT_MAX_ARGS];
147
- ndt_apply_spec_t spec = ndt_apply_spec_empty;
148
- const ndt_t *in_types[NDT_MAX_ARGS];
247
+ PyObject *pystack[NDT_MAX_ARGS];
149
248
  xnd_t stack[NDT_MAX_ARGS];
249
+ const ndt_t *types[NDT_MAX_ARGS];
250
+ int64_t li[NDT_MAX_ARGS];
251
+ ndt_apply_spec_t spec = ndt_apply_spec_empty;
150
252
  gm_kernel_t kernel;
151
- int i, k;
253
+ bool have_cpu_device = false;
254
+ ndt_t *dtype = NULL;
255
+ int nin, nout, nargs;
256
+ int k;
152
257
 
153
- if (kwds && PyDict_Size(kwds) > 0) {
154
- PyErr_SetString(PyExc_TypeError,
155
- "gufunc calls do not support keywords");
258
+ if (!PyArg_ParseTupleAndKeywords(positional_empty, kwargs, "|$OOO", kwlist,
259
+ &out, &dt, &cls)) {
156
260
  return NULL;
157
261
  }
158
262
 
159
- if (nin > NDT_MAX_ARGS) {
263
+ out = out == Py_None ? NULL : out;
264
+ dt = dt == Py_None ? NULL : dt;
265
+ cls = cls == Py_None ? (PyObject *)xnd : cls;
266
+
267
+ if (dt != NULL) {
268
+ if (out != NULL) {
269
+ PyErr_SetString(PyExc_TypeError,
270
+ "the 'out' and 'dtype' arguments are mutually exclusive");
271
+ return NULL;
272
+ }
273
+ if (!Ndt_Check(dt)) {
274
+ PyErr_Format(PyExc_TypeError,
275
+ "'dtype' argument must be ndt, got '%.200s'",
276
+ Py_TYPE(dt)->tp_name);
277
+ return NULL;
278
+ }
279
+ dtype = (ndt_t *)NDT(dt);
280
+ ndt_incref(dtype);
281
+ }
282
+
283
+ if (!PyType_Check(cls) || !PyType_IsSubtype((PyTypeObject *)cls, xnd)) {
160
284
  PyErr_SetString(PyExc_TypeError,
161
- "invalid number of arguments");
285
+ "the 'cls' argument must be a subtype of 'xnd'");
286
+ return NULL;
287
+ }
288
+
289
+ if (parse_args(pystack, &nin, &nout, &nargs, args, out) < 0) {
162
290
  return NULL;
163
291
  }
292
+ assert(nout == 0 || dtype == NULL);
164
293
 
165
- for (i = 0; i < nin; i++) {
166
- if (!Xnd_Check(a[i])) {
167
- PyErr_SetString(PyExc_TypeError, "arguments must be xnd");
294
+ for (k = 0; k < nargs; k++) {
295
+ const XndObject *x = (XndObject *)pystack[k];
296
+ if (!(x->mblock->xnd->flags&XND_CUDA_MANAGED)) {
297
+ have_cpu_device = true;
298
+ }
299
+
300
+ stack[k] = *CONST_XND((PyObject *)x);
301
+ types[k] = stack[k].type;
302
+ li[k] = stack[k].index;
303
+ }
304
+
305
+ if (have_cpu_device) {
306
+ if (self->flags & GM_CUDA_MANAGED_FUNC) {
307
+ PyErr_SetString(PyExc_ValueError,
308
+ "cannot run a cuda function on xnd objects with cpu memory");
309
+ clear_pystack(pystack, nargs);
168
310
  return NULL;
169
311
  }
170
- stack[i] = *CONST_XND(a[i]);
171
- in_types[i] = stack[i].type;
172
312
  }
173
313
 
174
- kernel = gm_select(&spec, self->tbl, self->name, in_types, (int)nin, stack, &ctx);
314
+ kernel = gm_select(&spec, self->tbl, self->name, types, li, nin, nout,
315
+ nout && check_broadcast, stack, &ctx);
175
316
  if (kernel.set == NULL) {
176
317
  return seterr(&ctx);
177
318
  }
178
319
 
179
- if (spec.nbroadcast > 0) {
180
- for (i = 0; i < nin; i++) {
181
- stack[i].type = spec.broadcast[i];
320
+ if (dtype) {
321
+ if (spec.nout != 1) {
322
+ ndt_err_format(&ctx, NDT_TypeError,
323
+ "the 'dtype' argument is only supported for a single "
324
+ "return value");
325
+ ndt_apply_spec_clear(&spec);
326
+ ndt_decref(dtype);
327
+ return seterr(&ctx);
328
+ }
329
+
330
+ const ndt_t *u = spec.types[spec.nin];
331
+ const ndt_t *v = ndt_copy_contiguous_dtype(u, dtype, 0, &ctx);
332
+
333
+ ndt_apply_spec_clear(&spec);
334
+ ndt_decref(dtype);
335
+
336
+ if (v == NULL) {
337
+ return seterr(&ctx);
338
+ }
339
+
340
+ types[nin] = v;
341
+ kernel = gm_select(&spec, self->tbl, self->name, types, li, nin, 1,
342
+ 1 && check_broadcast, stack, &ctx);
343
+ if (kernel.set == NULL) {
344
+ return seterr(&ctx);
182
345
  }
183
346
  }
184
347
 
185
- for (i = 0; i < spec.nout; i++) {
186
- if (ndt_is_concrete(spec.out[i])) {
187
- PyObject *x = Xnd_EmptyFromType(xnd, spec.out[i]);
188
- if (x == NULL) {
189
- clear_objects(result, i);
190
- ndt_apply_spec_clear(&spec);
348
+ /*
349
+ * Replace args/kwargs types with types after substitution and broadcasting.
350
+ * This includes 'out' types, if explicitly passed as kwargs.
351
+ */
352
+ for (int i = 0; i < spec.nargs; i++) {
353
+ stack[i].type = spec.types[i];
354
+ }
355
+
356
+ if (nout == 0) {
357
+ /* 'out' types have been inferred, create new XndObjects. */
358
+ for (int i = 0; i < spec.nout; i++) {
359
+ if (ndt_is_concrete(spec.types[nin+i])) {
360
+ uint32_t flags = self->flags == GM_CUDA_MANAGED_FUNC ? XND_CUDA_MANAGED : 0;
361
+ PyObject *x = Xnd_EmptyFromType((PyTypeObject *)cls, spec.types[nin+i], flags);
362
+ if (x == NULL) {
363
+ clear_pystack(pystack, nin+i);
364
+ ndt_apply_spec_clear(&spec);
191
365
  return NULL;
192
366
  }
193
- result[i] = x;
367
+ pystack[nin+i] = x;
194
368
  stack[nin+i] = *CONST_XND(x);
195
- }
196
- else {
197
- result[i] = NULL;
198
- stack[nin+i] = xnd_error;
199
- }
369
+ }
370
+ else {
371
+ clear_pystack(pystack, nin+i);
372
+ ndt_apply_spec_clear(&spec);
373
+ PyErr_SetString(PyExc_ValueError,
374
+ "arguments with abstract types are temporarily disabled");
375
+ return NULL;
376
+ }
377
+ }
200
378
  }
201
379
 
202
- #ifdef HAVE_PTHREAD_H
203
- if (gm_apply_thread(&kernel, stack, spec.outer_dims, spec.flags,
204
- max_threads, &ctx) < 0) {
205
- clear_objects(result, spec.nout);
206
- return seterr(&ctx);
207
- }
208
- #else
209
- if (gm_apply(&kernel, stack, spec.outer_dims, &ctx) < 0) {
210
- clear_objects(result, spec.nout);
211
- return seterr(&ctx);
212
- }
213
- #endif
380
+ if (self->flags == GM_CUDA_MANAGED_FUNC) {
381
+ #if HAVE_CUDA
382
+ if (!check_broadcast) {
383
+ ndt_err_format(&ctx, NDT_NotImplementedError,
384
+ "fold() is currently not supported on cuda");
385
+ clear_pystack(pystack, spec.nargs);
386
+ ndt_apply_spec_clear(&spec);
387
+ return seterr(&ctx);
388
+ }
214
389
 
215
- for (i = 0; i < spec.nout; i++) {
216
- if (ndt_is_abstract(spec.out[i])) {
217
- ndt_del(spec.out[i]);
218
- PyObject *x = Xnd_FromXnd(xnd, &stack[nin+i]);
219
- stack[nin+i] = xnd_error;
220
- if (x == NULL) {
221
- clear_objects(result, i);
222
- for (k = i+1; k < spec.nout; k++) {
223
- if (ndt_is_abstract(spec.out[k])) {
224
- xnd_del_buffer(&stack[nin+k], XND_OWN_ALL);
225
- }
226
- }
227
- }
228
- result[i] = x;
390
+ const int ret = gm_apply(&kernel, stack, spec.outer_dims, &ctx);
391
+
392
+ if (xnd_cuda_device_synchronize(&ctx) < 0 || ret < 0) {
393
+ clear_pystack(pystack, spec.nargs);
394
+ ndt_apply_spec_clear(&spec);
395
+ return seterr(&ctx);
229
396
  }
397
+ #else
398
+ ndt_err_format(&ctx, NDT_RuntimeError,
399
+ "internal error: GM_CUDA_MANAGED_FUNC set in a build without cuda support");
400
+ clear_pystack(pystack, spec.nargs);
401
+ ndt_apply_spec_clear(&spec);
402
+ return seterr(&ctx);
403
+ #endif
230
404
  }
405
+ else {
406
+ #ifdef HAVE_PTHREAD_H
407
+ const int rounding = fegetround();
408
+ fesetround(FE_TONEAREST);
409
+
410
+ const int64_t N = enable_threads ? max_threads : 1;
411
+ const int ret = gm_apply_thread(&kernel, stack, spec.outer_dims, N,
412
+ &ctx);
413
+ fesetround(rounding);
414
+
415
+ if (ret < 0) {
416
+ clear_pystack(pystack, spec.nargs);
417
+ ndt_apply_spec_clear(&spec);
418
+ return seterr(&ctx);
419
+ }
420
+ #else
421
+ const int rounding = fegetround();
422
+ fesetround(FE_TONEAREST);
423
+
424
+ const int ret = gm_apply(&kernel, stack, spec.outer_dims, &ctx);
231
425
 
232
- if (spec.nbroadcast > 0) {
233
- for (i = 0; i < nin; i++) {
234
- ndt_del(spec.broadcast[i]);
426
+ fesetround(rounding);
427
+
428
+ if (ret < 0) {
429
+ clear_pystack(pystack, spec.nargs);
430
+ ndt_apply_spec_clear(&spec);
431
+ return seterr(&ctx);
235
432
  }
433
+ #endif
236
434
  }
237
435
 
238
- switch (spec.nout) {
239
- case 0: Py_RETURN_NONE;
240
- case 1: return result[0];
436
+ nin = spec.nin;
437
+ nout = spec.nout;
438
+ nargs = spec.nargs;
439
+ ndt_apply_spec_clear(&spec);
440
+
441
+ switch (nout) {
442
+ case 0: {
443
+ clear_pystack(pystack, nargs);
444
+ Py_RETURN_NONE;
445
+ }
446
+ case 1: {
447
+ clear_pystack(pystack, nin);
448
+ return pystack[nin];
449
+ }
241
450
  default: {
242
- PyObject *tuple = PyTuple_New(spec.nout);
451
+ PyObject *tuple = PyTuple_New(nout);
243
452
  if (tuple == NULL) {
244
- clear_objects(result, spec.nout);
453
+ clear_pystack(pystack, nargs);
245
454
  return NULL;
246
455
  }
247
- for (i = 0; i < spec.nout; i++) {
248
- PyTuple_SET_ITEM(tuple, i, result[i]);
456
+ for (int i = 0; i < nout; i++) {
457
+ PyTuple_SET_ITEM(tuple, i, pystack[nin+i]);
249
458
  }
250
459
  return tuple;
251
460
  }
@@ -253,7 +462,39 @@ gufunc_call(GufuncObject *self, PyObject *args, PyObject *kwds)
253
462
  }
254
463
 
255
464
  static PyObject *
256
- gufunc_kernels(GufuncObject *self, PyObject *args GM_UNUSED)
465
+ gufunc_call(GufuncObject *self, PyObject *args, PyObject *kwargs)
466
+ {
467
+ return _gufunc_call(self, args, kwargs, true, true);
468
+ }
469
+
470
+ static PyObject *
471
+ gufunc_getdevice(GufuncObject *self, PyObject *args GM_UNUSED)
472
+ {
473
+ if (self->flags & GM_CUDA_MANAGED_FUNC) {
474
+ return PyUnicode_FromString("cuda:managed");
475
+ }
476
+
477
+ Py_RETURN_NONE;
478
+ }
479
+
480
+ static PyObject *
481
+ gufunc_getidentity(GufuncObject *self, PyObject *args GM_UNUSED)
482
+ {
483
+ Py_INCREF(self->identity);
484
+ return self->identity;
485
+ }
486
+
487
+ static int
488
+ gufunc_setidentity(GufuncObject *self, PyObject *value, void *closure GM_UNUSED)
489
+ {
490
+ Py_DECREF(self->identity);
491
+ Py_INCREF(value);
492
+ self->identity = value;
493
+ return 0;
494
+ }
495
+
496
+ static PyObject *
497
+ gufunc_getkernels(GufuncObject *self, PyObject *args GM_UNUSED)
257
498
  {
258
499
  NDT_STATIC_CONTEXT(ctx);
259
500
  PyObject *list, *tmp;
@@ -294,7 +535,9 @@ gufunc_kernels(GufuncObject *self, PyObject *args GM_UNUSED)
294
535
 
295
536
  static PyGetSetDef gufunc_getsets [] =
296
537
  {
297
- { "kernels", (getter)gufunc_kernels, NULL, NULL, NULL},
538
+ { "device", (getter)gufunc_getdevice, NULL, NULL, NULL},
539
+ { "identity", (getter)gufunc_getidentity, (setter)gufunc_setidentity, NULL, NULL},
540
+ { "kernels", (getter)gufunc_getkernels, NULL, NULL, NULL},
298
541
  {NULL}
299
542
  };
300
543
 
@@ -323,13 +566,25 @@ struct map_args {
323
566
  const gm_tbl_t *tbl;
324
567
  };
325
568
 
569
+ static int
570
+ Gufunc_CheckExact(const PyObject *v)
571
+ {
572
+ return Py_TYPE(v) == &Gufunc_Type;
573
+ }
574
+
575
+ static int
576
+ Gufunc_Check(const PyObject *v)
577
+ {
578
+ return PyObject_TypeCheck(v, &Gufunc_Type);
579
+ }
580
+
326
581
  static int
327
582
  add_function(const gm_func_t *f, void *args)
328
583
  {
329
584
  struct map_args *a = (struct map_args *)args;
330
585
  PyObject *func;
331
586
 
332
- func = gufunc_new(a->tbl, f->name);
587
+ func = gufunc_new(a->tbl, f->name, GM_CPU_FUNC);
333
588
  if (func == NULL) {
334
589
  return -1;
335
590
  }
@@ -349,10 +604,40 @@ Gumath_AddFunctions(PyObject *m, const gm_tbl_t *tbl)
349
604
  return 0;
350
605
  }
351
606
 
607
+ static int
608
+ add_cuda_function(const gm_func_t *f, void *args)
609
+ {
610
+ struct map_args *a = (struct map_args *)args;
611
+ PyObject *func;
612
+
613
+ func = gufunc_new(a->tbl, f->name, GM_CUDA_MANAGED_FUNC);
614
+ if (func == NULL) {
615
+ return -1;
616
+ }
617
+
618
+ return PyModule_AddObject(a->module, f->name, func);
619
+ }
620
+
621
+ static int
622
+ Gumath_AddCudaFunctions(PyObject *m, const gm_tbl_t *tbl)
623
+ {
624
+ struct map_args args = {m, tbl};
625
+
626
+ if (gm_tbl_map(tbl, add_cuda_function, &args) < 0) {
627
+ return -1;
628
+ }
629
+
630
+ return 0;
631
+ }
632
+
352
633
  static PyObject *
353
634
  init_api(void)
354
635
  {
636
+ gumath_api[Gufunc_CheckExact_INDEX] = (void *)Gufunc_CheckExact;
637
+ gumath_api[Gufunc_Check_INDEX] = (void *)Gufunc_Check;
355
638
  gumath_api[Gumath_AddFunctions_INDEX] = (void *)Gumath_AddFunctions;
639
+ gumath_api[Gumath_AddFunctions_INDEX] = (void *)Gumath_AddFunctions;
640
+ gumath_api[Gumath_AddCudaFunctions_INDEX] = (void *)Gumath_AddCudaFunctions;
356
641
 
357
642
  return PyCapsule_New(gumath_api, "gumath._gumath._API", NULL);
358
643
  }
@@ -362,6 +647,75 @@ init_api(void)
362
647
  /* Module */
363
648
  /****************************************************************************/
364
649
 
650
+ static PyObject *
651
+ gufunc_vfold(PyObject *m GM_UNUSED, PyObject *args, PyObject *kwargs)
652
+ {
653
+ static char *kwlist[] = {"f", "acc", NULL};
654
+ PyObject *func = Py_None;
655
+ PyObject *acc = Py_None;
656
+ PyObject *tuple;
657
+ PyObject *dict;
658
+ PyObject *res;
659
+ Py_ssize_t size, i;
660
+ int ret;
661
+
662
+ ret = PyArg_ParseTupleAndKeywords(positional_empty, kwargs, "|$OO", kwlist,
663
+ &func, &acc);
664
+ if (ret < 0) {
665
+ return NULL;
666
+ }
667
+
668
+ if (!Gufunc_Check(func)) {
669
+ PyErr_Format(PyExc_TypeError,
670
+ "vfold: expected gufunc object, got '%.200s'", Py_TYPE(func));
671
+ return NULL;
672
+ }
673
+
674
+ if (!Xnd_Check(acc)) {
675
+ PyErr_Format(PyExc_TypeError,
676
+ "vfold: expected xnd instance, got '%.200s'", Py_TYPE(acc));
677
+ return NULL;
678
+ }
679
+
680
+ /* Push the accumulator onto the argument stack. */
681
+ size = PyTuple_Size(args);
682
+ tuple = PyTuple_New(size+1);
683
+ if (tuple == NULL) {
684
+ return NULL;
685
+ }
686
+
687
+ Py_INCREF(acc);
688
+ PyTuple_SET_ITEM(tuple, 0, acc);
689
+ for (i = 0; i < size; i++) {
690
+ PyObject *v = PyTuple_GET_ITEM(args, i);
691
+ Py_INCREF(v);
692
+ PyTuple_SET_ITEM(tuple, i+1, v);
693
+ }
694
+
695
+ /* Simultaneously use the accumulator as the 'out' argument. */
696
+ dict = PyDict_New();
697
+ if (dict == NULL) {
698
+ Py_DECREF(tuple);
699
+ return NULL;
700
+ }
701
+ if (PyDict_SetItemString(dict, "out", acc) < 0) {
702
+ Py_DECREF(dict);
703
+ Py_DECREF(tuple);
704
+ return NULL;
705
+ }
706
+ if (PyDict_SetItemString(dict, "cls", (PyObject *)(Py_TYPE(acc))) < 0) {
707
+ Py_DECREF(dict);
708
+ Py_DECREF(tuple);
709
+ return NULL;
710
+ }
711
+
712
+ res = _gufunc_call((GufuncObject *)func, tuple, dict, false, false);
713
+ Py_DECREF(tuple);
714
+ Py_DECREF(dict);
715
+
716
+ return res;
717
+ }
718
+
365
719
  static PyObject *
366
720
  unsafe_add_kernel(PyObject *m GM_UNUSED, PyObject *args, PyObject *kwds)
367
721
  {
@@ -388,8 +742,8 @@ unsafe_add_kernel(PyObject *m GM_UNUSED, PyObject *args, PyObject *kwds)
388
742
  k.name = name;
389
743
  k.sig = sig;
390
744
 
391
- if (strcmp(tag, "Opt") == 0) {
392
- k.Opt = p;
745
+ if (strcmp(tag, "Opt") == 0) { /* XXX */
746
+ k.OptC = p;
393
747
  }
394
748
  else if (strcmp(tag, "C") == 0) {
395
749
  k.C = p;
@@ -418,7 +772,7 @@ unsafe_add_kernel(PyObject *m GM_UNUSED, PyObject *args, PyObject *kwds)
418
772
  return seterr(&ctx);
419
773
  }
420
774
 
421
- return gufunc_new(table, f->name);
775
+ return gufunc_new(table, f->name, GM_CPU_FUNC);
422
776
  }
423
777
 
424
778
  static void
@@ -490,6 +844,7 @@ set_max_threads(PyObject *m UNUSED, PyObject *obj)
490
844
  static PyMethodDef gumath_methods [] =
491
845
  {
492
846
  /* Methods */
847
+ { "vfold", (PyCFunction)gufunc_vfold, METH_VARARGS|METH_KEYWORDS, NULL },
493
848
  { "unsafe_add_kernel", (PyCFunction)unsafe_add_kernel, METH_VARARGS|METH_KEYWORDS, NULL },
494
849
  { "get_max_threads", (PyCFunction)get_max_threads, METH_NOARGS, NULL },
495
850
  { "set_max_threads", (PyCFunction)set_max_threads, METH_O, NULL },
@@ -554,11 +909,21 @@ PyInit__gumath(void)
554
909
  goto error;
555
910
  }
556
911
 
912
+ positional_empty = PyTuple_New(0);
913
+ if (positional_empty == NULL) {
914
+ goto error;
915
+ }
916
+
557
917
  m = PyModule_Create(&gumath_module);
558
918
  if (m == NULL) {
559
919
  goto error;
560
920
  }
561
921
 
922
+ Py_INCREF(&Gufunc_Type);
923
+ if (PyModule_AddObject(m, "gufunc", (PyObject *)&Gufunc_Type) < 0) {
924
+ goto error;
925
+ }
926
+
562
927
  Py_INCREF(capsule);
563
928
  if (PyModule_AddObject(m, "_API", capsule) < 0) {
564
929
  goto error;
@@ -571,6 +936,7 @@ PyInit__gumath(void)
571
936
  return m;
572
937
 
573
938
  error:
939
+ Py_CLEAR(positional_empty);
574
940
  Py_CLEAR(xnd);
575
941
  Py_CLEAR(m);
576
942
  return NULL;