gumath 0.2.0dev5 → 0.2.0dev8
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CONTRIBUTING.md +7 -2
- data/Gemfile +0 -3
- data/ext/ruby_gumath/GPATH +0 -0
- data/ext/ruby_gumath/GRTAGS +0 -0
- data/ext/ruby_gumath/GTAGS +0 -0
- data/ext/ruby_gumath/extconf.rb +0 -5
- data/ext/ruby_gumath/functions.c +10 -2
- data/ext/ruby_gumath/gufunc_object.c +15 -4
- data/ext/ruby_gumath/gufunc_object.h +9 -3
- data/ext/ruby_gumath/gumath/Makefile +63 -0
- data/ext/ruby_gumath/gumath/Makefile.in +1 -0
- data/ext/ruby_gumath/gumath/config.h +56 -0
- data/ext/ruby_gumath/gumath/config.h.in +3 -0
- data/ext/ruby_gumath/gumath/config.log +497 -0
- data/ext/ruby_gumath/gumath/config.status +1034 -0
- data/ext/ruby_gumath/gumath/configure +375 -4
- data/ext/ruby_gumath/gumath/configure.ac +47 -3
- data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
- data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
- data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
- data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
- data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
- data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
- data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
- data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
- data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
- data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
- data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
- data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
- data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
- data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
- data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
- data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
- data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
- data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
- data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
- data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
- data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
- data/ext/ruby_gumath/gumath/setup.py +67 -6
- data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
- data/ext/ruby_gumath/include/gumath.h +55 -14
- data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
- data/ext/ruby_gumath/lib/libgumath.a +0 -0
- data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/ruby_gumath.c +231 -70
- data/ext/ruby_gumath/ruby_gumath.h +4 -1
- data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
- data/ext/ruby_gumath/util.c +34 -0
- data/ext/ruby_gumath/util.h +9 -0
- data/gumath.gemspec +3 -2
- data/lib/gumath.rb +55 -1
- data/lib/gumath/version.rb +2 -2
- data/lib/ruby_gumath.so +0 -0
- metadata +63 -10
- data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
- data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
- data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
Binary file
|
@@ -0,0 +1 @@
|
|
1
|
+
ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3
|
@@ -0,0 +1 @@
|
|
1
|
+
ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
@@ -72,7 +72,7 @@ clear_all_slices(xnd_t *slices[], int *nslices, int stop)
|
|
72
72
|
{
|
73
73
|
for (int i = 0; i < stop; i++) {
|
74
74
|
for (int k = 0; k < nslices[i]; k++) {
|
75
|
-
|
75
|
+
ndt_decref(slices[i][k].type);
|
76
76
|
}
|
77
77
|
ndt_free(slices[i]);
|
78
78
|
}
|
@@ -94,16 +94,27 @@ apply_thread(void *arg)
|
|
94
94
|
|
95
95
|
int
|
96
96
|
gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
|
97
|
-
|
97
|
+
const int64_t nthreads, ndt_context_t *ctx)
|
98
98
|
{
|
99
99
|
const int nrows = (int)kernel->set->sig->Function.nargs;
|
100
100
|
ALLOCA(xnd_t *, slices, nrows);
|
101
101
|
ALLOCA(int, nslices, nrows);
|
102
102
|
struct thread_info *tinfo;
|
103
103
|
int ncols, tnum;
|
104
|
+
bool use_threads = true;
|
104
105
|
|
105
|
-
if (nthreads <= 1 || nrows == 0 || outer_dims == 0
|
106
|
-
|
106
|
+
if (nthreads <= 1 || nrows == 0 || outer_dims == 0) {
|
107
|
+
use_threads = false;
|
108
|
+
}
|
109
|
+
|
110
|
+
for (int i = 0; i < nrows; i++) {
|
111
|
+
const ndt_t *t = stack[i].type;
|
112
|
+
if (!ndt_is_ndarray(t) || ndt_nelem(t) < GM_THREAD_CUTOFF) {
|
113
|
+
use_threads = false;
|
114
|
+
}
|
115
|
+
}
|
116
|
+
|
117
|
+
if (!use_threads) {
|
107
118
|
return gm_apply(kernel, stack, outer_dims, ctx);
|
108
119
|
}
|
109
120
|
|
@@ -147,6 +158,7 @@ gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
|
|
147
158
|
&tinfo[tnum]);
|
148
159
|
if (ret != 0) {
|
149
160
|
clear_all_slices(slices, nslices, nrows);
|
161
|
+
ndt_free(tinfo);
|
150
162
|
ndt_err_format(ctx, NDT_RuntimeError, "could not create thread");
|
151
163
|
return -1;
|
152
164
|
}
|
@@ -169,6 +181,7 @@ gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
|
|
169
181
|
}
|
170
182
|
|
171
183
|
clear_all_slices(slices, nslices, nrows);
|
184
|
+
ndt_free(tinfo);
|
172
185
|
|
173
186
|
return ndt_err_occurred(ctx) ? -1 : 0;
|
174
187
|
}
|
Binary file
|
@@ -38,11 +38,99 @@
|
|
38
38
|
#include "ndtypes.h"
|
39
39
|
#include "xnd.h"
|
40
40
|
#include "gumath.h"
|
41
|
+
#include "overflow.h"
|
41
42
|
|
42
43
|
|
44
|
+
static int _gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
|
45
|
+
const int outer_dims, ndt_context_t *ctx);
|
46
|
+
|
47
|
+
int
|
48
|
+
array_shape_check(xnd_t *x, const int64_t shape, ndt_context_t *ctx)
|
49
|
+
{
|
50
|
+
const ndt_t *t = x->type;
|
51
|
+
|
52
|
+
if (t->tag != Array) {
|
53
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
54
|
+
"type mismatch in outer dimensions");
|
55
|
+
return -1;
|
56
|
+
}
|
57
|
+
|
58
|
+
if (XND_ARRAY_DATA(x->ptr) == NULL) {
|
59
|
+
bool overflow = false;
|
60
|
+
const int64_t size = MULi64(shape, t->Array.itemsize, &overflow);
|
61
|
+
if (overflow) {
|
62
|
+
ndt_err_format(ctx, NDT_ValueError,
|
63
|
+
"datasize of flexible array is too large");
|
64
|
+
return -1;
|
65
|
+
}
|
66
|
+
|
67
|
+
char *data = ndt_aligned_calloc(t->align, size);
|
68
|
+
if (data == NULL) {
|
69
|
+
ndt_err_format(ctx, NDT_MemoryError, "out of memory");
|
70
|
+
return -1;
|
71
|
+
}
|
72
|
+
|
73
|
+
XND_ARRAY_SHAPE(x->ptr) = shape;
|
74
|
+
XND_ARRAY_DATA(x->ptr) = data;
|
75
|
+
|
76
|
+
return 0;
|
77
|
+
}
|
78
|
+
else if (XND_ARRAY_SHAPE(x->ptr) != shape) {
|
79
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
80
|
+
"shape mismatch in outer dimensions");
|
81
|
+
return -1;
|
82
|
+
}
|
83
|
+
else {
|
84
|
+
return 0;
|
85
|
+
}
|
86
|
+
}
|
87
|
+
|
88
|
+
static inline bool
|
89
|
+
any_stored_index(xnd_t stack[], const int nargs)
|
90
|
+
{
|
91
|
+
for (int i = 0; i < nargs; i++) {
|
92
|
+
if (stack[i].ptr == NULL) {
|
93
|
+
continue;
|
94
|
+
}
|
95
|
+
|
96
|
+
const ndt_t *t = stack[i].type;
|
97
|
+
if (have_stored_index(t)) {
|
98
|
+
return true;
|
99
|
+
}
|
100
|
+
}
|
101
|
+
|
102
|
+
return false;
|
103
|
+
}
|
104
|
+
|
43
105
|
int
|
44
106
|
gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
|
45
107
|
const int outer_dims, ndt_context_t *ctx)
|
108
|
+
{
|
109
|
+
if (any_stored_index(stack, nargs)) {
|
110
|
+
ALLOCA(xnd_t, next, nargs);
|
111
|
+
|
112
|
+
for (int i = 0; i < nargs; i++) {
|
113
|
+
const ndt_t *t = stack[i].type;
|
114
|
+
if (have_stored_index(t)) {
|
115
|
+
next[i] = apply_stored_indices(&stack[i], ctx);
|
116
|
+
if (xnd_err_occurred(&next[i])) {
|
117
|
+
return -1;
|
118
|
+
}
|
119
|
+
}
|
120
|
+
else {
|
121
|
+
next[i] = stack[i];
|
122
|
+
}
|
123
|
+
}
|
124
|
+
|
125
|
+
return _gm_xnd_map(f, next, nargs, outer_dims, ctx);
|
126
|
+
}
|
127
|
+
|
128
|
+
return _gm_xnd_map(f, stack, nargs, outer_dims, ctx);
|
129
|
+
}
|
130
|
+
|
131
|
+
static int
|
132
|
+
_gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
|
133
|
+
const int outer_dims, ndt_context_t *ctx)
|
46
134
|
{
|
47
135
|
ALLOCA(xnd_t, next, nargs);
|
48
136
|
const ndt_t *t;
|
@@ -123,6 +211,28 @@ gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
|
|
123
211
|
return 0;
|
124
212
|
}
|
125
213
|
|
214
|
+
case Array: {
|
215
|
+
const int64_t shape = XND_ARRAY_SHAPE(stack[0].ptr);
|
216
|
+
|
217
|
+
for (int k = 1; k < nargs; k++) {
|
218
|
+
if (array_shape_check(&stack[k], shape, ctx) < 0) {
|
219
|
+
return -1;
|
220
|
+
}
|
221
|
+
}
|
222
|
+
|
223
|
+
for (int64_t i = 0; i < shape; i++) {
|
224
|
+
for (int k = 0; k < nargs; k++) {
|
225
|
+
next[k] = xnd_array_next(&stack[k], i);
|
226
|
+
}
|
227
|
+
|
228
|
+
if (gm_xnd_map(f, next, nargs, outer_dims-1, ctx) < 0) {
|
229
|
+
return -1;
|
230
|
+
}
|
231
|
+
}
|
232
|
+
|
233
|
+
return 0;
|
234
|
+
}
|
235
|
+
|
126
236
|
default:
|
127
237
|
ndt_err_format(ctx, NDT_NotImplementedError, "unsupported type");
|
128
238
|
return -1;
|
Binary file
|
@@ -33,7 +33,157 @@
|
|
33
33
|
from ndtypes import ndt
|
34
34
|
from xnd import xnd
|
35
35
|
from ._gumath import *
|
36
|
+
from . import functions as _fn
|
36
37
|
|
38
|
+
try:
|
39
|
+
from . import cuda as _cd
|
40
|
+
except ImportError:
|
41
|
+
_cd = None
|
42
|
+
|
43
|
+
|
44
|
+
__all__ = ['fold', 'functions', 'get_max_threads', 'gufunc', 'reduce',
           'set_max_threads', 'unsafe_add_kernel', 'vfold', 'xndvectorize']

