gumath 0.2.0dev5 → 0.2.0dev8
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CONTRIBUTING.md +7 -2
- data/Gemfile +0 -3
- data/ext/ruby_gumath/GPATH +0 -0
- data/ext/ruby_gumath/GRTAGS +0 -0
- data/ext/ruby_gumath/GTAGS +0 -0
- data/ext/ruby_gumath/extconf.rb +0 -5
- data/ext/ruby_gumath/functions.c +10 -2
- data/ext/ruby_gumath/gufunc_object.c +15 -4
- data/ext/ruby_gumath/gufunc_object.h +9 -3
- data/ext/ruby_gumath/gumath/Makefile +63 -0
- data/ext/ruby_gumath/gumath/Makefile.in +1 -0
- data/ext/ruby_gumath/gumath/config.h +56 -0
- data/ext/ruby_gumath/gumath/config.h.in +3 -0
- data/ext/ruby_gumath/gumath/config.log +497 -0
- data/ext/ruby_gumath/gumath/config.status +1034 -0
- data/ext/ruby_gumath/gumath/configure +375 -4
- data/ext/ruby_gumath/gumath/configure.ac +47 -3
- data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
- data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
- data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
- data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
- data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
- data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
- data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
- data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
- data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
- data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
- data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
- data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
- data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
- data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
- data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
- data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
- data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
- data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
- data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
- data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
- data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
- data/ext/ruby_gumath/gumath/setup.py +67 -6
- data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
- data/ext/ruby_gumath/include/gumath.h +55 -14
- data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
- data/ext/ruby_gumath/lib/libgumath.a +0 -0
- data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/ruby_gumath.c +231 -70
- data/ext/ruby_gumath/ruby_gumath.h +4 -1
- data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
- data/ext/ruby_gumath/util.c +34 -0
- data/ext/ruby_gumath/util.h +9 -0
- data/gumath.gemspec +3 -2
- data/lib/gumath.rb +55 -1
- data/lib/gumath/version.rb +2 -2
- data/lib/ruby_gumath.so +0 -0
- metadata +63 -10
- data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
- data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
- data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
@@ -27,10 +27,14 @@ LIBXNDDIR = ..\xnd\libxnd
|
|
27
27
|
|
28
28
|
OPT = /MT /Ox /GS /EHsc /fp:strict
|
29
29
|
OPT_SHARED = /DGM_EXPORT /DNDT_IMPORT /DXND_IMPORT /MD /Ox /GS /EHsc /fp:strict /Fo.objs^\
|
30
|
+
OPT_NOFP = /MT /Ox /GS /EHsc
|
31
|
+
OPT_SHARED_NOFP = /DGM_EXPORT /DNDT_IMPORT /DXND_IMPORT /MD /Ox /GS /EHsc /Fo.objs^\
|
30
32
|
|
31
33
|
COMMON_CFLAGS = /nologo /W4 /wd4200 /wd4201 /wd4204
|
32
34
|
CFLAGS = $(COMMON_CFLAGS) $(OPT)
|
33
35
|
CFLAGS_SHARED = $(COMMON_CFLAGS) $(OPT_SHARED)
|
36
|
+
CFLAGS_NOFP = $(COMMON_CFLAGS) $(OPT_NOFP)
|
37
|
+
CFLAGS_SHARED_NOFP = $(COMMON_CFLAGS) $(OPT_SHARED_NOFP)
|
34
38
|
|
35
39
|
|
36
40
|
default: $(LIBSTATIC) $(LIBSHARED)
|
@@ -40,11 +44,14 @@ default: $(LIBSTATIC) $(LIBSHARED)
|
|
40
44
|
copy /y $(LIBSHARED) ..\python\gumath
|
41
45
|
|
42
46
|
|
43
|
-
OBJS = apply.obj func.obj nploops.obj tbl.obj xndloops.obj \
|
44
|
-
|
47
|
+
OBJS = apply.obj func.obj nploops.obj tbl.obj xndloops.obj cpu_host_unary.obj \
|
48
|
+
cpu_device_unary.obj cpu_host_binary.obj cpu_device_binary.obj cpu_device_msvc.obj \
|
49
|
+
common.obj examples.obj graph.obj pdist.obj
|
45
50
|
|
46
51
|
SHARED_OBJS = .objs/apply.obj .objs/func.obj .objs/nploops.obj .objs/tbl.obj .objs/xndloops.obj \
|
47
|
-
.objs/
|
52
|
+
.objs/cpu_host_unary.obj .objs/cpu_device_unary.obj .objs/cpu_host_binary.obj \
|
53
|
+
.objs/cpu_device_binary.obj .objs/cpu_device_msvc.obj .objs/common.obj \
|
54
|
+
.objs/examples.obj .objs/graph.obj .objs/pdist.obj
|
48
55
|
|
49
56
|
|
50
57
|
$(LIBSTATIC):\
|
@@ -101,21 +108,53 @@ Makefile xndloops.c gumath.h
|
|
101
108
|
Makefile xndloops.c gumath.h
|
102
109
|
$(CC) "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c xndloops.c
|
103
110
|
|
104
|
-
|
105
|
-
Makefile kernels\
|
106
|
-
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\
|
111
|
+
cpu_host_unary.obj:\
|
112
|
+
Makefile kernels\cpu_host_unary.c kernels\common.h gumath.h
|
113
|
+
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\cpu_host_unary.c
|
107
114
|
|
108
|
-
.objs\
|
109
|
-
Makefile kernels\
|
110
|
-
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\
|
115
|
+
.objs\cpu_host_unary.obj:\
|
116
|
+
Makefile kernels\cpu_host_unary.c kernels\common.h gumath.h
|
117
|
+
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\cpu_host_unary.c
|
111
118
|
|
112
|
-
|
113
|
-
Makefile kernels\
|
114
|
-
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\
|
119
|
+
cpu_device_unary.obj:\
|
120
|
+
Makefile kernels\cpu_device_unary.cc kernels\common.h gumath.h
|
121
|
+
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\cpu_device_unary.cc
|
115
122
|
|
116
|
-
.objs\
|
117
|
-
Makefile kernels\
|
118
|
-
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\
|
123
|
+
.objs\cpu_device_unary.obj:\
|
124
|
+
Makefile kernels\cpu_device_unary.cc kernels\common.h gumath.h
|
125
|
+
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\cpu_device_unary.cc
|
126
|
+
|
127
|
+
cpu_host_binary.obj:\
|
128
|
+
Makefile kernels\cpu_host_binary.c kernels\common.h gumath.h
|
129
|
+
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\cpu_host_binary.c
|
130
|
+
|
131
|
+
.objs\cpu_host_binary.obj:\
|
132
|
+
Makefile kernels\cpu_host_binary.c kernels\common.h gumath.h
|
133
|
+
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\cpu_host_binary.c
|
134
|
+
|
135
|
+
cpu_device_binary.obj:\
|
136
|
+
Makefile kernels\cpu_device_binary.cc kernels\common.h gumath.h
|
137
|
+
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\cpu_device_binary.cc
|
138
|
+
|
139
|
+
.objs\cpu_device_binary.obj:\
|
140
|
+
Makefile kernels\cpu_device_binary.cc kernels\common.h gumath.h
|
141
|
+
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\cpu_device_binary.cc
|
142
|
+
|
143
|
+
cpu_device_msvc.obj:\
|
144
|
+
Makefile kernels\cpu_device_msvc.cc kernels\common.h gumath.h
|
145
|
+
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_NOFP) -c kernels\cpu_device_msvc.cc
|
146
|
+
|
147
|
+
.objs\cpu_device_msvc.obj:\
|
148
|
+
Makefile kernels\cpu_device_msvc.cc kernels\common.h gumath.h
|
149
|
+
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED_NOFP) -c kernels\cpu_device_msvc.cc
|
150
|
+
|
151
|
+
common.obj:\
|
152
|
+
Makefile kernels\common.c kernels\common.h gumath.h
|
153
|
+
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\common.c
|
154
|
+
|
155
|
+
.objs\common.obj:\
|
156
|
+
Makefile kernels\common.c kernels\common.h gumath.h
|
157
|
+
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\common.c
|
119
158
|
|
120
159
|
examples.obj:\
|
121
160
|
Makefile extending\examples.c gumath.h
|
@@ -40,6 +40,29 @@
|
|
40
40
|
#include "gumath.h"
|
41
41
|
|
42
42
|
|
43
|
+
/* flags that apply to all arguments */
|
44
|
+
#define OPT_Z (NDT_EXT_ZERO|NDT_INNER_C)
|
45
|
+
#define OPT_C (NDT_EXT_C|NDT_INNER_C)
|
46
|
+
#define OPT_S (NDT_EXT_STRIDED|NDT_INNER_STRIDED)
|
47
|
+
#define OPT_SC (NDT_EXT_STRIDED|NDT_INNER_C)
|
48
|
+
#define OPT_SF (NDT_EXT_STRIDED|NDT_INNER_F)
|
49
|
+
|
50
|
+
#define INNER_C (NDT_INNER_C)
|
51
|
+
#define INNER_F (NDT_INNER_F)
|
52
|
+
#define INNER_S (NDT_INNER_STRIDED)
|
53
|
+
#define INNER_X (NDT_INNER_XND)
|
54
|
+
|
55
|
+
/* kernel requests */
|
56
|
+
#define REQ_LOOP_C(flags) ((flags&OPT_C) == OPT_C)
|
57
|
+
#define REQ_LOOP_Z(flags) ((flags&OPT_C) == OPT_C || (flags&OPT_Z) == OPT_Z)
|
58
|
+
#define REQ_LOOP_S(flags) ((flags&OPT_S) == OPT_S)
|
59
|
+
|
60
|
+
#define REQ_INNER_C(flags) ((flags&INNER_C) == INNER_C)
|
61
|
+
#define REQ_INNER_F(flags) ((flags&INNER_F) == INNER_F)
|
62
|
+
#define REQ_INNER_S(flags) ((flags&INNER_S) == INNER_S)
|
63
|
+
#define REQ_INNER_X(flags) ((flags&INNER_X) == INNER_X)
|
64
|
+
|
65
|
+
|
43
66
|
static int
|
44
67
|
sum_inner_dimensions(const xnd_t stack[], int nargs, int outer_dims)
|
45
68
|
{
|
@@ -55,6 +78,18 @@ sum_inner_dimensions(const xnd_t stack[], int nargs, int outer_dims)
|
|
55
78
|
return sum;
|
56
79
|
}
|
57
80
|
|
81
|
+
static inline bool
|
82
|
+
opt_safe(int outer, ndt_context_t *ctx)
|
83
|
+
{
|
84
|
+
if (outer == 0) {
|
85
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
86
|
+
"internal error: optimized kernel called with outer_dims==0");
|
87
|
+
return false;
|
88
|
+
}
|
89
|
+
|
90
|
+
return true;
|
91
|
+
}
|
92
|
+
|
58
93
|
int
|
59
94
|
gm_apply(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
|
60
95
|
ndt_context_t *ctx)
|
@@ -62,30 +97,43 @@ gm_apply(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
|
|
62
97
|
const int nargs = (int)kernel->set->sig->Function.nargs;
|
63
98
|
|
64
99
|
switch (kernel->flag) {
|
65
|
-
case
|
66
|
-
if (outer_dims
|
67
|
-
ndt_err_format(ctx, NDT_RuntimeError,
|
68
|
-
"gm_xnd_map: optimized elementwise kernel called with "
|
69
|
-
"outer_dims==0");
|
100
|
+
case OPT_C: {
|
101
|
+
if (!opt_safe(outer_dims, ctx)) {
|
70
102
|
return -1;
|
71
103
|
}
|
72
104
|
|
73
|
-
return gm_xnd_map(kernel->set->
|
105
|
+
return gm_xnd_map(kernel->set->OptC, stack, nargs, outer_dims-1, ctx);
|
74
106
|
}
|
75
107
|
|
76
|
-
case
|
108
|
+
case OPT_Z: {
|
109
|
+
if (!opt_safe(outer_dims, ctx)) {
|
110
|
+
return -1;
|
111
|
+
}
|
112
|
+
|
113
|
+
return gm_xnd_map(kernel->set->OptZ, stack, nargs, outer_dims-1, ctx);
|
114
|
+
}
|
115
|
+
|
116
|
+
case OPT_S: {
|
117
|
+
if (!opt_safe(outer_dims, ctx)) {
|
118
|
+
return -1;
|
119
|
+
}
|
120
|
+
|
121
|
+
return gm_xnd_map(kernel->set->OptS, stack, nargs, outer_dims-1, ctx);
|
122
|
+
}
|
123
|
+
|
124
|
+
case INNER_C: {
|
77
125
|
return gm_xnd_map(kernel->set->C, stack, nargs, outer_dims, ctx);
|
78
126
|
}
|
79
127
|
|
80
|
-
case
|
128
|
+
case INNER_F: {
|
81
129
|
return gm_xnd_map(kernel->set->Fortran, stack, nargs, outer_dims, ctx);
|
82
130
|
}
|
83
131
|
|
84
|
-
case
|
132
|
+
case INNER_X: {
|
85
133
|
return gm_xnd_map(kernel->set->Xnd, stack, nargs, outer_dims, ctx);
|
86
134
|
}
|
87
135
|
|
88
|
-
case
|
136
|
+
case INNER_S: {
|
89
137
|
const int sum_inner = sum_inner_dimensions(stack, nargs, outer_dims);
|
90
138
|
const int dims_size = outer_dims + sum_inner;
|
91
139
|
const int steps_size = nargs * outer_dims + sum_inner;
|
@@ -105,8 +153,9 @@ gm_apply(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
|
|
105
153
|
}
|
106
154
|
}
|
107
155
|
|
108
|
-
|
109
|
-
|
156
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
157
|
+
"internal error: kernel selection failed");
|
158
|
+
return -1;
|
110
159
|
}
|
111
160
|
|
112
161
|
static gm_kernel_t
|
@@ -117,35 +166,48 @@ select_kernel(const ndt_apply_spec_t *spec, const gm_kernel_set_t *set,
|
|
117
166
|
|
118
167
|
kernel.set = set;
|
119
168
|
|
120
|
-
if (set->
|
121
|
-
kernel.flag =
|
169
|
+
if (REQ_LOOP_C(spec->flags) && set->OptC != NULL) {
|
170
|
+
kernel.flag = OPT_C;
|
171
|
+
return kernel;
|
172
|
+
}
|
173
|
+
|
174
|
+
if (REQ_LOOP_Z(spec->flags) && set->OptZ != NULL) {
|
175
|
+
kernel.flag = OPT_Z;
|
176
|
+
return kernel;
|
177
|
+
}
|
178
|
+
|
179
|
+
if (REQ_LOOP_S(spec->flags) && set->OptS != NULL) {
|
180
|
+
kernel.flag = OPT_S;
|
122
181
|
return kernel;
|
123
182
|
}
|
124
183
|
|
125
|
-
if (set->C != NULL
|
126
|
-
kernel.flag =
|
184
|
+
if (REQ_INNER_C(spec->flags) && set->C != NULL) {
|
185
|
+
kernel.flag = INNER_C;
|
127
186
|
return kernel;
|
128
187
|
}
|
129
188
|
|
130
|
-
if (set->Fortran != NULL
|
131
|
-
kernel.flag =
|
189
|
+
if (REQ_INNER_F(spec->flags) && set->Fortran != NULL) {
|
190
|
+
kernel.flag = INNER_F;
|
132
191
|
return kernel;
|
133
192
|
}
|
134
193
|
|
135
|
-
if (set->Strided != NULL
|
136
|
-
kernel.flag =
|
194
|
+
if (REQ_INNER_S(spec->flags) && set->Strided != NULL) {
|
195
|
+
kernel.flag = INNER_S;
|
137
196
|
return kernel;
|
138
197
|
}
|
139
198
|
|
140
|
-
if (set->Xnd != NULL
|
141
|
-
kernel.flag =
|
199
|
+
if (REQ_INNER_X(spec->flags) && set->Xnd != NULL) {
|
200
|
+
kernel.flag = INNER_X;
|
142
201
|
return kernel;
|
143
202
|
}
|
144
203
|
|
145
204
|
kernel.set = NULL;
|
146
205
|
ndt_err_format(ctx, NDT_RuntimeError,
|
147
|
-
"could not find specialized kernel for '%s' input (available: %s, %s, %s, %s)",
|
206
|
+
"could not find specialized kernel for '%s' input (available: %s, %s, %s, %s, %s, %s, %s)",
|
148
207
|
ndt_apply_flags_as_string(spec),
|
208
|
+
set->OptC ? "OptC" : "_",
|
209
|
+
set->OptZ ? "OptZ" : "_",
|
210
|
+
set->OptS ? "OptS" : "_",
|
149
211
|
set->C ? "C" : "_",
|
150
212
|
set->Fortran ? "Fortran" : "_",
|
151
213
|
set->Xnd ? "Xnd" : "_",
|
@@ -157,8 +219,8 @@ select_kernel(const ndt_apply_spec_t *spec, const gm_kernel_set_t *set,
|
|
157
219
|
/* Look up a multimethod by name and select a kernel. */
|
158
220
|
gm_kernel_t
|
159
221
|
gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
|
160
|
-
const ndt_t *
|
161
|
-
ndt_context_t *ctx)
|
222
|
+
const ndt_t *types[], const int64_t li[], int nin, int nout,
|
223
|
+
bool check_broadcast, const xnd_t args[], ndt_context_t *ctx)
|
162
224
|
{
|
163
225
|
gm_kernel_t empty_kernel = {0U, NULL};
|
164
226
|
const gm_func_t *f;
|
@@ -171,7 +233,8 @@ gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
|
|
171
233
|
}
|
172
234
|
|
173
235
|
if (f->typecheck != NULL) {
|
174
|
-
const gm_kernel_set_t *set = f->typecheck(spec, f,
|
236
|
+
const gm_kernel_set_t *set = f->typecheck(spec, f, types, li, nin, nout,
|
237
|
+
check_broadcast, ctx);
|
175
238
|
if (set == NULL) {
|
176
239
|
return empty_kernel;
|
177
240
|
}
|
@@ -180,7 +243,8 @@ gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
|
|
180
243
|
|
181
244
|
for (i = 0; i < f->nkernels; i++) {
|
182
245
|
const gm_kernel_set_t *set = &f->kernels[i];
|
183
|
-
if (ndt_typecheck(spec, set->sig,
|
246
|
+
if (ndt_typecheck(spec, set->sig, types, li, nin, nout,
|
247
|
+
check_broadcast, set->constraint, args,
|
184
248
|
ctx) < 0) {
|
185
249
|
ndt_err_clear(ctx);
|
186
250
|
continue;
|
@@ -188,7 +252,7 @@ gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
|
|
188
252
|
return select_kernel(spec, set, ctx);
|
189
253
|
}
|
190
254
|
|
191
|
-
s = ndt_list_as_string(
|
255
|
+
s = ndt_list_as_string(types, nin, ctx);
|
192
256
|
if (s == NULL) {
|
193
257
|
return empty_kernel;
|
194
258
|
}
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
@@ -181,9 +181,10 @@ static xnd_t
|
|
181
181
|
mk_return_array(int32_t p[], const int64_t N, const int32_t u,
|
182
182
|
ndt_context_t *ctx)
|
183
183
|
{
|
184
|
-
|
185
|
-
|
186
|
-
|
184
|
+
ndt_offsets_t *ndim2_offsets = NULL;
|
185
|
+
ndt_offsets_t *ndim1_offsets = NULL;
|
186
|
+
int32_t *ptr;
|
187
|
+
const ndt_t *t, *type;
|
187
188
|
int64_t sum;
|
188
189
|
int32_t v;
|
189
190
|
|
@@ -191,45 +192,51 @@ mk_return_array(int32_t p[], const int64_t N, const int32_t u,
|
|
191
192
|
goto offset_overflow;
|
192
193
|
}
|
193
194
|
|
194
|
-
ndim2_offsets =
|
195
|
+
ndim2_offsets = ndt_offsets_new(2, ctx);
|
195
196
|
if (ndim2_offsets == NULL) {
|
196
|
-
(void)ndt_memory_error(ctx);
|
197
197
|
return xnd_error;
|
198
198
|
}
|
199
|
-
|
200
|
-
|
199
|
+
ptr = (int32_t *)ndim2_offsets->v;
|
200
|
+
ptr[0] = 0;
|
201
|
+
ptr[1] = (int32_t)N;
|
201
202
|
|
202
203
|
|
203
|
-
ndim1_offsets =
|
204
|
+
ndim1_offsets = ndt_offsets_new((int32_t)(N+1), ctx);
|
204
205
|
if (ndim1_offsets == NULL) {
|
205
|
-
(
|
206
|
+
ndt_decref_offsets(ndim2_offsets);
|
206
207
|
return xnd_error;
|
207
208
|
}
|
208
209
|
|
209
210
|
sum = 0;
|
211
|
+
ptr = (int32_t *)ndim1_offsets->v;
|
210
212
|
for (v = 0; v < N; v++) {
|
211
|
-
|
213
|
+
ptr[v] = (int32_t)sum;
|
212
214
|
int64_t n = write_path(NULL, 0, p, N, u, v);
|
213
215
|
sum += n;
|
214
216
|
if (sum > INT32_MAX) {
|
215
217
|
goto offset_overflow;
|
216
218
|
}
|
217
219
|
}
|
218
|
-
|
220
|
+
ptr[v] = (int32_t)sum;
|
219
221
|
|
220
222
|
|
221
|
-
|
222
|
-
if (
|
223
|
+
type = ndt_from_string("node", ctx);
|
224
|
+
if (type == NULL) {
|
223
225
|
goto error;
|
224
226
|
}
|
225
227
|
|
226
|
-
t = ndt_var_dim(
|
228
|
+
t = ndt_var_dim(type, ndim1_offsets, 0, NULL, false, ctx);
|
229
|
+
ndt_decref(type);
|
230
|
+
ndt_decref_offsets(ndim1_offsets);
|
227
231
|
ndim1_offsets = NULL;
|
228
232
|
if (t == NULL) {
|
229
233
|
goto error;
|
230
234
|
}
|
235
|
+
type = t;
|
231
236
|
|
232
|
-
t = ndt_var_dim(
|
237
|
+
t = ndt_var_dim(type, ndim2_offsets, 0, NULL, false, ctx);
|
238
|
+
ndt_decref(type);
|
239
|
+
ndt_decref_offsets(ndim2_offsets);
|
233
240
|
ndim2_offsets = NULL;
|
234
241
|
if (t == NULL) {
|
235
242
|
goto error;
|
@@ -244,16 +251,16 @@ mk_return_array(int32_t p[], const int64_t N, const int32_t u,
|
|
244
251
|
|
245
252
|
t = out.type->VarDim.type;
|
246
253
|
for (v = 0; v < N; v++) {
|
247
|
-
int32_t shape = t->Concrete.VarDim.offsets[v+1]-t->Concrete.VarDim.offsets[v];
|
248
|
-
char *
|
249
|
-
(void)write_path((int32_t *)
|
254
|
+
int32_t shape = t->Concrete.VarDim.offsets->v[v+1]-t->Concrete.VarDim.offsets->v[v];
|
255
|
+
char *cp = out.ptr + t->Concrete.VarDim.offsets->v[v] * t->Concrete.VarDim.itemsize;
|
256
|
+
(void)write_path((int32_t *)cp, shape, p, N, u, v);
|
250
257
|
}
|
251
258
|
|
252
259
|
return out;
|
253
260
|
|
254
261
|
error:
|
255
|
-
|
256
|
-
|
262
|
+
ndt_decref_offsets(ndim2_offsets);
|
263
|
+
ndt_decref_offsets(ndim1_offsets);
|
257
264
|
return xnd_error;
|
258
265
|
|
259
266
|
offset_overflow:
|
@@ -73,7 +73,7 @@ gm_func_del(gm_func_t *f)
|
|
73
73
|
ndt_free(f->name);
|
74
74
|
|
75
75
|
for (int i = 0; i < f->nkernels; i++) {
|
76
|
-
|
76
|
+
ndt_decref(f->kernels[i].sig);
|
77
77
|
}
|
78
78
|
|
79
79
|
ndt_free(f);
|
@@ -101,7 +101,7 @@ gm_add_kernel(gm_tbl_t *tbl, const gm_kernel_init_t *k, ndt_context_t *ctx)
|
|
101
101
|
{
|
102
102
|
gm_func_t *f = gm_tbl_find(tbl, k->name, ctx);
|
103
103
|
gm_kernel_set_t kernel;
|
104
|
-
ndt_t *t;
|
104
|
+
const ndt_t *t;
|
105
105
|
|
106
106
|
if (f == NULL) {
|
107
107
|
ndt_err_clear(ctx);
|
@@ -117,7 +117,7 @@ gm_add_kernel(gm_tbl_t *tbl, const gm_kernel_init_t *k, ndt_context_t *ctx)
|
|
117
117
|
}
|
118
118
|
|
119
119
|
if (f->nkernels == GM_MAX_KERNELS) {
|
120
|
-
|
120
|
+
ndt_decref(t);
|
121
121
|
ndt_err_format(ctx, NDT_RuntimeError,
|
122
122
|
"%s: maximum number of kernels reached for", f->name);
|
123
123
|
return -1;
|
@@ -125,11 +125,13 @@ gm_add_kernel(gm_tbl_t *tbl, const gm_kernel_init_t *k, ndt_context_t *ctx)
|
|
125
125
|
|
126
126
|
kernel.sig = t;
|
127
127
|
kernel.constraint = k->constraint;
|
128
|
-
kernel.
|
128
|
+
kernel.OptC = k->OptC;
|
129
|
+
kernel.OptZ = k->OptZ;
|
130
|
+
kernel.OptS = k->OptS;
|
129
131
|
kernel.C = k->C;
|
130
132
|
kernel.Fortran = k->Fortran;
|
131
|
-
kernel.Strided = k->Strided;
|
132
133
|
kernel.Xnd = k->Xnd;
|
134
|
+
kernel.Strided = k->Strided;
|
133
135
|
|
134
136
|
f->kernels[f->nkernels++] = kernel;
|
135
137
|
return 0;
|
@@ -141,7 +143,7 @@ gm_add_kernel_typecheck(gm_tbl_t *tbl, const gm_kernel_init_t *k, ndt_context_t
|
|
141
143
|
{
|
142
144
|
gm_func_t *f = gm_tbl_find(tbl, k->name, ctx);
|
143
145
|
gm_kernel_set_t kernel;
|
144
|
-
ndt_t *t;
|
146
|
+
const ndt_t *t;
|
145
147
|
|
146
148
|
if (f == NULL) {
|
147
149
|
ndt_err_clear(ctx);
|
@@ -158,7 +160,7 @@ gm_add_kernel_typecheck(gm_tbl_t *tbl, const gm_kernel_init_t *k, ndt_context_t
|
|
158
160
|
}
|
159
161
|
|
160
162
|
if (f->nkernels == GM_MAX_KERNELS) {
|
161
|
-
|
163
|
+
ndt_decref(t);
|
162
164
|
ndt_err_format(ctx, NDT_RuntimeError,
|
163
165
|
"%s: maximum number of kernels reached for", f->name);
|
164
166
|
return -1;
|
@@ -166,11 +168,13 @@ gm_add_kernel_typecheck(gm_tbl_t *tbl, const gm_kernel_init_t *k, ndt_context_t
|
|
166
168
|
|
167
169
|
kernel.sig = t;
|
168
170
|
kernel.constraint = k->constraint;
|
169
|
-
kernel.
|
171
|
+
kernel.OptC = k->OptC;
|
172
|
+
kernel.OptZ = k->OptZ;
|
173
|
+
kernel.OptS = k->OptS;
|
170
174
|
kernel.C = k->C;
|
171
175
|
kernel.Fortran = k->Fortran;
|
172
|
-
kernel.Strided = k->Strided;
|
173
176
|
kernel.Xnd = k->Xnd;
|
177
|
+
kernel.Strided = k->Strided;
|
174
178
|
|
175
179
|
f->kernels[f->nkernels++] = kernel;
|
176
180
|
return 0;
|