gumath 0.2.0dev5 → 0.2.0dev8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CONTRIBUTING.md +7 -2
- data/Gemfile +0 -3
- data/ext/ruby_gumath/GPATH +0 -0
- data/ext/ruby_gumath/GRTAGS +0 -0
- data/ext/ruby_gumath/GTAGS +0 -0
- data/ext/ruby_gumath/extconf.rb +0 -5
- data/ext/ruby_gumath/functions.c +10 -2
- data/ext/ruby_gumath/gufunc_object.c +15 -4
- data/ext/ruby_gumath/gufunc_object.h +9 -3
- data/ext/ruby_gumath/gumath/Makefile +63 -0
- data/ext/ruby_gumath/gumath/Makefile.in +1 -0
- data/ext/ruby_gumath/gumath/config.h +56 -0
- data/ext/ruby_gumath/gumath/config.h.in +3 -0
- data/ext/ruby_gumath/gumath/config.log +497 -0
- data/ext/ruby_gumath/gumath/config.status +1034 -0
- data/ext/ruby_gumath/gumath/configure +375 -4
- data/ext/ruby_gumath/gumath/configure.ac +47 -3
- data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
- data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
- data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
- data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
- data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
- data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
- data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
- data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
- data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
- data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
- data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
- data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
- data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
- data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
- data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
- data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
- data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
- data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
- data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
- data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
- data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
- data/ext/ruby_gumath/gumath/setup.py +67 -6
- data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
- data/ext/ruby_gumath/include/gumath.h +55 -14
- data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
- data/ext/ruby_gumath/lib/libgumath.a +0 -0
- data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/ruby_gumath.c +231 -70
- data/ext/ruby_gumath/ruby_gumath.h +4 -1
- data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
- data/ext/ruby_gumath/util.c +34 -0
- data/ext/ruby_gumath/util.h +9 -0
- data/gumath.gemspec +3 -2
- data/lib/gumath.rb +55 -1
- data/lib/gumath/version.rb +2 -2
- data/lib/ruby_gumath.so +0 -0
- metadata +63 -10
- data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
- data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
- data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
|
@@ -27,10 +27,14 @@ LIBXNDDIR = ..\xnd\libxnd
|
|
|
27
27
|
|
|
28
28
|
OPT = /MT /Ox /GS /EHsc /fp:strict
|
|
29
29
|
OPT_SHARED = /DGM_EXPORT /DNDT_IMPORT /DXND_IMPORT /MD /Ox /GS /EHsc /fp:strict /Fo.objs^\
|
|
30
|
+
OPT_NOFP = /MT /Ox /GS /EHsc
|
|
31
|
+
OPT_SHARED_NOFP = /DGM_EXPORT /DNDT_IMPORT /DXND_IMPORT /MD /Ox /GS /EHsc /Fo.objs^\
|
|
30
32
|
|
|
31
33
|
COMMON_CFLAGS = /nologo /W4 /wd4200 /wd4201 /wd4204
|
|
32
34
|
CFLAGS = $(COMMON_CFLAGS) $(OPT)
|
|
33
35
|
CFLAGS_SHARED = $(COMMON_CFLAGS) $(OPT_SHARED)
|
|
36
|
+
CFLAGS_NOFP = $(COMMON_CFLAGS) $(OPT_NOFP)
|
|
37
|
+
CFLAGS_SHARED_NOFP = $(COMMON_CFLAGS) $(OPT_SHARED_NOFP)
|
|
34
38
|
|
|
35
39
|
|
|
36
40
|
default: $(LIBSTATIC) $(LIBSHARED)
|
|
@@ -40,11 +44,14 @@ default: $(LIBSTATIC) $(LIBSHARED)
|
|
|
40
44
|
copy /y $(LIBSHARED) ..\python\gumath
|
|
41
45
|
|
|
42
46
|
|
|
43
|
-
OBJS = apply.obj func.obj nploops.obj tbl.obj xndloops.obj \
|
|
44
|
-
|
|
47
|
+
OBJS = apply.obj func.obj nploops.obj tbl.obj xndloops.obj cpu_host_unary.obj \
|
|
48
|
+
cpu_device_unary.obj cpu_host_binary.obj cpu_device_binary.obj cpu_device_msvc.obj \
|
|
49
|
+
common.obj examples.obj graph.obj pdist.obj
|
|
45
50
|
|
|
46
51
|
SHARED_OBJS = .objs/apply.obj .objs/func.obj .objs/nploops.obj .objs/tbl.obj .objs/xndloops.obj \
|
|
47
|
-
.objs/
|
|
52
|
+
.objs/cpu_host_unary.obj .objs/cpu_device_unary.obj .objs/cpu_host_binary.obj \
|
|
53
|
+
.objs/cpu_device_binary.obj .objs/cpu_device_msvc.obj .objs/common.obj \
|
|
54
|
+
.objs/examples.obj .objs/graph.obj .objs/pdist.obj
|
|
48
55
|
|
|
49
56
|
|
|
50
57
|
$(LIBSTATIC):\
|
|
@@ -101,21 +108,53 @@ Makefile xndloops.c gumath.h
|
|
|
101
108
|
Makefile xndloops.c gumath.h
|
|
102
109
|
$(CC) "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c xndloops.c
|
|
103
110
|
|
|
104
|
-
|
|
105
|
-
Makefile kernels\
|
|
106
|
-
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\
|
|
111
|
+
cpu_host_unary.obj:\
|
|
112
|
+
Makefile kernels\cpu_host_unary.c kernels\common.h gumath.h
|
|
113
|
+
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\cpu_host_unary.c
|
|
107
114
|
|
|
108
|
-
.objs\
|
|
109
|
-
Makefile kernels\
|
|
110
|
-
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\
|
|
115
|
+
.objs\cpu_host_unary.obj:\
|
|
116
|
+
Makefile kernels\cpu_host_unary.c kernels\common.h gumath.h
|
|
117
|
+
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\cpu_host_unary.c
|
|
111
118
|
|
|
112
|
-
|
|
113
|
-
Makefile kernels\
|
|
114
|
-
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\
|
|
119
|
+
cpu_device_unary.obj:\
|
|
120
|
+
Makefile kernels\cpu_device_unary.cc kernels\common.h gumath.h
|
|
121
|
+
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\cpu_device_unary.cc
|
|
115
122
|
|
|
116
|
-
.objs\
|
|
117
|
-
Makefile kernels\
|
|
118
|
-
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\
|
|
123
|
+
.objs\cpu_device_unary.obj:\
|
|
124
|
+
Makefile kernels\cpu_device_unary.cc kernels\common.h gumath.h
|
|
125
|
+
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\cpu_device_unary.cc
|
|
126
|
+
|
|
127
|
+
cpu_host_binary.obj:\
|
|
128
|
+
Makefile kernels\cpu_host_binary.c kernels\common.h gumath.h
|
|
129
|
+
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\cpu_host_binary.c
|
|
130
|
+
|
|
131
|
+
.objs\cpu_host_binary.obj:\
|
|
132
|
+
Makefile kernels\cpu_host_binary.c kernels\common.h gumath.h
|
|
133
|
+
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\cpu_host_binary.c
|
|
134
|
+
|
|
135
|
+
cpu_device_binary.obj:\
|
|
136
|
+
Makefile kernels\cpu_device_binary.cc kernels\common.h gumath.h
|
|
137
|
+
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\cpu_device_binary.cc
|
|
138
|
+
|
|
139
|
+
.objs\cpu_device_binary.obj:\
|
|
140
|
+
Makefile kernels\cpu_device_binary.cc kernels\common.h gumath.h
|
|
141
|
+
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\cpu_device_binary.cc
|
|
142
|
+
|
|
143
|
+
cpu_device_msvc.obj:\
|
|
144
|
+
Makefile kernels\cpu_device_msvc.cc kernels\common.h gumath.h
|
|
145
|
+
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_NOFP) -c kernels\cpu_device_msvc.cc
|
|
146
|
+
|
|
147
|
+
.objs\cpu_device_msvc.obj:\
|
|
148
|
+
Makefile kernels\cpu_device_msvc.cc kernels\common.h gumath.h
|
|
149
|
+
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED_NOFP) -c kernels\cpu_device_msvc.cc
|
|
150
|
+
|
|
151
|
+
common.obj:\
|
|
152
|
+
Makefile kernels\common.c kernels\common.h gumath.h
|
|
153
|
+
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS) -c kernels\common.c
|
|
154
|
+
|
|
155
|
+
.objs\common.obj:\
|
|
156
|
+
Makefile kernels\common.c kernels\common.h gumath.h
|
|
157
|
+
$(CC) -I. "-I$(LIBNDTYPESINCLUDE)" "-I$(LIBXNDINCLUDE)" $(CFLAGS_SHARED) -c kernels\common.c
|
|
119
158
|
|
|
120
159
|
examples.obj:\
|
|
121
160
|
Makefile extending\examples.c gumath.h
|
|
@@ -40,6 +40,29 @@
|
|
|
40
40
|
#include "gumath.h"
|
|
41
41
|
|
|
42
42
|
|
|
43
|
+
/* flags that apply to all arguments */
|
|
44
|
+
#define OPT_Z (NDT_EXT_ZERO|NDT_INNER_C)
|
|
45
|
+
#define OPT_C (NDT_EXT_C|NDT_INNER_C)
|
|
46
|
+
#define OPT_S (NDT_EXT_STRIDED|NDT_INNER_STRIDED)
|
|
47
|
+
#define OPT_SC (NDT_EXT_STRIDED|NDT_INNER_C)
|
|
48
|
+
#define OPT_SF (NDT_EXT_STRIDED|NDT_INNER_F)
|
|
49
|
+
|
|
50
|
+
#define INNER_C (NDT_INNER_C)
|
|
51
|
+
#define INNER_F (NDT_INNER_F)
|
|
52
|
+
#define INNER_S (NDT_INNER_STRIDED)
|
|
53
|
+
#define INNER_X (NDT_INNER_XND)
|
|
54
|
+
|
|
55
|
+
/* kernel requests */
|
|
56
|
+
#define REQ_LOOP_C(flags) ((flags&OPT_C) == OPT_C)
|
|
57
|
+
#define REQ_LOOP_Z(flags) ((flags&OPT_C) == OPT_C || (flags&OPT_Z) == OPT_Z)
|
|
58
|
+
#define REQ_LOOP_S(flags) ((flags&OPT_S) == OPT_S)
|
|
59
|
+
|
|
60
|
+
#define REQ_INNER_C(flags) ((flags&INNER_C) == INNER_C)
|
|
61
|
+
#define REQ_INNER_F(flags) ((flags&INNER_F) == INNER_F)
|
|
62
|
+
#define REQ_INNER_S(flags) ((flags&INNER_S) == INNER_S)
|
|
63
|
+
#define REQ_INNER_X(flags) ((flags&INNER_X) == INNER_X)
|
|
64
|
+
|
|
65
|
+
|
|
43
66
|
static int
|
|
44
67
|
sum_inner_dimensions(const xnd_t stack[], int nargs, int outer_dims)
|
|
45
68
|
{
|
|
@@ -55,6 +78,18 @@ sum_inner_dimensions(const xnd_t stack[], int nargs, int outer_dims)
|
|
|
55
78
|
return sum;
|
|
56
79
|
}
|
|
57
80
|
|
|
81
|
+
static inline bool
|
|
82
|
+
opt_safe(int outer, ndt_context_t *ctx)
|
|
83
|
+
{
|
|
84
|
+
if (outer == 0) {
|
|
85
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
|
86
|
+
"internal error: optimized kernel called with outer_dims==0");
|
|
87
|
+
return false;
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
return true;
|
|
91
|
+
}
|
|
92
|
+
|
|
58
93
|
int
|
|
59
94
|
gm_apply(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
|
|
60
95
|
ndt_context_t *ctx)
|
|
@@ -62,30 +97,43 @@ gm_apply(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
|
|
|
62
97
|
const int nargs = (int)kernel->set->sig->Function.nargs;
|
|
63
98
|
|
|
64
99
|
switch (kernel->flag) {
|
|
65
|
-
case
|
|
66
|
-
if (outer_dims
|
|
67
|
-
ndt_err_format(ctx, NDT_RuntimeError,
|
|
68
|
-
"gm_xnd_map: optimized elementwise kernel called with "
|
|
69
|
-
"outer_dims==0");
|
|
100
|
+
case OPT_C: {
|
|
101
|
+
if (!opt_safe(outer_dims, ctx)) {
|
|
70
102
|
return -1;
|
|
71
103
|
}
|
|
72
104
|
|
|
73
|
-
return gm_xnd_map(kernel->set->
|
|
105
|
+
return gm_xnd_map(kernel->set->OptC, stack, nargs, outer_dims-1, ctx);
|
|
74
106
|
}
|
|
75
107
|
|
|
76
|
-
case
|
|
108
|
+
case OPT_Z: {
|
|
109
|
+
if (!opt_safe(outer_dims, ctx)) {
|
|
110
|
+
return -1;
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
return gm_xnd_map(kernel->set->OptZ, stack, nargs, outer_dims-1, ctx);
|
|
114
|
+
}
|
|
115
|
+
|
|
116
|
+
case OPT_S: {
|
|
117
|
+
if (!opt_safe(outer_dims, ctx)) {
|
|
118
|
+
return -1;
|
|
119
|
+
}
|
|
120
|
+
|
|
121
|
+
return gm_xnd_map(kernel->set->OptS, stack, nargs, outer_dims-1, ctx);
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
case INNER_C: {
|
|
77
125
|
return gm_xnd_map(kernel->set->C, stack, nargs, outer_dims, ctx);
|
|
78
126
|
}
|
|
79
127
|
|
|
80
|
-
case
|
|
128
|
+
case INNER_F: {
|
|
81
129
|
return gm_xnd_map(kernel->set->Fortran, stack, nargs, outer_dims, ctx);
|
|
82
130
|
}
|
|
83
131
|
|
|
84
|
-
case
|
|
132
|
+
case INNER_X: {
|
|
85
133
|
return gm_xnd_map(kernel->set->Xnd, stack, nargs, outer_dims, ctx);
|
|
86
134
|
}
|
|
87
135
|
|
|
88
|
-
case
|
|
136
|
+
case INNER_S: {
|
|
89
137
|
const int sum_inner = sum_inner_dimensions(stack, nargs, outer_dims);
|
|
90
138
|
const int dims_size = outer_dims + sum_inner;
|
|
91
139
|
const int steps_size = nargs * outer_dims + sum_inner;
|
|
@@ -105,8 +153,9 @@ gm_apply(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims,
|
|
|
105
153
|
}
|
|
106
154
|
}
|
|
107
155
|
|
|
108
|
-
|
|
109
|
-
|
|
156
|
+
ndt_err_format(ctx, NDT_RuntimeError,
|
|
157
|
+
"internal error: kernel selection failed");
|
|
158
|
+
return -1;
|
|
110
159
|
}
|
|
111
160
|
|
|
112
161
|
static gm_kernel_t
|
|
@@ -117,35 +166,48 @@ select_kernel(const ndt_apply_spec_t *spec, const gm_kernel_set_t *set,
|
|
|
117
166
|
|
|
118
167
|
kernel.set = set;
|
|
119
168
|
|
|
120
|
-
if (set->
|
|
121
|
-
kernel.flag =
|
|
169
|
+
if (REQ_LOOP_C(spec->flags) && set->OptC != NULL) {
|
|
170
|
+
kernel.flag = OPT_C;
|
|
171
|
+
return kernel;
|
|
172
|
+
}
|
|
173
|
+
|
|
174
|
+
if (REQ_LOOP_Z(spec->flags) && set->OptZ != NULL) {
|
|
175
|
+
kernel.flag = OPT_Z;
|
|
176
|
+
return kernel;
|
|
177
|
+
}
|
|
178
|
+
|
|
179
|
+
if (REQ_LOOP_S(spec->flags) && set->OptS != NULL) {
|
|
180
|
+
kernel.flag = OPT_S;
|
|
122
181
|
return kernel;
|
|
123
182
|
}
|
|
124
183
|
|
|
125
|
-
if (set->C != NULL
|
|
126
|
-
kernel.flag =
|
|
184
|
+
if (REQ_INNER_C(spec->flags) && set->C != NULL) {
|
|
185
|
+
kernel.flag = INNER_C;
|
|
127
186
|
return kernel;
|
|
128
187
|
}
|
|
129
188
|
|
|
130
|
-
if (set->Fortran != NULL
|
|
131
|
-
kernel.flag =
|
|
189
|
+
if (REQ_INNER_F(spec->flags) && set->Fortran != NULL) {
|
|
190
|
+
kernel.flag = INNER_F;
|
|
132
191
|
return kernel;
|
|
133
192
|
}
|
|
134
193
|
|
|
135
|
-
if (set->Strided != NULL
|
|
136
|
-
kernel.flag =
|
|
194
|
+
if (REQ_INNER_S(spec->flags) && set->Strided != NULL) {
|
|
195
|
+
kernel.flag = INNER_S;
|
|
137
196
|
return kernel;
|
|
138
197
|
}
|
|
139
198
|
|
|
140
|
-
if (set->Xnd != NULL
|
|
141
|
-
kernel.flag =
|
|
199
|
+
if (REQ_INNER_X(spec->flags) && set->Xnd != NULL) {
|
|
200
|
+
kernel.flag = INNER_X;
|
|
142
201
|
return kernel;
|
|
143
202
|
}
|
|
144
203
|
|
|
145
204
|
kernel.set = NULL;
|
|
146
205
|
ndt_err_format(ctx, NDT_RuntimeError,
|
|
147
|
-
"could not find specialized kernel for '%s' input (available: %s, %s, %s, %s)",
|
|
206
|
+
"could not find specialized kernel for '%s' input (available: %s, %s, %s, %s, %s, %s, %s)",
|
|
148
207
|
ndt_apply_flags_as_string(spec),
|
|
208
|
+
set->OptC ? "OptC" : "_",
|
|
209
|
+
set->OptZ ? "OptZ" : "_",
|
|
210
|
+
set->OptS ? "OptS" : "_",
|
|
149
211
|
set->C ? "C" : "_",
|
|
150
212
|
set->Fortran ? "Fortran" : "_",
|
|
151
213
|
set->Xnd ? "Xnd" : "_",
|
|
@@ -157,8 +219,8 @@ select_kernel(const ndt_apply_spec_t *spec, const gm_kernel_set_t *set,
|
|
|
157
219
|
/* Look up a multimethod by name and select a kernel. */
|
|
158
220
|
gm_kernel_t
|
|
159
221
|
gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
|
|
160
|
-
const ndt_t *
|
|
161
|
-
ndt_context_t *ctx)
|
|
222
|
+
const ndt_t *types[], const int64_t li[], int nin, int nout,
|
|
223
|
+
bool check_broadcast, const xnd_t args[], ndt_context_t *ctx)
|
|
162
224
|
{
|
|
163
225
|
gm_kernel_t empty_kernel = {0U, NULL};
|
|
164
226
|
const gm_func_t *f;
|
|
@@ -171,7 +233,8 @@ gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
|
|
|
171
233
|
}
|
|
172
234
|
|
|
173
235
|
if (f->typecheck != NULL) {
|
|
174
|
-
const gm_kernel_set_t *set = f->typecheck(spec, f,
|
|
236
|
+
const gm_kernel_set_t *set = f->typecheck(spec, f, types, li, nin, nout,
|
|
237
|
+
check_broadcast, ctx);
|
|
175
238
|
if (set == NULL) {
|
|
176
239
|
return empty_kernel;
|
|
177
240
|
}
|
|
@@ -180,7 +243,8 @@ gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
|
|
|
180
243
|
|
|
181
244
|
for (i = 0; i < f->nkernels; i++) {
|
|
182
245
|
const gm_kernel_set_t *set = &f->kernels[i];
|
|
183
|
-
if (ndt_typecheck(spec, set->sig,
|
|
246
|
+
if (ndt_typecheck(spec, set->sig, types, li, nin, nout,
|
|
247
|
+
check_broadcast, set->constraint, args,
|
|
184
248
|
ctx) < 0) {
|
|
185
249
|
ndt_err_clear(ctx);
|
|
186
250
|
continue;
|
|
@@ -188,7 +252,7 @@ gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
|
|
|
188
252
|
return select_kernel(spec, set, ctx);
|
|
189
253
|
}
|
|
190
254
|
|
|
191
|
-
s = ndt_list_as_string(
|
|
255
|
+
s = ndt_list_as_string(types, nin, ctx);
|
|
192
256
|
if (s == NULL) {
|
|
193
257
|
return empty_kernel;
|
|
194
258
|
}
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
@@ -181,9 +181,10 @@ static xnd_t
|
|
|
181
181
|
mk_return_array(int32_t p[], const int64_t N, const int32_t u,
|
|
182
182
|
ndt_context_t *ctx)
|
|
183
183
|
{
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
184
|
+
ndt_offsets_t *ndim2_offsets = NULL;
|
|
185
|
+
ndt_offsets_t *ndim1_offsets = NULL;
|
|
186
|
+
int32_t *ptr;
|
|
187
|
+
const ndt_t *t, *type;
|
|
187
188
|
int64_t sum;
|
|
188
189
|
int32_t v;
|
|
189
190
|
|
|
@@ -191,45 +192,51 @@ mk_return_array(int32_t p[], const int64_t N, const int32_t u,
|
|
|
191
192
|
goto offset_overflow;
|
|
192
193
|
}
|
|
193
194
|
|
|
194
|
-
ndim2_offsets =
|
|
195
|
+
ndim2_offsets = ndt_offsets_new(2, ctx);
|
|
195
196
|
if (ndim2_offsets == NULL) {
|
|
196
|
-
(void)ndt_memory_error(ctx);
|
|
197
197
|
return xnd_error;
|
|
198
198
|
}
|
|
199
|
-
|
|
200
|
-
|
|
199
|
+
ptr = (int32_t *)ndim2_offsets->v;
|
|
200
|
+
ptr[0] = 0;
|
|
201
|
+
ptr[1] = (int32_t)N;
|
|
201
202
|
|
|
202
203
|
|
|
203
|
-
ndim1_offsets =
|
|
204
|
+
ndim1_offsets = ndt_offsets_new((int32_t)(N+1), ctx);
|
|
204
205
|
if (ndim1_offsets == NULL) {
|
|
205
|
-
(
|
|
206
|
+
ndt_decref_offsets(ndim2_offsets);
|
|
206
207
|
return xnd_error;
|
|
207
208
|
}
|
|
208
209
|
|
|
209
210
|
sum = 0;
|
|
211
|
+
ptr = (int32_t *)ndim1_offsets->v;
|
|
210
212
|
for (v = 0; v < N; v++) {
|
|
211
|
-
|
|
213
|
+
ptr[v] = (int32_t)sum;
|
|
212
214
|
int64_t n = write_path(NULL, 0, p, N, u, v);
|
|
213
215
|
sum += n;
|
|
214
216
|
if (sum > INT32_MAX) {
|
|
215
217
|
goto offset_overflow;
|
|
216
218
|
}
|
|
217
219
|
}
|
|
218
|
-
|
|
220
|
+
ptr[v] = (int32_t)sum;
|
|
219
221
|
|
|
220
222
|
|
|
221
|
-
|
|
222
|
-
if (
|
|
223
|
+
type = ndt_from_string("node", ctx);
|
|
224
|
+
if (type == NULL) {
|
|
223
225
|
goto error;
|
|
224
226
|
}
|
|
225
227
|
|
|
226
|
-
t = ndt_var_dim(
|
|
228
|
+
t = ndt_var_dim(type, ndim1_offsets, 0, NULL, false, ctx);
|
|
229
|
+
ndt_decref(type);
|
|
230
|
+
ndt_decref_offsets(ndim1_offsets);
|
|
227
231
|
ndim1_offsets = NULL;
|
|
228
232
|
if (t == NULL) {
|
|
229
233
|
goto error;
|
|
230
234
|
}
|
|
235
|
+
type = t;
|
|
231
236
|
|
|
232
|
-
t = ndt_var_dim(
|
|
237
|
+
t = ndt_var_dim(type, ndim2_offsets, 0, NULL, false, ctx);
|
|
238
|
+
ndt_decref(type);
|
|
239
|
+
ndt_decref_offsets(ndim2_offsets);
|
|
233
240
|
ndim2_offsets = NULL;
|
|
234
241
|
if (t == NULL) {
|
|
235
242
|
goto error;
|
|
@@ -244,16 +251,16 @@ mk_return_array(int32_t p[], const int64_t N, const int32_t u,
|
|
|
244
251
|
|
|
245
252
|
t = out.type->VarDim.type;
|
|
246
253
|
for (v = 0; v < N; v++) {
|
|
247
|
-
int32_t shape = t->Concrete.VarDim.offsets[v+1]-t->Concrete.VarDim.offsets[v];
|
|
248
|
-
char *
|
|
249
|
-
(void)write_path((int32_t *)
|
|
254
|
+
int32_t shape = t->Concrete.VarDim.offsets->v[v+1]-t->Concrete.VarDim.offsets->v[v];
|
|
255
|
+
char *cp = out.ptr + t->Concrete.VarDim.offsets->v[v] * t->Concrete.VarDim.itemsize;
|
|
256
|
+
(void)write_path((int32_t *)cp, shape, p, N, u, v);
|
|
250
257
|
}
|
|
251
258
|
|
|
252
259
|
return out;
|
|
253
260
|
|
|
254
261
|
error:
|
|
255
|
-
|
|
256
|
-
|
|
262
|
+
ndt_decref_offsets(ndim2_offsets);
|
|
263
|
+
ndt_decref_offsets(ndim1_offsets);
|
|
257
264
|
return xnd_error;
|
|
258
265
|
|
|
259
266
|
offset_overflow:
|
|
@@ -73,7 +73,7 @@ gm_func_del(gm_func_t *f)
|
|
|
73
73
|
ndt_free(f->name);
|
|
74
74
|
|
|
75
75
|
for (int i = 0; i < f->nkernels; i++) {
|
|
76
|
-
|
|
76
|
+
ndt_decref(f->kernels[i].sig);
|
|
77
77
|
}
|
|
78
78
|
|
|
79
79
|
ndt_free(f);
|
|
@@ -101,7 +101,7 @@ gm_add_kernel(gm_tbl_t *tbl, const gm_kernel_init_t *k, ndt_context_t *ctx)
|
|
|
101
101
|
{
|
|
102
102
|
gm_func_t *f = gm_tbl_find(tbl, k->name, ctx);
|
|
103
103
|
gm_kernel_set_t kernel;
|
|
104
|
-
ndt_t *t;
|
|
104
|
+
const ndt_t *t;
|
|
105
105
|
|
|
106
106
|
if (f == NULL) {
|
|
107
107
|
ndt_err_clear(ctx);
|
|
@@ -117,7 +117,7 @@ gm_add_kernel(gm_tbl_t *tbl, const gm_kernel_init_t *k, ndt_context_t *ctx)
|
|
|
117
117
|
}
|
|
118
118
|
|
|
119
119
|
if (f->nkernels == GM_MAX_KERNELS) {
|
|
120
|
-
|
|
120
|
+
ndt_decref(t);
|
|
121
121
|
ndt_err_format(ctx, NDT_RuntimeError,
|
|
122
122
|
"%s: maximum number of kernels reached for", f->name);
|
|
123
123
|
return -1;
|
|
@@ -125,11 +125,13 @@ gm_add_kernel(gm_tbl_t *tbl, const gm_kernel_init_t *k, ndt_context_t *ctx)
|
|
|
125
125
|
|
|
126
126
|
kernel.sig = t;
|
|
127
127
|
kernel.constraint = k->constraint;
|
|
128
|
-
kernel.
|
|
128
|
+
kernel.OptC = k->OptC;
|
|
129
|
+
kernel.OptZ = k->OptZ;
|
|
130
|
+
kernel.OptS = k->OptS;
|
|
129
131
|
kernel.C = k->C;
|
|
130
132
|
kernel.Fortran = k->Fortran;
|
|
131
|
-
kernel.Strided = k->Strided;
|
|
132
133
|
kernel.Xnd = k->Xnd;
|
|
134
|
+
kernel.Strided = k->Strided;
|
|
133
135
|
|
|
134
136
|
f->kernels[f->nkernels++] = kernel;
|
|
135
137
|
return 0;
|
|
@@ -141,7 +143,7 @@ gm_add_kernel_typecheck(gm_tbl_t *tbl, const gm_kernel_init_t *k, ndt_context_t
|
|
|
141
143
|
{
|
|
142
144
|
gm_func_t *f = gm_tbl_find(tbl, k->name, ctx);
|
|
143
145
|
gm_kernel_set_t kernel;
|
|
144
|
-
ndt_t *t;
|
|
146
|
+
const ndt_t *t;
|
|
145
147
|
|
|
146
148
|
if (f == NULL) {
|
|
147
149
|
ndt_err_clear(ctx);
|
|
@@ -158,7 +160,7 @@ gm_add_kernel_typecheck(gm_tbl_t *tbl, const gm_kernel_init_t *k, ndt_context_t
|
|
|
158
160
|
}
|
|
159
161
|
|
|
160
162
|
if (f->nkernels == GM_MAX_KERNELS) {
|
|
161
|
-
|
|
163
|
+
ndt_decref(t);
|
|
162
164
|
ndt_err_format(ctx, NDT_RuntimeError,
|
|
163
165
|
"%s: maximum number of kernels reached for", f->name);
|
|
164
166
|
return -1;
|
|
@@ -166,11 +168,13 @@ gm_add_kernel_typecheck(gm_tbl_t *tbl, const gm_kernel_init_t *k, ndt_context_t
|
|
|
166
168
|
|
|
167
169
|
kernel.sig = t;
|
|
168
170
|
kernel.constraint = k->constraint;
|
|
169
|
-
kernel.
|
|
171
|
+
kernel.OptC = k->OptC;
|
|
172
|
+
kernel.OptZ = k->OptZ;
|
|
173
|
+
kernel.OptS = k->OptS;
|
|
170
174
|
kernel.C = k->C;
|
|
171
175
|
kernel.Fortran = k->Fortran;
|
|
172
|
-
kernel.Strided = k->Strided;
|
|
173
176
|
kernel.Xnd = k->Xnd;
|
|
177
|
+
kernel.Strided = k->Strided;
|
|
174
178
|
|
|
175
179
|
f->kernels[f->nkernels++] = kernel;
|
|
176
180
|
return 0;
|