# The 'cuda' submodule is optional: its import above may fail and leave _cd
# set to None.  A name listed in __all__ that does not exist makes
# "from gumath import *" raise AttributeError, so only advertise 'cuda' when
# the submodule was actually imported.
if _cd is not None:
    __all__.insert(0, 'cuda')
|
46
|
+
|
47
|
+
|
48
|
+
# ==============================================================================
|
49
|
+
# Init identity elements
|
50
|
+
# ==============================================================================
|
51
|
+
|
52
|
+
# Identity elements: reduce_cpu() seeds its accumulator with f.identity when
# it is set.  Assigned at the Python level for now; this might belong in C.
_fn.multiply.identity = 1
_fn.add.identity = 0
|
55
|
+
|
56
|
+
|
57
|
+
# ==============================================================================
|
58
|
+
# General fold function
|
59
|
+
# ==============================================================================
|
60
|
+
|
61
|
+
def fold(f, acc, x):
    """Fold the gufunc ``f`` over ``x`` starting from the accumulator ``acc``.

    Thin wrapper that forwards to ``vfold`` with ``x`` positional and
    ``f``/``acc`` as keyword arguments.
    """
    return vfold(x, acc=acc, f=f)
|
63
|
+
|
64
|
+
|
65
|
+
# ==============================================================================
|
66
|
+
# NumPy's reduce in terms of fold
|
67
|
+
# ==============================================================================
|
68
|
+
|
69
|
+
def _get_axes(axes, ndim):
|
70
|
+
type_err = "'axes' must be None, a single integer or a tuple of integers"
|
71
|
+
value_err = "axis with value %d out of range"
|
72
|
+
duplicate_err = "'axes' argument contains duplicate values"
|
73
|
+
if axes is None:
|
74
|
+
axes = tuple(range(ndim))
|
75
|
+
elif isinstance(axes, int):
|
76
|
+
axes = (axes,)
|
77
|
+
elif not isinstance(axes, tuple) or \
|
78
|
+
any(not isinstance(v, int) for v in axes):
|
79
|
+
raise TypeError(type_err)
|
80
|
+
|
81
|
+
if any(n >= ndim for n in axes):
|
82
|
+
raise ValueError(value_err % n)
|
83
|
+
|
84
|
+
if len(set(axes)) != len(axes):
|
85
|
+
raise ValueError(duplicate_err)
|
86
|
+
|
87
|
+
return list(axes)
|
88
|
+
|
89
|
+
def _copyto(dest, value):
    """Write ``value``, converted to ``dest.dtype``, into ``dest`` in place."""
    _fn.copy(xnd(value, dtype=dest.dtype), out=dest)
