gumath 0.2.0dev5 → 0.2.0dev8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CONTRIBUTING.md +7 -2
- data/Gemfile +0 -3
- data/ext/ruby_gumath/GPATH +0 -0
- data/ext/ruby_gumath/GRTAGS +0 -0
- data/ext/ruby_gumath/GTAGS +0 -0
- data/ext/ruby_gumath/extconf.rb +0 -5
- data/ext/ruby_gumath/functions.c +10 -2
- data/ext/ruby_gumath/gufunc_object.c +15 -4
- data/ext/ruby_gumath/gufunc_object.h +9 -3
- data/ext/ruby_gumath/gumath/Makefile +63 -0
- data/ext/ruby_gumath/gumath/Makefile.in +1 -0
- data/ext/ruby_gumath/gumath/config.h +56 -0
- data/ext/ruby_gumath/gumath/config.h.in +3 -0
- data/ext/ruby_gumath/gumath/config.log +497 -0
- data/ext/ruby_gumath/gumath/config.status +1034 -0
- data/ext/ruby_gumath/gumath/configure +375 -4
- data/ext/ruby_gumath/gumath/configure.ac +47 -3
- data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
- data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
- data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
- data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
- data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
- data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
- data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
- data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
- data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
- data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
- data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
- data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
- data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
- data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
- data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
- data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
- data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
- data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
- data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
- data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
- data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
- data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
- data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
- data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
- data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
- data/ext/ruby_gumath/gumath/setup.py +67 -6
- data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
- data/ext/ruby_gumath/include/gumath.h +55 -14
- data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
- data/ext/ruby_gumath/lib/libgumath.a +0 -0
- data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
- data/ext/ruby_gumath/ruby_gumath.c +231 -70
- data/ext/ruby_gumath/ruby_gumath.h +4 -1
- data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
- data/ext/ruby_gumath/util.c +34 -0
- data/ext/ruby_gumath/util.h +9 -0
- data/gumath.gemspec +3 -2
- data/lib/gumath.rb +55 -1
- data/lib/gumath/version.rb +2 -2
- data/lib/ruby_gumath.so +0 -0
- metadata +63 -10
- data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
- data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
- data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
#include <Python.h>
|
|
2
|
+
#include "ndtypes.h"
|
|
3
|
+
#include "pyndtypes.h"
|
|
4
|
+
#include "gumath.h"
|
|
5
|
+
#include "pygumath.h"
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
/****************************************************************************/
|
|
9
|
+
/* Module globals */
|
|
10
|
+
/****************************************************************************/
|
|
11
|
+
|
|
12
|
+
/* Function table */
/* Kernel table for this extension module. PyInit_cuda allocates it on the
   first successful initialization and passes it to Gumath_AddCudaFunctions;
   nothing in this file frees it afterwards. */
static gm_tbl_t *table = NULL;


/****************************************************************************/
/* Module */
/****************************************************************************/

/* Single-phase module definition. m_size == -1 declares that the module
   keeps its state in C globals (`table` above) and therefore does not
   support sub-interpreters. No methods are listed here: the CUDA functions
   are attached by PyInit_cuda via Gumath_AddCudaFunctions. */
