gumath 0.2.0dev5 → 0.2.0dev8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CONTRIBUTING.md +7 -2
- data/Gemfile +0 -3
- data/ext/ruby_gumath/GPATH +0 -0
- data/ext/ruby_gumath/GRTAGS +0 -0
- data/ext/ruby_gumath/GTAGS +0 -0
- data/ext/ruby_gumath/extconf.rb +0 -5
- data/ext/ruby_gumath/functions.c +10 -2
- data/ext/ruby_gumath/gufunc_object.c +15 -4
- data/ext/ruby_gumath/gufunc_object.h +9 -3
- data/ext/ruby_gumath/gumath/Makefile +63 -0
- data/ext/ruby_gumath/gumath/Makefile.in +1 -0
- data/ext/ruby_gumath/gumath/config.h +56 -0
- data/ext/ruby_gumath/gumath/config.h.in +3 -0
- data/ext/ruby_gumath/gumath/config.log +497 -0
- data/ext/ruby_gumath/gumath/config.status +1034 -0
- data/ext/ruby_gumath/gumath/configure +375 -4
- data/ext/ruby_gumath/gumath/configure.ac +47 -3
- data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
- data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
- data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
- data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
- data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
- data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
- data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
- data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
- data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
- data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
- data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
- data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
- data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
- data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
- data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
- data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
- data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
- data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
- data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
- data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
- data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
- data/ext/ruby_gumath/gumath/setup.py +67 -6
- data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
- data/ext/ruby_gumath/include/gumath.h +55 -14
- data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
- data/ext/ruby_gumath/lib/libgumath.a +0 -0
- data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/ruby_gumath.c +231 -70
- data/ext/ruby_gumath/ruby_gumath.h +4 -1
- data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
- data/ext/ruby_gumath/util.c +34 -0
- data/ext/ruby_gumath/util.h +9 -0
- data/gumath.gemspec +3 -2
- data/lib/gumath.rb +55 -1
- data/lib/gumath/version.rb +2 -2
- data/lib/ruby_gumath.so +0 -0
- metadata +63 -10
- data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
- data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
- data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
|
Binary file
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
@@ -72,7 +72,7 @@ clear_all_slices(xnd_t *slices[], int *nslices, int stop)
|
|
|
72
72
|
{
|
|
73
73
|
for (int i = 0; i < stop; i++) {
|
|
74
74
|
for (int k = 0; k < nslices[i]; k++) {
|
|
75
|
-
|
|
75
|
+
ndt_decref(slices[i][k].type);
|
|
76
76
|
}
|
|
77
77
|
ndt_free(slices[i]);
|
|
78
78
|
}
|
|
@@ -94,16 +94,27 @@ apply_thread(void *arg)
|
|
|
94
94
|
|
|
95
95
|
int
|
|
96
96
|
gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
|
|
97
|
-
|
|
97
|
+
const int64_t nthreads, ndt_context_t *ctx)
|
|
98
98
|
{
|
|
99
99
|
const int nrows = (int)kernel->set->sig->Function.nargs;
|
|
100
100
|
ALLOCA(xnd_t *, slices, nrows);
|
|
101
101
|
ALLOCA(int, nslices, nrows);
|
|
102
102
|
struct thread_info *tinfo;
|
|
103
103
|
int ncols, tnum;
|
|
104
|
+
bool use_threads = true;
|
|
104
105
|
|
|
105
|
-
if (nthreads <= 1 || nrows == 0 || outer_dims == 0
|
|
106
|
-
|
|
106
|
+
if (nthreads <= 1 || nrows == 0 || outer_dims == 0) {
|
|
107
|
+
use_threads = false;
|
|
108
|
+
}
|
|
109
|
+
|
|
110
|
+
for (int i = 0; i < nrows; i++) {
|
|
111
|
+
const ndt_t *t = stack[i].type;
|
|
112
|
+
if (!ndt_is_ndarray(t) || ndt_nelem(t) < GM_THREAD_CUTOFF) {
|
|
113
|
+
use_threads = false;
|
|
114
|
+
}
|
|
115
|
+
}
|
|
116
|
+
|
|
117
|
+
if (!use_threads) {
|
|
107
118
|
return gm_apply(kernel, stack, outer_dims, ctx);
|
|
108
119
|
}
|
|
109
120
|
|
|
@@ -147,6 +158,7 @@ gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
|
|
|
147
158
|
&tinfo[tnum]);
|
|
148
159
|
if (ret != 0) {
|
|
149
160
|
clear_all_slices(slices, nslices, nrows);
|
|
161
|
+
ndt_free(tinfo);
|
|
150
162
|
ndt_err_format(ctx, NDT_RuntimeError, "could not create thread");
|
|
151
163
|
return -1;
|
|
152
164
|
}
|
|
@@ -169,6 +181,7 @@ gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
|
|
|
169
181
|
}
|
|
170
182
|
|
|
171
183
|
clear_all_slices(slices, nslices, nrows);
|
|
184
|
+
ndt_free(tinfo);
|
|
172
185
|
|
|
173
186
|
return ndt_err_occurred(ctx) ? -1 : 0;
|
|
174
187
|
}
|
|
Binary file
|
|
@@ -38,11 +38,99 @@
|
|
|
38
38
|
#include "ndtypes.h"
|
|
39
39
|
#include "xnd.h"
|
|
40
40
|
#include "gumath.h"
|
|
41
|
+
#include "overflow.h"
|
|
41
42
|
|
|
42
43
|
|
|
44
|
+
static int _gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
|
|
45
|
+
const int outer_dims, ndt_context_t *ctx);
|
|
46
|
+
|
|
47
|
+
int
|
|
48
|
+
array_shape_check(xnd_t *x, const int64_t shape, ndt_context_t *ctx)
|
|
49
|
+
{
|
|
50
|
+
const ndt_t *t = x->type;
|
|
51
|
+
|
|
52
|
+
if (t->tag != Array) {
|
|
53
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
|
54
|
+
"type mismatch in outer dimensions");
|
|
55
|
+
return -1;
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
if (XND_ARRAY_DATA(x->ptr) == NULL) {
|
|
59
|
+
bool overflow = false;
|
|
60
|
+
const int64_t size = MULi64(shape, t->Array.itemsize, &overflow);
|
|
61
|
+
if (overflow) {
|
|
62
|
+
ndt_err_format(ctx, NDT_ValueError,
|
|
63
|
+
"datasize of flexible array is too large");
|
|
64
|
+
return -1;
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
char *data = ndt_aligned_calloc(t->align, size);
|
|
68
|
+
if (data == NULL) {
|
|
69
|
+
ndt_err_format(ctx, NDT_MemoryError, "out of memory");
|
|
70
|
+
return -1;
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
XND_ARRAY_SHAPE(x->ptr) = shape;
|
|
74
|
+
XND_ARRAY_DATA(x->ptr) = data;
|
|
75
|
+
|
|
76
|
+
return 0;
|
|
77
|
+
}
|
|
78
|
+
else if (XND_ARRAY_SHAPE(x->ptr) != shape) {
|
|
79
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
|
80
|
+
"shape mismatch in outer dimensions");
|
|
81
|
+
return -1;
|
|
82
|
+
}
|
|
83
|
+
else {
|
|
84
|
+
return 0;
|
|
85
|
+
}
|
|
86
|
+
}
|
|
87
|
+
|
|
88
|
+
static inline bool
|
|
89
|
+
any_stored_index(xnd_t stack[], const int nargs)
|
|
90
|
+
{
|
|
91
|
+
for (int i = 0; i < nargs; i++) {
|
|
92
|
+
if (stack[i].ptr == NULL) {
|
|
93
|
+
continue;
|
|
94
|
+
}
|
|
95
|
+
|
|
96
|
+
const ndt_t *t = stack[i].type;
|
|
97
|
+
if (have_stored_index(t)) {
|
|
98
|
+
return true;
|
|
99
|
+
}
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
return false;
|
|
103
|
+
}
|
|
104
|
+
|
|
43
105
|
int
|
|
44
106
|
gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
|
|
45
107
|
const int outer_dims, ndt_context_t *ctx)
|
|
108
|
+
{
|
|
109
|
+
if (any_stored_index(stack, nargs)) {
|
|
110
|
+
ALLOCA(xnd_t, next, nargs);
|
|
111
|
+
|
|
112
|
+
for (int i = 0; i < nargs; i++) {
|
|
113
|
+
const ndt_t *t = stack[i].type;
|
|
114
|
+
if (have_stored_index(t)) {
|
|
115
|
+
next[i] = apply_stored_indices(&stack[i], ctx);
|
|
116
|
+
if (xnd_err_occurred(&next[i])) {
|
|
117
|
+
return -1;
|
|
118
|
+
}
|
|
119
|
+
}
|
|
120
|
+
else {
|
|
121
|
+
next[i] = stack[i];
|
|
122
|
+
}
|
|
123
|
+
}
|
|
124
|
+
|
|
125
|
+
return _gm_xnd_map(f, next, nargs, outer_dims, ctx);
|
|
126
|
+
}
|
|
127
|
+
|
|
128
|
+
return _gm_xnd_map(f, stack, nargs, outer_dims, ctx);
|
|
129
|
+
}
|
|
130
|
+
|
|
131
|
+
static int
|
|
132
|
+
_gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
|
|
133
|
+
const int outer_dims, ndt_context_t *ctx)
|
|
46
134
|
{
|
|
47
135
|
ALLOCA(xnd_t, next, nargs);
|
|
48
136
|
const ndt_t *t;
|
|
@@ -123,6 +211,28 @@ gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
|
|
|
123
211
|
return 0;
|
|
124
212
|
}
|
|
125
213
|
|
|
214
|
+
case Array: {
|
|
215
|
+
const int64_t shape = XND_ARRAY_SHAPE(stack[0].ptr);
|
|
216
|
+
|
|
217
|
+
for (int k = 1; k < nargs; k++) {
|
|
218
|
+
if (array_shape_check(&stack[k], shape, ctx) < 0) {
|
|
219
|
+
return -1;
|
|
220
|
+
}
|
|
221
|
+
}
|
|
222
|
+
|
|
223
|
+
for (int64_t i = 0; i < shape; i++) {
|
|
224
|
+
for (int k = 0; k < nargs; k++) {
|
|
225
|
+
next[k] = xnd_array_next(&stack[k], i);
|
|
226
|
+
}
|
|
227
|
+
|
|
228
|
+
if (gm_xnd_map(f, next, nargs, outer_dims-1, ctx) < 0) {
|
|
229
|
+
return -1;
|
|
230
|
+
}
|
|
231
|
+
}
|
|
232
|
+
|
|
233
|
+
return 0;
|
|
234
|
+
}
|
|
235
|
+
|
|
126
236
|
default:
|
|
127
237
|
ndt_err_format(ctx, NDT_NotImplementedError, "unsupported type");
|
|
128
238
|
return -1;
|
|
Binary file
|
|
@@ -33,7 +33,157 @@
|
|
|
33
33
|
from ndtypes import ndt
|
|
34
34
|
from xnd import xnd
|
|
35
35
|
from ._gumath import *
|
|
36
|
+
from . import functions as _fn
|
|
36
37
|
|
|
38
|
+
try:
|
|
39
|
+
from . import cuda as _cd
|
|
40
|
+
except ImportError:
|
|
41
|
+
_cd = None
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
__all__ = ['cuda', 'fold', 'functions', 'get_max_threads', 'gufunc', 'reduce',
|
|
45
|
+
'set_max_threads', 'unsafe_add_kernel', 'vfold', 'xndvectorize']
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# ==============================================================================
|
|
49
|
+
# Init identity elements
|
|
50
|
+
# ==============================================================================
|
|
51
|
+
|
|
52
|
+
# This is done here now, perhaps it should be on the C level.
|
|
53
|
+
_fn.add.identity = 0
|
|
54
|
+
_fn.multiply.identity = 1
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# ==============================================================================
|
|
58
|
+
# General fold function
|
|
59
|
+
# ==============================================================================
|
|
60
|
+
|
|
61
|
+
def fold(f, acc, x):
|
|
62
|
+
return vfold(x, f=f, acc=acc)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# ==============================================================================
|
|
66
|
+
# NumPy's reduce in terms of fold
|
|
67
|
+
# ==============================================================================
|
|
68
|
+
|
|
69
|
+
def _get_axes(axes, ndim):
|
|
70
|
+
type_err = "'axes' must be None, a single integer or a tuple of integers"
|
|
71
|
+
value_err = "axis with value %d out of range"
|
|
72
|
+
duplicate_err = "'axes' argument contains duplicate values"
|
|
73
|
+
if axes is None:
|
|
74
|
+
axes = tuple(range(ndim))
|
|
75
|
+
elif isinstance(axes, int):
|
|
76
|
+
axes = (axes,)
|
|
77
|
+
elif not isinstance(axes, tuple) or \
|
|
78
|
+
any(not isinstance(v, int) for v in axes):
|
|
79
|
+
raise TypeError(type_err)
|
|
80
|
+
|
|
81
|
+
if any(n >= ndim for n in axes):
|
|
82
|
+
raise ValueError(value_err % n)
|
|
83
|
+
|
|
84
|
+
if len(set(axes)) != len(axes):
|
|
85
|
+
raise ValueError(duplicate_err)
|
|
86
|
+
|
|
87
|
+
return list(axes)
|
|
88
|
+
|
|
89
|
+
def _copyto(dest, value):
|
|
90
|
+
x = xnd(value, dtype=dest.dtype)
|
|
91
|
+
_fn.copy(x, out=dest)
|
|
92
|
+
|
|
93
|
+
def reduce_cpu(f, x, axes, dtype):
|
|
94
|
+
"""NumPy's reduce in terms of fold."""
|
|
95
|
+
axes = _get_axes(axes, x.ndim)
|
|
96
|
+
if not axes:
|
|
97
|
+
return x
|
|
98
|
+
|
|
99
|
+
permute = [n for n in range(x.ndim) if n not in axes]
|
|
100
|
+
permute = axes + permute
|
|
101
|
+
|
|
102
|
+
T = x.transpose(permute=permute)
|
|
103
|
+
|
|
104
|
+
N = len(axes)
|
|
105
|
+
t = T.type.at(N, dtype=dtype)
|
|
106
|
+
acc = x.empty(t, device=x.device)
|
|
107
|
+
|
|
108
|
+
if f.identity is not None:
|
|
109
|
+
_copyto(acc, f.identity)
|
|
110
|
+
tl = T
|
|
111
|
+
elif N == 1 and T.type.shape[0] > 0:
|
|
112
|
+
hd, tl = T[0], T[1:]
|
|
113
|
+
acc[()] = hd
|
|
114
|
+
else:
|
|
115
|
+
raise ValueError(
|
|
116
|
+
"reduction not possible for function without an identity element")
|
|
117
|
+
|
|
118
|
+
return fold(f, acc, tl)
|
|
119
|
+
|
|
120
|
+
def reduce_cuda(g, x, axes, dtype):
|
|
121
|
+
"""Reductions in CUDA use the thrust library for speed and have limited
|
|
122
|
+
functionality."""
|
|
123
|
+
if axes != 0:
|
|
124
|
+
raise NotImplementedError("'axes' keyword is not implemented for CUDA")
|
|
125
|
+
|
|
126
|
+
return g(x, dtype=dtype)
|
|
127
|
+
|
|
128
|
+
def get_cuda_reduction_func(f):
|
|
129
|
+
if _cd is None:
|
|
130
|
+
return None
|
|
131
|
+
elif f == _cd.add:
|
|
132
|
+
return _cd.reduce_add
|
|
133
|
+
elif f == _cd.multiply:
|
|
134
|
+
return _cd.reduce_multiply
|
|
135
|
+
else:
|
|
136
|
+
return None
|
|
137
|
+
|
|
138
|
+
def reduce(f, x, axes=0, dtype=None):
|
|
139
|
+
if dtype is None:
|
|
140
|
+
dtype = maxcast[x.dtype]
|
|
141
|
+
|
|
142
|
+
g = get_cuda_reduction_func(f)
|
|
143
|
+
if g is not None:
|
|
144
|
+
return reduce_cuda(g, x, axes, dtype)
|
|
145
|
+
|
|
146
|
+
return reduce_cpu(f, x, axes, dtype)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
maxcast = {
|
|
150
|
+
ndt("int8"): ndt("int64"),
|
|
151
|
+
ndt("int16"): ndt("int64"),
|
|
152
|
+
ndt("int32"): ndt("int64"),
|
|
153
|
+
ndt("int64"): ndt("int64"),
|
|
154
|
+
ndt("uint8"): ndt("uint64"),
|
|
155
|
+
ndt("uint16"): ndt("uint64"),
|
|
156
|
+
ndt("uint32"): ndt("uint64"),
|
|
157
|
+
ndt("uint64"): ndt("uint64"),
|
|
158
|
+
ndt("bfloat16"): ndt("float64"),
|
|
159
|
+
ndt("float16"): ndt("float64"),
|
|
160
|
+
ndt("float32"): ndt("float64"),
|
|
161
|
+
ndt("float64"): ndt("float64"),
|
|
162
|
+
ndt("complex32"): ndt("complex128"),
|
|
163
|
+
ndt("complex64"): ndt("complex128"),
|
|
164
|
+
ndt("complex128"): ndt("complex128"),
|
|
165
|
+
|
|
166
|
+
ndt("?int8"): ndt("?int64"),
|
|
167
|
+
ndt("?int16"): ndt("?int64"),
|
|
168
|
+
ndt("?int32"): ndt("?int64"),
|
|
169
|
+
ndt("?int64"): ndt("?int64"),
|
|
170
|
+
ndt("?uint8"): ndt("?uint64"),
|
|
171
|
+
ndt("?uint16"): ndt("?uint64"),
|
|
172
|
+
ndt("?uint32"): ndt("?uint64"),
|
|
173
|
+
ndt("?uint64"): ndt("?uint64"),
|
|
174
|
+
ndt("?bfloat16"): ndt("?float64"),
|
|
175
|
+
ndt("?float16"): ndt("?float64"),
|
|
176
|
+
ndt("?float32"): ndt("?float64"),
|
|
177
|
+
ndt("?float64"): ndt("?float64"),
|
|
178
|
+
ndt("?complex32"): ndt("?complex128"),
|
|
179
|
+
ndt("?complex64"): ndt("?complex128"),
|
|
180
|
+
ndt("?complex128"): ndt("?complex128"),
|
|
181
|
+
}
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
# ==============================================================================
|
|
185
|
+
# Numba's GUVectorize on xnd arrays
|
|
186
|
+
# ==============================================================================
|
|
37
187
|
|
|
38
188
|
try:
|
|
39
189
|
import numpy as np
|
|
@@ -45,16 +45,24 @@
|
|
|
45
45
|
#define GUMATH_MODULE
|
|
46
46
|
#include "pygumath.h"
|
|
47
47
|
|
|
48
|
+
|
|
48
49
|
#ifdef _MSC_VER
|
|
49
50
|
#ifndef UNUSED
|
|
50
51
|
#define UNUSED
|
|
51
52
|
#endif
|
|
53
|
+
#include <float.h>
|
|
54
|
+
#include <fenv.h>
|
|
55
|
+
#pragma fenv_access(on)
|
|
52
56
|
#else
|
|
53
57
|
#if defined(__GNUC__) && !defined(__INTEL_COMPILER)
|
|
54
58
|
#define UNUSED __attribute__((unused))
|
|
55
59
|
#else
|
|
56
60
|
#define UNUSED
|
|
57
61
|
#endif
|
|
62
|
+
#include <fenv.h>
|
|
63
|
+
#if 0 /* Not supported by gcc and clang. */
|
|
64
|
+
#pragma STDC FENV_ACCESS ON
|
|
65
|
+
#endif
|
|
58
66
|
#endif
|
|
59
67
|
|
|
60
68
|
|
|
@@ -73,6 +81,9 @@ static gm_tbl_t *table = NULL;
|
|
|
73
81
|
/* Xnd type */
|
|
74
82
|
static PyTypeObject *xnd = NULL;
|
|
75
83
|
|
|
84
|
+
/* Empty positional arguments */
|
|
85
|
+
static PyObject *positional_empty = NULL;
|
|
86
|
+
|
|
76
87
|
/* Maximum number of threads */
|
|
77
88
|
static int64_t max_threads = 1;
|
|
78
89
|
|
|
@@ -95,7 +106,7 @@ seterr(ndt_context_t *ctx)
|
|
|
95
106
|
static PyTypeObject Gufunc_Type;
|
|
96
107
|
|
|
97
108
|
static PyObject *
|
|
98
|
-
gufunc_new(const gm_tbl_t *tbl, const char *name)
|
|
109
|
+
gufunc_new(const gm_tbl_t *tbl, const char *name, const uint32_t flags)
|
|
99
110
|
{
|
|
100
111
|
NDT_STATIC_CONTEXT(ctx);
|
|
101
112
|
GufuncObject *self;
|
|
@@ -106,12 +117,16 @@ gufunc_new(const gm_tbl_t *tbl, const char *name)
|
|
|
106
117
|
}
|
|
107
118
|
|
|
108
119
|
self->tbl = tbl;
|
|
120
|
+
self->flags = flags;
|
|
109
121
|
|
|
110
122
|
self->name = ndt_strdup(name, &ctx);
|
|
111
123
|
if (self->name == NULL) {
|
|
112
124
|
return seterr(&ctx);
|
|
113
125
|
}
|
|
114
126
|
|
|
127
|
+
self->identity = Py_None;
|
|
128
|
+
Py_INCREF(self->identity);
|
|
129
|
+
|
|
115
130
|
return (PyObject *)self;
|
|
116
131
|
}
|
|
117
132
|
|
|
@@ -119,6 +134,7 @@ static void
|
|
|
119
134
|
gufunc_dealloc(GufuncObject *self)
|
|
120
135
|
{
|
|
121
136
|
ndt_free(self->name);
|
|
137
|
+
Py_DECREF(self->identity);
|
|
122
138
|
PyObject_Del(self);
|
|
123
139
|
}
|
|
124
140
|
|
|
@@ -128,124 +144,317 @@ gufunc_dealloc(GufuncObject *self)
|
|
|
128
144
|
/****************************************************************************/
|
|
129
145
|
|
|
130
146
|
static void
|
|
131
|
-
|
|
147
|
+
clear_pystack(PyObject *pystack[], Py_ssize_t len)
|
|
148
|
+
{
|
|
149
|
+
for (Py_ssize_t i = 0; i < len; i++) {
|
|
150
|
+
Py_CLEAR(pystack[i]);
|
|
151
|
+
}
|
|
152
|
+
}
|
|
153
|
+
|
|
154
|
+
static int
|
|
155
|
+
parse_args(PyObject *pystack[NDT_MAX_ARGS], int *py_nin, int *py_nout, int *py_nargs,
|
|
156
|
+
PyObject *args, PyObject *out)
|
|
132
157
|
{
|
|
133
|
-
Py_ssize_t
|
|
158
|
+
Py_ssize_t nin;
|
|
159
|
+
Py_ssize_t nout;
|
|
134
160
|
|
|
135
|
-
|
|
136
|
-
|
|
161
|
+
if (!args || !PyTuple_Check(args)) {
|
|
162
|
+
const char *name = args ? Py_TYPE(args)->tp_name : "NULL";
|
|
163
|
+
PyErr_Format(PyExc_SystemError,
|
|
164
|
+
"internal error: expected tuple, got '%.200s'", name);
|
|
165
|
+
return -1;
|
|
137
166
|
}
|
|
167
|
+
|
|
168
|
+
nin = PyTuple_GET_SIZE(args);
|
|
169
|
+
if (nin > NDT_MAX_ARGS) {
|
|
170
|
+
PyErr_Format(PyExc_TypeError,
|
|
171
|
+
"maximum number of arguments is %d, got %n", NDT_MAX_ARGS, nin);
|
|
172
|
+
return -1;
|
|
173
|
+
}
|
|
174
|
+
|
|
175
|
+
for (Py_ssize_t i = 0; i < nin; i++) {
|
|
176
|
+
PyObject *v = PyTuple_GET_ITEM(args, i);
|
|
177
|
+
if (!Xnd_Check(v)) {
|
|
178
|
+
PyErr_Format(PyExc_TypeError,
|
|
179
|
+
"expected xnd argument, got '%.200s'", Py_TYPE(v)->tp_name);
|
|
180
|
+
return -1;
|
|
181
|
+
}
|
|
182
|
+
|
|
183
|
+
pystack[i] = v;
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
if (out == NULL) {
|
|
187
|
+
nout = 0;
|
|
188
|
+
}
|
|
189
|
+
else {
|
|
190
|
+
if (Xnd_Check(out)) {
|
|
191
|
+
nout = 1;
|
|
192
|
+
if (nin+nout > NDT_MAX_ARGS) {
|
|
193
|
+
PyErr_Format(PyExc_TypeError,
|
|
194
|
+
"maximum number of arguments is %d, got %n", NDT_MAX_ARGS, nin+nout);
|
|
195
|
+
return -1;
|
|
196
|
+
}
|
|
197
|
+
pystack[nin] = out;
|
|
198
|
+
}
|
|
199
|
+
else if (PyTuple_Check(out)) {
|
|
200
|
+
nout = PyTuple_GET_SIZE(out);
|
|
201
|
+
if (nout > NDT_MAX_ARGS || nin+nout > NDT_MAX_ARGS) {
|
|
202
|
+
PyErr_Format(PyExc_TypeError,
|
|
203
|
+
"maximum number of arguments is %d, got %n", NDT_MAX_ARGS, nin+nout);
|
|
204
|
+
return -1;
|
|
205
|
+
}
|
|
206
|
+
|
|
207
|
+
for (Py_ssize_t i = 0; i < nout; i++) {
|
|
208
|
+
PyObject *v = PyTuple_GET_ITEM(out, i);
|
|
209
|
+
if (!Xnd_Check(v)) {
|
|
210
|
+
PyErr_Format(PyExc_TypeError,
|
|
211
|
+
"expected xnd argument, got '%.200s'", Py_TYPE(v)->tp_name);
|
|
212
|
+
return -1;
|
|
213
|
+
}
|
|
214
|
+
|
|
215
|
+
pystack[nin+i] = v;
|
|
216
|
+
}
|
|
217
|
+
}
|
|
218
|
+
else {
|
|
219
|
+
PyErr_Format(PyExc_TypeError,
|
|
220
|
+
"'out' argument must be xnd or a tuple of xnd, got '%.200s'",
|
|
221
|
+
Py_TYPE(out)->tp_name);
|
|
222
|
+
return -1;
|
|
223
|
+
}
|
|
224
|
+
}
|
|
225
|
+
|
|
226
|
+
for (int i = 0; i < nin+nout; i++) {
|
|
227
|
+
Py_INCREF(pystack[i]);
|
|
228
|
+
}
|
|
229
|
+
|
|
230
|
+
*py_nin = (int)nin;
|
|
231
|
+
*py_nout = (int)nout;
|
|
232
|
+
*py_nargs = (int)nin+(int)nout;
|
|
233
|
+
|
|
234
|
+
return 0;
|
|
138
235
|
}
|
|
139
236
|
|
|
140
237
|
static PyObject *
|
|
141
|
-
|
|
238
|
+
_gufunc_call(GufuncObject *self, PyObject *args, PyObject *kwargs,
|
|
239
|
+
bool enable_threads, bool check_broadcast)
|
|
142
240
|
{
|
|
241
|
+
static char *kwlist[] = {"out", "dtype", "cls", NULL};
|
|
242
|
+
PyObject *out = Py_None;
|
|
243
|
+
PyObject *dt = Py_None;
|
|
244
|
+
PyObject *cls = Py_None;
|
|
245
|
+
|
|
143
246
|
NDT_STATIC_CONTEXT(ctx);
|
|
144
|
-
|
|
145
|
-
PyObject **a = &PyTuple_GET_ITEM(args, 0);
|
|
146
|
-
PyObject *result[NDT_MAX_ARGS];
|
|
147
|
-
ndt_apply_spec_t spec = ndt_apply_spec_empty;
|
|
148
|
-
const ndt_t *in_types[NDT_MAX_ARGS];
|
|
247
|
+
PyObject *pystack[NDT_MAX_ARGS];
|
|
149
248
|
xnd_t stack[NDT_MAX_ARGS];
|
|
249
|
+
const ndt_t *types[NDT_MAX_ARGS];
|
|
250
|
+
int64_t li[NDT_MAX_ARGS];
|
|
251
|
+
ndt_apply_spec_t spec = ndt_apply_spec_empty;
|
|
150
252
|
gm_kernel_t kernel;
|
|
151
|
-
|
|
253
|
+
bool have_cpu_device = false;
|
|
254
|
+
ndt_t *dtype = NULL;
|
|
255
|
+
int nin, nout, nargs;
|
|
256
|
+
int k;
|
|
152
257
|
|
|
153
|
-
if (
|
|
154
|
-
|
|
155
|
-
"gufunc calls do not support keywords");
|
|
258
|
+
if (!PyArg_ParseTupleAndKeywords(positional_empty, kwargs, "|$OOO", kwlist,
|
|
259
|
+
&out, &dt, &cls)) {
|
|
156
260
|
return NULL;
|
|
157
261
|
}
|
|
158
262
|
|
|
159
|
-
|
|
263
|
+
out = out == Py_None ? NULL : out;
|
|
264
|
+
dt = dt == Py_None ? NULL : dt;
|
|
265
|
+
cls = cls == Py_None ? (PyObject *)xnd : cls;
|
|
266
|
+
|
|
267
|
+
if (dt != NULL) {
|
|
268
|
+
if (out != NULL) {
|
|
269
|
+
PyErr_SetString(PyExc_TypeError,
|
|
270
|
+
"the 'out' and 'dtype' arguments are mutually exclusive");
|
|
271
|
+
return NULL;
|
|
272
|
+
}
|
|
273
|
+
if (!Ndt_Check(dt)) {
|
|
274
|
+
PyErr_Format(PyExc_TypeError,
|
|
275
|
+
"'dtype' argument must be ndt, got '%.200s'",
|
|
276
|
+
Py_TYPE(dt)->tp_name);
|
|
277
|
+
return NULL;
|
|
278
|
+
}
|
|
279
|
+
dtype = (ndt_t *)NDT(dt);
|
|
280
|
+
ndt_incref(dtype);
|
|
281
|
+
}
|
|
282
|
+
|
|
283
|
+
if (!PyType_Check(cls) || !PyType_IsSubtype((PyTypeObject *)cls, xnd)) {
|
|
160
284
|
PyErr_SetString(PyExc_TypeError,
|
|
161
|
-
"
|
|
285
|
+
"the 'cls' argument must be a subtype of 'xnd'");
|
|
286
|
+
return NULL;
|
|
287
|
+
}
|
|
288
|
+
|
|
289
|
+
if (parse_args(pystack, &nin, &nout, &nargs, args, out) < 0) {
|
|
162
290
|
return NULL;
|
|
163
291
|
}
|
|
292
|
+
assert(nout == 0 || dtype == NULL);
|
|
164
293
|
|
|
165
|
-
for (
|
|
166
|
-
|
|
167
|
-
|
|
294
|
+
for (k = 0; k < nargs; k++) {
|
|
295
|
+
const XndObject *x = (XndObject *)pystack[k];
|
|
296
|
+
if (!(x->mblock->xnd->flags&XND_CUDA_MANAGED)) {
|
|
297
|
+
have_cpu_device = true;
|
|
298
|
+
}
|
|
299
|
+
|
|
300
|
+
stack[k] = *CONST_XND((PyObject *)x);
|
|
301
|
+
types[k] = stack[k].type;
|
|
302
|
+
li[k] = stack[k].index;
|
|
303
|
+
}
|
|
304
|
+
|
|
305
|
+
if (have_cpu_device) {
|
|
306
|
+
if (self->flags & GM_CUDA_MANAGED_FUNC) {
|
|
307
|
+
PyErr_SetString(PyExc_ValueError,
|
|
308
|
+
"cannot run a cuda function on xnd objects with cpu memory");
|
|
309
|
+
clear_pystack(pystack, nargs);
|
|
168
310
|
return NULL;
|
|
169
311
|
}
|
|
170
|
-
stack[i] = *CONST_XND(a[i]);
|
|
171
|
-
in_types[i] = stack[i].type;
|
|
172
312
|
}
|
|
173
313
|
|
|
174
|
-
kernel = gm_select(&spec, self->tbl, self->name,
|
|
314
|
+
kernel = gm_select(&spec, self->tbl, self->name, types, li, nin, nout,
|
|
315
|
+
nout && check_broadcast, stack, &ctx);
|
|
175
316
|
if (kernel.set == NULL) {
|
|
176
317
|
return seterr(&ctx);
|
|
177
318
|
}
|
|
178
319
|
|
|
179
|
-
if (
|
|
180
|
-
|
|
181
|
-
|
|
320
|
+
if (dtype) {
|
|
321
|
+
if (spec.nout != 1) {
|
|
322
|
+
ndt_err_format(&ctx, NDT_TypeError,
|
|
323
|
+
"the 'dtype' argument is only supported for a single "
|
|
324
|
+
"return value");
|
|
325
|
+
ndt_apply_spec_clear(&spec);
|
|
326
|
+
ndt_decref(dtype);
|
|
327
|
+
return seterr(&ctx);
|
|
328
|
+
}
|
|
329
|
+
|
|
330
|
+
const ndt_t *u = spec.types[spec.nin];
|
|
331
|
+
const ndt_t *v = ndt_copy_contiguous_dtype(u, dtype, 0, &ctx);
|
|
332
|
+
|
|
333
|
+
ndt_apply_spec_clear(&spec);
|
|
334
|
+
ndt_decref(dtype);
|
|
335
|
+
|
|
336
|
+
if (v == NULL) {
|
|
337
|
+
return seterr(&ctx);
|
|
338
|
+
}
|
|
339
|
+
|
|
340
|
+
types[nin] = v;
|
|
341
|
+
kernel = gm_select(&spec, self->tbl, self->name, types, li, nin, 1,
|
|
342
|
+
1 && check_broadcast, stack, &ctx);
|
|
343
|
+
if (kernel.set == NULL) {
|
|
344
|
+
return seterr(&ctx);
|
|
182
345
|
}
|
|
183
346
|
}
|
|
184
347
|
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
348
|
+
/*
|
|
349
|
+
* Replace args/kwargs types with types after substitution and broadcasting.
|
|
350
|
+
* This includes 'out' types, if explicitly passed as kwargs.
|
|
351
|
+
*/
|
|
352
|
+
for (int i = 0; i < spec.nargs; i++) {
|
|
353
|
+
stack[i].type = spec.types[i];
|
|
354
|
+
}
|
|
355
|
+
|
|
356
|
+
if (nout == 0) {
|
|
357
|
+
/* 'out' types have been inferred, create new XndObjects. */
|
|
358
|
+
for (int i = 0; i < spec.nout; i++) {
|
|
359
|
+
if (ndt_is_concrete(spec.types[nin+i])) {
|
|
360
|
+
uint32_t flags = self->flags == GM_CUDA_MANAGED_FUNC ? XND_CUDA_MANAGED : 0;
|
|
361
|
+
PyObject *x = Xnd_EmptyFromType((PyTypeObject *)cls, spec.types[nin+i], flags);
|
|
362
|
+
if (x == NULL) {
|
|
363
|
+
clear_pystack(pystack, nin+i);
|
|
364
|
+
ndt_apply_spec_clear(&spec);
|
|
191
365
|
return NULL;
|
|
192
366
|
}
|
|
193
|
-
|
|
367
|
+
pystack[nin+i] = x;
|
|
194
368
|
stack[nin+i] = *CONST_XND(x);
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
369
|
+
}
|
|
370
|
+
else {
|
|
371
|
+
clear_pystack(pystack, nin+i);
|
|
372
|
+
ndt_apply_spec_clear(&spec);
|
|
373
|
+
PyErr_SetString(PyExc_ValueError,
|
|
374
|
+
"arguments with abstract types are temporarily disabled");
|
|
375
|
+
return NULL;
|
|
376
|
+
}
|
|
377
|
+
}
|
|
200
378
|
}
|
|
201
379
|
|
|
202
|
-
|
|
203
|
-
if
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
return seterr(&ctx);
|
|
212
|
-
}
|
|
213
|
-
#endif
|
|
380
|
+
if (self->flags == GM_CUDA_MANAGED_FUNC) {
|
|
381
|
+
#if HAVE_CUDA
|
|
382
|
+
if (!check_broadcast) {
|
|
383
|
+
ndt_err_format(&ctx, NDT_NotImplementedError,
|
|
384
|
+
"fold() is currently not supported on cuda");
|
|
385
|
+
clear_pystack(pystack, spec.nargs);
|
|
386
|
+
ndt_apply_spec_clear(&spec);
|
|
387
|
+
return seterr(&ctx);
|
|
388
|
+
}
|
|
214
389
|
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
clear_objects(result, i);
|
|
222
|
-
for (k = i+1; k < spec.nout; k++) {
|
|
223
|
-
if (ndt_is_abstract(spec.out[k])) {
|
|
224
|
-
xnd_del_buffer(&stack[nin+k], XND_OWN_ALL);
|
|
225
|
-
}
|
|
226
|
-
}
|
|
227
|
-
}
|
|
228
|
-
result[i] = x;
|
|
390
|
+
const int ret = gm_apply(&kernel, stack, spec.outer_dims, &ctx);
|
|
391
|
+
|
|
392
|
+
if (xnd_cuda_device_synchronize(&ctx) < 0 || ret < 0) {
|
|
393
|
+
clear_pystack(pystack, spec.nargs);
|
|
394
|
+
ndt_apply_spec_clear(&spec);
|
|
395
|
+
return seterr(&ctx);
|
|
229
396
|
}
|
|
397
|
+
#else
|
|
398
|
+
ndt_err_format(&ctx, NDT_RuntimeError,
|
|
399
|
+
"internal error: GM_CUDA_MANAGED_FUNC set in a build without cuda support");
|
|
400
|
+
clear_pystack(pystack, spec.nargs);
|
|
401
|
+
ndt_apply_spec_clear(&spec);
|
|
402
|
+
return seterr(&ctx);
|
|
403
|
+
#endif
|
|
230
404
|
}
|
|
405
|
+
else {
|
|
406
|
+
#ifdef HAVE_PTHREAD_H
|
|
407
|
+
const int rounding = fegetround();
|
|
408
|
+
fesetround(FE_TONEAREST);
|
|
409
|
+
|
|
410
|
+
const int64_t N = enable_threads ? max_threads : 1;
|
|
411
|
+
const int ret = gm_apply_thread(&kernel, stack, spec.outer_dims, N,
|
|
412
|
+
&ctx);
|
|
413
|
+
fesetround(rounding);
|
|
414
|
+
|
|
415
|
+
if (ret < 0) {
|
|
416
|
+
clear_pystack(pystack, spec.nargs);
|
|
417
|
+
ndt_apply_spec_clear(&spec);
|
|
418
|
+
return seterr(&ctx);
|
|
419
|
+
}
|
|
420
|
+
#else
|
|
421
|
+
const int rounding = fegetround();
|
|
422
|
+
fesetround(FE_TONEAREST);
|
|
423
|
+
|
|
424
|
+
const int ret = gm_apply(&kernel, stack, spec.outer_dims, &ctx);
|
|
231
425
|
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
426
|
+
fesetround(rounding);
|
|
427
|
+
|
|
428
|
+
if (ret < 0) {
|
|
429
|
+
clear_pystack(pystack, spec.nargs);
|
|
430
|
+
ndt_apply_spec_clear(&spec);
|
|
431
|
+
return seterr(&ctx);
|
|
235
432
|
}
|
|
433
|
+
#endif
|
|
236
434
|
}
|
|
237
435
|
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
436
|
+
nin = spec.nin;
|
|
437
|
+
nout = spec.nout;
|
|
438
|
+
nargs = spec.nargs;
|
|
439
|
+
ndt_apply_spec_clear(&spec);
|
|
440
|
+
|
|
441
|
+
switch (nout) {
|
|
442
|
+
case 0: {
|
|
443
|
+
clear_pystack(pystack, nargs);
|
|
444
|
+
Py_RETURN_NONE;
|
|
445
|
+
}
|
|
446
|
+
case 1: {
|
|
447
|
+
clear_pystack(pystack, nin);
|
|
448
|
+
return pystack[nin];
|
|
449
|
+
}
|
|
241
450
|
default: {
|
|
242
|
-
PyObject *tuple = PyTuple_New(
|
|
451
|
+
PyObject *tuple = PyTuple_New(nout);
|
|
243
452
|
if (tuple == NULL) {
|
|
244
|
-
|
|
453
|
+
clear_pystack(pystack, nargs);
|
|
245
454
|
return NULL;
|
|
246
455
|
}
|
|
247
|
-
for (i = 0; i <
|
|
248
|
-
PyTuple_SET_ITEM(tuple, i,
|
|
456
|
+
for (int i = 0; i < nout; i++) {
|
|
457
|
+
PyTuple_SET_ITEM(tuple, i, pystack[nin+i]);
|
|
249
458
|
}
|
|
250
459
|
return tuple;
|
|
251
460
|
}
|
|
@@ -253,7 +462,39 @@ gufunc_call(GufuncObject *self, PyObject *args, PyObject *kwds)
|
|
|
253
462
|
}
|
|
254
463
|
|
|
255
464
|
static PyObject *
|
|
256
|
-
|
|
465
|
+
gufunc_call(GufuncObject *self, PyObject *args, PyObject *kwargs)
|
|
466
|
+
{
|
|
467
|
+
return _gufunc_call(self, args, kwargs, true, true);
|
|
468
|
+
}
|
|
469
|
+
|
|
470
|
+
static PyObject *
|
|
471
|
+
gufunc_getdevice(GufuncObject *self, PyObject *args GM_UNUSED)
|
|
472
|
+
{
|
|
473
|
+
if (self->flags & GM_CUDA_MANAGED_FUNC) {
|
|
474
|
+
return PyUnicode_FromString("cuda:managed");
|
|
475
|
+
}
|
|
476
|
+
|
|
477
|
+
Py_RETURN_NONE;
|
|
478
|
+
}
|
|
479
|
+
|
|
480
|
+
static PyObject *
|
|
481
|
+
gufunc_getidentity(GufuncObject *self, PyObject *args GM_UNUSED)
|
|
482
|
+
{
|
|
483
|
+
Py_INCREF(self->identity);
|
|
484
|
+
return self->identity;
|
|
485
|
+
}
|
|
486
|
+
|
|
487
|
+
static int
|
|
488
|
+
gufunc_setidentity(GufuncObject *self, PyObject *value, void *closure GM_UNUSED)
|
|
489
|
+
{
|
|
490
|
+
Py_DECREF(self->identity);
|
|
491
|
+
Py_INCREF(value);
|
|
492
|
+
self->identity = value;
|
|
493
|
+
return 0;
|
|
494
|
+
}
|
|
495
|
+
|
|
496
|
+
static PyObject *
|
|
497
|
+
gufunc_getkernels(GufuncObject *self, PyObject *args GM_UNUSED)
|
|
257
498
|
{
|
|
258
499
|
NDT_STATIC_CONTEXT(ctx);
|
|
259
500
|
PyObject *list, *tmp;
|
|
@@ -294,7 +535,9 @@ gufunc_kernels(GufuncObject *self, PyObject *args GM_UNUSED)
|
|
|
294
535
|
|
|
295
536
|
static PyGetSetDef gufunc_getsets [] =
|
|
296
537
|
{
|
|
297
|
-
{ "
|
|
538
|
+
{ "device", (getter)gufunc_getdevice, NULL, NULL, NULL},
|
|
539
|
+
{ "identity", (getter)gufunc_getidentity, (setter)gufunc_setidentity, NULL, NULL},
|
|
540
|
+
{ "kernels", (getter)gufunc_getkernels, NULL, NULL, NULL},
|
|
298
541
|
{NULL}
|
|
299
542
|
};
|
|
300
543
|
|
|
@@ -323,13 +566,25 @@ struct map_args {
|
|
|
323
566
|
const gm_tbl_t *tbl;
|
|
324
567
|
};
|
|
325
568
|
|
|
569
|
+
static int
|
|
570
|
+
Gufunc_CheckExact(const PyObject *v)
|
|
571
|
+
{
|
|
572
|
+
return Py_TYPE(v) == &Gufunc_Type;
|
|
573
|
+
}
|
|
574
|
+
|
|
575
|
+
static int
|
|
576
|
+
Gufunc_Check(const PyObject *v)
|
|
577
|
+
{
|
|
578
|
+
return PyObject_TypeCheck(v, &Gufunc_Type);
|
|
579
|
+
}
|
|
580
|
+
|
|
326
581
|
static int
|
|
327
582
|
add_function(const gm_func_t *f, void *args)
|
|
328
583
|
{
|
|
329
584
|
struct map_args *a = (struct map_args *)args;
|
|
330
585
|
PyObject *func;
|
|
331
586
|
|
|
332
|
-
func = gufunc_new(a->tbl, f->name);
|
|
587
|
+
func = gufunc_new(a->tbl, f->name, GM_CPU_FUNC);
|
|
333
588
|
if (func == NULL) {
|
|
334
589
|
return -1;
|
|
335
590
|
}
|
|
@@ -349,10 +604,40 @@ Gumath_AddFunctions(PyObject *m, const gm_tbl_t *tbl)
|
|
|
349
604
|
return 0;
|
|
350
605
|
}
|
|
351
606
|
|
|
607
|
+
static int
|
|
608
|
+
add_cuda_function(const gm_func_t *f, void *args)
|
|
609
|
+
{
|
|
610
|
+
struct map_args *a = (struct map_args *)args;
|
|
611
|
+
PyObject *func;
|
|
612
|
+
|
|
613
|
+
func = gufunc_new(a->tbl, f->name, GM_CUDA_MANAGED_FUNC);
|
|
614
|
+
if (func == NULL) {
|
|
615
|
+
return -1;
|
|
616
|
+
}
|
|
617
|
+
|
|
618
|
+
return PyModule_AddObject(a->module, f->name, func);
|
|
619
|
+
}
|
|
620
|
+
|
|
621
|
+
static int
|
|
622
|
+
Gumath_AddCudaFunctions(PyObject *m, const gm_tbl_t *tbl)
|
|
623
|
+
{
|
|
624
|
+
struct map_args args = {m, tbl};
|
|
625
|
+
|
|
626
|
+
if (gm_tbl_map(tbl, add_cuda_function, &args) < 0) {
|
|
627
|
+
return -1;
|
|
628
|
+
}
|
|
629
|
+
|
|
630
|
+
return 0;
|
|
631
|
+
}
|
|
632
|
+
|
|
352
633
|
static PyObject *
|
|
353
634
|
init_api(void)
|
|
354
635
|
{
|
|
636
|
+
gumath_api[Gufunc_CheckExact_INDEX] = (void *)Gufunc_CheckExact;
|
|
637
|
+
gumath_api[Gufunc_Check_INDEX] = (void *)Gufunc_Check;
|
|
355
638
|
gumath_api[Gumath_AddFunctions_INDEX] = (void *)Gumath_AddFunctions;
|
|
639
|
+
gumath_api[Gumath_AddFunctions_INDEX] = (void *)Gumath_AddFunctions;
|
|
640
|
+
gumath_api[Gumath_AddCudaFunctions_INDEX] = (void *)Gumath_AddCudaFunctions;
|
|
356
641
|
|
|
357
642
|
return PyCapsule_New(gumath_api, "gumath._gumath._API", NULL);
|
|
358
643
|
}
|
|
@@ -362,6 +647,75 @@ init_api(void)
|
|
|
362
647
|
/* Module */
|
|
363
648
|
/****************************************************************************/
|
|
364
649
|
|
|
650
|
+
/*
 * vfold(*args, f=<gufunc>, acc=<xnd>)
 *
 * Calls the gufunc 'f' with the accumulator 'acc' prepended to the
 * positional arguments and, at the same time, passed as the 'out' keyword
 * argument; the type of 'acc' is passed as 'cls'.  Both trailing boolean
 * options of _gufunc_call() are switched off for this call.
 *
 * 'f' and 'acc' are keyword-only.  Both default to None, which fails the
 * type checks below, so in practice both must be supplied.
 *
 * Returns the result of _gufunc_call(), or NULL with an exception set.
 */
static PyObject *
gufunc_vfold(PyObject *m GM_UNUSED, PyObject *args, PyObject *kwargs)
{
    static char *kwlist[] = {"f", "acc", NULL};
    PyObject *func = Py_None;
    PyObject *acc = Py_None;
    PyObject *tuple;
    PyObject *dict;
    PyObject *res;
    Py_ssize_t size, i;

    /* Only the keywords are parsed here, hence the empty positional tuple.
       PyArg_ParseTupleAndKeywords() returns 0 (not a negative value) on
       failure. */
    if (!PyArg_ParseTupleAndKeywords(positional_empty, kwargs, "|$OO", kwlist,
                                     &func, &acc)) {
        return NULL;
    }

    if (!Gufunc_Check(func)) {
        /* %s needs a C string: pass the type's name, not the type object. */
        PyErr_Format(PyExc_TypeError,
            "vfold: expected gufunc object, got '%.200s'",
            Py_TYPE(func)->tp_name);
        return NULL;
    }

    if (!Xnd_Check(acc)) {
        PyErr_Format(PyExc_TypeError,
            "vfold: expected xnd instance, got '%.200s'",
            Py_TYPE(acc)->tp_name);
        return NULL;
    }

    /* Push the accumulator onto the argument stack. */
    size = PyTuple_Size(args);
    if (size < 0) {
        return NULL;
    }

    tuple = PyTuple_New(size+1);
    if (tuple == NULL) {
        return NULL;
    }

    /* PyTuple_SET_ITEM() steals a reference, so take one for each item. */
    Py_INCREF(acc);
    PyTuple_SET_ITEM(tuple, 0, acc);
    for (i = 0; i < size; i++) {
        PyObject *v = PyTuple_GET_ITEM(args, i);
        Py_INCREF(v);
        PyTuple_SET_ITEM(tuple, i+1, v);
    }

    /* Simultaneously use the accumulator as the 'out' argument. */
    dict = PyDict_New();
    if (dict == NULL) {
        Py_DECREF(tuple);
        return NULL;
    }
    if (PyDict_SetItemString(dict, "out", acc) < 0) {
        Py_DECREF(dict);
        Py_DECREF(tuple);
        return NULL;
    }
    if (PyDict_SetItemString(dict, "cls", (PyObject *)(Py_TYPE(acc))) < 0) {
        Py_DECREF(dict);
        Py_DECREF(tuple);
        return NULL;
    }

    res = _gufunc_call((GufuncObject *)func, tuple, dict, false, false);
    Py_DECREF(tuple);
    Py_DECREF(dict);

    return res;
}
|
|
718
|
+
|
|
365
719
|
static PyObject *
|
|
366
720
|
unsafe_add_kernel(PyObject *m GM_UNUSED, PyObject *args, PyObject *kwds)
|
|
367
721
|
{
|
|
@@ -388,8 +742,8 @@ unsafe_add_kernel(PyObject *m GM_UNUSED, PyObject *args, PyObject *kwds)
|
|
|
388
742
|
k.name = name;
|
|
389
743
|
k.sig = sig;
|
|
390
744
|
|
|
391
|
-
if (strcmp(tag, "Opt") == 0) {
|
|
392
|
-
k.
|
|
745
|
+
if (strcmp(tag, "Opt") == 0) { /* XXX */
|
|
746
|
+
k.OptC = p;
|
|
393
747
|
}
|
|
394
748
|
else if (strcmp(tag, "C") == 0) {
|
|
395
749
|
k.C = p;
|
|
@@ -418,7 +772,7 @@ unsafe_add_kernel(PyObject *m GM_UNUSED, PyObject *args, PyObject *kwds)
|
|
|
418
772
|
return seterr(&ctx);
|
|
419
773
|
}
|
|
420
774
|
|
|
421
|
-
return gufunc_new(table, f->name);
|
|
775
|
+
return gufunc_new(table, f->name, GM_CPU_FUNC);
|
|
422
776
|
}
|
|
423
777
|
|
|
424
778
|
static void
|
|
@@ -490,6 +844,7 @@ set_max_threads(PyObject *m UNUSED, PyObject *obj)
|
|
|
490
844
|
static PyMethodDef gumath_methods [] =
|
|
491
845
|
{
|
|
492
846
|
/* Methods */
|
|
847
|
+
{ "vfold", (PyCFunction)gufunc_vfold, METH_VARARGS|METH_KEYWORDS, NULL },
|
|
493
848
|
{ "unsafe_add_kernel", (PyCFunction)unsafe_add_kernel, METH_VARARGS|METH_KEYWORDS, NULL },
|
|
494
849
|
{ "get_max_threads", (PyCFunction)get_max_threads, METH_NOARGS, NULL },
|
|
495
850
|
{ "set_max_threads", (PyCFunction)set_max_threads, METH_O, NULL },
|
|
@@ -554,11 +909,21 @@ PyInit__gumath(void)
|
|
|
554
909
|
goto error;
|
|
555
910
|
}
|
|
556
911
|
|
|
912
|
+
positional_empty = PyTuple_New(0);
|
|
913
|
+
if (positional_empty == NULL) {
|
|
914
|
+
goto error;
|
|
915
|
+
}
|
|
916
|
+
|
|
557
917
|
m = PyModule_Create(&gumath_module);
|
|
558
918
|
if (m == NULL) {
|
|
559
919
|
goto error;
|
|
560
920
|
}
|
|
561
921
|
|
|
922
|
+
Py_INCREF(&Gufunc_Type);
|
|
923
|
+
if (PyModule_AddObject(m, "gufunc", (PyObject *)&Gufunc_Type) < 0) {
|
|
924
|
+
goto error;
|
|
925
|
+
}
|
|
926
|
+
|
|
562
927
|
Py_INCREF(capsule);
|
|
563
928
|
if (PyModule_AddObject(m, "_API", capsule) < 0) {
|
|
564
929
|
goto error;
|
|
@@ -571,6 +936,7 @@ PyInit__gumath(void)
|
|
|
571
936
|
return m;
|
|
572
937
|
|
|
573
938
|
error:
|
|
939
|
+
Py_CLEAR(positional_empty);
|
|
574
940
|
Py_CLEAR(xnd);
|
|
575
941
|
Py_CLEAR(m);
|
|
576
942
|
return NULL;
|