|
92
|
+
|
93
|
+
def reduce_cpu(f, x, axes, dtype):
    """NumPy's reduce in terms of fold.

    The reduced ``axes`` are moved to the front of ``x`` and the gufunc ``f``
    is folded over them into an accumulator of type ``dtype``.  The
    accumulator is seeded with ``f.identity`` when that is set.  Otherwise a
    single, non-empty reduced axis is seeded with its first element.  ``x`` is
    returned unchanged if there is nothing to reduce.

    Raises ValueError when ``f`` has no identity and the reduction is over
    more than one axis or over an empty axis.
    """
    axes = _get_axes(axes, x.ndim)
    if not axes:
        return x

    kept = [n for n in range(x.ndim) if n not in axes]
    T = x.transpose(permute=axes + kept)

    N = len(axes)
    acc = x.empty(T.type.at(N, dtype=dtype), device=x.device)

    if f.identity is not None:
        _copyto(acc, f.identity)
        return fold(f, acc, T)

    if N != 1 or T.type.shape[0] <= 0:
        raise ValueError(
            "reduction not possible for function without an identity element")

    # No identity: start from the first slice and fold over the rest.
    acc[()] = T[0]
    return fold(f, acc, T[1:])
|
119
|
+
|
120
|
+
def reduce_cuda(g, x, axes, dtype):
    """Reductions in CUDA use the thrust library for speed and have limited
    functionality: only the default ``axes`` value 0 is accepted, and the
    work is delegated to the CUDA reduction function ``g``."""
    if axes == 0:
        return g(x, dtype=dtype)

    raise NotImplementedError("'axes' keyword is not implemented for CUDA")
|
127
|
+
|
128
|
+
def get_cuda_reduction_func(f):
    """Return the dedicated CUDA reduction for ``f``, or None.

    Only the CUDA ``add`` and ``multiply`` gufuncs have one.  None is also
    returned when the CUDA submodule is unavailable.
    """
    if _cd is not None:
        if f == _cd.add:
            return _cd.reduce_add
        if f == _cd.multiply:
            return _cd.reduce_multiply
    return None
|
137
|
+
|
138
|
+
def reduce(f, x, axes=0, dtype=None):
    """Reduce ``x`` with the gufunc ``f`` along ``axes`` (default: axis 0).

    ``dtype`` defaults to ``maxcast[x.dtype]``, the widened accumulator type.
    CUDA ``add``/``multiply`` use their dedicated reduction kernels; every
    other function goes through reduce_cpu().
    """
    if dtype is None:
        dtype = maxcast[x.dtype]

    cuda_reduce = get_cuda_reduction_func(f)
    if cuda_reduce is None:
        return reduce_cpu(f, x, axes, dtype)
    return reduce_cuda(cuda_reduce, x, axes, dtype)