static struct PyModuleDef cuda_module = {
    PyModuleDef_HEAD_INIT,  /* m_base */
    "cuda",                 /* m_name */
    NULL,                   /* m_doc */
    -1,                     /* m_size */
    NULL,                   /* m_methods */
    NULL,                   /* m_slots */
    NULL,                   /* m_traverse */
    NULL,                   /* m_clear */
    NULL                    /* m_free */
};
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
PyMODINIT_FUNC
|
|
34
|
+
PyInit_cuda(void)
|
|
35
|
+
{
|
|
36
|
+
NDT_STATIC_CONTEXT(ctx);
|
|
37
|
+
PyObject *m = NULL;
|
|
38
|
+
static int initialized = 0;
|
|
39
|
+
|
|
40
|
+
if (!initialized) {
|
|
41
|
+
if (import_ndtypes() < 0) {
|
|
42
|
+
return NULL;
|
|
43
|
+
}
|
|
44
|
+
if (import_gumath() < 0) {
|
|
45
|
+
return NULL;
|
|
46
|
+
}
|
|
47
|
+
|
|
48
|
+
table = gm_tbl_new(&ctx);
|
|
49
|
+
if (table == NULL) {
|
|
50
|
+
return Ndt_SetError(&ctx);
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
if (gm_init_cuda_unary_kernels(table, &ctx) < 0) {
|
|
54
|
+
return Ndt_SetError(&ctx);
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
if (gm_init_cuda_binary_kernels(table, &ctx) < 0) {
|
|
58
|
+
return Ndt_SetError(&ctx);
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
initialized = 1;
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
m = PyModule_Create(&cuda_module);
|
|
65
|
+
if (m == NULL) {
|
|
66
|
+
goto error;
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
if (Gumath_AddCudaFunctions(m, table) < 0) {
|
|
70
|
+
goto error;
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
return m;
|
|
74
|
+
|
|
75
|
+
error:
|
|
76
|
+
Py_CLEAR(m);
|
|
77
|
+
return NULL;
|
|
78
|
+
}
|
|
@@ -56,11 +56,6 @@ PyInit_examples(void)
|
|
|
56
56
|
}
|
|
57
57
|
|
|
58
58
|
/* extending examples */
|
|
59
|
-
#ifndef _MSC_VER
|
|
60
|
-
if (gm_init_bfloat16_kernels(table, &ctx) < 0) {
|
|
61
|
-
return Ndt_SetError(&ctx);
|
|
62
|
-
}
|
|
63
|
-
#endif
|
|
64
59
|
if (gm_init_graph_kernels(table, &ctx) < 0) {
|
|
65
60
|
return Ndt_SetError(&ctx);
|
|
66
61
|
}
|
|
@@ -50,10 +50,10 @@ PyInit_functions(void)
|
|
|
50
50
|
return Ndt_SetError(&ctx);
|
|
51
51
|
}
|
|
52
52
|
|
|
53
|
-
if (
|
|
53
|
+
if (gm_init_cpu_unary_kernels(table, &ctx) < 0) {
|
|
54
54
|
return Ndt_SetError(&ctx);
|
|
55
55
|
}
|
|
56
|
-
if (
|
|
56
|
+
if (gm_init_cpu_binary_kernels(table, &ctx) < 0) {
|
|
57
57
|
return Ndt_SetError(&ctx);
|
|
58
58
|
}
|
|
59
59
|
|
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* BSD 3-Clause License
|
|
3
|
+
*
|
|
4
|
+
* Copyright (c) 2017-2018, plures
|
|
5
|
+
* All rights reserved.
|
|
6
|
+
*
|
|
7
|
+
* Redistribution and use in source and binary forms, with or without
|
|
8
|
+
* modification, are permitted provided that the following conditions are met:
|
|
9
|
+
*
|
|
10
|
+
* 1. Redistributions of source code must retain the above copyright notice,
|
|
11
|
+
* this list of conditions and the following disclaimer.
|
|
12
|
+
*
|
|
13
|
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
14
|
+
* this list of conditions and the following disclaimer in the documentation
|
|
15
|
+
* and/or other materials provided with the distribution.
|
|
16
|
+
*
|
|
17
|
+
* 3. Neither the name of the copyright holder nor the names of its
|
|
18
|
+
* contributors may be used to endorse or promote products derived from
|
|
19
|
+
* this software without specific prior written permission.
|
|
20
|
+
*
|
|
21
|
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
22
|
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
23
|
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
24
|
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
25
|
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
26
|
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
27
|
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
28
|
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
29
|
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
30
|
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
31
|
+
*/
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
#ifndef GUMATH_H
|
|
35
|
+
#define GUMATH_H
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
/*
 * Standard headers are included before the C linkage block is opened.
 * C++ only permits including a standard header outside of any declaration,
 * and `extern "C" { ... }` is a linkage specification. The project headers
 * stay inside the block, as before, so their declarations keep C linkage
 * when this header is compiled as C++.
 */
#ifdef __cplusplus
#include <cstdint>
#else
#include <stdint.h>
#include <stdbool.h>  /* `bool` appears in the prototypes below */
#endif

#ifdef __cplusplus
extern "C" {
#endif

#include "ndtypes.h"
#include "xnd.h"
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
/*
 * Compiler portability macros.
 *
 *   GM_API     Linkage decoration for the public functions below. With MSVC
 *              it is __declspec(dllexport) when GM_EXPORT is defined,
 *              __declspec(dllimport) when GM_IMPORT is defined, and empty
 *              otherwise. With other compilers it is always empty.
 *   GM_UNUSED  Marks a possibly unused entity. It is __attribute__((unused))
 *              on GCC-compatible compilers other than Intel, and empty
 *              elsewhere (on MSVC unless it is already defined).
 *   ALLOCA     ALLOCA(type, name, nmemb) declares `name` as storage for
 *              `nmemb` objects of `type`: a _alloca() buffer on MSVC, a
 *              variable length array elsewhere.
 */
#ifdef _MSC_VER
#if defined (GM_EXPORT)
#define GM_API __declspec(dllexport)
#elif defined(GM_IMPORT)
#define GM_API __declspec(dllimport)
#else
#define GM_API
#endif

#ifndef GM_UNUSED
#define GM_UNUSED
#endif

#include <malloc.h>
/* `nmemb` is parenthesized so an argument such as `n + 1` is scaled as a
   whole. The explicit cast lets the macro compile as C++ as well, where
   void * does not convert implicitly to other pointer types. */
#define ALLOCA(type, name, nmemb) type *name = (type *)_alloca((nmemb) * sizeof(type))
#else
#define GM_API
#if defined(__GNUC__) && !defined(__INTEL_COMPILER)
#define GM_UNUSED __attribute__((unused))
#else
#define GM_UNUSED
#endif

#define ALLOCA(type, name, nmemb) type name[nmemb]
#endif
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
/* Capacity of the inline `kernels` array in struct gm_func, i.e. the
   maximum number of kernel sets one function can hold. */
#define GM_MAX_KERNELS 8192
/* NOTE(review): not used in this header; presumably the work-size
   threshold above which threaded application (gm_apply_thread) is used.
   Confirm in apply.c/thread.c. */
#define GM_THREAD_CUTOFF 1000000

/* Aliases for C float/double, named after their intended 32/64-bit widths. */
typedef float float32_t;
typedef double float64_t;


/* Kernel operating on xnd containers: `stack` holds the kernel arguments.
   Returns an int status. NOTE(review): presumably < 0 on error, with the
   error recorded in `ctx`; confirm. */
typedef int (* gm_xnd_kernel_t)(xnd_t stack[], ndt_context_t *ctx);
/* Kernel with a NumPy-style strided-loop signature (see the "NumPy
   signature" fields below): argument data pointers, loop dimensions, byte
   steps and an opaque `data` pointer. */
typedef int (* gm_strided_kernel_t)(char **args, intptr_t *dimensions, intptr_t *steps, void *data);
|
|
88
|
+
|
|
89
|
+
/*
 * Collection of specialized kernels for a single function signature.
 *
 * NOTE: The specialized kernel lookup scheme is transitional and may
 * be replaced by something else.
 *
 * This should be considered as a first version of a kernel request
 * protocol.
 *
 * Each kernel field is one variant specialized for a memory layout. Fields
 * may be NULL; `Xnd` is the fallback used for non-contiguous data or when
 * the specialized variants are NULL.
 */
typedef struct {
    const ndt_t *sig;                   /* the function signature */
    const ndt_constraint_t *constraint; /* NOTE(review): optional extra check on `sig`, presumably NULL when unused; confirm */

    /* Xnd signatures */
    gm_xnd_kernel_t OptC;    /* C in inner+1 dimensions */
    gm_xnd_kernel_t OptZ;    /* C in inner dimensions, C or zero stride in (inner+1)th. */
    gm_xnd_kernel_t OptS;    /* strided in (inner+1)th. */
    gm_xnd_kernel_t C;       /* C in inner dimensions */
    gm_xnd_kernel_t Fortran; /* Fortran in inner dimensions */
    gm_xnd_kernel_t Xnd;     /* selected if non-contiguous or the other fields are NULL */

    /* NumPy signature */
    gm_strided_kernel_t Strided;
} gm_kernel_set_t;
|
|
113
|
+
|
|
114
|
+
/* Static description of a named typedef: a `name`, a `type` given as a
   string, and ndtypes methods `meth`. NOTE(review): the code that consumes
   this struct is not visible in this header; confirm its semantics there. */
typedef struct {
    const char *name;
    const char *type;
    const ndt_methods_t *meth;
} gm_typedef_init_t;

/* Static description of one kernel, passed to gm_add_kernel() and
   gm_add_kernel_typecheck(). It has the same kernel fields as
   gm_kernel_set_t, but the signature is a string and it adds `name` and
   `cap`. */
typedef struct {
    const char *name;                   /* NOTE(review): presumably the function the kernel is added to; confirm */
    const char *sig;                    /* signature, in string form */
    const ndt_constraint_t *constraint;
    uint32_t cap;                       /* NOTE(review): meaning not visible here, presumably capability/target flags; confirm */

    /* Xnd signatures */
    gm_xnd_kernel_t OptC;
    gm_xnd_kernel_t OptZ;
    gm_xnd_kernel_t OptS;
    gm_xnd_kernel_t C;
    gm_xnd_kernel_t Fortran;
    gm_xnd_kernel_t Xnd;

    /* NumPy signature */
    gm_strided_kernel_t Strided;
} gm_kernel_init_t;
|
|
137
|
+
|
|
138
|
+
/* Actual kernel selected for application */
/* Returned by gm_select() and passed to gm_apply()/gm_apply_thread().
   NOTE(review): `flag` presumably records which variant of `set` was
   chosen; confirm in apply.c. */
typedef struct {
    uint32_t flag;
    const gm_kernel_set_t *set;
} gm_kernel_t;

/* Multimethod with associated kernels */
typedef struct gm_func gm_func_t;
/* Optional fast type checker for a function (see gm_func.typecheck): it
   receives the `nin` input types in `in` and returns the matching kernel
   set, filling in `spec`. NOTE(review): the meaning of `li` and the
   failure convention (presumably NULL with `ctx` set) are not visible
   here; confirm. */
typedef const gm_kernel_set_t *(*gm_typecheck_t)(ndt_apply_spec_t *spec, const gm_func_t *f,
                                                 const ndt_t *in[], const int64_t li[],
                                                 int nin, int nout, bool check_broadcast,
                                                 ndt_context_t *ctx);
/* The kernels are stored inline, so one gm_func is large (GM_MAX_KERNELS
   kernel sets, roughly half a megabyte on 64-bit targets). Do not place it
   on the stack. */
struct gm_func {
    char *name;
    gm_typecheck_t typecheck; /* Experimental optimized type-checking, may be NULL. */
    int nkernels;             /* NOTE(review): presumably the number of used `kernels` entries */
    gm_kernel_set_t kernels[GM_MAX_KERNELS];
};


/* Kernel table. The struct is incomplete in this header, so it is only
   accessible through the gm_tbl_* functions. */
typedef struct _gm_tbl gm_tbl_t;
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
/******************************************************************************/
/* Functions */
/******************************************************************************/

/* Constructor/destructor for a standalone gm_func_t named `name`.
   NOTE(review): presumably NULL with `ctx` set on failure; confirm. */
GM_API gm_func_t *gm_func_new(const char *name, ndt_context_t *ctx);
GM_API void gm_func_del(gm_func_t *f);

/* Registration into a kernel table. The int-returning functions report
   failure with a negative value (callers test `< 0`); presumably with the
   error recorded in `ctx`. gm_add_kernel_typecheck additionally takes a
   gm_typecheck_t for the function's `typecheck` slot. */
GM_API gm_func_t *gm_add_func(gm_tbl_t *tbl, const char *name, ndt_context_t *ctx);
GM_API int gm_add_kernel(gm_tbl_t *tbl, const gm_kernel_init_t *kernel, ndt_context_t *ctx);
GM_API int gm_add_kernel_typecheck(gm_tbl_t *tbl, const gm_kernel_init_t *kernel, ndt_context_t *ctx, gm_typecheck_t f);

/* Kernel selection and application. gm_select returns the gm_kernel_t
   chosen for function `name` given `nin` inputs and `nout` outputs, filling
   in `spec`. gm_apply/gm_apply_thread run a selected kernel on `stack`;
   the threaded variant takes a thread count. NOTE(review): error signaling
   of gm_select (returned by value) is not visible here; confirm. */
GM_API gm_kernel_t gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
                             const ndt_t *types[], const int64_t li[], int nin, int nout,
                             bool check_broadcast, const xnd_t args[], ndt_context_t *ctx);
GM_API int gm_apply(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, ndt_context_t *ctx);
GM_API int gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, const int64_t nthreads, ndt_context_t *ctx);
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
/******************************************************************************/
/* NumPy loops */
/******************************************************************************/

/* Helpers for running gm_strided_kernel_t kernels over xnd arguments.
   NOTE(review): exact contracts (who allocates `args`/`dimensions`/`steps`,
   required sizes) are not visible in this header; confirm in nploops.c. */
GM_API int gm_np_flatten(char **args, const int nargs,
                         int64_t *dimensions, int64_t *strides, const xnd_t stack[],
                         ndt_context_t *ctx);

GM_API int gm_np_convert_xnd(char **args, const int nargs,
                             intptr_t *dimensions, const int dims_size,
                             intptr_t *steps, const int steps_size,
                             xnd_t stack[], const int outer_dims,
                             ndt_context_t *ctx);

GM_API int gm_np_map(const gm_strided_kernel_t f,
                     char **args, int nargs,
                     intptr_t *dimensions,
                     intptr_t *steps,
                     void *data,
                     int outer_dims);


/******************************************************************************/
/* Xnd loops */
/******************************************************************************/

/* Xnd-kernel helpers: a shape check for one argument and a map of
   gm_xnd_kernel_t `f` over `nargs` arguments in `stack` with `outer_dims`
   outer dimensions. NOTE(review): return conventions not visible here. */
GM_API int array_shape_check(xnd_t *x, const int64_t shape, ndt_context_t *ctx);
GM_API int gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
                      const int outer_dims, ndt_context_t *ctx);
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
/******************************************************************************/
/* Gufunc table */
/******************************************************************************/
/* gm_tbl_new returns NULL on failure with the error recorded in `ctx`
   (callers convert it with Ndt_SetError). */
GM_API gm_tbl_t *gm_tbl_new(ndt_context_t *ctx);
GM_API void gm_tbl_del(gm_tbl_t *t);

/* Keyed storage of gm_func_t values by name string. NOTE(review): error
   conventions of add/find and the stop condition of gm_tbl_map (which takes
   a callback `f` and a `state` pointer) are not visible here; confirm. */
GM_API int gm_tbl_add(gm_tbl_t *tbl, const char *key, gm_func_t *value, ndt_context_t *ctx);
GM_API gm_func_t *gm_tbl_find(const gm_tbl_t *tbl, const char *key, ndt_context_t *ctx);
GM_API int gm_tbl_map(const gm_tbl_t *tbl, int (*f)(const gm_func_t *, void *state), void *state);


/******************************************************************************/
/* Library initialization and tables */
/******************************************************************************/

/* NOTE(review): gm_init/gm_finalize are presumably library-wide
   setup/teardown; confirm. */
GM_API void gm_init(void);
/* The gm_init_*_kernels functions add a family of kernels to `tbl`. They
   return a negative value on failure, with the error in `ctx`; callers
   test `< 0` and then call Ndt_SetError. */
GM_API int gm_init_cpu_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
GM_API int gm_init_cpu_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
GM_API int gm_init_bitwise_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);

GM_API int gm_init_cuda_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
GM_API int gm_init_cuda_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);

GM_API int gm_init_example_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
GM_API int gm_init_graph_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
GM_API int gm_init_quaternion_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
GM_API int gm_init_pdist_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);

GM_API void gm_finalize(void);
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
#ifdef __cplusplus
|
|
242
|
+
} /* END extern "C" */
|
|
243
|
+
#endif
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
#endif /* GUMATH_H */
|
|
Binary file
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3
|
|
Binary file
|
|
@@ -49,10 +49,15 @@ extern "C" {
|
|
|
49
49
|
/* Exposed here for the benefit of Numba. The API should not be regarded
|
|
50
50
|
stable across versions. */
|
|
51
51
|
|
|
52
|
+
/* Bit values for GufuncObject.flags ("memory target"): CPU functions and
   CUDA managed-memory functions. */
#define GM_CPU_FUNC 0x0001U
#define GM_CUDA_MANAGED_FUNC 0x0002U

/* Python object representing one named gumath function and the kernel
   table it is dispatched from. NOTE(review): ownership/lifetime of `name`
   and `identity` is not visible in this header; confirm in _gumath.c. */
typedef struct {
    PyObject_HEAD
    const gm_tbl_t *tbl;   /* kernel table */
    uint32_t flags;        /* memory target */
    char *name;            /* function name */
    PyObject *identity;    /* identity element */
} GufuncObject;
|
|
57
62
|
|
|
58
63
|
|
|
@@ -60,21 +65,45 @@ typedef struct {
|
|
|
60
65
|
/* Capsule API */
|
|
61
66
|
/****************************************************************************/
|
|
62
67
|
|
|
63
|
-
#define
|
|
68
|
+
/* Capsule API slot descriptions. Each exported function X is described by
   X_INDEX (its slot in the _gumath_api pointer table), X_RETURN (its return
   type) and X_ARGS (its parenthesized parameter list). The declaration and
   cast macros further down combine these. */
#define Gufunc_CheckExact_INDEX 0
#define Gufunc_CheckExact_RETURN int
#define Gufunc_CheckExact_ARGS (const PyObject *)

#define Gufunc_Check_INDEX 1
#define Gufunc_Check_RETURN int
#define Gufunc_Check_ARGS (const PyObject *)

#define Gumath_AddFunctions_INDEX 2
|
|
66
79
|
|
|
67
|
-
#define
|
|
80
|
+
/* Used by the gumath.cuda extension (PyInit_cuda) to add the functions of
   its kernel table to the module. */
#define Gumath_AddCudaFunctions_INDEX 3
#define Gumath_AddCudaFunctions_RETURN int
#define Gumath_AddCudaFunctions_ARGS (PyObject *, const gm_tbl_t *)

/* One more than the highest *_INDEX above. NOTE(review): presumably the
   length of the _gumath_api table; keep in sync when adding entries. */
#define GUMATH_MAX_API 4
|
|
68
85
|
|
|
69
86
|
|
|
70
87
|
#ifdef GUMATH_MODULE
|
|
88
|
+
static Gufunc_CheckExact_RETURN Gufunc_CheckExact Gufunc_CheckExact_ARGS;
|
|
89
|
+
static Gufunc_Check_RETURN Gufunc_Check Gufunc_Check_ARGS;
|
|
71
90
|
static Gumath_AddFunctions_RETURN Gumath_AddFunctions Gumath_AddFunctions_ARGS;
|
|
91
|
+
static Gumath_AddCudaFunctions_RETURN Gumath_AddCudaFunctions Gumath_AddCudaFunctions_ARGS;
|
|
72
92
|
#else
|
|
73
93
|
static void **_gumath_api;
|
|
74
94
|
|
|
95
|
+
#define Gufunc_CheckExact \
|
|
96
|
+
(*(Gufunc_CheckExact_RETURN (*)Gufunc_CheckExact_ARGS) _gumath_api[Gufunc_CheckExact_INDEX])
|
|
97
|
+
|
|
98
|
+
#define Gufunc_Check \
|
|
99
|
+
(*(Gufunc_Check_RETURN (*)Gufunc_Check_ARGS) _gumath_api[Gufunc_Check_INDEX])
|
|
100
|
+
|
|
75
101
|
#define Gumath_AddFunctions \
|
|
76
102
|
(*(Gumath_AddFunctions_RETURN (*)Gumath_AddFunctions_ARGS) _gumath_api[Gumath_AddFunctions_INDEX])
|
|
77
103
|
|
|
104
|
+
#define Gumath_AddCudaFunctions \
|
|
105
|
+
(*(Gumath_AddCudaFunctions_RETURN (*)Gumath_AddCudaFunctions_ARGS) _gumath_api[Gumath_AddCudaFunctions_INDEX])
|
|
106
|
+
|
|
78
107
|
|
|
79
108
|
static int
|
|
80
109
|
import_gumath(void)
|
|
@@ -0,0 +1,767 @@
|
|
|
1
|
+
#
|
|
2
|
+
# BSD 3-Clause License
|
|
3
|
+
#
|
|
4
|
+
# Copyright (c) 2017-2018, plures
|
|
5
|
+
# All rights reserved.
|
|
6
|
+
#
|
|
7
|
+
# Redistribution and use in source and binary forms, with or without
|
|
8
|
+
# modification, are permitted provided that the following conditions are met:
|
|
9
|
+
#
|
|
10
|
+
# 1. Redistributions of source code must retain the above copyright notice,
|
|
11
|
+
# this list of conditions and the following disclaimer.
|
|
12
|
+
#
|
|
13
|
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
14
|
+
# this list of conditions and the following disclaimer in the documentation
|
|
15
|
+
# and/or other materials provided with the distribution.
|
|
16
|
+
#
|
|
17
|
+
# 3. Neither the name of the copyright holder nor the names of its
|
|
18
|
+
# contributors may be used to endorse or promote products derived from
|
|
19
|
+
# this software without specific prior written permission.
|
|
20
|
+
#
|
|
21
|
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
22
|
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
23
|
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
24
|
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
25
|
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
26
|
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
27
|
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
28
|
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
29
|
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
30
|
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
31
|
+
#
|
|
32
|
+
|
|
33
|
+
# Pure-Python NDArray class and functions for generating test cases.
|
|
34
|
+
|
|
35
|
+
from itertools import accumulate, count, product
|
|
36
|
+
from random import randrange, sample
|
|
37
|
+
from collections import namedtuple
|
|
38
|
+
import math
|
|
39
|
+
import struct
|
|
40
|
+
import unittest
|
|
41
|
+
from randdec import all_unary, all_binary
|
|
42
|
+
from randfloat import un_randfloat, bin_randfloat
|
|
43
|
+
import numpy as np
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def skip_if(condition, reason):
    """Skip the running test with *reason* when *condition* is truthy."""
    if not condition:
        return
    raise unittest.SkipTest(reason)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
# ======================================================================
|
|
52
|
+
# Minimal test cases
|
|
53
|
+
# ======================================================================
|
|
54
|
+
|
|
55
|
+
# (data, type string, dtype) triples: a 1-D case, a 2 x 1000 case and a
# 1000 x 2 case, first for float64 (data scaled by 1/100) and then for
# float32 (data scaled by 1/10). The 1000 x 2 data deliberately repeats one
# inner list object 1000 times.
TEST_CASES = [
    case
    for dtype, divisor in (("float64", 100.0), ("float32", 10.0))
    for case in (
        ([float(i)/divisor for i in range(2000)], "2000 * " + dtype, dtype),
        ([[float(i)/divisor for i in range(1000)], [float(i+1) for i in range(1000)]],
         "2 * 1000 * " + dtype, dtype),
        (1000 * [[float(i+1) for i in range(2)]], "1000 * 2 * " + dtype, dtype),
    )
]
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
# ======================================================================
|
|
73
|
+
# Definition of generalized slicing and indexing
|
|
74
|
+
# ======================================================================
|
|
75
|
+
|
|
76
|
+
def have_none(lst):
    """Return True if *lst* is None or contains None at any nesting depth.

    Lists and tuples are searched element-wise and dicts by value. Any other
    object only counts if it is None itself.
    """
    if isinstance(lst, dict):
        children = lst.values()
    elif isinstance(lst, (list, tuple)):
        children = lst
    else:
        return lst is None
    return any(have_none(child) for child in children)
|
|
82
|
+
|
|
83
|
+
def sinrec(lst):
    """Apply math.sin to every leaf of an arbitrarily nested list.

    None leaves (missing values) are propagated unchanged. Numeric leaves
    may be int or float; float support matters because the module's test
    data is floating point. Any other leaf type raises TypeError.
    """
    if isinstance(lst, list):
        return [sinrec(item) for item in lst]
    elif isinstance(lst, (int, float, type(None))):
        return None if lst is None else math.sin(lst)
    else:
        raise TypeError("unexpected operand type '%s'" % type(lst))
|
|
90
|
+
|
|
91
|
+
def mulrec(lst1, lst2):
    """Multiply two nested lists element-wise.

    The lists are walked in lockstep (with zip semantics, so extra trailing
    items of the longer list are ignored). A result leaf is None if either
    operand leaf is None. Numeric leaves may be int or float, matching the
    floating-point test data. Any other combination raises TypeError.
    """
    if isinstance(lst1, list) and isinstance(lst2, list):
        return [mulrec(*pair) for pair in zip(lst1, lst2)]
    elif (isinstance(lst1, (int, float, type(None))) and
          isinstance(lst2, (int, float, type(None)))):
        return None if lst1 is None or lst2 is None else lst1 * lst2
    else:
        raise TypeError("unexpected operand types '%s', '%s'" %
                        (type(lst1), type(lst2)))
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def maxlevel(lst):
    """Return maximum nesting depth of lists in *lst*.

    A non-list has depth 0 and an empty list has depth 1. Tuples and other
    containers are treated as leaves.
    """
    if not isinstance(lst, list):
        return 0
    deepest = 0
    for item in lst:
        deepest = max(deepest, maxlevel(item))
    return deepest + 1
|
|
113
|
+
|
|
114
|
+
def getitem(lst, indices):
    """Definition for multidimensional slicing and indexing on arbitrarily
    shaped nested lists.

    *indices* is a tuple of ints and slices applied one per nesting level.
    An int selects (and removes) a level; a slice keeps it and applies the
    remaining indices to every selected element.
    """
    if not indices:
        return lst

    head, rest = indices[0], indices[1:]
    selected = list.__getitem__(lst, head)

    if isinstance(head, int):
        return getitem(selected, rest)

    if selected:
        return [getitem(sub, rest) for sub in selected]

    # Empty slice: check if all subsequent indices are in range for the
    # full slice, raise IndexError otherwise. This is NumPy's behavior.
    if lst:
        getitem(lst, (slice(None),) + rest)
    elif any(isinstance(k, int) for k in rest):
        raise IndexError
    return []
|
|
137
|
+
|
|
138
|
+
class NDArray(list):
    """A simple wrapper for using generalized slicing/indexing on a list.

    ``dtype`` is accepted for call compatibility but is not used.
    """

    def __init__(self, value, dtype=None):
        super().__init__(value)
        # Nesting depth bounds how many indices __getitem__ accepts.
        self.maxlevel = maxlevel(value)

    def __getitem__(self, indices):
        key = indices if isinstance(indices, tuple) else (indices,)

        if len(key) > self.maxlevel:  # NumPy
            raise IndexError("too many indices")

        for k in key:
            if not isinstance(k, (int, slice)):
                raise TypeError(
                    "index must be int or slice or a tuple of integers and slices")

        result = getitem(self, key)
        if isinstance(result, list):
            return NDArray(result)
        return result

    def sin(self):
        """Return a new NDArray with math.sin applied to every leaf."""
        return NDArray(sinrec(self))

    def __mul__(self, other):
        """Return the element-wise product with another nested list."""
        return NDArray(mulrec(self, other))
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
# ======================================================================
|
|
167
|
+
# Generate test cases
|
|
168
|
+
# ======================================================================
|
|
169
|
+
|
|
170
|
+
# Hand-written fixed-shape nested lists (at each depth all sub-lists have
# the same length) for subscript tests. Covers empty dimensions and nesting
# depths 1 to 3.
SUBSCRIPT_FIXED_TEST_CASES = [
    [],
    [[]],
    [[], []],
    [[0], [1]],
    [[0], [1], [2]],
    [[0, 1], [1, 2], [2 ,3]],
    [[[]]],
    [[[0]]],
    [[[], []]],
    [[[0], [1]]],
    [[[0, 1], [2, 3]]],
    [[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
    [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
]
|
|
185
|
+
|
|
186
|
+
# Hand-written ragged (variable-length) nested lists of depth 3 for
# subscript tests. The second case also contains None entries and an empty
# innermost list.
SUBSCRIPT_VAR_TEST_CASES = [
    [[[0, 1], [2, 3]], [[4, 5, 6], [7]], [[8, 9]]],
    [[[0, 1], [2, 3]], [[4, 5, None], [None], [7]], [[], [None, 8]], [[9, 10]]],
    [[[0, 1, 2], [3, 4, 5, 6], [7, 8, 9, 10]], [[11, 12, 13, 14], [15, 16, 17], [18, 19]]],
    [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9]], [[10, 11]]]
]
|
|
192
|
+
|
|
193
|
+
def single_fixed(max_ndim=4, min_shape=1, max_shape=10):
    """Return one random fixed-shape nested list of depth *max_ndim*.

    One extent in [min_shape, max_shape] is drawn per dimension before
    building; extents[0] is the innermost dimension. The leaves are
    consecutive integers from 0 in row-major order.
    """
    leaves = count()
    extents = [randrange(min_shape, max_shape + 1) for _ in range(max_ndim)]

    def build(depth):
        if depth:
            return [build(depth - 1) for _ in range(extents[depth - 1])]
        return next(leaves)

    return build(max_ndim)
|
|
203
|
+
|
|
204
|
+
def gen_fixed(max_ndim=4, min_shape=1, max_shape=10, ncases=30):
    """Yield *ncases* random fixed-shape nested lists built by single_fixed.

    max_ndim, min_shape and max_shape are forwarded to single_fixed.
    *ncases* defaults to 30, the previously hard-coded count.

    Raises ValueError when iteration starts if max_ndim < 0 or the bounds
    do not satisfy 0 <= min_shape <= max_shape.
    """
    # Explicit check instead of `assert`: assertions are stripped under -O.
    if not (max_ndim >= 0 and 0 <= min_shape <= max_shape):
        raise ValueError(
            "invalid bounds: need max_ndim >= 0 and "
            "0 <= min_shape <= max_shape")

    for _ in range(ncases):
        yield single_fixed(max_ndim, min_shape, max_shape)
|
|
209
|
+
|
|
210
|
+
def single_var(max_ndim=4, min_shape=1, max_shape=10):
    """Return one random ragged nested list of depth *max_ndim*.

    Every sub-list draws its own length. Innermost lists use the range
    [min_shape, max_shape]. Outer lists use the same range, except that a
    min_shape of 0 is raised to 1, so only innermost lists can be empty.
    Leaves are consecutive integers from 0 in depth-first order.
    """
    leaves = count()
    outer_min = min_shape if min_shape else 1

    def build(depth):
        if depth == 0:
            return next(leaves)
        low = min_shape if depth == 1 else outer_min
        return [build(depth - 1) for _ in range(randrange(low, max_shape + 1))]

    return build(max_ndim)
|
|
224
|
+
|
|
225
|
+
def gen_var(max_ndim=4, min_shape=1, max_shape=10):
    """Yield 30 random ragged nested lists (see single_var).

    This is a generator, so the argument check runs when iteration starts,
    not at call time.
    """
    assert max_ndim >= 0 and 0 <= min_shape <= max_shape

    yield from (single_var(max_ndim, min_shape, max_shape) for _ in range(30))
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def genindices():
    """Yield every integer index tuple of length 1, 2 and 3 over range(4).

    Shorter tuples come first; within a length, tuples are produced in
    lexicographic order.
    """
    for length in (1, 2, 3):
        yield from product(range(4), repeat=length)
|
|
242
|
+
|
|
243
|
+
def rslice(ndim):
    """Return a random slice with non-negative bounds.

    start and stop are drawn from [0, ndim], and step from [-ndim-1, ndim]
    excluding 0.  Each of the three fields is then independently replaced
    by None with probability 1/5.
    """
    lo, hi = randrange(0, ndim + 1), randrange(0, ndim + 1)
    stride = 0
    while not stride:
        stride = randrange(-ndim - 1, ndim + 1)

    def maybe_none(value):
        # Drop this field with probability 1/5.
        return None if randrange(5) == 4 else value

    return slice(maybe_none(lo), maybe_none(hi), maybe_none(stride))
|
|
253
|
+
|
|
254
|
+
def rslice_neg(ndim):
    """Return a random slice whose fields may be negative.

    start, stop and step are all drawn from [-ndim-1, ndim].  step is
    redrawn until it is non-zero.  Unlike rslice, no field is ever None.
    """
    lo_bound, hi_bound = -ndim - 1, ndim + 1
    first = randrange(lo_bound, hi_bound)
    last = randrange(lo_bound, hi_bound)
    stride = randrange(lo_bound, hi_bound)
    while stride == 0:
        stride = randrange(lo_bound, hi_bound)
    return slice(first, last, stride)
|
|
261
|
+
|
|
262
|
+
def multislice(ndim):
    """Return a tuple of between 1 and ndim random slices (see rslice)."""
    nslices = randrange(1, ndim + 1)
    return tuple([rslice(ndim) for _ in range(nslices)])
|
|
264
|
+
|
|
265
|
+
def randslices(ndim):
    """Yield five random slice tuples (see multislice)."""
    yield from (multislice(ndim) for _ in range(5))
|
|
268
|
+
|
|
269
|
+
def gen_indices_or_slices():
    """Yield five random keys for a 3-dimensional container.

    Each key is, with equal probability, either a tuple of three integer
    indices in [0, 4) or a tuple of random slices (see multislice).
    """
    for _ in range(5):
        if not randrange(2):
            yield multislice(3)
            continue
        yield (randrange(4), randrange(4), randrange(4))
|
|
275
|
+
|
|
276
|
+
def genslices(n):
    """Yield every slice over a single dimension of length n.

    start, stop and step each range over None and the integers -n..n.
    Combinations with step == 0 are skipped.
    """
    candidates = [None, *range(-n, n + 1)]
    for lo, hi, stride in product(candidates, repeat=3):
        if stride == 0:
            continue
        yield slice(lo, hi, stride)
|
|
286
|
+
|
|
287
|
+
def genslices_ndim(ndim, shape):
    """Yield every tuple of slices over the first 'ndim' entries of 'shape'.

    The tuples form the Cartesian product of genslices(shape[k]) for
    k in range(ndim).
    """
    per_dim = (genslices(shape[k]) for k in range(ndim))
    yield from product(*per_dim)
|
|
291
|
+
|
|
292
|
+
def mixed_index(max_ndim):
    """Return a random key mixing integer indices and slices.

    The key has between 1 and max_ndim entries.  Each entry is, with equal
    probability, an integer in [-max_ndim, max_ndim) or a slice from
    rslice (called with the chosen key length).
    """
    ndim = randrange(1, max_ndim + 1)
    key = []
    for _ in range(ndim):
        entry = randrange(-max_ndim, max_ndim) if randrange(2) else rslice(ndim)
        key.append(entry)
    return tuple(key)
|
|
301
|
+
|
|
302
|
+
def mixed_index_neg(max_ndim):
    """Like mixed_index, but slice entries come from rslice_neg.

    The key has between 1 and max_ndim entries.  Each entry is, with equal
    probability, an integer in [-max_ndim, max_ndim) or rslice_neg(ndim).
    """
    ndim = randrange(1, max_ndim + 1)
    return tuple(
        randrange(-max_ndim, max_ndim) if randrange(2) else rslice_neg(ndim)
        for _ in range(ndim)
    )
|
|
311
|
+
|
|
312
|
+
def mixed_indices(max_ndim):
    """Yield five keys from mixed_index, then five from mixed_index_neg."""
    yield from (mixed_index(max_ndim) for _ in range(5))
    yield from (mixed_index_neg(max_ndim) for _ in range(5))
|
|
317
|
+
|
|
318
|
+
def itos(indices):
    """Render an index tuple as comma-separated text.

    Integers are printed as-is; any other entry is treated as a slice and
    printed as 'start:stop:step', e.g. (1, slice(0, 2)) -> '1, 0:2:None'.
    """
    parts = []
    for item in indices:
        if isinstance(item, int):
            parts.append(str(item))
        else:
            parts.append("%s:%s:%s" % (item.start, item.stop, item.step))
    return ", ".join(parts)
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
# ======================================================================
|
|
324
|
+
# Split a shape into N almost equal slices
|
|
325
|
+
# ======================================================================
|
|
326
|
+
|
|
327
|
+
def start(i, r, q):
    """Return the first index of chunk i.

    The chunks come from splitting into pieces where the first r chunks
    have q + 1 items and the remaining chunks have q items.
    """
    if i >= r:
        return r + i * q
    return i * (q + 1)
|
|
329
|
+
|
|
330
|
+
def stop(i, r, q):
    """Return the end index (exclusive) of chunk i.

    The first r chunks have q + 1 items and the remaining chunks have
    q items.
    """
    if i >= r:
        return r + (i + 1) * q
    return (i + 1) * (q + 1)
|
|
332
|
+
|
|
333
|
+
def step(i, r, q):
    """Return the size of chunk i: q + 1 for the first r chunks, else q."""
    if i >= r:
        return q
    return q + 1
|
|
335
|
+
|
|
336
|
+
def sl(i, r, q):
    """Return slice(start(i, r, q), stop(i, r, q)) for chunk i."""
    lo = start(i, r, q)
    hi = stop(i, r, q)
    return slice(lo, hi)
|
|
338
|
+
|
|
339
|
+
def prepend(x, xs):
    """Return a new list with x prepended to every tuple in xs."""
    return list(map(lambda tail: (x,) + tail, xs))
|
|
341
|
+
|
|
342
|
+
def last_column(i, r, q, n):
    """Return [(sl(k, r, q),) for k in range(n)]: n one-slice index tuples.

    The parameter 'i' is not used (the original comprehension shadowed it);
    it is kept so existing callers keep working.
    """
    return [(sl(k, r, q),) for k in range(n)]
|
|
344
|
+
|
|
345
|
+
def schedule(n, shape):
    """Split the index space of 'shape' into n almost equal pieces.

    Returns a list of index tuples made of slices:

    * shape == []: a single empty tuple.
    * n <= shape[0]: only the first dimension is split, into n contiguous
      chunks (see last_column).
    * otherwise: with q, r = divmod(n, shape[0]), row j of the first
      dimension is split into step(j, r, q) pieces over the remaining
      dimensions (see column).

    Raises ValueError if n <= 0 or shape[0] <= 0.
    """
    assert isinstance(n, int) and isinstance(shape, list)
    if n <= 0:
        raise ValueError("n must be greater than zero")
    if shape == []:
        return [()]

    first, rest = shape[0], shape[1:]
    if first <= 0:
        raise ValueError("shape must be greater than zero")

    if n > first:
        q, r = divmod(n, first)
        return column(0, r, q, first, rest)
    q, r = divmod(first, n)
    return last_column(0, r, q, n)
|
|
360
|
+
|
|
361
|
+
def column(i, r, q, m, ms):
    """Schedule rows i .. m-1 of the first dimension.

    Row j is pinned with slice(j, j + 1), and its remaining dimensions 'ms'
    are split into step(j, r, q) pieces via schedule().  The per-row
    results are concatenated in row order.

    Returns an empty list when i >= m.
    """
    # A plain loop replaces the former one-level-per-row recursion.  That
    # recursion rebuilt the accumulated list with '+' at every level
    # (quadratic copying), could hit the recursion limit for large m, and
    # never terminated when i > m.
    result = []
    for j in range(i, m):
        result.extend(prepend(slice(j, j + 1), schedule(step(j, r, q), ms)))
    return result
|
|
366
|
+
|
|
367
|
+
# ======================================================================
|
|
368
|
+
# Split an xnd object into N subtrees
|
|
369
|
+
# ======================================================================
|
|
370
|
+
|
|
371
|
+
def zero_in_shape(shape):
    """Return True if any dimension in 'shape' equals 0."""
    return any(dim == 0 for dim in shape)
|
|
376
|
+
|
|
377
|
+
def split_xnd(x, n, max_outer=None):
    """Split x into n sub-views by indexing it with schedule()'d keys.

    x must expose x.type.shape and accept tuple-of-slice indexing.  When
    max_outer is given, only the leading max_outer dimensions are
    scheduled.  Raises ValueError if any dimension of the full shape is 0.
    """
    full_shape = list(x.type.shape)
    if zero_in_shape(full_shape):
        raise ValueError("split does not support zeros in shape")
    outer = full_shape if max_outer is None else full_shape[:max_outer]
    return [x[key] for key in schedule(n, outer)]
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
# ======================================================================
|
|
388
|
+
# Generate test cases
|
|
389
|
+
# ======================================================================
|
|
390
|
+
|
|
391
|
+
# Function names under test, grouped by kind ("unary", "binary", and
# "binary_mv" for divmod, which has two outputs in implemented_sigs), then
# by result/implementation pattern.
functions = {
    "unary": {
        "default": ["copy", "abs"],
        "arith": ["negative"],
        "complex_math_with_half": ["exp", "log", "log10", "sqrt", "sin", "cos"],
        "complex_math": [
            "tan", "asin", "acos", "atan", "sinh", "cosh", "tanh",
            "asinh", "acosh", "atanh",
        ],
        "real_math_with_half": ["fabs", "exp2", "log2"],
        "real_math": [
            "expm1", "log1p", "logb", "cbrt", "erf", "erfc", "lgamma",
            "tgamma", "ceil", "floor", "trunc", "round", "nearbyint",
        ],
        "bitwise": ["invert"],
    },
    "binary": {
        "default": [
            "add", "subtract", "multiply", "floor_divide", "remainder", "power",
        ],
        "float_result": ["divide"],
        "bool_result": [
            "less_equal", "less", "greater_equal", "greater", "equal", "not_equal",
        ],
        "bitwise": ["bitwise_and", "bitwise_or", "bitwise_xor"],
    },
    "binary_mv": {
        "default": ["divmod"],
    },
}
|
|
413
|
+
|
|
414
|
+
def complex_noimpl(name):
    """Return True if 'name' is listed under functions['unary']['real_math']
    or functions['unary']['real_math_with_half'] (real-only unary math)."""
    unary = functions["unary"]
    return name in unary["real_math"] or name in unary["real_math_with_half"]
|
|
417
|
+
|
|
418
|
+
def half_noimpl(name):
    """Return True if 'name' is reported as unavailable for half precision.

    That covers the unary real_math and complex_math functions, plus
    floor_divide and remainder.
    """
    unary = functions["unary"]
    if name in unary["real_math"] or name in unary["complex_math"]:
        return True
    return name in ("floor_divide", "remainder")
|
|
422
|
+
|
|
423
|
+
# Element type names grouped by kind (bool is grouped with the unsigned ints).
tunsigned = ["bool", "uint8", "uint16", "uint32", "uint64"]
tsigned = ["int8", "int16", "int32", "int64"]
tfloat = ["bfloat16", "float16", "float32", "float64"]
tcomplex = ["complex32", "complex64", "complex128"]
|
|
427
|
+
|
|
428
|
+
# (min, max, exp) for every element type:
#   * integer types: min/max are the representable range and exp is 0;
#   * float types: min/max are +/-2**(significand bits), which appears to be
#     the range of exactly representable integers, and exp matches the
#     format's largest exponent;
#   * complex types reuse the values of their component float type.
tinfo = {
    "bool": (0, 1, 0),
    "uint8": (0, 2**8 - 1, 0),
    "uint16": (0, 2**16 - 1, 0),
    "uint32": (0, 2**32 - 1, 0),
    "uint64": (0, 2**64 - 1, 0),
    "int8": (-2**7, 2**7 - 1, 0),
    "int16": (-2**15, 2**15 - 1, 0),
    "int32": (-2**31, 2**31 - 1, 0),
    "int64": (-2**63, 2**63 - 1, 0),
    "float16": (-2**11, 2**11, 15),
    "bfloat16": (-2**8, 2**8, 127),
    "float32": (-2**24, 2**24, 127),
    "float64": (-2**53, 2**53, 1023),
    "complex32": (-2**11, 2**11, 15),
    "complex64": (-2**24, 2**24, 127),
    "complex128": (-2**53, 2**53, 1023),
}
|
|
446
|
+
|
|
447
|
+
class Tint(object):
    """Description of an integer (or bool) element type for the tests.

    min/max/exp are looked up in 'tinfo'.  Two instances compare equal, and
    hash equally, when their (type, min, max, exp) tuples match.  They are
    used as dict keys in the signature tables.
    """

    def __init__(self, type):
        if type not in tunsigned + tsigned:
            raise ValueError("not an integer type: '%s'" % type)
        self.type = type
        self.min, self.max, self.exp = tinfo[type]
        self.all = (self.type, self.min, self.max, self.exp)

    def __repr__(self):
        return self.type

    def __eq__(self, other):
        # isinstance(obj, cls): the reversed isinstance(Tint, other) raised
        # TypeError on every comparison (e.g. on dict-key hash collisions).
        return isinstance(other, Tint) and self.all == other.all

    def __hash__(self):
        return hash(self.all)

    def testcases(self):
        """Yield 0, the type's min and max, then 10 random in-range values."""
        yield 0
        yield self.min
        yield self.max
        for _ in range(10):
            yield randrange(self.min, self.max + 1)

    def cpu_noimpl(self, f=None):
        """Integer types are never reported as unimplemented on the CPU."""
        return False

    def cpu_nokern(self, f=None):
        """Integer types are never reported as missing a CPU kernel."""
        return False

    def cuda_noimpl(self, f=None):
        """Integer types are never reported as unimplemented on CUDA."""
        return False

    def cuda_nokern(self, f=None):
        """Integer types are never reported as missing a CUDA kernel."""
        return False
|
|
474
|
+
|
|
475
|
+
class Tfloat(object):
    """Description of a real floating-point element type for the tests.

    min/max/exp are looked up in 'tinfo'.  Two instances compare equal, and
    hash equally, when their (type, min, max, exp) tuples match.
    """

    def __init__(self, type):
        if type not in tfloat:
            raise ValueError("not a float type: '%s'" % type)
        self.type = type
        self.min, self.max, self.exp = tinfo[type]
        self.all = (self.type, self.min, self.max, self.exp)

    def __repr__(self):
        return self.type

    def __eq__(self, other):
        # isinstance(obj, cls), checked against this class.  The previous
        # isinstance(Tint, other) had the arguments reversed (TypeError on
        # every comparison) and named the wrong class.
        return isinstance(other, Tfloat) and self.all == other.all

    def __hash__(self):
        return hash(self.all)

    def testcases(self):
        """Yield fixed edge values, then generated values converted to float.

        all_unary and un_randfloat are helpers from outside this section.
        """
        yield 0
        yield 0.5
        yield -0.5
        yield self.min
        yield self.max
        prec = randrange(1, 10)
        for v in all_unary(prec, self.exp, 1):
            yield float(v)
        for v in un_randfloat():
            yield float(v)

    def cpu_noimpl(self, f=None):
        """Return True for float16, which is reported as unimplemented on CPU."""
        return self.type == "float16"

    def cpu_nokern(self, f=None):
        """Float types are never reported as missing a CPU kernel."""
        return False

    def cuda_noimpl(self, f=None):
        """On CUDA, float16 lacks the functions listed by half_noimpl()."""
        if self.type == "float16":
            return half_noimpl(f)
        # Explicit bool instead of implicitly falling off the end (None).
        return False

    def cuda_nokern(self, f=None):
        """Float types are never reported as missing a CUDA kernel."""
        return False
|
|
508
|
+
|
|
509
|
+
class Tcomplex(object):
    """Description of a complex element type for the tests.

    min/max/exp are looked up in 'tinfo'.  They match the component float
    type.  Two instances compare equal, and hash equally, when their
    (type, min, max, exp) tuples match.
    """

    def __init__(self, type):
        if type not in tcomplex:
            raise ValueError("not a complex type: '%s'" % type)
        self.type = type
        self.min, self.max, self.exp = tinfo[type]
        self.all = (self.type, self.min, self.max, self.exp)

    def __repr__(self):
        return self.type

    def __eq__(self, other):
        # isinstance(obj, cls), checked against this class.  The previous
        # isinstance(Tint, other) had the arguments reversed (TypeError on
        # every comparison) and named the wrong class.
        return isinstance(other, Tcomplex) and self.all == other.all

    def __hash__(self):
        return hash(self.all)

    def testcases(self):
        """Yield fixed edge values, then generated complex values.

        all_binary and bin_randfloat are helpers from outside this section.
        Each yields pairs that become the real and imaginary parts.
        """
        yield 0
        yield 0.5
        yield -0.5
        yield 0.5j
        yield -0.5j
        yield self.min
        yield self.max
        prec = randrange(1, 10)
        for v, w in all_binary(prec, self.exp, 1):
            yield complex(float(v), float(w))
        for v, w in bin_randfloat():
            yield complex(float(v), float(w))

    def cpu_noimpl(self, f=None):
        """complex32 is always unimplemented; otherwise ask complex_noimpl(f)."""
        if self.type == "complex32":
            return True
        return complex_noimpl(f)

    def cpu_nokern(self, f=None):
        """floor_divide and remainder are reported as having no CPU kernel."""
        return f in ("floor_divide", "remainder")

    def cuda_noimpl(self, f=None):
        """complex32 is always unimplemented; otherwise ask complex_noimpl(f)."""
        if self.type == "complex32":
            return True
        return complex_noimpl(f)

    def cuda_nokern(self, f=None):
        """floor_divide and remainder are reported as having no CUDA kernel."""
        return f in ("floor_divide", "remainder")
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
# Element types in rank order: unsigned ints, signed ints, floats, complex.
# The order matters: for "float_result", init_unary_cast/init_binary_cast
# start searching at index 8, which is Tfloat("float16") here.  bool is not
# part of this list; init_unary_cast_tbl prepends it for "default".
tinfo_default = [
    # unsigned integers (indices 0-3)
    Tint("uint8"), Tint("uint16"), Tint("uint32"), Tint("uint64"),
    # signed integers (indices 4-7)
    Tint("int8"), Tint("int16"), Tint("int32"), Tint("int64"),
    # floating point (indices 8-11)
    Tfloat("float16"), Tfloat("bfloat16"), Tfloat("float32"), Tfloat("float64"),
    # complex (indices 12-14)
    Tcomplex("complex32"), Tcomplex("complex64"), Tcomplex("complex128"),
]
|
|
566
|
+
|
|
567
|
+
# Element types accepted by the bitwise functions, in rank order: bool, then
# unsigned and signed integers.
tinfo_bitwise = [
    Tint("bool"),
    Tint("uint8"), Tint("uint16"), Tint("uint32"), Tint("uint64"),
    Tint("int8"), Tint("int16"), Tint("int32"), Tint("int64"),
]
|
|
578
|
+
|
|
579
|
+
# Signatures treated as implemented, keyed by function kind and pattern.
# The "unary" and "binary" sub-tables start empty and are filled by
# init_unary_cast_tbl/init_binary_cast_tbl.  "binary_mv" (divmod) is listed
# explicitly, mapping an input pair to an output pair.
implemented_sigs = {
    "unary": {"default": {}, "float_result": {}},
    "binary": {"default": {}, "float_result": {}, "bool_result": {}, "bitwise": {}},
    "binary_mv": {
        "default": {
            (Tint("uint8"), Tint("uint8")): (Tint("uint8"), Tint("uint8")),
            (Tint("uint16"), Tint("uint16")): (Tint("uint16"), Tint("uint16")),
            (Tint("uint32"), Tint("uint32")): (Tint("uint32"), Tint("uint32")),
            (Tint("uint64"), Tint("uint64")): (Tint("uint64"), Tint("uint64")),
            (Tint("int8"), Tint("int8")): (Tint("int8"), Tint("int8")),
            (Tint("int16"), Tint("int16")): (Tint("int16"), Tint("int16")),
            (Tint("int32"), Tint("int32")): (Tint("int32"), Tint("int32")),
            (Tint("int64"), Tint("int64")): (Tint("int64"), Tint("int64")),
            (Tfloat("float32"), Tfloat("float32")): (Tfloat("float32"), Tfloat("float32")),
            (Tfloat("float64"), Tfloat("float64")): (Tfloat("float64"), Tfloat("float64")),
        },
    },
}
|
|
601
|
+
|
|
602
|
+
# Signatures whose output type can hold every input value.  Filled by
# init_unary_cast/init_binary_cast.
exact_sigs = {
    "unary": {"default": {}, "float_result": {}},
    "binary": {"default": {}, "float_result": {}, "bool_result": {}, "bitwise": {}},
}
|
|
610
|
+
|
|
611
|
+
# Signatures whose output type cannot hold every input value.  Filled by
# init_unary_cast/init_binary_cast.
inexact_sigs = {
    "unary": {"default": {}, "float_result": {}},
    "binary": {"default": {}, "float_result": {}, "bool_result": {}, "bitwise": {}},
}
|
|
619
|
+
|
|
620
|
+
def init_unary_cast(pattern, tinfo, rank):
    """Record the unary signatures of input type t = tinfo[rank] for 'pattern'.

    Candidate output types are tinfo[start:].  start is 'rank', or
    max(8, rank) for "float_result" (index 8 is float16, the first float in
    tinfo_default).  For each candidate cast:

    * if cast's [min, max] contains t's range, it is an exact cast.  The
      first such cast (unless exactly one of t/cast is bfloat16) is also
      recorded in implemented_sigs.
    * otherwise it is recorded in inexact_sigs.

    Each table maps the key (t,) to a single type, so later candidates
    overwrite earlier ones in exact_sigs and inexact_sigs.
    """
    t = tinfo[rank]

    start = max(8, rank) if pattern == "float_result" else rank
    found_cast = False

    # Bound by 'tinfo', the list being indexed.  For "default" it has bool
    # prepended and is one longer than tinfo_default, so bounding by
    # len(tinfo_default) skipped complex128 entirely (it got no signature).
    # For "bitwise" that bound indexed past the end of tinfo_bitwise.
    for i in range(start, len(tinfo)):
        cast = tinfo[i]
        if cast.min <= t.min and t.max <= cast.max:
            if found_cast or (t.type == "bfloat16") != (cast.type == "bfloat16"):
                exact_sigs["unary"][pattern][(t,)] = cast
            else:
                found_cast = True
                implemented_sigs["unary"][pattern][(t,)] = cast
                exact_sigs["unary"][pattern][(t,)] = cast
        else:
            inexact_sigs["unary"][pattern][(t,)] = cast
|
|
637
|
+
|
|
638
|
+
def init_unary_cast_tbl(pattern):
    """Populate the unary signature tables for 'pattern' over all input types.

    "default" uses bool + tinfo_default, "float_result" uses tinfo_default,
    and "bitwise" uses tinfo_bitwise.

    Raises ValueError for any other pattern.
    """
    if pattern == "default":
        tinfo = [Tint("bool")] + tinfo_default
    elif pattern == "float_result":
        tinfo = tinfo_default
    elif pattern == "bitwise":
        tinfo = tinfo_bitwise
    else:
        # Format 'pattern'.  The undefined name 'func' used here before
        # turned this error into a NameError.
        raise ValueError("unsupported function type '%s'" % pattern)

    # The module-level unary tables only pre-create "default" and
    # "float_result"; make sure a slot exists for the chosen pattern.
    for table in (implemented_sigs, exact_sigs, inexact_sigs):
        table["unary"].setdefault(pattern, {})

    for rank in range(len(tinfo)):
        init_unary_cast(pattern, tinfo, rank)
|
|
650
|
+
|
|
651
|
+
def is_binary_common_cast(cast, t, u):
    """Return True if 'cast' can hold every value of both t and u.

    Both inputs' [min, max] ranges must lie within cast's range.  When cast
    is a Tfloat, both inputs' exp must also not exceed cast.exp.  Other
    cast types, including Tcomplex, skip the exp check.
    """
    def contains(x):
        return cast.min <= x.min and x.max <= cast.max

    if not (contains(t) and contains(u)):
        return False
    if not isinstance(cast, Tfloat):
        return True
    return t.exp <= cast.exp and u.exp <= cast.exp
|
|
659
|
+
|
|
660
|
+
def init_binary_cast(pattern, tinfo, rank1, rank2):
    """Record the binary signatures of the pair (tinfo[rank1], tinfo[rank2]).

    The pair is keyed as (t, u), with t the lower-ranked type.  Candidate
    common types are tinfo[start:].  start is the higher rank, and at least
    8 (float16, the first float in tinfo_default) for "float_result".

    * The first candidate that holds both inputs (is_binary_common_cast) is
      recorded in implemented_sigs and exact_sigs.
    * Later holding candidates overwrite exact_sigs.
    * Non-holding candidates overwrite inexact_sigs.

    The recorded result type is bool for "bool_result", otherwise the
    common type itself.
    """
    min_rank = min(rank1, rank2)
    max_rank = max(rank1, rank2)

    t = tinfo[min_rank]
    u = tinfo[max_rank]

    start = max(8, max_rank) if pattern == "float_result" else max_rank
    smallest_common_cast = False

    # The ranks index 'tinfo', so candidates must come from 'tinfo' too.
    # Reading tinfo_default was only right when tinfo is tinfo_default.  For
    # "bitwise" (tinfo_bitwise starts with bool) every candidate was shifted
    # by one: (bool, bool) -> uint8, and int64 pairs got no signature.
    for i in range(start, len(tinfo)):
        common_cast = tinfo[i]
        w = Tint("bool") if pattern == "bool_result" else common_cast
        if is_binary_common_cast(common_cast, t, u):
            if smallest_common_cast:
                exact_sigs["binary"][pattern][(t, u)] = w
            else:
                smallest_common_cast = True
                implemented_sigs["binary"][pattern][(t, u)] = w
                exact_sigs["binary"][pattern][(t, u)] = w
        else:
            inexact_sigs["binary"][pattern][(t, u)] = w
|
|
682
|
+
|
|
683
|
+
def init_binary_cast_tbl(pattern):
    """Populate the binary signature tables for 'pattern' over all type pairs.

    "default", "float_result" and "bool_result" use tinfo_default, and
    "bitwise" uses tinfo_bitwise.  Raises ValueError for any other pattern.
    """
    if pattern in ("default", "float_result", "bool_result"):
        tinfo = tinfo_default
    elif pattern == "bitwise":
        tinfo = tinfo_bitwise
    else:
        raise ValueError("unsupported function type '%s'" % pattern)

    ranks = range(len(tinfo))
    for rank1, rank2 in product(ranks, repeat=2):
        init_binary_cast(pattern, tinfo, rank1, rank2)
|
|
694
|
+
|
|
695
|
+
_struct_format = {
|
|
696
|
+
"float16": "e",
|
|
697
|
+
"float32": "f",
|
|
698
|
+
"float64": "d",
|
|
699
|
+
"complex32": "e",
|
|
700
|
+
"complex64": "f",
|
|
701
|
+
"complex128": "d"
|
|
702
|
+
}
|
|
703
|
+
|
|
704
|
+
def roundtrip_ne(v, fmt):
    """Return True if storing the real value v with struct format 'fmt' overflows.

    * fmt "e" (half precision): any packing failure (OverflowError or
      struct.error) counts as overflow.
    * other formats: an infinite v is not an overflow.  A finite v
      overflows if it packs to infinity, or if struct refuses to pack it
      with OverflowError (as standard-size formats such as "<f" do).
    """
    if fmt == "e":
        try:
            struct.pack(fmt, v)
        except (OverflowError, struct.error):
            return True
        return False

    if math.isinf(v):
        return False
    try:
        packed = struct.pack(fmt, v)
    except OverflowError:
        # An out-of-range finite value may be rejected rather than rounded
        # to inf.  Either way it does not fit, so report overflow instead
        # of letting the exception escape.
        return True
    return math.isinf(float(struct.unpack(fmt, packed)[0]))
|
|
717
|
+
|
|
718
|
+
def struct_overflow(v, t):
    """Return True if value v overflows the storage format of type t.

    Types without an entry in _struct_format (integers, bfloat16, ...)
    never report overflow.  For Tcomplex types, the real and imaginary
    parts are checked separately.
    """
    fmt = _struct_format.get(t.type)
    if fmt is None:
        return False

    if not isinstance(t, Tcomplex):
        return roundtrip_ne(v, fmt)
    return roundtrip_ne(v.real, fmt) or roundtrip_ne(v.imag, fmt)
|
|
728
|
+
|
|
729
|
+
|
|
730
|
+
# Fill the unary and binary signature tables at import time.
for _pattern in ("default", "float_result"):
    init_unary_cast_tbl(_pattern)

for _pattern in ("default", "float_result", "bool_result", "bitwise"):
    init_binary_cast_tbl(_pattern)

del _pattern
|
|
737
|
+
|
|
738
|
+
|
|
739
|
+
_np_names = {
|
|
740
|
+
"asin" : "arcsin",
|
|
741
|
+
"acos" : "arccos",
|
|
742
|
+
"atan" : "arctan",
|
|
743
|
+
"asinh" : "arcsinh",
|
|
744
|
+
"acosh" : "arccosh",
|
|
745
|
+
"atanh" : "arctanh",
|
|
746
|
+
"nearbyint" : "round",
|
|
747
|
+
}
|
|
748
|
+
|
|
749
|
+
def np_function(name):
    """Return the NumPy name for gumath function 'name' (unchanged if unmapped)."""
    try:
        return _np_names[name]
    except KeyError:
        return name
|
|
751
|
+
|
|
752
|
+
def np_noimpl(name):
    """Return True if there is no NumPy function to compare 'name' against.

    "round" is always reported as missing, because np.round corresponds to
    gumath.nearbyint rather than gumath.round.
    """
    if name == "round":
        return True
    return not hasattr(np, name)
|
|
761
|
+
|
|
762
|
+
def gen_axes(ndim):
|
|
763
|
+
for i in range(ndim):
|
|
764
|
+
yield i
|
|
765
|
+
lst = list(range(ndim))
|
|
766
|
+
for i in range(ndim):
|
|
767
|
+
yield tuple(sample(lst, i))
|