gumath 0.2.0dev5 → 0.2.0dev8
- checksums.yaml +4 -4
- data/CONTRIBUTING.md +7 -2
- data/Gemfile +0 -3
- data/ext/ruby_gumath/GPATH +0 -0
- data/ext/ruby_gumath/GRTAGS +0 -0
- data/ext/ruby_gumath/GTAGS +0 -0
- data/ext/ruby_gumath/extconf.rb +0 -5
- data/ext/ruby_gumath/functions.c +10 -2
- data/ext/ruby_gumath/gufunc_object.c +15 -4
- data/ext/ruby_gumath/gufunc_object.h +9 -3
- data/ext/ruby_gumath/gumath/Makefile +63 -0
- data/ext/ruby_gumath/gumath/Makefile.in +1 -0
- data/ext/ruby_gumath/gumath/config.h +56 -0
- data/ext/ruby_gumath/gumath/config.h.in +3 -0
- data/ext/ruby_gumath/gumath/config.log +497 -0
- data/ext/ruby_gumath/gumath/config.status +1034 -0
- data/ext/ruby_gumath/gumath/configure +375 -4
- data/ext/ruby_gumath/gumath/configure.ac +47 -3
- data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
- data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
- data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
- data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
- data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
- data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
- data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
- data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
- data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
- data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
- data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
- data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
- data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
- data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
- data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
- data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
- data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
- data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
- data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
- data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
- data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
- data/ext/ruby_gumath/gumath/setup.py +67 -6
- data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
- data/ext/ruby_gumath/include/gumath.h +55 -14
- data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
- data/ext/ruby_gumath/lib/libgumath.a +0 -0
- data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/ruby_gumath.c +231 -70
- data/ext/ruby_gumath/ruby_gumath.h +4 -1
- data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
- data/ext/ruby_gumath/util.c +34 -0
- data/ext/ruby_gumath/util.h +9 -0
- data/gumath.gemspec +3 -2
- data/lib/gumath.rb +55 -1
- data/lib/gumath/version.rb +2 -2
- data/lib/ruby_gumath.so +0 -0
- metadata +63 -10
- data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
- data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
- data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
--- /dev/null
+++ b/data/ext/ruby_gumath/gumath/python/gumath/cuda.c
@@ -0,0 +1,78 @@
+#include <Python.h>
+#include "ndtypes.h"
+#include "pyndtypes.h"
+#include "gumath.h"
+#include "pygumath.h"
+
+
+/****************************************************************************/
+/* Module globals */
+/****************************************************************************/
+
+/* Function table */
+static gm_tbl_t *table = NULL;
+
+
+/****************************************************************************/
+/* Module */
+/****************************************************************************/
+
+static struct PyModuleDef cuda_module = {
+    PyModuleDef_HEAD_INIT,  /* m_base */
+    "cuda",                 /* m_name */
+    NULL,                   /* m_doc */
+    -1,                     /* m_size */
+    NULL,                   /* m_methods */
+    NULL,                   /* m_slots */
+    NULL,                   /* m_traverse */
+    NULL,                   /* m_clear */
+    NULL                    /* m_free */
+};
+
+
+PyMODINIT_FUNC
+PyInit_cuda(void)
+{
+    NDT_STATIC_CONTEXT(ctx);
+    PyObject *m = NULL;
+    static int initialized = 0;
+
+    if (!initialized) {
+        if (import_ndtypes() < 0) {
+            return NULL;
+        }
+        if (import_gumath() < 0) {
+            return NULL;
+        }
+
+        table = gm_tbl_new(&ctx);
+        if (table == NULL) {
+            return Ndt_SetError(&ctx);
+        }
+
+        if (gm_init_cuda_unary_kernels(table, &ctx) < 0) {
+            return Ndt_SetError(&ctx);
+        }
+
+        if (gm_init_cuda_binary_kernels(table, &ctx) < 0) {
+            return Ndt_SetError(&ctx);
+        }
+
+        initialized = 1;
+    }
+
+    m = PyModule_Create(&cuda_module);
+    if (m == NULL) {
+        goto error;
+    }
+
+    if (Gumath_AddCudaFunctions(m, table) < 0) {
+        goto error;
+    }
+
+    return m;
+
+error:
+    Py_CLEAR(m);
+    return NULL;
+}
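The new `cuda.c` extension module registers the CUDA unary/binary kernel tables and exposes them through `Gumath_AddCudaFunctions`. A minimal usage sketch from Python, assuming a CUDA-enabled build; the `device="cuda:managed"` keyword is taken from the upstream gumath/xnd test suite and is an assumption here, not an API documented by this gem:

```python
# Sketch only: requires libgumath built with CUDA support.
# device="cuda:managed" (CUDA managed memory) is assumed from upstream tests.
from xnd import xnd
import gumath.cuda as cd

x = xnd([1.0, 2.0, 3.0], device="cuda:managed")
y = xnd([4.0, 5.0, 6.0], device="cuda:managed")

print(cd.sin(x))          # kernel registered by gm_init_cuda_unary_kernels
print(cd.multiply(x, y))  # kernel registered by gm_init_cuda_binary_kernels
```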
--- a/data/ext/ruby_gumath/gumath/python/gumath/examples.c
+++ b/data/ext/ruby_gumath/gumath/python/gumath/examples.c
@@ -56,11 +56,6 @@ PyInit_examples(void)
     }
 
     /* extending examples */
-#ifndef _MSC_VER
-    if (gm_init_bfloat16_kernels(table, &ctx) < 0) {
-        return Ndt_SetError(&ctx);
-    }
-#endif
     if (gm_init_graph_kernels(table, &ctx) < 0) {
         return Ndt_SetError(&ctx);
     }
--- a/data/ext/ruby_gumath/gumath/python/gumath/functions.c
+++ b/data/ext/ruby_gumath/gumath/python/gumath/functions.c
@@ -50,10 +50,10 @@ PyInit_functions(void)
         return Ndt_SetError(&ctx);
     }
 
-    if (
+    if (gm_init_cpu_unary_kernels(table, &ctx) < 0) {
         return Ndt_SetError(&ctx);
     }
-    if (
+    if (gm_init_cpu_binary_kernels(table, &ctx) < 0) {
         return Ndt_SetError(&ctx);
     }
 
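On the CPU side, the kernel initializers are now split into `gm_init_cpu_unary_kernels` and `gm_init_cpu_binary_kernels`. The Python-level entry point is unchanged; a short sketch (function names taken from the kernel lists elsewhere in this diff):

```python
from xnd import xnd
import gumath.functions as fn

x = xnd([0.0, 0.5, 1.0])
print(fn.sin(x))      # registered by gm_init_cpu_unary_kernels
print(fn.add(x, x))   # registered by gm_init_cpu_binary_kernels
```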
--- /dev/null
+++ b/data/ext/ruby_gumath/gumath/python/gumath/gumath.h
@@ -0,0 +1,246 @@
+/*
+ * BSD 3-Clause License
+ *
+ * Copyright (c) 2017-2018, plures
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ * this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+
+#ifndef GUMATH_H
+#define GUMATH_H
+
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#ifdef __cplusplus
+#include <cstdint>
+#else
+#include <stdint.h>
+#endif
+
+#include "ndtypes.h"
+#include "xnd.h"
+
+
+#ifdef _MSC_VER
+#if defined (GM_EXPORT)
+#define GM_API __declspec(dllexport)
+#elif defined(GM_IMPORT)
+#define GM_API __declspec(dllimport)
+#else
+#define GM_API
+#endif
+
+#ifndef GM_UNUSED
+#define GM_UNUSED
+#endif
+
+#include "malloc.h"
+#define ALLOCA(type, name, nmemb) type *name = _alloca(nmemb * sizeof(type))
+#else
+#define GM_API
+#if defined(__GNUC__) && !defined(__INTEL_COMPILER)
+#define GM_UNUSED __attribute__((unused))
+#else
+#define GM_UNUSED
+#endif
+
+#define ALLOCA(type, name, nmemb) type name[nmemb]
+#endif
+
+
+#define GM_MAX_KERNELS 8192
+#define GM_THREAD_CUTOFF 1000000
+
+typedef float float32_t;
+typedef double float64_t;
+
+
+typedef int (* gm_xnd_kernel_t)(xnd_t stack[], ndt_context_t *ctx);
+typedef int (* gm_strided_kernel_t)(char **args, intptr_t *dimensions, intptr_t *steps, void *data);
+
+/*
+ * Collection of specialized kernels for a single function signature.
+ *
+ * NOTE: The specialized kernel lookup scheme is transitional and may
+ * be replaced by something else.
+ *
+ * This should be considered as a first version of a kernel request
+ * protocol.
+ */
+typedef struct {
+    const ndt_t *sig;
+    const ndt_constraint_t *constraint;
+
+    /* Xnd signatures */
+    gm_xnd_kernel_t OptC;    /* C in inner+1 dimensions */
+    gm_xnd_kernel_t OptZ;    /* C in inner dimensions, C or zero stride in (inner+1)th. */
+    gm_xnd_kernel_t OptS;    /* strided in (inner+1)th. */
+    gm_xnd_kernel_t C;       /* C in inner dimensions */
+    gm_xnd_kernel_t Fortran; /* Fortran in inner dimensions */
+    gm_xnd_kernel_t Xnd;     /* selected if non-contiguous or the other fields are NULL */
+
+    /* NumPy signature */
+    gm_strided_kernel_t Strided;
+} gm_kernel_set_t;
+
+typedef struct {
+    const char *name;
+    const char *type;
+    const ndt_methods_t *meth;
+} gm_typedef_init_t;
+
+typedef struct {
+    const char *name;
+    const char *sig;
+    const ndt_constraint_t *constraint;
+    uint32_t cap;
+
+    /* Xnd signatures */
+    gm_xnd_kernel_t OptC;
+    gm_xnd_kernel_t OptZ;
+    gm_xnd_kernel_t OptS;
+    gm_xnd_kernel_t C;
+    gm_xnd_kernel_t Fortran;
+    gm_xnd_kernel_t Xnd;
+
+    /* NumPy signature */
+    gm_strided_kernel_t Strided;
+} gm_kernel_init_t;
+
+/* Actual kernel selected for application */
+typedef struct {
+    uint32_t flag;
+    const gm_kernel_set_t *set;
+} gm_kernel_t;
+
+/* Multimethod with associated kernels */
+typedef struct gm_func gm_func_t;
+typedef const gm_kernel_set_t *(*gm_typecheck_t)(ndt_apply_spec_t *spec, const gm_func_t *f,
+                                                 const ndt_t *in[], const int64_t li[],
+                                                 int nin, int nout, bool check_broadcast,
+                                                 ndt_context_t *ctx);
+struct gm_func {
+    char *name;
+    gm_typecheck_t typecheck; /* Experimental optimized type-checking, may be NULL. */
+    int nkernels;
+    gm_kernel_set_t kernels[GM_MAX_KERNELS];
+};
+
+
+typedef struct _gm_tbl gm_tbl_t;
+
+
+/******************************************************************************/
+/* Functions */
+/******************************************************************************/
+
+GM_API gm_func_t *gm_func_new(const char *name, ndt_context_t *ctx);
+GM_API void gm_func_del(gm_func_t *f);
+
+GM_API gm_func_t *gm_add_func(gm_tbl_t *tbl, const char *name, ndt_context_t *ctx);
+GM_API int gm_add_kernel(gm_tbl_t *tbl, const gm_kernel_init_t *kernel, ndt_context_t *ctx);
+GM_API int gm_add_kernel_typecheck(gm_tbl_t *tbl, const gm_kernel_init_t *kernel, ndt_context_t *ctx, gm_typecheck_t f);
+
+GM_API gm_kernel_t gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
+                             const ndt_t *types[], const int64_t li[], int nin, int nout,
+                             bool check_broadcast, const xnd_t args[], ndt_context_t *ctx);
+GM_API int gm_apply(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, ndt_context_t *ctx);
+GM_API int gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, const int64_t nthreads, ndt_context_t *ctx);
+
+
+/******************************************************************************/
+/* NumPy loops */
+/******************************************************************************/
+
+GM_API int gm_np_flatten(char **args, const int nargs,
+                         int64_t *dimensions, int64_t *strides, const xnd_t stack[],
+                         ndt_context_t *ctx);
+
+GM_API int gm_np_convert_xnd(char **args, const int nargs,
+                             intptr_t *dimensions, const int dims_size,
+                             intptr_t *steps, const int steps_size,
+                             xnd_t stack[], const int outer_dims,
+                             ndt_context_t *ctx);
+
+GM_API int gm_np_map(const gm_strided_kernel_t f,
+                     char **args, int nargs,
+                     intptr_t *dimensions,
+                     intptr_t *steps,
+                     void *data,
+                     int outer_dims);
+
+
+/******************************************************************************/
+/* Xnd loops */
+/******************************************************************************/
+
+GM_API int array_shape_check(xnd_t *x, const int64_t shape, ndt_context_t *ctx);
+GM_API int gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
+                      const int outer_dims, ndt_context_t *ctx);
+
+
+/******************************************************************************/
+/* Gufunc table */
+/******************************************************************************/
+GM_API gm_tbl_t *gm_tbl_new(ndt_context_t *ctx);
+GM_API void gm_tbl_del(gm_tbl_t *t);
+
+GM_API int gm_tbl_add(gm_tbl_t *tbl, const char *key, gm_func_t *value, ndt_context_t *ctx);
+GM_API gm_func_t *gm_tbl_find(const gm_tbl_t *tbl, const char *key, ndt_context_t *ctx);
+GM_API int gm_tbl_map(const gm_tbl_t *tbl, int (*f)(const gm_func_t *, void *state), void *state);
+
+
+/******************************************************************************/
+/* Library initialization and tables */
+/******************************************************************************/
+
+GM_API void gm_init(void);
+GM_API int gm_init_cpu_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
+GM_API int gm_init_cpu_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
+GM_API int gm_init_bitwise_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
+
+GM_API int gm_init_cuda_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
+GM_API int gm_init_cuda_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
+
+GM_API int gm_init_example_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
+GM_API int gm_init_graph_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
+GM_API int gm_init_quaternion_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
+GM_API int gm_init_pdist_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
+
+GM_API void gm_finalize(void);
+
+
+#ifdef __cplusplus
+} /* END extern "C" */
+#endif
+
+
+#endif /* GUMATH_H */
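The header's `gm_kernel_set_t` comment describes a transitional kernel request protocol: for one signature a function may carry several specialized kernels (OptC, OptZ, OptS, C, Fortran, Strided) plus a generic `Xnd` fallback. A purely illustrative Python sketch of that selection idea (not library code; the real lookup lives in `gm_select` on the C side):

```python
# Illustrative model of the gm_kernel_set_t preference order described above.
# Slot names mirror the struct fields; the mapping of layouts to slots is a
# simplification for exposition, not the exact logic of gm_select().
def select_kernel(kernel_set, layout):
    preference = {
        "opt_c": ["OptC", "C", "Xnd"],   # C-contiguous in inner+1 dimensions
        "c": ["C", "Xnd"],               # C-contiguous inner dimensions
        "fortran": ["Fortran", "Xnd"],   # Fortran-contiguous inner dimensions
        "strided": ["OptS", "Xnd"],      # strided (inner+1)th dimension
    }
    for slot in preference.get(layout, ["Xnd"]):
        kernel = kernel_set.get(slot)
        if kernel is not None:
            return kernel
    raise LookupError("no kernel available for layout %r" % layout)

# Example: a set that only provides the generic Xnd kernel.
kset = {"Xnd": lambda stack: stack}
assert select_kernel(kset, "c") is kset["Xnd"]
```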
Binary file: data/ext/ruby_gumath/gumath/python/gumath/libgumath.a
--- /dev/null
+++ b/data/ext/ruby_gumath/gumath/python/gumath/libgumath.so
@@ -0,0 +1 @@
+ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3
--- /dev/null
+++ b/data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0
@@ -0,0 +1 @@
+ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3
Binary file: data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3
--- a/data/ext/ruby_gumath/gumath/python/gumath/pygumath.h
+++ b/data/ext/ruby_gumath/gumath/python/gumath/pygumath.h
@@ -49,10 +49,15 @@ extern "C" {
 /* Exposed here for the benefit of Numba. The API should not be regarded
    stable across versions. */
 
+#define GM_CPU_FUNC 0x0001U
+#define GM_CUDA_MANAGED_FUNC 0x0002U
+
 typedef struct {
     PyObject_HEAD
     const gm_tbl_t *tbl; /* kernel table */
+    uint32_t flags;      /* memory target */
     char *name;          /* function name */
+    PyObject *identity;  /* identity element */
 } GufuncObject;
 
 
@@ -60,21 +65,45 @@ typedef struct {
 /* Capsule API */
 /****************************************************************************/
 
-#define
+#define Gufunc_CheckExact_INDEX 0
+#define Gufunc_CheckExact_RETURN int
+#define Gufunc_CheckExact_ARGS (const PyObject *)
+
+#define Gufunc_Check_INDEX 1
+#define Gufunc_Check_RETURN int
+#define Gufunc_Check_ARGS (const PyObject *)
+
+#define Gumath_AddFunctions_INDEX 2
 #define Gumath_AddFunctions_RETURN int
 #define Gumath_AddFunctions_ARGS (PyObject *, const gm_tbl_t *)
 
-#define
+#define Gumath_AddCudaFunctions_INDEX 3
+#define Gumath_AddCudaFunctions_RETURN int
+#define Gumath_AddCudaFunctions_ARGS (PyObject *, const gm_tbl_t *)
+
+#define GUMATH_MAX_API 4
 
 
 #ifdef GUMATH_MODULE
+static Gufunc_CheckExact_RETURN Gufunc_CheckExact Gufunc_CheckExact_ARGS;
+static Gufunc_Check_RETURN Gufunc_Check Gufunc_Check_ARGS;
 static Gumath_AddFunctions_RETURN Gumath_AddFunctions Gumath_AddFunctions_ARGS;
+static Gumath_AddCudaFunctions_RETURN Gumath_AddCudaFunctions Gumath_AddCudaFunctions_ARGS;
 #else
 static void **_gumath_api;
 
+#define Gufunc_CheckExact \
+    (*(Gufunc_CheckExact_RETURN (*)Gufunc_CheckExact_ARGS) _gumath_api[Gufunc_CheckExact_INDEX])
+
+#define Gufunc_Check \
+    (*(Gufunc_Check_RETURN (*)Gufunc_Check_ARGS) _gumath_api[Gufunc_Check_INDEX])
+
 #define Gumath_AddFunctions \
     (*(Gumath_AddFunctions_RETURN (*)Gumath_AddFunctions_ARGS) _gumath_api[Gumath_AddFunctions_INDEX])
 
+#define Gumath_AddCudaFunctions \
+    (*(Gumath_AddCudaFunctions_RETURN (*)Gumath_AddCudaFunctions_ARGS) _gumath_api[Gumath_AddCudaFunctions_INDEX])
+
 
 static int
 import_gumath(void)
--- /dev/null
+++ b/data/ext/ruby_gumath/gumath/python/gumath_aux.py
@@ -0,0 +1,767 @@
+#
+# BSD 3-Clause License
+#
+# Copyright (c) 2017-2018, plures
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice,
+# this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+
+# Python NDarray and functions for generating test cases.
+
+from itertools import accumulate, count, product
+from random import randrange, sample
+from collections import namedtuple
+import math
+import struct
+import unittest
+from randdec import all_unary, all_binary
+from randfloat import un_randfloat, bin_randfloat
+import numpy as np
+
+
+def skip_if(condition, reason):
+    if condition:
+        raise unittest.SkipTest(reason)
+
+
+# ======================================================================
+# Minimal test cases
+# ======================================================================
+
+TEST_CASES = [
+    ([float(i)/100.0 for i in range(2000)], "2000 * float64", "float64"),
+
+    ([[float(i)/100.0 for i in range(1000)], [float(i+1) for i in range(1000)]],
+     "2 * 1000 * float64", "float64"),
+
+    (1000 * [[float(i+1) for i in range(2)]], "1000 * 2 * float64", "float64"),
+
+    ([float(i)/10.0 for i in range(2000)], "2000 * float32", "float32"),
+
+    ([[float(i)/10.0 for i in range(1000)], [float(i+1) for i in range(1000)]],
+     "2 * 1000 * float32", "float32"),
+
+    (1000 * [[float(i+1) for i in range(2)]], "1000 * 2 * float32", "float32"),
+]
+
+
+# ======================================================================
+# Definition of generalized slicing and indexing
+# ======================================================================
+
+def have_none(lst):
+    if isinstance(lst, (list, tuple)):
+        return any(have_none(item) for item in lst)
+    if isinstance(lst, dict):
+        return any(have_none(item) for item in lst.values())
+    return lst is None
+
+def sinrec(lst):
+    if isinstance(lst, list):
+        return [sinrec(item) for item in lst]
+    elif isinstance(lst, (int, type(None))):
+        return None if lst is None else math.sin(lst)
+    else:
+        raise TypeError("unexpected operand type '%s'" % type(lst))
+
+def mulrec(lst1, lst2):
+    if isinstance(lst1, list) and isinstance(lst2, list):
+        return [mulrec(*pair) for pair in zip(lst1, lst2)]
+    elif isinstance(lst1, (int, type(None))) and isinstance(lst2, (int, type(None))):
+        return None if lst1 is None or lst2 is None else lst1 * lst2
+    else:
+        raise TypeError("unexpected operand types '%s', '%s'" %
+                        (type(lst1), type(lst2)))
+
+
+def maxlevel(lst):
+    """Return maximum nesting depth"""
+    maxlev = 0
+    def f(lst, level):
+        nonlocal maxlev
+        if isinstance(lst, list):
+            level += 1
+            maxlev = max(level, maxlev)
+            for item in lst:
+                f(item, level)
+    f(lst, 0)
+    return maxlev
+
+def getitem(lst, indices):
+    """Definition for multidimensional slicing and indexing on arbitrarily
+       shaped nested lists.
+    """
+    if not indices:
+        return lst
+
+    i, indices = indices[0], indices[1:]
+    item = list.__getitem__(lst, i)
+
+    if isinstance(i, int):
+        return getitem(item, indices)
+
+    # Empty slice: check if all subsequent indices are in range for the
+    # full slice, raise IndexError otherwise. This is NumPy's behavior.
+    if not item:
+        if lst:
+            _ = getitem(lst, (slice(None),) + indices)
+        elif any(isinstance(k, int) for k in indices):
+            raise IndexError
+        return []
+
+    return [getitem(x, indices) for x in item]
+
+class NDArray(list):
+    """A simple wrapper for using generalized slicing/indexing on a list."""
+    def __init__(self, value, dtype=None):
+        list.__init__(self, value)
+        self.maxlevel = maxlevel(value)
+
+    def __getitem__(self, indices):
+        if not isinstance(indices, tuple):
+            indices = (indices,)
+
+        if len(indices) > self.maxlevel: # NumPy
+            raise IndexError("too many indices")
+
+        if not all(isinstance(i, (int, slice)) for i in indices):
+            raise TypeError(
+                "index must be int or slice or a tuple of integers and slices")
+
+        result = getitem(self, indices)
+        return NDArray(result) if isinstance(result, list) else result
+
+    def sin(self):
+        return NDArray(sinrec(self))
+
+    def __mul__(self, other):
+        return NDArray(mulrec(self, other))
+
+
+
+# ======================================================================
+# Generate test cases
+# ======================================================================
+
+SUBSCRIPT_FIXED_TEST_CASES = [
+    [],
+    [[]],
+    [[], []],
+    [[0], [1]],
+    [[0], [1], [2]],
+    [[0, 1], [1, 2], [2 ,3]],
+    [[[]]],
+    [[[0]]],
+    [[[], []]],
+    [[[0], [1]]],
+    [[[0, 1], [2, 3]]],
+    [[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
+    [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
+]
+
+SUBSCRIPT_VAR_TEST_CASES = [
+    [[[0, 1], [2, 3]], [[4, 5, 6], [7]], [[8, 9]]],
+    [[[0, 1], [2, 3]], [[4, 5, None], [None], [7]], [[], [None, 8]], [[9, 10]]],
+    [[[0, 1, 2], [3, 4, 5, 6], [7, 8, 9, 10]], [[11, 12, 13, 14], [15, 16, 17], [18, 19]]],
+    [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9]], [[10, 11]]]
+]
+
+def single_fixed(max_ndim=4, min_shape=1, max_shape=10):
+    nat = count()
+    shape = [randrange(min_shape, max_shape+1) for _ in range(max_ndim)]
+
+    def f(ndim):
+        if ndim == 0:
+            return next(nat)
+        return [f(ndim-1) for _ in range(shape[ndim-1])]
+
+    return f(max_ndim)
+
+def gen_fixed(max_ndim=4, min_shape=1, max_shape=10):
+    assert max_ndim >=0 and min_shape >=0 and min_shape <= max_shape
+
+    for _ in range(30):
+        yield single_fixed(max_ndim, min_shape, max_shape)
+
+def single_var(max_ndim=4, min_shape=1, max_shape=10):
+    nat = count()
+
+    def f(ndim):
+        if ndim == 0:
+            return next(nat)
+        if ndim == 1:
+            shape = randrange(min_shape, max_shape+1)
+        else:
+            n = 1 if min_shape == 0 else min_shape
+            shape = randrange(n, max_shape+1)
+        return [f(ndim-1) for _ in range(shape)]
+
+    return f(max_ndim)
+
+def gen_var(max_ndim=4, min_shape=1, max_shape=10):
+    assert max_ndim >=0 and min_shape >=0 and min_shape <= max_shape
+
+    for _ in range(30):
+        yield single_var(max_ndim, min_shape, max_shape)
+
+
+def genindices():
+    for i in range(4):
+        yield (i,)
+    for i in range(4):
+        for j in range(4):
+            yield (i, j)
+    for i in range(4):
+        for j in range(4):
+            for k in range(4):
+                yield (i, j, k)
+
+def rslice(ndim):
+    start = randrange(0, ndim+1)
+    stop = randrange(0, ndim+1)
+    step = 0
+    while step == 0:
+        step = randrange(-ndim-1, ndim+1)
+    start = None if randrange(5) == 4 else start
+    stop = None if randrange(5) == 4 else stop
+    step = None if randrange(5) == 4 else step
+    return slice(start, stop, step)
+
+def rslice_neg(ndim):
+    start = randrange(-ndim-1, ndim+1)
+    stop = randrange(-ndim-1, ndim+1)
+    step = 0
+    while step == 0:
+        step = randrange(-ndim-1, ndim+1)
+    return slice(start, stop, step)
+
+def multislice(ndim):
+    return tuple(rslice(ndim) for _ in range(randrange(1, ndim+1)))
+
+def randslices(ndim):
+    for i in range(5):
+        yield multislice(ndim)
+
+def gen_indices_or_slices():
+    for i in range(5):
+        if randrange(2):
+            yield (randrange(4), randrange(4), randrange(4))
+        else:
+            yield multislice(3)
+
+def genslices(n):
+    """Generate all possible slices for a single dimension."""
+    def range_with_none():
+        yield None
+        yield from range(-n, n+1)
+
+    for t in product(range_with_none(), range_with_none(), range_with_none()):
+        s = slice(*t)
+        if s.step != 0:
+            yield s
+
+def genslices_ndim(ndim, shape):
+    """Generate all possible slice tuples for 'shape'."""
+    iterables = [genslices(shape[n]) for n in range(ndim)]
+    yield from product(*iterables)
+
+def mixed_index(max_ndim):
+    ndim = randrange(1, max_ndim+1)
+    indices = []
+    for i in range(1, ndim+1):
+        if randrange(2):
+            indices.append(randrange(-max_ndim, max_ndim))
+        else:
+            indices.append(rslice(ndim))
+    return tuple(indices)
+
+def mixed_index_neg(max_ndim):
+    ndim = randrange(1, max_ndim+1)
+    indices = []
+    for i in range(1, ndim+1):
+        if randrange(2):
+            indices.append(randrange(-max_ndim, max_ndim))
+        else:
+            indices.append(rslice_neg(ndim))
+    return tuple(indices)
+
+def mixed_indices(max_ndim):
+    for i in range(5):
+        yield mixed_index(max_ndim)
+    for i in range(5):
+        yield mixed_index_neg(max_ndim)
+
+def itos(indices):
+    return ", ".join(str(i) if isinstance(i, int) else "%s:%s:%s" %
+                     (i.start, i.stop, i.step) for i in indices)
+
+
+# ======================================================================
+# Split a shape into N almost equal slices
+# ======================================================================
+
+def start(i, r, q):
+    return i*(q+1) if i < r else r+i*q
+
+def stop(i, r, q):
+    return (i+1)*(q+1) if i < r else r+(i+1)*q
+
+def step(i, r, q):
+    return q+1 if i < r else q
+
+def sl(i, r, q):
+    return slice(start(i, r, q), stop(i, r, q))
+
+def prepend(x, xs):
+    return [(x,) + t for t in xs]
+
+def last_column(i, r, q, n):
+    return [(sl(i, r, q),) for i in range(n)]
+
+def schedule(n, shape):
+    assert isinstance(n, int) and isinstance(shape, list)
+    if (n <= 0):
+        raise ValueError("n must be greater than zero")
+    if shape == []:
+        return [()]
+    m, ms = shape[0], shape[1:]
+    if (m <= 0):
+        raise ValueError("shape must be greater than zero")
+    if n <= m:
+        q, r = divmod(m, n)
+        return last_column(0, r, q, n)
+    else:
+        q, r = divmod(n, m)
+        return column(0, r, q, m, ms)
+
+def column(i, r, q, m, ms):
+    if i == m: return []
+    return prepend(slice(i, i+1),
+                   schedule(step(i, r, q), ms)) + \
+           column(i+1, r, q, m, ms)
+
+# ======================================================================
+# Split an xnd object into N subtrees
+# ======================================================================
+
+def zero_in_shape(shape):
+    for i in shape:
+        if i == 0:
+            return True
+    return False
+
+def split_xnd(x, n, max_outer=None):
+    shape = list(x.type.shape)
+    if zero_in_shape(shape):
+        raise ValueError("split does not support zeros in shape")
+    if max_outer is not None:
+        shape = shape[:max_outer]
+    indices_list = schedule(n, shape)
+    return [x[i] for i in indices_list]
+
+
+# ======================================================================
+# Generate test cases
+# ======================================================================
+
+functions = {
+    "unary": {
+        "default": ["copy", "abs"],
+        "arith": ["negative"],
+        "complex_math_with_half": ["exp", "log", "log10", "sqrt", "sin", "cos"],
+        "complex_math": ["tan", "asin", "acos", "atan", "sinh", "cosh", "tanh",
+                         "asinh", "acosh", "atanh"],
+        "real_math_with_half": ["fabs", "exp2", "log2"],
+        "real_math": ["expm1", "log1p", "logb", "cbrt", "erf", "erfc", "lgamma",
+                      "tgamma", "ceil", "floor", "trunc", "round", "nearbyint"],
+        "bitwise": ["invert"],
+    },
+    "binary": {
+        "default": ["add", "subtract", "multiply", "floor_divide", "remainder", "power"],
+        "float_result": ["divide"],
+        "bool_result": ["less_equal", "less", "greater_equal", "greater", "equal", "not_equal"],
+        "bitwise": ["bitwise_and", "bitwise_or", "bitwise_xor"]
+    },
+    "binary_mv": {
+        "default": ["divmod"],
+    }
+}
+
+def complex_noimpl(name):
+    return name in functions["unary"]["real_math"] or \
+           name in functions["unary"]["real_math_with_half"]
+
+def half_noimpl(name):
+    return name in functions["unary"]["real_math"] or \
+           name in functions["unary"]["complex_math"] or \
+           name in ("floor_divide", "remainder")
+
+tunsigned = ["bool", "uint8", "uint16", "uint32", "uint64"]
+tsigned = ["int8", "int16", "int32", "int64"]
+tfloat = ["bfloat16", "float16", "float32", "float64"]
+tcomplex = ["complex32", "complex64", "complex128"]
+
+tinfo = {
+    "bool": (0, 1, 0),
+    "uint8": (0, 2**8-1, 0),
+    "uint16": (0, 2**16-1, 0),
+    "uint32": (0, 2**32-1, 0),
+    "uint64": (0, 2**64-1, 0),
+    "int8": (-2**7, 2**7-1, 0),
+    "int16": (-2**15, 2**15-1, 0),
+    "int32": (-2**31, 2**31-1, 0),
+    "int64": (-2**63, 2**63-1, 0),
+    "float16": (-2**11, 2**11, 15),
+    "bfloat16": (-2**8, 2**8, 127),
+    "float32": (-2**24, 2**24, 127),
+    "float64": (-2**53, 2**53, 1023),
+    "complex32": (-2**11, 2**11, 15),
+    "complex64": (-2**24, 2**24, 127),
+    "complex128": (-2**53, 2**53, 1023)
+}
+
+class Tint(object):
+    def __init__(self, type):
+        if type not in tunsigned + tsigned:
+            raise ValueError("not an integer type: '%s'" % type)
+        self.type = type
+        self.min, self.max, self.exp = tinfo[type]
+        self.all = (self.type, self.min, self.max, self.exp)
+    def __repr__(self):
+        return self.type
+    def __eq__(self, other):
+        return isinstance(Tint, other) and self.all == other.all
+    def __hash__(self):
+        return hash(self.all)
+    def testcases(self):
+        yield 0
+        yield self.min
+        yield self.max
+        for i in range(10):
+            yield randrange(self.min, self.max+1)
+    def cpu_noimpl(self, f=None):
+        return False
+    def cpu_nokern(self, f=None):
+        return False
+    def cuda_noimpl(self, f=None):
+        return False
+    def cuda_nokern(self, f=None):
+        return False
+
+class Tfloat(object):
+    def __init__(self, type):
+        if type not in tfloat:
+            raise ValueError("not a float type: '%s'" % type)
+        self.type = type
+        self.min, self.max, self.exp = tinfo[type]
+        self.all = (self.type, self.min, self.max, self.exp)
+    def __repr__(self):
+        return self.type
+    def __eq__(self, other):
+        return isinstance(Tint, other) and self.all == other.all
+    def __hash__(self):
+        return hash(self.all)
+    def testcases(self):
+        yield 0
+        yield 0.5
+        yield -0.5
+        yield self.min
+        yield self.max
+        prec = randrange(1, 10)
+        for v in all_unary(prec, self.exp, 1):
+            yield float(v)
+        for v in un_randfloat():
+            yield float(v)
+    def cpu_noimpl(self, f=None):
+        return self.type == "float16"
+    def cpu_nokern(self, f=None):
+        return False
+    def cuda_noimpl(self, f=None):
+        if self.type == "float16":
+            return half_noimpl(f)
+    def cuda_nokern(self, f=None):
+        return False
+
+class Tcomplex(object):
+    def __init__(self, type):
+        if type not in tcomplex:
+            raise ValueError("not a complex type: '%s'" % type)
+        self.type = type
+        self.min, self.max, self.exp = tinfo[type]
+        self.all = (self.type, self.min, self.max, self.exp)
+    def __repr__(self):
+        return self.type
+    def __eq__(self, other):
+        return isinstance(Tint, other) and self.all == other.all
+    def __hash__(self):
+        return hash(self.all)
+    def testcases(self):
+        yield 0
+        yield 0.5
+        yield -0.5
+        yield 0.5j
+        yield -0.5j
+        yield self.min
+        yield self.max
+        prec = randrange(1, 10)
+        for v, w in all_binary(prec, self.exp, 1):
+            yield complex(float(v), float(w))
+        for v, w in bin_randfloat():
+            yield complex(float(v), float(w))
+    def cpu_noimpl(self, f=None):
+        if self.type == "complex32":
+            return True
+        return complex_noimpl(f)
+    def cpu_nokern(self, f=None):
+        return f in ("floor_divide", "remainder")
+    def cuda_noimpl(self, f=None):
+        if self.type == "complex32":
+            return True
+        return complex_noimpl(f)
+    def cuda_nokern(self, f=None):
+        return f in ("floor_divide", "remainder")
+
+
+tinfo_default = [
+    Tint("uint8"),
+    Tint("uint16"),
+    Tint("uint32"),
+    Tint("uint64"),
+    Tint("int8"),
+    Tint("int16"),
+    Tint("int32"),
+    Tint("int64"),
+    Tfloat("float16"),
+    Tfloat("bfloat16"),
+    Tfloat("float32"),
+    Tfloat("float64"),
+    Tcomplex("complex32"),
+    Tcomplex("complex64"),
+    Tcomplex("complex128")
+]
+
+tinfo_bitwise = [
+    Tint("bool"),
+    Tint("uint8"),
+    Tint("uint16"),
+    Tint("uint32"),
+    Tint("uint64"),
+    Tint("int8"),
+    Tint("int16"),
+    Tint("int32"),
+    Tint("int64")
+]
+
+implemented_sigs = {
+    "unary": {
+        "default": {}, "float_result": {}
+    },
+    "binary": {
+        "default": {}, "float_result": {}, "bool_result": {}, "bitwise": {}
+    },
+    "binary_mv": {
+        "default": {
+            (Tint("uint8"), Tint("uint8")): (Tint("uint8"), Tint("uint8")),
+            (Tint("uint16"), Tint("uint16")): (Tint("uint16"), Tint("uint16")),
+            (Tint("uint32"), Tint("uint32")): (Tint("uint32"), Tint("uint32")),
+            (Tint("uint64"), Tint("uint64")): (Tint("uint64"), Tint("uint64")),
+            (Tint("int8"), Tint("int8")): (Tint("int8"), Tint("int8")),
+            (Tint("int16"), Tint("int16")): (Tint("int16"), Tint("int16")),
+            (Tint("int32"), Tint("int32")): (Tint("int32"), Tint("int32")),
+            (Tint("int64"), Tint("int64")): (Tint("int64"), Tint("int64")),
+            (Tfloat("float32"), Tfloat("float32")): (Tfloat("float32"), Tfloat("float32")),
+            (Tfloat("float64"), Tfloat("float64")): (Tfloat("float64"), Tfloat("float64"))
+        },
+    }
+}
+
+exact_sigs = {
+    "unary": {
+        "default": {}, "float_result": {}
+    },
+    "binary": {
+        "default": {}, "float_result": {}, "bool_result": {}, "bitwise": {}
+    }
+}
+
+inexact_sigs = {
+    "unary": {
+        "default": {}, "float_result": {}
+    },
+    "binary": {
+        "default": {}, "float_result": {}, "bool_result": {}, "bitwise": {}
+    }
+}
+
+def init_unary_cast(pattern, tinfo, rank):
+    t = tinfo[rank]
+
+    start = max(8, rank) if pattern == "float_result" else rank
+    found_cast = False
+
+    for i in range(start, len(tinfo_default)):
+        cast = tinfo[i]
+        if cast.min <= t.min and t.max <= cast.max:
+            if found_cast or (t.type=="bfloat16") != (cast.type=="bfloat16"):
+                exact_sigs["unary"][pattern][(t,)] = cast
+            else:
+                found_cast = True
+                implemented_sigs["unary"][pattern][(t,)] = cast
+                exact_sigs["unary"][pattern][(t,)] = cast
+        else:
+            inexact_sigs["unary"][pattern][(t,)] = cast
+
+def init_unary_cast_tbl(pattern):
+    if pattern == "default":
+        tinfo = [Tint("bool")] + tinfo_default
+    elif pattern == "float_result":
+        tinfo = tinfo_default
+    elif pattern == "bitwise":
+        tinfo = tinfo_bitwise
+    else:
+        raise ValueError("unsupported function type '%s'" % func)
+
+    for rank, _ in enumerate(tinfo):
+        init_unary_cast(pattern, tinfo, rank)
+
+def is_binary_common_cast(cast, t, u):
+    if cast.min <= t.min and t.max <= cast.max and \
+       cast.min <= u.min and u.max <= cast.max:
+        if isinstance(cast, Tfloat):
+            return t.exp <= cast.exp and u.exp <= cast.exp
+        else:
+            return True
+    return False
+
+def init_binary_cast(pattern, tinfo, rank1, rank2):
+    min_rank = min(rank1, rank2)
+    max_rank = max(rank1, rank2)
+
+    t = tinfo[min_rank]
+    u = tinfo[max_rank]
+
+    start = max(8, max_rank) if pattern == "float_result" else max_rank
+    smallest_common_cast = False
+
+    for i in range(start, len(tinfo_default)):
+        common_cast = tinfo_default[i]
+        w = Tint("bool") if pattern == "bool_result" else common_cast
+        if is_binary_common_cast(common_cast, t, u):
+            if smallest_common_cast:
+                exact_sigs["binary"][pattern][(t, u)] = w
+            else:
+                smallest_common_cast = True
+                implemented_sigs["binary"][pattern][(t, u)] = w
+                exact_sigs["binary"][pattern][(t, u)] = w
+        else:
+            inexact_sigs["binary"][pattern][(t, u)] = w
+
+def init_binary_cast_tbl(pattern):
+    if pattern == "default" or pattern == "float_result" or pattern == "bool_result":
+        tinfo = tinfo_default
+    elif pattern == "bitwise":
+        tinfo = tinfo_bitwise
+    else:
+        raise ValueError("unsupported function type '%s'" % pattern)
+
+    for rank1, _ in enumerate(tinfo):
+        for rank2, _ in enumerate(tinfo):
+            init_binary_cast(pattern, tinfo, rank1, rank2)
+
+_struct_format = {
+    "float16": "e",
+    "float32": "f",
+    "float64": "d",
+    "complex32": "e",
+    "complex64": "f",
+    "complex128": "d"
+}
+
+def roundtrip_ne(v, fmt):
+    if fmt == "e":
+        try:
+            struct.pack(fmt, v)
+        except (OverflowError, struct.error):
+            return True
+        else:
+            return False
+    else:
+        if math.isinf(v):
+            return False
+        s = struct.unpack(fmt, struct.pack(fmt, v))[0]
+        return math.isinf(float(s))
+
+def struct_overflow(v, t):
+    try:
+        fmt = _struct_format[t.type]
+    except KeyError:
+        return False
+
+    if isinstance(t, Tcomplex):
+        return roundtrip_ne(v.real, fmt) or roundtrip_ne(v.imag, fmt)
+    else:
+        return roundtrip_ne(v, fmt)
+
+
+init_unary_cast_tbl("default")
+init_unary_cast_tbl("float_result")
+
+init_binary_cast_tbl("default")
+init_binary_cast_tbl("float_result")
+init_binary_cast_tbl("bool_result")
+init_binary_cast_tbl("bitwise")
+
+
+_np_names = {
+    "asin" : "arcsin",
+    "acos" : "arccos",
+    "atan" : "arctan",
+    "asinh" : "arcsinh",
+    "acosh" : "arccosh",
+    "atanh" : "arctanh",
+    "nearbyint" : "round",
+}
+
+def np_function(name):
+    return _np_names.get(name, name)
+
+def np_noimpl(name):
+    if name == "round":
+        # np.round == gumath.nearbyint
+        return True
+    try:
+        getattr(np, name)
+        return False
+    except AttributeError:
+        return True
+
+def gen_axes(ndim):
+    for i in range(ndim):
+        yield i
+    lst = list(range(ndim))
+    for i in range(ndim):
+        yield tuple(sample(lst, i))
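`gumath_aux.py` provides the pure-Python reference machinery for the expanded test suite: `NDArray` implements NumPy-style slicing on nested lists, and `schedule`/`split_xnd` split an outer shape into n nearly equal slice tuples for the multithreaded-apply tests. A small worked example (values follow directly from the definitions above):

```python
# Assumes gumath_aux.py (above) and its helpers (randdec, randfloat, numpy)
# are importable, e.g. when run from the gumath python/ directory.
from gumath_aux import NDArray, schedule

# schedule(3, [7]) splits a length-7 first dimension into chunks of 3, 2, 2:
assert schedule(3, [7]) == [(slice(0, 3),), (slice(3, 5),), (slice(5, 7),)]

# The same index tuples can be applied to an NDArray reference value:
lst = NDArray([[i, i + 1] for i in range(7)])    # shape 7 x 2
chunks = [lst[s] for s in schedule(3, [7])]
assert [len(c) for c in chunks] == [3, 2, 2]
```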
|