|
147
|
+
|
148
|
+
|
149
|
+
# Default accumulator dtype used by reduce(): each integer, float and complex
# dtype maps to the widest type of its kind, and each option ("?") dtype maps
# to the corresponding option type.  Key order matches the original table:
# plain dtypes first, then option dtypes.
maxcast = {
    ndt(prefix + src): ndt(prefix + dst)
    for prefix in ("", "?")
    for src, dst in (
        ("int8", "int64"),
        ("int16", "int64"),
        ("int32", "int64"),
        ("int64", "int64"),
        ("uint8", "uint64"),
        ("uint16", "uint64"),
        ("uint32", "uint64"),
        ("uint64", "uint64"),
        ("bfloat16", "float64"),
        ("float16", "float64"),
        ("float32", "float64"),
        ("float64", "float64"),
        ("complex32", "complex128"),
        ("complex64", "complex128"),
        ("complex128", "complex128"),
    )
}
|
182
|
+
|
183
|
+
|
184
|
+
# ==============================================================================
|
185
|
+
# Numba's GUVectorize on xnd arrays
|
186
|
+
# ==============================================================================
|
37
187
|
|
38
188
|
try:
|
39
189
|
import numpy as np
|
@@ -45,16 +45,24 @@
|
|
45
45
|
#define GUMATH_MODULE
|
46
46
|
#include "pygumath.h"
|
47
47
|
|
48
|
+
|
48
49
|
#ifdef _MSC_VER
|
49
50
|
#ifndef UNUSED
|
50
51
|
#define UNUSED
|
51
52
|
#endif
|
53
|
+
#include <float.h>
|
54
|
+
#include <fenv.h>
|
55
|
+
#pragma fenv_access(on)
|
52
56
|
#else
|
53
57
|
#if defined(__GNUC__) && !defined(__INTEL_COMPILER)
|
54
58
|
#define UNUSED __attribute__((unused))
|
55
59
|
#else
|
56
60
|
#define UNUSED
|
57
61
|
#endif
|
62
|
+
#include <fenv.h>
|
63
|
+
#if 0 /* Not supported by gcc and clang. */
|
64
|
+
#pragma STDC FENV_ACCESS ON
|
65
|
+
#endif
|
58
66
|
#endif
|
59
67
|
|
60
68
|
|
@@ -73,6 +81,9 @@ static gm_tbl_t *table = NULL;
|
|
73
81
|
/* Xnd type */
|
74
82
|
static PyTypeObject *xnd = NULL;
|
75
83
|
|
84
|
+
/* Empty positional arguments */
|
85
|
+
static PyObject *positional_empty = NULL;
|
86
|
+
|
76
87
|
/* Maximum number of threads */
|
77
88
|
static int64_t max_threads = 1;
|
78
89
|
|
@@ -95,7 +106,7 @@ seterr(ndt_context_t *ctx)
|
|
95
106
|
static PyTypeObject Gufunc_Type;
|
96
107
|
|
97
108
|
static PyObject *
|
98
|
-
gufunc_new(const gm_tbl_t *tbl, const char *name)
|
109
|
+
gufunc_new(const gm_tbl_t *tbl, const char *name, const uint32_t flags)
|
99
110
|
{
|
100
111
|
NDT_STATIC_CONTEXT(ctx);
|
101
112
|
GufuncObject *self;
|
@@ -106,12 +117,16 @@ gufunc_new(const gm_tbl_t *tbl, const char *name)
|
|
106
117
|
}
|
107
118
|
|
108
119
|
self->tbl = tbl;
|
120
|
+
self->flags = flags;
|
109
121
|
|
110
122
|
self->name = ndt_strdup(name, &ctx);
|
111
123
|
if (self->name == NULL) {
|
112
124
|
return seterr(&ctx);
|
113
125
|
}
|
114
126
|
|
127
|
+
self->identity = Py_None;
|
128
|
+
Py_INCREF(self->identity);
|
129
|
+
|
115
130
|
return (PyObject *)self;
|
116
131
|
}
|
117
132
|
|
@@ -119,6 +134,7 @@ static void
|
|
119
134
|
gufunc_dealloc(GufuncObject *self)
|
120
135
|
{
|
121
136
|
ndt_free(self->name);
|
137
|
+
Py_DECREF(self->identity);
|
122
138
|
PyObject_Del(self);
|
123
139
|
}
|
124
140
|
|
@@ -128,124 +144,317 @@ gufunc_dealloc(GufuncObject *self)
|
|
128
144
|
/****************************************************************************/
|
129
145
|
|
130
146
|
static void
|
131
|
-
|
147
|
+
clear_pystack(PyObject *pystack[], Py_ssize_t len)
|
148
|
+
{
|
149
|
+
for (Py_ssize_t i = 0; i < len; i++) {
|
150
|
+
Py_CLEAR(pystack[i]);
|
151
|
+
}
|
152
|
+
}
|
153
|
+
|
154
|
+
static int
|
155
|
+
parse_args(PyObject *pystack[NDT_MAX_ARGS], int *py_nin, int *py_nout, int *py_nargs,
|
156
|
+
PyObject *args, PyObject *out)
|
132
157
|
{
|
133
|
-
Py_ssize_t
|
158
|
+
Py_ssize_t nin;
|
159
|
+
Py_ssize_t nout;
|
134
160
|
|
135
|
-
|
136
|
-
|
161
|
+
if (!args || !PyTuple_Check(args)) {
|
162
|
+
const char *name = args ? Py_TYPE(args)->tp_name : "NULL";
|
163
|
+
PyErr_Format(PyExc_SystemError,
|
164
|
+
"internal error: expected tuple, got '%.200s'", name);
|
165
|
+
return -1;
|
137
166
|
}
|
167
|
+
|
168
|
+
nin = PyTuple_GET_SIZE(args);
|
169
|
+
if (nin > NDT_MAX_ARGS) {
|
170
|
+
PyErr_Format(PyExc_TypeError,
|
171
|
+
"maximum number of arguments is %d, got %n", NDT_MAX_ARGS, nin);
|
172
|
+
return -1;
|
173
|
+
}
|
174
|
+
|
175
|
+
for (Py_ssize_t i = 0; i < nin; i++) {
|
176
|
+
PyObject *v = PyTuple_GET_ITEM(args, i);
|
177
|
+
if (!Xnd_Check(v)) {
|
178
|
+
PyErr_Format(PyExc_TypeError,
|
179
|
+
"expected xnd argument, got '%.200s'", Py_TYPE(v)->tp_name);
|
180
|
+
return -1;
|
181
|
+
}
|
182
|
+
|
183
|
+
pystack[i] = v;
|
184
|
+
}
|
185
|
+
|
186
|
+
if (out == NULL) {
|
187
|
+
nout = 0;
|
188
|
+
}
|
189
|
+
else {
|
190
|
+
if (Xnd_Check(out)) {
|
191
|
+
nout = 1;
|
192
|
+
if (nin+nout > NDT_MAX_ARGS) {
|
193
|
+
PyErr_Format(PyExc_TypeError,
|
194
|
+
"maximum number of arguments is %d, got %n", NDT_MAX_ARGS, nin+nout);
|
195
|
+
return -1;
|
196
|
+
}
|
197
|
+
pystack[nin] = out;
|
198
|
+
}
|
199
|
+
else if (PyTuple_Check(out)) {
|
200
|
+
nout = PyTuple_GET_SIZE(out);
|
201
|
+
if (nout > NDT_MAX_ARGS || nin+nout > NDT_MAX_ARGS) {
|
202
|
+
PyErr_Format(PyExc_TypeError,
|
203
|
+
"maximum number of arguments is %d, got %n", NDT_MAX_ARGS, nin+nout);
|
204
|
+
return -1;
|
205
|
+
}
|
206
|
+
|
207
|
+
for (Py_ssize_t i = 0; i < nout; i++) {
|
208
|
+
PyObject *v = PyTuple_GET_ITEM(out, i);
|
209
|
+
if (!Xnd_Check(v)) {
|
210
|
+
PyErr_Format(PyExc_TypeError,
|
211
|
+
"expected xnd argument, got '%.200s'", Py_TYPE(v)->tp_name);
|
212
|
+
return -1;
|
213
|
+
}
|
214
|
+
|
215
|
+
pystack[nin+i] = v;
|
216
|
+
}
|
217
|
+
}
|
218
|
+
else {
|
219
|
+
PyErr_Format(PyExc_TypeError,
|
220
|
+
"'out' argument must be xnd or a tuple of xnd, got '%.200s'",
|
221
|
+
Py_TYPE(out)->tp_name);
|
222
|
+
return -1;
|
223
|
+
}
|
224
|
+
}
|
225
|
+
|
226
|
+
for (int i = 0; i < nin+nout; i++) {
|
227
|
+
Py_INCREF(pystack[i]);
|
228
|
+
}
|
229
|
+
|
230
|
+
*py_nin = (int)nin;
|
231
|
+
*py_nout = (int)nout;
|
232
|
+
*py_nargs = (int)nin+(int)nout;
|
233
|
+
|
234
|
+
return 0;
|
138
235
|
}
|
139
236
|
|
140
237
|
static PyObject *
|
141
|
-
|
238
|
+
_gufunc_call(GufuncObject *self, PyObject *args, PyObject *kwargs,
|
239
|
+
bool enable_threads, bool check_broadcast)
|
142
240
|
{
|
241
|
+
static char *kwlist[] = {"out", "dtype", "cls", NULL};
|
242
|
+
PyObject *out = Py_None;
|
243
|
+
PyObject *dt = Py_None;
|
244
|
+
PyObject *cls = Py_None;
|
245
|
+
|
143
246
|
NDT_STATIC_CONTEXT(ctx);
|
144
|
-
|
145
|
-
PyObject **a = &PyTuple_GET_ITEM(args, 0);
|
146
|
-
PyObject *result[NDT_MAX_ARGS];
|
147
|
-
ndt_apply_spec_t spec = ndt_apply_spec_empty;
|
148
|
-
const ndt_t *in_types[NDT_MAX_ARGS];
|
247
|
+
PyObject *pystack[NDT_MAX_ARGS];
|
149
248
|
xnd_t stack[NDT_MAX_ARGS];
|
249
|
+
const ndt_t *types[NDT_MAX_ARGS];
|
250
|
+
int64_t li[NDT_MAX_ARGS];
|
251
|
+
ndt_apply_spec_t spec = ndt_apply_spec_empty;
|
150
252
|
gm_kernel_t kernel;
|
151
|
-
|
253
|
+
bool have_cpu_device = false;
|
254
|
+
ndt_t *dtype = NULL;
|
255
|
+
int nin, nout, nargs;
|
256
|
+
int k;
|
152
257
|
|
153
|
-
if (
|
154
|
-
|
155
|
-
"gufunc calls do not support keywords");
|
258
|
+
if (!PyArg_ParseTupleAndKeywords(positional_empty, kwargs, "|$OOO", kwlist,
|
259
|
+
&out, &dt, &cls)) {
|
156
260
|
return NULL;
|
157
261
|
}
|
158
262
|
|
159
|
-
|
263
|
+
out = out == Py_None ? NULL : out;
|
264
|
+
dt = dt == Py_None ? NULL : dt;
|
265
|
+
cls = cls == Py_None ? (PyObject *)xnd : cls;
|
266
|
+
|
267
|
+
if (dt != NULL) {
|
268
|
+
if (out != NULL) {
|
269
|
+
PyErr_SetString(PyExc_TypeError,
|
270
|
+
"the 'out' and 'dtype' arguments are mutually exclusive");
|
271
|
+
return NULL;
|
272
|
+
}
|
273
|
+
if (!Ndt_Check(dt)) {
|
274
|
+
PyErr_Format(PyExc_TypeError,
|
275
|
+
"'dtype' argument must be ndt, got '%.200s'",
|
276
|
+
Py_TYPE(dt)->tp_name);
|
277
|
+
return NULL;
|
278
|
+
}
|
279
|
+
dtype = (ndt_t *)NDT(dt);
|
280
|
+
ndt_incref(dtype);
|
281
|
+
}
|
282
|
+
|
283
|
+
if (!PyType_Check(cls) || !PyType_IsSubtype((PyTypeObject *)cls, xnd)) {
|
160
284
|
PyErr_SetString(PyExc_TypeError,
|
161
|
-
"
|
285
|
+
"the 'cls' argument must be a subtype of 'xnd'");
|
286
|
+
return NULL;
|
287
|
+
}
|
288
|
+
|
289
|
+
if (parse_args(pystack, &nin, &nout, &nargs, args, out) < 0) {
|
162
290
|
return NULL;
|
163
291
|
}
|
292
|
+
assert(nout == 0 || dtype == NULL);
|
164
293
|
|
165
|
-
for (
|
166
|
-
|
167
|
-
|
294
|
+
for (k = 0; k < nargs; k++) {
|
295
|
+
const XndObject *x = (XndObject *)pystack[k];
|
296
|
+
if (!(x->mblock->xnd->flags&XND_CUDA_MANAGED)) {
|
297
|
+
have_cpu_device = true;
|
298
|
+
}
|
299
|
+
|
300
|
+
stack[k] = *CONST_XND((PyObject *)x);
|
301
|
+
types[k] = stack[k].type;
|
302
|
+
li[k] = stack[k].index;
|
303
|
+
}
|
304
|
+
|
305
|
+
if (have_cpu_device) {
|
306
|
+
if (self->flags & GM_CUDA_MANAGED_FUNC) {
|
307
|
+
PyErr_SetString(PyExc_ValueError,
|
308
|
+
"cannot run a cuda function on xnd objects with cpu memory");
|
309
|
+
clear_pystack(pystack, nargs);
|
168
310
|
return NULL;
|
169
311
|
}
|
170
|
-
stack[i] = *CONST_XND(a[i]);
|
171
|
-
in_types[i] = stack[i].type;
|
172
312
|
}
|
173
313
|
|
174
|
-
kernel = gm_select(&spec, self->tbl, self->name,
|
314
|
+
kernel = gm_select(&spec, self->tbl, self->name, types, li, nin, nout,
|
315
|
+
nout && check_broadcast, stack, &ctx);
|
175
316
|
if (kernel.set == NULL) {
|
176
317
|
return seterr(&ctx);
|
177
318
|
}
|
178
319
|
|
179
|
-
if (
|
180
|
-
|
181
|
-
|
320
|
+
if (dtype) {
|
321
|
+
if (spec.nout != 1) {
|
322
|
+
ndt_err_format(&ctx, NDT_TypeError,
|
323
|
+
"the 'dtype' argument is only supported for a single "
|
324
|
+
"return value");
|
325
|
+
ndt_apply_spec_clear(&spec);
|
326
|
+
ndt_decref(dtype);
|
327
|
+
return seterr(&ctx);
|
328
|
+
}
|
329
|
+
|
330
|
+
const ndt_t *u = spec.types[spec.nin];
|
331
|
+
const ndt_t *v = ndt_copy_contiguous_dtype(u, dtype, 0, &ctx);
|
332
|
+
|
333
|
+
ndt_apply_spec_clear(&spec);
|
334
|
+
ndt_decref(dtype);
|
335
|
+
|
336
|
+
if (v == NULL) {
|
337
|
+
return seterr(&ctx);
|
338
|
+
}
|
339
|
+
|
340
|
+
types[nin] = v;
|
341
|
+
kernel = gm_select(&spec, self->tbl, self->name, types, li, nin, 1,
|
342
|
+
1 && check_broadcast, stack, &ctx);
|
343
|
+
if (kernel.set == NULL) {
|
344
|
+
return seterr(&ctx);
|
182
345
|
}
|
183
346
|
}
|
184
347
|
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
348
|
+
/*
|
349
|
+
* Replace args/kwargs types with types after substitution and broadcasting.
|
350
|
+
* This includes 'out' types, if explicitly passed as kwargs.
|
351
|
+
*/
|
352
|
+
for (int i = 0; i < spec.nargs; i++) {
|
353
|
+
stack[i].type = spec.types[i];
|
354
|
+
}
|
355
|
+
|
356
|
+
if (nout == 0) {
|
357
|
+
/* 'out' types have been inferred, create new XndObjects. */
|
358
|
+
for (int i = 0; i < spec.nout; i++) {
|
359
|
+
if (ndt_is_concrete(spec.types[nin+i])) {
|
360
|
+
uint32_t flags = self->flags == GM_CUDA_MANAGED_FUNC ? XND_CUDA_MANAGED : 0;
|
361
|
+
PyObject *x = Xnd_EmptyFromType((PyTypeObject *)cls, spec.types[nin+i], flags);
|
362
|
+
if (x == NULL) {
|
363
|
+
clear_pystack(pystack, nin+i);
|
364
|
+
ndt_apply_spec_clear(&spec);
|
191
365
|
return NULL;
|
192
366
|
}
|
193
|
-
|
367
|
+
pystack[nin+i] = x;
|
194
368
|
stack[nin+i] = *CONST_XND(x);
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
369
|
+
}
|
370
|
+
else {
|
371
|
+
clear_pystack(pystack, nin+i);
|
372
|
+
ndt_apply_spec_clear(&spec);
|
373
|
+
PyErr_SetString(PyExc_ValueError,
|
374
|
+
"arguments with abstract types are temporarily disabled");
|
375
|
+
return NULL;
|
376
|
+
}
|
377
|
+
}
|
200
378
|
}
|
201
379
|
|
202
|
-
|
203
|
-
if
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
return seterr(&ctx);
|
212
|
-
}
|
213
|
-
#endif
|
380
|
+
if (self->flags == GM_CUDA_MANAGED_FUNC) {
|
381
|
+
#if HAVE_CUDA
|
382
|
+
if (!check_broadcast) {
|
383
|
+
ndt_err_format(&ctx, NDT_NotImplementedError,
|
384
|
+
"fold() is currently not supported on cuda");
|
385
|
+
clear_pystack(pystack, spec.nargs);
|
386
|
+
ndt_apply_spec_clear(&spec);
|
387
|
+
return seterr(&ctx);
|
388
|
+
}
|
214
389
|
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
clear_objects(result, i);
|
222
|
-
for (k = i+1; k < spec.nout; k++) {
|
223
|
-
if (ndt_is_abstract(spec.out[k])) {
|
224
|
-
xnd_del_buffer(&stack[nin+k], XND_OWN_ALL);
|
225
|
-
}
|
226
|
-
}
|
227
|
-
}
|
228
|
-
result[i] = x;
|
390
|
+
const int ret = gm_apply(&kernel, stack, spec.outer_dims, &ctx);
|
391
|
+
|
392
|
+
if (xnd_cuda_device_synchronize(&ctx) < 0 || ret < 0) {
|
393
|
+
clear_pystack(pystack, spec.nargs);
|
394
|
+
ndt_apply_spec_clear(&spec);
|
395
|
+
return seterr(&ctx);
|
229
396
|
}
|
397
|
+
#else
|
398
|
+
ndt_err_format(&ctx, NDT_RuntimeError,
|
399
|
+
"internal error: GM_CUDA_MANAGED_FUNC set in a build without cuda support");
|
400
|
+
clear_pystack(pystack, spec.nargs);
|
401
|
+
ndt_apply_spec_clear(&spec);
|
402
|
+
return seterr(&ctx);
|
403
|
+
#endif
|
230
404
|
}
|
405
|
+
else {
|
406
|
+
#ifdef HAVE_PTHREAD_H
|
407
|
+
const int rounding = fegetround();
|
408
|
+
fesetround(FE_TONEAREST);
|
409
|
+
|
410
|
+
const int64_t N = enable_threads ? max_threads : 1;
|
411
|
+
const int ret = gm_apply_thread(&kernel, stack, spec.outer_dims, N,
|
412
|
+
&ctx);
|
413
|
+
fesetround(rounding);
|
414
|
+
|
415
|
+
if (ret < 0) {
|
416
|
+
clear_pystack(pystack, spec.nargs);
|
417
|
+
ndt_apply_spec_clear(&spec);
|
418
|
+
return seterr(&ctx);
|
419
|
+
}
|
420
|
+
#else
|
421
|
+
const int rounding = fegetround();
|
422
|
+
fesetround(FE_TONEAREST);
|
423
|
+
|
424
|
+
const int ret = gm_apply(&kernel, stack, spec.outer_dims, &ctx);
|
231
425
|
|
232
|
-
|
233
|
-
|
234
|
-
|
426
|
+
fesetround(rounding);
|
427
|
+
|
428
|
+
if (ret < 0) {
|
429
|
+
clear_pystack(pystack, spec.nargs);
|
430
|
+
ndt_apply_spec_clear(&spec);
|
431
|
+
return seterr(&ctx);
|
235
432
|
}
|
433
|
+
#endif
|
236
434
|
}
|
237
435
|
|
238
|
-
|
239
|
-
|
240
|
-
|
436
|
+
nin = spec.nin;
|
437
|
+
nout = spec.nout;
|
438
|
+
nargs = spec.nargs;
|
439
|
+
ndt_apply_spec_clear(&spec);
|
440
|
+
|
441
|
+
switch (nout) {
|
442
|
+
case 0: {
|
443
|
+
clear_pystack(pystack, nargs);
|
444
|
+
Py_RETURN_NONE;
|
445
|
+
}
|
446
|
+
case 1: {
|
447
|
+
clear_pystack(pystack, nin);
|
448
|
+
return pystack[nin];
|
449
|
+
}
|
241
450
|
default: {
|
242
|
-
PyObject *tuple = PyTuple_New(
|
451
|
+
PyObject *tuple = PyTuple_New(nout);
|
243
452
|
if (tuple == NULL) {
|
244
|
-
|
453
|
+
clear_pystack(pystack, nargs);
|
245
454
|
return NULL;
|
246
455
|
}
|
247
|
-
for (i = 0; i <
|
248
|
-
PyTuple_SET_ITEM(tuple, i,
|
456
|
+
for (int i = 0; i < nout; i++) {
|
457
|
+
PyTuple_SET_ITEM(tuple, i, pystack[nin+i]);
|
249
458
|
}
|
250
459
|
return tuple;
|
251
460
|
}
|
@@ -253,7 +462,39 @@ gufunc_call(GufuncObject *self, PyObject *args, PyObject *kwds)
|
|
253
462
|
}
|
254
463
|
|
255
464
|
static PyObject *
|
256
|
-
|
465
|
+
gufunc_call(GufuncObject *self, PyObject *args, PyObject *kwargs)
|
466
|
+
{
|
467
|
+
return _gufunc_call(self, args, kwargs, true, true);
|
468
|
+
}
|
469
|
+
|
470
|
+
static PyObject *
|
471
|
+
gufunc_getdevice(GufuncObject *self, PyObject *args GM_UNUSED)
|
472
|
+
{
|
473
|
+
if (self->flags & GM_CUDA_MANAGED_FUNC) {
|
474
|
+
return PyUnicode_FromString("cuda:managed");
|
475
|
+
}
|
476
|
+
|
477
|
+
Py_RETURN_NONE;
|
478
|
+
}
|
479
|
+
|
480
|
+
static PyObject *
|
481
|
+
gufunc_getidentity(GufuncObject *self, PyObject *args GM_UNUSED)
|
482
|
+
{
|
483
|
+
Py_INCREF(self->identity);
|
484
|
+
return self->identity;
|
485
|
+
}
|
486
|
+
|
487
|
+
static int
|
488
|
+
gufunc_setidentity(GufuncObject *self, PyObject *value, void *closure GM_UNUSED)
|
489
|
+
{
|
490
|
+
Py_DECREF(self->identity);
|
491
|
+
Py_INCREF(value);
|
492
|
+
self->identity = value;
|
493
|
+
return 0;
|
494
|
+
}
|
495
|
+
|
496
|
+
static PyObject *
|
497
|
+
gufunc_getkernels(GufuncObject *self, PyObject *args GM_UNUSED)
|
257
498
|
{
|
258
499
|
NDT_STATIC_CONTEXT(ctx);
|
259
500
|
PyObject *list, *tmp;
|
@@ -294,7 +535,9 @@ gufunc_kernels(GufuncObject *self, PyObject *args GM_UNUSED)
|
|
294
535
|
|
295
536
|
static PyGetSetDef gufunc_getsets [] =
|
296
537
|
{
|
297
|
-
{ "
|
538
|
+
{ "device", (getter)gufunc_getdevice, NULL, NULL, NULL},
|
539
|
+
{ "identity", (getter)gufunc_getidentity, (setter)gufunc_setidentity, NULL, NULL},
|
540
|
+
{ "kernels", (getter)gufunc_getkernels, NULL, NULL, NULL},
|
298
541
|
{NULL}
|
299
542
|
};
|
300
543
|
|
@@ -323,13 +566,25 @@ struct map_args {
|
|
323
566
|
const gm_tbl_t *tbl;
|
324
567
|
};
|
325
568
|
|
569
|
+
static int
|
570
|
+
Gufunc_CheckExact(const PyObject *v)
|
571
|
+
{
|
572
|
+
return Py_TYPE(v) == &Gufunc_Type;
|
573
|
+
}
|
574
|
+
|
575
|
+
static int
|
576
|
+
Gufunc_Check(const PyObject *v)
|
577
|
+
{
|
578
|
+
return PyObject_TypeCheck(v, &Gufunc_Type);
|
579
|
+
}
|
580
|
+
|
326
581
|
static int
|
327
582
|
add_function(const gm_func_t *f, void *args)
|
328
583
|
{
|
329
584
|
struct map_args *a = (struct map_args *)args;
|
330
585
|
PyObject *func;
|
331
586
|
|
332
|
-
func = gufunc_new(a->tbl, f->name);
|
587
|
+
func = gufunc_new(a->tbl, f->name, GM_CPU_FUNC);
|
333
588
|
if (func == NULL) {
|
334
589
|
return -1;
|
335
590
|
}
|
@@ -349,10 +604,40 @@ Gumath_AddFunctions(PyObject *m, const gm_tbl_t *tbl)
|
|
349
604
|
return 0;
|
350
605
|
}
|
351
606
|
|
607
|
+
static int
|
608
|
+
add_cuda_function(const gm_func_t *f, void *args)
|
609
|
+
{
|
610
|
+
struct map_args *a = (struct map_args *)args;
|
611
|
+
PyObject *func;
|
612
|
+
|
613
|
+
func = gufunc_new(a->tbl, f->name, GM_CUDA_MANAGED_FUNC);
|
614
|
+
if (func == NULL) {
|
615
|
+
return -1;
|
616
|
+
}
|
617
|
+
|
618
|
+
return PyModule_AddObject(a->module, f->name, func);
|
619
|
+
}
|
620
|
+
|
621
|
+
static int
|
622
|
+
Gumath_AddCudaFunctions(PyObject *m, const gm_tbl_t *tbl)
|
623
|
+
{
|
624
|
+
struct map_args args = {m, tbl};
|
625
|
+
|
626
|
+
if (gm_tbl_map(tbl, add_cuda_function, &args) < 0) {
|
627
|
+
return -1;
|
628
|
+
}
|
629
|
+
|
630
|
+
return 0;
|
631
|
+
}
|
632
|
+
|
352
633
|
static PyObject *
|
353
634
|
init_api(void)
|
354
635
|
{
|
636
|
+
gumath_api[Gufunc_CheckExact_INDEX] = (void *)Gufunc_CheckExact;
|
637
|
+
gumath_api[Gufunc_Check_INDEX] = (void *)Gufunc_Check;
|
355
638
|
gumath_api[Gumath_AddFunctions_INDEX] = (void *)Gumath_AddFunctions;
|
639
|
+
gumath_api[Gumath_AddFunctions_INDEX] = (void *)Gumath_AddFunctions;
|
640
|
+
gumath_api[Gumath_AddCudaFunctions_INDEX] = (void *)Gumath_AddCudaFunctions;
|
356
641
|
|
357
642
|
return PyCapsule_New(gumath_api, "gumath._gumath._API", NULL);
|
358
643
|
}
|
@@ -362,6 +647,75 @@ init_api(void)
|
|
362
647
|
/* Module */
|
363
648
|
/****************************************************************************/
|
364
649
|
|
650
|
+
/*
 * vfold(*args, f=<gufunc>, acc=<xnd>)
 *
 * Calls the gufunc 'f' as  f(acc, *args, out=acc, cls=type(acc)).  The
 * accumulator 'acc' is therefore both the first input and the output.
 * Presumably this performs one fold step that writes its result in place.
 *
 * Returns whatever _gufunc_call() returns (new reference), or NULL with an
 * exception set.
 */
static PyObject *
gufunc_vfold(PyObject *m GM_UNUSED, PyObject *args, PyObject *kwargs)
{
    static char *kwlist[] = {"f", "acc", NULL};
    PyObject *func = Py_None;
    PyObject *acc = Py_None;
    PyObject *tuple = NULL;
    PyObject *dict = NULL;
    PyObject *res = NULL;
    Py_ssize_t size, i;

    /* Only the keyword-only arguments are parsed here; the positional
       arguments are forwarded to the gufunc untouched.  Note that
       PyArg_ParseTupleAndKeywords() returns 0 (not a negative value) on
       failure. */
    if (!PyArg_ParseTupleAndKeywords(positional_empty, kwargs, "|$OO", kwlist,
                                     &func, &acc)) {
        return NULL;
    }

    if (!Gufunc_Check(func)) {
        PyErr_Format(PyExc_TypeError,
            "vfold: expected gufunc object, got '%.200s'",
            Py_TYPE(func)->tp_name);
        return NULL;
    }

    if (!Xnd_Check(acc)) {
        PyErr_Format(PyExc_TypeError,
            "vfold: expected xnd instance, got '%.200s'",
            Py_TYPE(acc)->tp_name);
        return NULL;
    }

    /* Push the accumulator onto the argument stack. */
    size = PyTuple_GET_SIZE(args);
    tuple = PyTuple_New(size+1);
    if (tuple == NULL) {
        return NULL;
    }

    Py_INCREF(acc);
    PyTuple_SET_ITEM(tuple, 0, acc);
    for (i = 0; i < size; i++) {
        PyObject *v = PyTuple_GET_ITEM(args, i);
        Py_INCREF(v);
        PyTuple_SET_ITEM(tuple, i+1, v);
    }

    /* Simultaneously use the accumulator as the 'out' argument, and make
       the result class match the accumulator's type. */
    dict = PyDict_New();
    if (dict == NULL) {
        goto out;
    }
    if (PyDict_SetItemString(dict, "out", acc) < 0 ||
        PyDict_SetItemString(dict, "cls", (PyObject *)Py_TYPE(acc)) < 0) {
        goto out;
    }

    res = _gufunc_call((GufuncObject *)func, tuple, dict, false, false);

out:
    Py_XDECREF(dict);
    Py_DECREF(tuple);
    return res;
}
|
718
|
+
|
365
719
|
static PyObject *
|
366
720
|
unsafe_add_kernel(PyObject *m GM_UNUSED, PyObject *args, PyObject *kwds)
|
367
721
|
{
|
@@ -388,8 +742,8 @@ unsafe_add_kernel(PyObject *m GM_UNUSED, PyObject *args, PyObject *kwds)
|
|
388
742
|
k.name = name;
|
389
743
|
k.sig = sig;
|
390
744
|
|
391
|
-
if (strcmp(tag, "Opt") == 0) {
|
392
|
-
k.
|
745
|
+
if (strcmp(tag, "Opt") == 0) { /* XXX */
|
746
|
+
k.OptC = p;
|
393
747
|
}
|
394
748
|
else if (strcmp(tag, "C") == 0) {
|
395
749
|
k.C = p;
|
@@ -418,7 +772,7 @@ unsafe_add_kernel(PyObject *m GM_UNUSED, PyObject *args, PyObject *kwds)
|
|
418
772
|
return seterr(&ctx);
|
419
773
|
}
|
420
774
|
|
421
|
-
return gufunc_new(table, f->name);
|
775
|
+
return gufunc_new(table, f->name, GM_CPU_FUNC);
|
422
776
|
}
|
423
777
|
|
424
778
|
static void
|
@@ -490,6 +844,7 @@ set_max_threads(PyObject *m UNUSED, PyObject *obj)
|
|
490
844
|
static PyMethodDef gumath_methods [] =
|
491
845
|
{
|
492
846
|
/* Methods */
|
847
|
+
{ "vfold", (PyCFunction)gufunc_vfold, METH_VARARGS|METH_KEYWORDS, NULL },
|
493
848
|
{ "unsafe_add_kernel", (PyCFunction)unsafe_add_kernel, METH_VARARGS|METH_KEYWORDS, NULL },
|
494
849
|
{ "get_max_threads", (PyCFunction)get_max_threads, METH_NOARGS, NULL },
|
495
850
|
{ "set_max_threads", (PyCFunction)set_max_threads, METH_O, NULL },
|
@@ -554,11 +909,21 @@ PyInit__gumath(void)
|
|
554
909
|
goto error;
|
555
910
|
}
|
556
911
|
|
912
|
+
positional_empty = PyTuple_New(0);
|
913
|
+
if (positional_empty == NULL) {
|
914
|
+
goto error;
|
915
|
+
}
|
916
|
+
|
557
917
|
m = PyModule_Create(&gumath_module);
|
558
918
|
if (m == NULL) {
|
559
919
|
goto error;
|
560
920
|
}
|
561
921
|
|
922
|
+
Py_INCREF(&Gufunc_Type);
|
923
|
+
if (PyModule_AddObject(m, "gufunc", (PyObject *)&Gufunc_Type) < 0) {
|
924
|
+
goto error;
|
925
|
+
}
|
926
|
+
|
562
927
|
Py_INCREF(capsule);
|
563
928
|
if (PyModule_AddObject(m, "_API", capsule) < 0) {
|
564
929
|
goto error;
|
@@ -571,6 +936,7 @@ PyInit__gumath(void)
|
|
571
936
|
return m;
|
572
937
|
|
573
938
|
error:
|
939
|
+
Py_CLEAR(positional_empty);
|
574
940
|
Py_CLEAR(xnd);
|
575
941
|
Py_CLEAR(m);
|
576
942
|
return NULL;
|