gumath 0.2.0dev5 → 0.2.0dev8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. checksums.yaml +4 -4
  2. data/CONTRIBUTING.md +7 -2
  3. data/Gemfile +0 -3
  4. data/ext/ruby_gumath/GPATH +0 -0
  5. data/ext/ruby_gumath/GRTAGS +0 -0
  6. data/ext/ruby_gumath/GTAGS +0 -0
  7. data/ext/ruby_gumath/extconf.rb +0 -5
  8. data/ext/ruby_gumath/functions.c +10 -2
  9. data/ext/ruby_gumath/gufunc_object.c +15 -4
  10. data/ext/ruby_gumath/gufunc_object.h +9 -3
  11. data/ext/ruby_gumath/gumath/Makefile +63 -0
  12. data/ext/ruby_gumath/gumath/Makefile.in +1 -0
  13. data/ext/ruby_gumath/gumath/config.h +56 -0
  14. data/ext/ruby_gumath/gumath/config.h.in +3 -0
  15. data/ext/ruby_gumath/gumath/config.log +497 -0
  16. data/ext/ruby_gumath/gumath/config.status +1034 -0
  17. data/ext/ruby_gumath/gumath/configure +375 -4
  18. data/ext/ruby_gumath/gumath/configure.ac +47 -3
  19. data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
  20. data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
  21. data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
  22. data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
  23. data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
  24. data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
  25. data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
  26. data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
  27. data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
  28. data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
  29. data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
  30. data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
  31. data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
  32. data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
  33. data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
  34. data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
  35. data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
  36. data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
  37. data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
  38. data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
  39. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
  40. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
  41. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
  42. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
  43. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
  44. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
  45. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
  46. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
  47. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
  48. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
  49. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
  50. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
  51. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
  52. data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
  53. data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
  54. data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
  55. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
  56. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
  57. data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
  58. data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
  59. data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
  60. data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
  61. data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
  62. data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
  63. data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
  64. data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
  65. data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
  66. data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
  67. data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
  68. data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
  69. data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
  70. data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
  71. data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
  72. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
  73. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
  74. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
  75. data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
  76. data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
  77. data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
  78. data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
  79. data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
  80. data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
  81. data/ext/ruby_gumath/gumath/setup.py +67 -6
  82. data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
  83. data/ext/ruby_gumath/include/gumath.h +55 -14
  84. data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
  85. data/ext/ruby_gumath/lib/libgumath.a +0 -0
  86. data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
  87. data/ext/ruby_gumath/ruby_gumath.c +231 -70
  88. data/ext/ruby_gumath/ruby_gumath.h +4 -1
  89. data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
  90. data/ext/ruby_gumath/util.c +34 -0
  91. data/ext/ruby_gumath/util.h +9 -0
  92. data/gumath.gemspec +3 -2
  93. data/lib/gumath.rb +55 -1
  94. data/lib/gumath/version.rb +2 -2
  95. data/lib/ruby_gumath.so +0 -0
  96. metadata +63 -10
  97. data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
  98. data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
  99. data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
@@ -0,0 +1,78 @@
1
+ #include <Python.h>
2
+ #include "ndtypes.h"
3
+ #include "pyndtypes.h"
4
+ #include "gumath.h"
5
+ #include "pygumath.h"
6
+
7
+
8
+ /****************************************************************************/
9
+ /* Module globals */
10
+ /****************************************************************************/
11
+
12
+ /* Function table */
13
+ static gm_tbl_t *table = NULL;
14
+
15
+
16
+ /****************************************************************************/
17
+ /* Module */
18
+ /****************************************************************************/
19
+
20
+ static struct PyModuleDef cuda_module = {
21
+ PyModuleDef_HEAD_INIT, /* m_base */
22
+ "cuda", /* m_name */
23
+ NULL, /* m_doc */
24
+ -1, /* m_size */
25
+ NULL, /* m_methods */
26
+ NULL, /* m_slots */
27
+ NULL, /* m_traverse */
28
+ NULL, /* m_clear */
29
+ NULL /* m_free */
30
+ };
31
+
32
+
33
+ PyMODINIT_FUNC
34
+ PyInit_cuda(void)
35
+ {
36
+ NDT_STATIC_CONTEXT(ctx);
37
+ PyObject *m = NULL;
38
+ static int initialized = 0;
39
+
40
+ if (!initialized) {
41
+ if (import_ndtypes() < 0) {
42
+ return NULL;
43
+ }
44
+ if (import_gumath() < 0) {
45
+ return NULL;
46
+ }
47
+
48
+ table = gm_tbl_new(&ctx);
49
+ if (table == NULL) {
50
+ return Ndt_SetError(&ctx);
51
+ }
52
+
53
+ if (gm_init_cuda_unary_kernels(table, &ctx) < 0) {
54
+ return Ndt_SetError(&ctx);
55
+ }
56
+
57
+ if (gm_init_cuda_binary_kernels(table, &ctx) < 0) {
58
+ return Ndt_SetError(&ctx);
59
+ }
60
+
61
+ initialized = 1;
62
+ }
63
+
64
+ m = PyModule_Create(&cuda_module);
65
+ if (m == NULL) {
66
+ goto error;
67
+ }
68
+
69
+ if (Gumath_AddCudaFunctions(m, table) < 0) {
70
+ goto error;
71
+ }
72
+
73
+ return m;
74
+
75
+ error:
76
+ Py_CLEAR(m);
77
+ return NULL;
78
+ }
@@ -56,11 +56,6 @@ PyInit_examples(void)
56
56
  }
57
57
 
58
58
  /* extending examples */
59
- #ifndef _MSC_VER
60
- if (gm_init_bfloat16_kernels(table, &ctx) < 0) {
61
- return Ndt_SetError(&ctx);
62
- }
63
- #endif
64
59
  if (gm_init_graph_kernels(table, &ctx) < 0) {
65
60
  return Ndt_SetError(&ctx);
66
61
  }
@@ -50,10 +50,10 @@ PyInit_functions(void)
50
50
  return Ndt_SetError(&ctx);
51
51
  }
52
52
 
53
- if (gm_init_unary_kernels(table, &ctx) < 0) {
53
+ if (gm_init_cpu_unary_kernels(table, &ctx) < 0) {
54
54
  return Ndt_SetError(&ctx);
55
55
  }
56
- if (gm_init_binary_kernels(table, &ctx) < 0) {
56
+ if (gm_init_cpu_binary_kernels(table, &ctx) < 0) {
57
57
  return Ndt_SetError(&ctx);
58
58
  }
59
59
 
@@ -0,0 +1,246 @@
1
+ /*
2
+ * BSD 3-Clause License
3
+ *
4
+ * Copyright (c) 2017-2018, plures
5
+ * All rights reserved.
6
+ *
7
+ * Redistribution and use in source and binary forms, with or without
8
+ * modification, are permitted provided that the following conditions are met:
9
+ *
10
+ * 1. Redistributions of source code must retain the above copyright notice,
11
+ * this list of conditions and the following disclaimer.
12
+ *
13
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
14
+ * this list of conditions and the following disclaimer in the documentation
15
+ * and/or other materials provided with the distribution.
16
+ *
17
+ * 3. Neither the name of the copyright holder nor the names of its
18
+ * contributors may be used to endorse or promote products derived from
19
+ * this software without specific prior written permission.
20
+ *
21
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+ */
32
+
33
+
34
+ #ifndef GUMATH_H
35
+ #define GUMATH_H
36
+
37
+
38
+ #ifdef __cplusplus
39
+ extern "C" {
40
+ #endif
41
+
42
+ #ifdef __cplusplus
43
+ #include <cstdint>
44
+ #else
45
+ #include <stdint.h>
46
+ #endif
47
+
48
+ #include "ndtypes.h"
49
+ #include "xnd.h"
50
+
51
+
52
/*
 * Compiler-specific helpers:
 *   GM_API     - export/import decoration for public symbols (MSVC DLLs),
 *                empty elsewhere.
 *   GM_UNUSED  - marks a possibly unused variable/function (GCC-style
 *                attribute where available); a prior definition wins.
 *   ALLOCA     - declares 'name' as scratch storage for 'nmemb' objects of
 *                'type': _alloca() on MSVC, a C99 variable-length array
 *                elsewhere.
 */
#ifdef _MSC_VER
  #if defined (GM_EXPORT)
    #define GM_API __declspec(dllexport)
  #elif defined(GM_IMPORT)
    #define GM_API __declspec(dllimport)
  #else
    #define GM_API
  #endif

  #ifndef GM_UNUSED
    #define GM_UNUSED
  #endif

  #include "malloc.h"
  /* 'nmemb' is parenthesized so that expressions such as 'a + b' are
     multiplied as a whole by sizeof(type). */
  #define ALLOCA(type, name, nmemb) type *name = _alloca((nmemb) * sizeof(type))
#else
  #define GM_API
  #ifndef GM_UNUSED
    #if defined(__GNUC__) && !defined(__INTEL_COMPILER)
      #define GM_UNUSED __attribute__((unused))
    #else
      #define GM_UNUSED
    #endif
  #endif

  #define ALLOCA(type, name, nmemb) type name[nmemb]
#endif
77
+
78
+
79
/* Capacity of the fixed-size kernels[] array embedded in struct gm_func. */
#define GM_MAX_KERNELS 8192
/* NOTE(review): not referenced in this header; the name suggests a size
   threshold for switching to threaded application (gm_apply_thread) --
   confirm against the implementation. */
#define GM_THREAD_CUTOFF 1000000

/* Explicit-width aliases for the C float and double types. */
typedef float float32_t;
typedef double float64_t;


/* Kernel operating on a stack of xnd_t arguments; reports failure through
   its int return value and ctx (exact status convention: confirm). */
typedef int (* gm_xnd_kernel_t)(xnd_t stack[], ndt_context_t *ctx);
/* Kernel with the NumPy-style strided signature: per-argument data
   pointers, dimensions, steps and an opaque user data pointer. */
typedef int (* gm_strided_kernel_t)(char **args, intptr_t *dimensions, intptr_t *steps, void *data);

/*
 * Collection of specialized kernels for a single function signature.
 *
 * NOTE: The specialized kernel lookup scheme is transitional and may
 * be replaced by something else.
 *
 * This should be considered as a first version of a kernel request
 * protocol.
 */
typedef struct {
    const ndt_t *sig;                    /* parsed function signature */
    const ndt_constraint_t *constraint;  /* optional extra constraint (presumably NULL if none) */

    /* Xnd signatures */
    gm_xnd_kernel_t OptC; /* C in inner+1 dimensions */
    gm_xnd_kernel_t OptZ; /* C in inner dimensions, C or zero stride in (inner+1)th. */
    gm_xnd_kernel_t OptS; /* strided in (inner+1)th. */
    gm_xnd_kernel_t C; /* C in inner dimensions */
    gm_xnd_kernel_t Fortran; /* Fortran in inner dimensions */
    gm_xnd_kernel_t Xnd; /* selected if non-contiguous or the other fields are NULL */

    /* NumPy signature */
    gm_strided_kernel_t Strided;
} gm_kernel_set_t;

/* A named type together with its type string and ndtypes methods.
   NOTE(review): no user of this struct is visible here; presumably used to
   register typedefs -- confirm. */
typedef struct {
    const char *name;
    const char *type;
    const ndt_methods_t *meth;
} gm_typedef_init_t;

/* Static description of one kernel set as passed to gm_add_kernel() and
   gm_add_kernel_typecheck(). Unlike gm_kernel_set_t, the signature is
   given as a string. */
typedef struct {
    const char *name;                    /* presumably the function the kernels belong to */
    const char *sig;                     /* signature string */
    const ndt_constraint_t *constraint;
    uint32_t cap;                        /* NOTE(review): meaning not visible here -- confirm */

    /* Xnd signatures */
    gm_xnd_kernel_t OptC;
    gm_xnd_kernel_t OptZ;
    gm_xnd_kernel_t OptS;
    gm_xnd_kernel_t C;
    gm_xnd_kernel_t Fortran;
    gm_xnd_kernel_t Xnd;

    /* NumPy signature */
    gm_strided_kernel_t Strided;
} gm_kernel_init_t;

/* Actual kernel selected for application */
typedef struct {
    uint32_t flag;               /* presumably which kernel of *set was chosen -- confirm */
    const gm_kernel_set_t *set;  /* kernel set matching the argument types */
} gm_kernel_t;

/* Multimethod with associated kernels */
typedef struct gm_func gm_func_t;
/* Optional specialized type checker for a gm_func_t: given the input types,
   returns the matching kernel set (failure convention: confirm). */
typedef const gm_kernel_set_t *(*gm_typecheck_t)(ndt_apply_spec_t *spec, const gm_func_t *f,
                                                 const ndt_t *in[], const int64_t li[],
                                                 int nin, int nout, bool check_broadcast,
                                                 ndt_context_t *ctx);
struct gm_func {
    char *name;
    gm_typecheck_t typecheck; /* Experimental optimized type-checking, may be NULL. */
    int nkernels;             /* presumably the number of used entries in kernels[] */
    gm_kernel_set_t kernels[GM_MAX_KERNELS];
};


/* Opaque table keyed by function name, holding gm_func_t values
   (see the gm_tbl_* functions below). */
typedef struct _gm_tbl gm_tbl_t;
159
+
160
+
161
/******************************************************************************/
/* Functions */
/******************************************************************************/

/* Create / destroy a standalone multimethod named 'name'. */
GM_API gm_func_t *gm_func_new(const char *name, ndt_context_t *ctx);
GM_API void gm_func_del(gm_func_t *f);

/* Add a function named 'name' to 'tbl' (presumably returns NULL with ctx
   set on failure -- confirm). */
GM_API gm_func_t *gm_add_func(gm_tbl_t *tbl, const char *name, ndt_context_t *ctx);
/* Register a kernel set described by 'kernel'; the _typecheck variant also
   installs a specialized type checker 'f'. */
GM_API int gm_add_kernel(gm_tbl_t *tbl, const gm_kernel_init_t *kernel, ndt_context_t *ctx);
GM_API int gm_add_kernel_typecheck(gm_tbl_t *tbl, const gm_kernel_init_t *kernel, ndt_context_t *ctx, gm_typecheck_t f);

/* Select the kernel of function 'name' in 'tbl' that matches the given
   argument types; 'spec' receives the apply specification. */
GM_API gm_kernel_t gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
                             const ndt_t *types[], const int64_t li[], int nin, int nout,
                             bool check_broadcast, const xnd_t args[], ndt_context_t *ctx);
/* Run a selected kernel over 'stack' (outer_dims outer dimensions),
   single-threaded or split across 'nthreads'. */
GM_API int gm_apply(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, ndt_context_t *ctx);
GM_API int gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, const int64_t nthreads, ndt_context_t *ctx);


/******************************************************************************/
/* NumPy loops */
/******************************************************************************/

/* Helpers for driving gm_strided_kernel_t kernels from xnd arguments. */
GM_API int gm_np_flatten(char **args, const int nargs,
                         int64_t *dimensions, int64_t *strides, const xnd_t stack[],
                         ndt_context_t *ctx);

GM_API int gm_np_convert_xnd(char **args, const int nargs,
                             intptr_t *dimensions, const int dims_size,
                             intptr_t *steps, const int steps_size,
                             xnd_t stack[], const int outer_dims,
                             ndt_context_t *ctx);

GM_API int gm_np_map(const gm_strided_kernel_t f,
                     char **args, int nargs,
                     intptr_t *dimensions,
                     intptr_t *steps,
                     void *data,
                     int outer_dims);


/******************************************************************************/
/* Xnd loops */
/******************************************************************************/

/* NOTE(review): exported without the gm_ prefix used by the rest of the API. */
GM_API int array_shape_check(xnd_t *x, const int64_t shape, ndt_context_t *ctx);
/* Apply an xnd kernel 'f' over the outer_dims outer dimensions of 'stack'. */
GM_API int gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
                      const int outer_dims, ndt_context_t *ctx);


/******************************************************************************/
/* Gufunc table */
/******************************************************************************/
/* Create / destroy a function table. */
GM_API gm_tbl_t *gm_tbl_new(ndt_context_t *ctx);
GM_API void gm_tbl_del(gm_tbl_t *t);

/* Insert / look up a gm_func_t under 'key'; gm_tbl_map calls 'f' on each
   entry with 'state' passed through. */
GM_API int gm_tbl_add(gm_tbl_t *tbl, const char *key, gm_func_t *value, ndt_context_t *ctx);
GM_API gm_func_t *gm_tbl_find(const gm_tbl_t *tbl, const char *key, ndt_context_t *ctx);
GM_API int gm_tbl_map(const gm_tbl_t *tbl, int (*f)(const gm_func_t *, void *state), void *state);


/******************************************************************************/
/* Library initialization and tables */
/******************************************************************************/

/* Global library setup / teardown. */
GM_API void gm_init(void);
/* Each gm_init_*_kernels() registers a family of kernels into 'tbl'.
   Callers treat a negative return as failure with the error recorded in
   'ctx'. */
GM_API int gm_init_cpu_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
GM_API int gm_init_cpu_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
GM_API int gm_init_bitwise_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);

GM_API int gm_init_cuda_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
GM_API int gm_init_cuda_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);

GM_API int gm_init_example_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
GM_API int gm_init_graph_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
GM_API int gm_init_quaternion_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
GM_API int gm_init_pdist_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);

GM_API void gm_finalize(void);
239
+
240
+
241
+ #ifdef __cplusplus
242
+ } /* END extern "C" */
243
+ #endif
244
+
245
+
246
+ #endif /* GUMATH_H */
@@ -0,0 +1 @@
1
+ ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3
@@ -0,0 +1 @@
1
+ ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3
@@ -49,10 +49,15 @@ extern "C" {
49
49
  /* Exposed here for the benefit of Numba. The API should not be regarded
50
50
  stable across versions. */
51
51
 
52
+ #define GM_CPU_FUNC 0x0001U
53
+ #define GM_CUDA_MANAGED_FUNC 0x0002U
54
+
52
55
/* Python object representing one named function of a gumath kernel table. */
typedef struct {
    PyObject_HEAD
    const gm_tbl_t *tbl; /* kernel table */
    uint32_t flags; /* memory target (presumably GM_CPU_FUNC or GM_CUDA_MANAGED_FUNC) */
    char *name; /* function name */
    PyObject *identity; /* identity element (presumably for reductions -- confirm) */
} GufuncObject;
57
62
 
58
63
 
@@ -60,21 +65,45 @@ typedef struct {
60
65
  /* Capsule API */
61
66
  /****************************************************************************/
62
67
 
63
- #define Gumath_AddFunctions_INDEX 0
68
+ #define Gufunc_CheckExact_INDEX 0
69
+ #define Gufunc_CheckExact_RETURN int
70
+ #define Gufunc_CheckExact_ARGS (const PyObject *)
71
+
72
+ #define Gufunc_Check_INDEX 1
73
+ #define Gufunc_Check_RETURN int
74
+ #define Gufunc_Check_ARGS (const PyObject *)
75
+
76
+ #define Gumath_AddFunctions_INDEX 2
64
77
  #define Gumath_AddFunctions_RETURN int
65
78
  #define Gumath_AddFunctions_ARGS (PyObject *, const gm_tbl_t *)
66
79
 
67
- #define GUMATH_MAX_API 1
80
+ #define Gumath_AddCudaFunctions_INDEX 3
81
+ #define Gumath_AddCudaFunctions_RETURN int
82
+ #define Gumath_AddCudaFunctions_ARGS (PyObject *, const gm_tbl_t *)
83
+
84
+ #define GUMATH_MAX_API 4
68
85
 
69
86
 
70
87
  #ifdef GUMATH_MODULE
88
+ static Gufunc_CheckExact_RETURN Gufunc_CheckExact Gufunc_CheckExact_ARGS;
89
+ static Gufunc_Check_RETURN Gufunc_Check Gufunc_Check_ARGS;
71
90
  static Gumath_AddFunctions_RETURN Gumath_AddFunctions Gumath_AddFunctions_ARGS;
91
+ static Gumath_AddCudaFunctions_RETURN Gumath_AddCudaFunctions Gumath_AddCudaFunctions_ARGS;
72
92
  #else
73
93
  static void **_gumath_api;
74
94
 
95
+ #define Gufunc_CheckExact \
96
+ (*(Gufunc_CheckExact_RETURN (*)Gufunc_CheckExact_ARGS) _gumath_api[Gufunc_CheckExact_INDEX])
97
+
98
+ #define Gufunc_Check \
99
+ (*(Gufunc_Check_RETURN (*)Gufunc_Check_ARGS) _gumath_api[Gufunc_Check_INDEX])
100
+
75
101
  #define Gumath_AddFunctions \
76
102
  (*(Gumath_AddFunctions_RETURN (*)Gumath_AddFunctions_ARGS) _gumath_api[Gumath_AddFunctions_INDEX])
77
103
 
104
+ #define Gumath_AddCudaFunctions \
105
+ (*(Gumath_AddCudaFunctions_RETURN (*)Gumath_AddCudaFunctions_ARGS) _gumath_api[Gumath_AddCudaFunctions_INDEX])
106
+
78
107
 
79
108
  static int
80
109
  import_gumath(void)
@@ -0,0 +1,767 @@
1
+ #
2
+ # BSD 3-Clause License
3
+ #
4
+ # Copyright (c) 2017-2018, plures
5
+ # All rights reserved.
6
+ #
7
+ # Redistribution and use in source and binary forms, with or without
8
+ # modification, are permitted provided that the following conditions are met:
9
+ #
10
+ # 1. Redistributions of source code must retain the above copyright notice,
11
+ # this list of conditions and the following disclaimer.
12
+ #
13
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
14
+ # this list of conditions and the following disclaimer in the documentation
15
+ # and/or other materials provided with the distribution.
16
+ #
17
+ # 3. Neither the name of the copyright holder nor the names of its
18
+ # contributors may be used to endorse or promote products derived from
19
+ # this software without specific prior written permission.
20
+ #
21
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+ #
32
+
33
+ # Python NDarray and functions for generating test cases.
34
+
35
+ from itertools import accumulate, count, product
36
+ from random import randrange, sample
37
+ from collections import namedtuple
38
+ import math
39
+ import struct
40
+ import unittest
41
+ from randdec import all_unary, all_binary
42
+ from randfloat import un_randfloat, bin_randfloat
43
+ import numpy as np
44
+
45
+
46
def skip_if(condition, reason):
    """Skip the running unittest with *reason* when *condition* is truthy."""
    if not condition:
        return
    raise unittest.SkipTest(reason)
49
+
50
+
51
+ # ======================================================================
52
+ # Minimal test cases
53
+ # ======================================================================
54
+
55
# Each entry is (nested Python data, type string, dtype name).  The data
# shapes match the type strings: a 2000-element vector, a 2 x 1000 matrix
# and a 1000 x 2 matrix, each once for float64 and once for float32.
TEST_CASES = [
    ([float(i)/100.0 for i in range(2000)], "2000 * float64", "float64"),

    ([[float(i)/100.0 for i in range(1000)], [float(i+1) for i in range(1000)]],
     "2 * 1000 * float64", "float64"),

    (1000 * [[float(i+1) for i in range(2)]], "1000 * 2 * float64", "float64"),

    ([float(i)/10.0 for i in range(2000)], "2000 * float32", "float32"),

    ([[float(i)/10.0 for i in range(1000)], [float(i+1) for i in range(1000)]],
     "2 * 1000 * float32", "float32"),

    (1000 * [[float(i+1) for i in range(2)]], "1000 * 2 * float32", "float32"),
]
70
+
71
+
72
+ # ======================================================================
73
+ # Definition of generalized slicing and indexing
74
+ # ======================================================================
75
+
76
def have_none(lst):
    """Return True if *lst* is None or contains None at any nesting depth.

    Lists and tuples are searched element-wise, dicts by their values; any
    other object counts only if it is None itself.
    """
    if lst is None:
        return True
    if isinstance(lst, (list, tuple)):
        children = lst
    elif isinstance(lst, dict):
        children = lst.values()
    else:
        return False
    return any(map(have_none, children))
82
+
83
def sinrec(lst):
    """Apply math.sin element-wise to a nested list of ints, keeping None."""
    if isinstance(lst, list):
        return list(map(sinrec, lst))
    if lst is None:
        return None
    if isinstance(lst, int):
        return math.sin(lst)
    raise TypeError("unexpected operand type '%s'" % type(lst))
90
+
91
def mulrec(lst1, lst2):
    """Multiply two nested int lists element-wise; None propagates.

    Lists are paired with zip(), so the result is as long as the shorter
    operand at each level.
    """
    if isinstance(lst1, list) and isinstance(lst2, list):
        return [mulrec(a, b) for a, b in zip(lst1, lst2)]

    leaf = (int, type(None))
    if not (isinstance(lst1, leaf) and isinstance(lst2, leaf)):
        raise TypeError("unexpected operand types '%s', '%s'" %
                        (type(lst1), type(lst2)))

    if lst1 is None or lst2 is None:
        return None
    return lst1 * lst2
99
+
100
+
101
def maxlevel(lst):
    """Return the maximum list-nesting depth of *lst*.

    Non-list values have depth 0; a list is one level deeper than its
    deepest element, so ``[]`` has depth 1.  Only ``list`` instances are
    descended into.
    """
    if not isinstance(lst, list):
        return 0
    return 1 + max((maxlevel(item) for item in lst), default=0)
113
+
114
def getitem(lst, indices):
    """Definition for multidimensional slicing and indexing on arbitrarily
    shaped nested lists.

    ``indices`` is a tuple of ints and/or slices, applied one per nesting
    level, outermost first.  An int selects a single element (removing that
    level); a slice keeps the level, and the remaining indices are applied
    to every selected element.

    Raises IndexError for out-of-range ints (via ``list.__getitem__``),
    including ints that follow a slice which selected nothing.
    """
    if not indices:
        return lst

    # Peel off the index for the current (outermost) level.
    i, indices = indices[0], indices[1:]
    item = list.__getitem__(lst, i)

    if isinstance(i, int):
        # Integer index: this level disappears; recurse into the element.
        return getitem(item, indices)

    # Empty slice: check if all subsequent indices are in range for the
    # full slice, raise IndexError otherwise. This is NumPy's behavior.
    if not item:
        if lst:
            # Only a possible IndexError matters; the result is discarded.
            _ = getitem(lst, (slice(None),) + indices)
        elif any(isinstance(k, int) for k in indices):
            # Nothing to validate against: later integer indices are
            # treated as out of range.
            raise IndexError
        return []

    return [getitem(x, indices) for x in item]
137
+
138
class NDArray(list):
    """Nested list supporting NumPy-like multidimensional subscripts.

    A subscript (an int, a slice, or a tuple of them) is resolved level by
    level by getitem(); list results are wrapped in NDArray again.
    """

    def __init__(self, value, dtype=None):
        # 'dtype' is accepted for call compatibility and ignored.
        super().__init__(value)
        self.maxlevel = maxlevel(value)  # nesting depth of the data

    def __getitem__(self, indices):
        key = indices if isinstance(indices, tuple) else (indices,)

        if len(key) > self.maxlevel:  # NumPy rejects over-indexing as well
            raise IndexError("too many indices")

        if any(not isinstance(k, (int, slice)) for k in key):
            raise TypeError(
                "index must be int or slice or a tuple of integers and slices")

        selected = getitem(self, key)
        if isinstance(selected, list):
            return NDArray(selected)
        return selected

    def sin(self):
        """Element-wise math.sin (see sinrec); None entries are kept."""
        return NDArray(sinrec(self))

    def __mul__(self, other):
        """Element-wise product with another nested list (see mulrec)."""
        return NDArray(mulrec(self, other))
163
+
164
+
165
+
166
+ # ======================================================================
167
+ # Generate test cases
168
+ # ======================================================================
169
+
170
# Rectangular (fixed-dimension) nested lists, including empty dimensions,
# presumably used as inputs for subscript tests -- confirm in the tests.
SUBSCRIPT_FIXED_TEST_CASES = [
    [],
    [[]],
    [[], []],
    [[0], [1]],
    [[0], [1], [2]],
    [[0, 1], [1, 2], [2 ,3]],
    [[[]]],
    [[[0]]],
    [[[], []]],
    [[[0], [1]]],
    [[[0, 1], [2, 3]]],
    [[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
    [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
]

# Ragged (variable-dimension) nested lists; some contain None entries and
# empty sublists.
SUBSCRIPT_VAR_TEST_CASES = [
    [[[0, 1], [2, 3]], [[4, 5, 6], [7]], [[8, 9]]],
    [[[0, 1], [2, 3]], [[4, 5, None], [None], [7]], [[], [None, 8]], [[9, 10]]],
    [[[0, 1, 2], [3, 4, 5, 6], [7, 8, 9, 10]], [[11, 12, 13, 14], [15, 16, 17], [18, 19]]],
    [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9]], [[10, 11]]]
]
192
+
193
def single_fixed(max_ndim=4, min_shape=1, max_shape=10):
    """Return a random rectangular nested list of consecutive naturals.

    One extent per dimension is drawn uniformly from [min_shape, max_shape];
    leaves are numbered 0, 1, 2, ... in row-major order.  With max_ndim == 0
    the result is the scalar 0.
    """
    labels = count()
    extents = [randrange(min_shape, max_shape + 1) for _ in range(max_ndim)]

    def nest(level):
        if level:
            return [nest(level - 1) for _ in range(extents[level - 1])]
        return next(labels)

    return nest(max_ndim)

def gen_fixed(max_ndim=4, min_shape=1, max_shape=10):
    """Yield 30 independent single_fixed() arrays."""
    assert max_ndim >= 0 and 0 <= min_shape <= max_shape

    yield from (single_fixed(max_ndim, min_shape, max_shape) for _ in range(30))
209
+
210
def single_var(max_ndim=4, min_shape=1, max_shape=10):
    """Return a random ragged nested list of consecutive naturals.

    Every sublist draws its own length: the innermost level uses
    [min_shape, max_shape]; outer levels use at least 1 (so the ragged
    structure is not collapsed by empty outer lists) unless max_shape is 0,
    in which case 0 is the only possible length.  Leaves are numbered
    0, 1, 2, ... in the order they are created.
    """
    nat = count()

    def f(ndim):
        if ndim == 0:
            return next(nat)
        if ndim == 1:
            shape = randrange(min_shape, max_shape+1)
        else:
            # Previously randrange(1, 1) raised ValueError for
            # min_shape == max_shape == 0.
            n = 1 if min_shape == 0 and max_shape > 0 else min_shape
            shape = randrange(n, max_shape+1)
        return [f(ndim-1) for _ in range(shape)]

    return f(max_ndim)

def gen_var(max_ndim=4, min_shape=1, max_shape=10):
    """Yield 30 independent single_var() arrays."""
    assert max_ndim >=0 and min_shape >=0 and min_shape <= max_shape

    for _ in range(30):
        yield single_var(max_ndim, min_shape, max_shape)
230
+
231
+
232
def genindices():
    """Yield every index tuple of length 1, 2 and 3 over range(4).

    Shorter tuples come first; within a length, tuples are in lexicographic
    order.
    """
    for ndim in (1, 2, 3):
        yield from product(range(4), repeat=ndim)
242
+
243
def rslice(ndim):
    """Return a random slice with start/stop in [0, ndim].

    The step is a non-zero value in [-ndim-1, ndim].  Afterwards each of
    start, stop and step is independently replaced by None with
    probability 1/5.
    """
    begin = randrange(0, ndim+1)
    end = randrange(0, ndim+1)
    stride = 0
    while not stride:
        stride = randrange(-ndim-1, ndim+1)
    # One randrange(5) draw per component, in start/stop/step order.
    begin, end, stride = (None if randrange(5) == 4 else v
                          for v in (begin, end, stride))
    return slice(begin, end, stride)

def rslice_neg(ndim):
    """Return a random slice whose start, stop and (non-zero) step all lie
    in [-ndim-1, ndim]; no component is None."""
    bounds = (-ndim-1, ndim+1)
    begin = randrange(*bounds)
    end = randrange(*bounds)
    stride = 0
    while stride == 0:
        stride = randrange(*bounds)
    return slice(begin, end, stride)

def multislice(ndim):
    """Return a tuple of between 1 and ndim random rslice(ndim) slices."""
    width = randrange(1, ndim+1)
    return tuple([rslice(ndim) for _ in range(width)])

def randslices(ndim):
    """Yield five random multislice(ndim) tuples."""
    yield from (multislice(ndim) for _ in range(5))

def gen_indices_or_slices():
    """Yield five subscripts, each (with equal probability) either a tuple of
    three ints in [0, 4) or a multislice(3) tuple."""
    for _ in range(5):
        yield ((randrange(4), randrange(4), randrange(4))
               if randrange(2) else multislice(3))
275
+
276
def genslices(n):
    """Generate all possible slices for a single dimension.

    start, stop and step each range over None followed by -n..n; slices
    with a zero step are skipped.
    """
    candidates = [None, *range(-n, n+1)]
    for lo, hi, stride in product(candidates, repeat=3):
        if stride != 0:
            yield slice(lo, hi, stride)

def genslices_ndim(ndim, shape):
    """Generate all possible slice tuples for the first ndim extents of
    'shape' (the Cartesian product of genslices() per dimension)."""
    per_dim = [genslices(shape[axis]) for axis in range(ndim)]
    for combo in product(*per_dim):
        yield combo
291
+
292
def _mixed_index(max_ndim, make_slice):
    """Shared body of mixed_index and mixed_index_neg.

    Draws a length in [1, max_ndim]; each component is, with equal
    probability, an int in [-max_ndim, max_ndim) or make_slice(length).
    """
    ndim = randrange(1, max_ndim+1)
    return tuple([randrange(-max_ndim, max_ndim) if randrange(2)
                  else make_slice(ndim)
                  for _ in range(ndim)])

def mixed_index(max_ndim):
    """Random subscript tuple mixing ints and rslice() slices."""
    return _mixed_index(max_ndim, rslice)

def mixed_index_neg(max_ndim):
    """Random subscript tuple mixing ints and rslice_neg() slices."""
    return _mixed_index(max_ndim, rslice_neg)

def mixed_indices(max_ndim):
    """Yield five mixed_index() tuples, then five mixed_index_neg() tuples."""
    for make in (mixed_index, mixed_index_neg):
        for _ in range(5):
            yield make(max_ndim)
317
+
318
def itos(indices):
    """Render a subscript tuple as text, e.g. (1, slice(None, 2, -1)) ->
    '1, None:2:-1'."""
    def render(idx):
        if isinstance(idx, int):
            return str(idx)
        return "%s:%s:%s" % (idx.start, idx.stop, idx.step)

    return ", ".join(map(render, indices))
321
+
322
+
323
+ # ======================================================================
324
+ # Split a shape into N almost equal slices
325
+ # ======================================================================
326
+
327
+ def start(i, r, q):
328
+ return i*(q+1) if i < r else r+i*q
329
+
330
+ def stop(i, r, q):
331
+ return (i+1)*(q+1) if i < r else r+(i+1)*q
332
+
333
+ def step(i, r, q):
334
+ return q+1 if i < r else q
335
+
336
+ def sl(i, r, q):
337
+ return slice(start(i, r, q), stop(i, r, q))
338
+
339
def prepend(x, xs):
    """Return a new list in which x is prepended to every tuple of xs."""
    result = []
    for tail in xs:
        result.append((x,) + tail)
    return result
+
342
def last_column(i, r, q, n):
    """Return [(sl(0, r, q),), ..., (sl(n-1, r, q),)].

    NOTE(review): the `i` argument is not used; chunks are always numbered
    from 0.
    """
    return [(sl(chunk, r, q),) for chunk in range(n)]
+
345
def schedule(n, shape):
    """Partition the index space of `shape` into n nearly equal index tuples.

    * shape == []: returns [()].
    * n <= shape[0]: returns n one-element tuples slicing the first
      dimension; the first shape[0] % n slices are one element longer.
    * n > shape[0]: with q, r = divmod(n, shape[0]), row k of the first
      dimension is combined with schedule(step(k, r, q), shape[1:]).

    Raises:
        ValueError: if n <= 0 or the first dimension is <= 0.
    """
    assert isinstance(n, int) and isinstance(shape, list)
    if n <= 0:
        raise ValueError("n must be greater than zero")
    if shape == []:
        return [()]

    outer, inner = shape[0], shape[1:]
    if outer <= 0:
        raise ValueError("shape must be greater than zero")

    if n > outer:
        q, r = divmod(n, outer)
        return column(0, r, q, outer, inner)
    q, r = divmod(outer, n)
    return last_column(0, r, q, n)
+
361
def column(i, r, q, m, ms):
    """Schedule rows i..m-1 of the outermost dimension.

    Row k contributes its own slice(k, k+1) prepended to
    schedule(step(k, r, q), ms), i.e. the remaining dimensions `ms` are split
    into step(k, r, q) pieces. The per-row lists are concatenated in row
    order.

    Implemented as a loop: the previous version recursed once per row, so an
    outer dimension of roughly 1000 rows or more exceeded Python's default
    recursion limit.
    """
    result = []
    for k in range(i, m):
        result.extend(prepend(slice(k, k + 1), schedule(step(k, r, q), ms)))
    return result
+
367
+ # ======================================================================
368
+ # Split an xnd object into N subtrees
369
+ # ======================================================================
370
+
371
def zero_in_shape(shape):
    """Return True if any dimension of `shape` equals 0."""
    return any(dim == 0 for dim in shape)
+
377
def split_xnd(x, n, max_outer=None):
    """Split `x` into pieces by indexing it with each tuple from schedule().

    Only the first `max_outer` dimensions are scheduled when it is given.

    Raises:
        ValueError: if any dimension of x.type.shape is 0.
    """
    dims = list(x.type.shape)
    if zero_in_shape(dims):
        raise ValueError("split does not support zeros in shape")
    outer = dims if max_outer is None else dims[:max_outer]
    return [x[idx] for idx in schedule(n, outer)]
+
386
+
387
+ # ======================================================================
388
+ # Generate test cases
389
+ # ======================================================================
390
+
391
# Names of the functions under test, grouped first by kind ("unary",
# "binary", "binary_mv") and then by category. complex_noimpl() and
# half_noimpl() look up the "unary" categories by name.
# NOTE(review): "binary_mv" presumably means multiple return values (its only
# entry is divmod) -- confirm.
functions = {
  "unary": {
    "default": ["copy", "abs"],
    "arith": ["negative"],
    "complex_math_with_half": ["exp", "log", "log10", "sqrt", "sin", "cos"],
    "complex_math": ["tan", "asin", "acos", "atan", "sinh", "cosh", "tanh",
                     "asinh", "acosh", "atanh"],
    "real_math_with_half": ["fabs", "exp2", "log2"],
    "real_math": ["expm1", "log1p", "logb", "cbrt", "erf", "erfc", "lgamma",
                  "tgamma", "ceil", "floor", "trunc", "round", "nearbyint"],
    "bitwise": ["invert"],
  },
  "binary": {
    "default": ["add", "subtract", "multiply", "floor_divide", "remainder", "power"],
    "float_result": ["divide"],
    "bool_result": ["less_equal", "less", "greater_equal", "greater", "equal", "not_equal"],
    "bitwise": ["bitwise_and", "bitwise_or", "bitwise_xor"]
  },
  "binary_mv": {
    "default": ["divmod"],
  }
}
+
414
def complex_noimpl(name):
    """Return True if `name` is in the unary "real_math" or "real_math_with_half" groups."""
    unary = functions["unary"]
    return any(name in unary[group] for group in ("real_math", "real_math_with_half"))
+
418
def half_noimpl(name):
    """Return True if `name` is floor_divide/remainder or in the unary
    "real_math" or "complex_math" groups."""
    if name in ("floor_divide", "remainder"):
        return True
    unary = functions["unary"]
    return name in unary["real_math"] or name in unary["complex_math"]
+
423
# Type-name groups; Tint accepts tunsigned + tsigned, Tfloat accepts tfloat
# and Tcomplex accepts tcomplex.
tunsigned = ["bool", "uint8", "uint16", "uint32", "uint64"]
tsigned = ["int8", "int16", "int32", "int64"]
tfloat = ["bfloat16", "float16", "float32", "float64"]
tcomplex = ["complex32", "complex64", "complex128"]

# type name -> (min, max, exp), unpacked by the Tint/Tfloat/Tcomplex
# constructors.
# Integer types: min/max are the full value range and exp is 0.
# Float types: NOTE(review): (-2**p, 2**p) looks like the range of integers
# exactly representable with a p-bit significand (p = 11, 8, 24, 53), and exp
# like the maximum binary exponent (15, 127, 1023) -- confirm.
# complex32/complex64/complex128 carry the same values as
# float16/float32/float64 respectively.
tinfo = {
  "bool": (0, 1, 0),
  "uint8": (0, 2**8-1, 0),
  "uint16": (0, 2**16-1, 0),
  "uint32": (0, 2**32-1, 0),
  "uint64": (0, 2**64-1, 0),
  "int8": (-2**7,  2**7-1, 0),
  "int16": (-2**15, 2**15-1, 0),
  "int32": (-2**31, 2**31-1, 0),
  "int64": (-2**63, 2**63-1, 0),
  "float16": (-2**11, 2**11, 15),
  "bfloat16": (-2**8, 2**8, 127),
  "float32": (-2**24, 2**24, 127),
  "float64": (-2**53, 2**53, 1023),
  "complex32": (-2**11, 2**11, 15),
  "complex64": (-2**24, 2**24, 127),
  "complex128": (-2**53, 2**53, 1023)
}
+
447
class Tint(object):
    """Test descriptor for an integer (or bool) type listed in `tinfo`.

    Instances are hashable and compare equal when their (type, min, max, exp)
    tuples match, so they can be used inside signature-table keys.
    """

    # `type` shadows the builtin but is part of the public constructor signature.
    def __init__(self, type):
        if type not in tunsigned + tsigned:
            raise ValueError("not an integer type: '%s'" % type)
        self.type = type
        self.min, self.max, self.exp = tinfo[type]
        # Identity tuple shared by __eq__ and __hash__.
        self.all = (self.type, self.min, self.max, self.exp)

    def __repr__(self):
        return self.type

    def __eq__(self, other):
        # isinstance(obj, cls): the arguments were previously swapped, which
        # raised TypeError whenever two distinct descriptors were compared
        # (e.g. when dict keys containing them had equal hashes).
        return isinstance(other, Tint) and self.all == other.all

    def __hash__(self):
        return hash(self.all)

    def testcases(self):
        """Yield 0, min, max, then ten random values in [min, max]."""
        yield 0
        yield self.min
        yield self.max
        for _ in range(10):
            yield randrange(self.min, self.max + 1)

    # Integer types: every function is treated as implemented, on CPU and CUDA.
    def cpu_noimpl(self, f=None):
        return False

    def cpu_nokern(self, f=None):
        return False

    def cuda_noimpl(self, f=None):
        return False

    def cuda_nokern(self, f=None):
        return False
+
475
class Tfloat(object):
    """Test descriptor for a floating-point type listed in `tfloat`.

    Instances are hashable and compare equal when their (type, min, max, exp)
    tuples match.
    """

    # `type` shadows the builtin but is part of the public constructor signature.
    def __init__(self, type):
        if type not in tfloat:
            raise ValueError("not a float type: '%s'" % type)
        self.type = type
        self.min, self.max, self.exp = tinfo[type]
        # Identity tuple shared by __eq__ and __hash__.
        self.all = (self.type, self.min, self.max, self.exp)

    def __repr__(self):
        return self.type

    def __eq__(self, other):
        # Previously `isinstance(Tint, other)`: swapped arguments and the wrong
        # class, raising TypeError on any comparison with another descriptor.
        return isinstance(other, Tfloat) and self.all == other.all

    def __hash__(self):
        return hash(self.all)

    def testcases(self):
        """Yield fixed values (0, +-0.5, min, max), then generated floats."""
        yield 0
        yield 0.5
        yield -0.5
        yield self.min
        yield self.max
        prec = randrange(1, 10)
        for v in all_unary(prec, self.exp, 1):
            yield float(v)
        for v in un_randfloat():
            yield float(v)

    def cpu_noimpl(self, f=None):
        # float16 has no CPU implementation at all.
        return self.type == "float16"

    def cpu_nokern(self, f=None):
        return False

    def cuda_noimpl(self, f=None):
        # Only float16 lacks some CUDA functions (see half_noimpl); return an
        # explicit False rather than the previous implicit None otherwise.
        return self.type == "float16" and half_noimpl(f)

    def cuda_nokern(self, f=None):
        return False
+
509
class Tcomplex(object):
    """Test descriptor for a complex type listed in `tcomplex`.

    Instances are hashable and compare equal when their (type, min, max, exp)
    tuples match.
    """

    # `type` shadows the builtin but is part of the public constructor signature.
    def __init__(self, type):
        if type not in tcomplex:
            raise ValueError("not a complex type: '%s'" % type)
        self.type = type
        self.min, self.max, self.exp = tinfo[type]
        # Identity tuple shared by __eq__ and __hash__.
        self.all = (self.type, self.min, self.max, self.exp)

    def __repr__(self):
        return self.type

    def __eq__(self, other):
        # Previously `isinstance(Tint, other)`: swapped arguments and the wrong
        # class, raising TypeError on any comparison with another descriptor.
        return isinstance(other, Tcomplex) and self.all == other.all

    def __hash__(self):
        return hash(self.all)

    def testcases(self):
        """Yield fixed values (0, +-0.5, +-0.5j, min, max), then generated complexes."""
        yield 0
        yield 0.5
        yield -0.5
        yield 0.5j
        yield -0.5j
        yield self.min
        yield self.max
        prec = randrange(1, 10)
        for v, w in all_binary(prec, self.exp, 1):
            yield complex(float(v), float(w))
        for v, w in bin_randfloat():
            yield complex(float(v), float(w))

    def cpu_noimpl(self, f=None):
        # complex32 is never implemented; otherwise real-only functions are not.
        if self.type == "complex32":
            return True
        return complex_noimpl(f)

    def cpu_nokern(self, f=None):
        return f in ("floor_divide", "remainder")

    def cuda_noimpl(self, f=None):
        if self.type == "complex32":
            return True
        return complex_noimpl(f)

    def cuda_nokern(self, f=None):
        return f in ("floor_divide", "remainder")
+
548
+
549
# Default type ordering: unsigned ints, signed ints, floats, complexes. The
# index of a type in this list is its "rank" in the cast tables.
tinfo_default = (
    [Tint(name) for name in ("uint8", "uint16", "uint32", "uint64",
                             "int8", "int16", "int32", "int64")]
    + [Tfloat(name) for name in ("float16", "bfloat16", "float32", "float64")]
    + [Tcomplex(name) for name in ("complex32", "complex64", "complex128")]
)
+
567
# Type ordering for bitwise operations: bool, unsigned ints, then signed ints.
tinfo_bitwise = [Tint(name) for name in tunsigned + tsigned]
+
579
# kind -> pattern -> {input type tuple: result type(s)}. The "unary" and
# "binary" tables are filled by init_unary_cast()/init_binary_cast(); the
# "binary_mv" table lists two-input, two-output signatures of one type.
implemented_sigs = {
    "unary": {"default": {}, "float_result": {}},
    "binary": {
        "default": {}, "float_result": {}, "bool_result": {}, "bitwise": {},
    },
    "binary_mv": {
        "default": {
            (ctor(name), ctor(name)): (ctor(name), ctor(name))
            for ctor, name in [
                (Tint, "uint8"), (Tint, "uint16"), (Tint, "uint32"), (Tint, "uint64"),
                (Tint, "int8"), (Tint, "int16"), (Tint, "int32"), (Tint, "int64"),
                (Tfloat, "float32"), (Tfloat, "float64"),
            ]
        },
    },
}
+
602
# Empty per-kind, per-pattern tables, filled by init_unary_cast() and
# init_binary_cast(). Every inner dict is a distinct object.
exact_sigs = {
    "unary": {pattern: {} for pattern in ("default", "float_result")},
    "binary": {pattern: {} for pattern in
               ("default", "float_result", "bool_result", "bitwise")},
}

inexact_sigs = {
    "unary": {pattern: {} for pattern in ("default", "float_result")},
    "binary": {pattern: {} for pattern in
               ("default", "float_result", "bool_result", "bitwise")},
}
+
620
def init_unary_cast(pattern, tinfo, rank):
    """Record unary result-type signatures for the input type t = tinfo[rank].

    Candidate result types are tinfo[start:], where start is `rank`, or
    max(8, rank) for the "float_result" pattern (index 8 of tinfo_default is
    float16). For each candidate `cast`:

    * if cast's [min, max] range does not contain t's, (t,) -> cast goes into
      inexact_sigs;
    * the first containing candidate goes into implemented_sigs and
      exact_sigs, unless exactly one of t and cast is bfloat16, in which case
      it goes into exact_sigs only;
    * later containing candidates go into exact_sigs only.

    Entries are keyed by (t,), so a later candidate overwrites an earlier one
    in exact_sigs / inexact_sigs.

    Args:
        pattern: key into the "unary" sub-tables (e.g. "default").
        tinfo: ordered list of type descriptors.
        rank: index of the input type in `tinfo`.
    """
    t = tinfo[rank]

    start = max(8, rank) if pattern == "float_result" else rank
    found_cast = False

    # Bound the scan by the list actually being indexed. The "default" pattern
    # passes [bool] + tinfo_default, one entry longer than tinfo_default, so
    # bounding by len(tinfo_default) dropped complex128 as a candidate (and
    # complex128 inputs got no signature at all).
    for i in range(start, len(tinfo)):
        cast = tinfo[i]
        if cast.min <= t.min and t.max <= cast.max:
            if found_cast or (t.type == "bfloat16") != (cast.type == "bfloat16"):
                exact_sigs["unary"][pattern][(t,)] = cast
            else:
                found_cast = True
                implemented_sigs["unary"][pattern][(t,)] = cast
                exact_sigs["unary"][pattern][(t,)] = cast
        else:
            inexact_sigs["unary"][pattern][(t,)] = cast
+
638
def init_unary_cast_tbl(pattern):
    """Fill the unary signature tables for every input type of `pattern`.

    "default" uses [bool] + tinfo_default, "float_result" uses tinfo_default
    and "bitwise" uses tinfo_bitwise.

    Raises:
        ValueError: for any other pattern.
    """
    if pattern == "default":
        tinfo = [Tint("bool")] + tinfo_default
    elif pattern == "float_result":
        tinfo = tinfo_default
    elif pattern == "bitwise":
        tinfo = tinfo_bitwise
    else:
        # Previously formatted the undefined name `func`, raising NameError.
        raise ValueError("unsupported function type '%s'" % pattern)

    for rank, _ in enumerate(tinfo):
        init_unary_cast(pattern, tinfo, rank)
+
651
def is_binary_common_cast(cast, t, u):
    """Return True if `cast` can hold the value ranges of both t and u.

    The [min, max] range of cast must contain those of t and u; when cast is
    a Tfloat, the exponents of t and u must also not exceed cast.exp.
    """
    fits = all(cast.min <= x.min and x.max <= cast.max for x in (t, u))
    if not fits:
        return False
    if not isinstance(cast, Tfloat):
        return True
    return t.exp <= cast.exp and u.exp <= cast.exp
+
660
def init_binary_cast(pattern, tinfo, rank1, rank2):
    """Record binary result-type signatures for the input pair (rank1, rank2).

    The pair is normalized so that t = tinfo[min rank] and u = tinfo[max rank].
    Candidate common types are tinfo[start:], where start is the max rank, or
    max(8, max rank) for "float_result". The first candidate accepted by
    is_binary_common_cast() goes into implemented_sigs and exact_sigs. Later
    accepted candidates go into exact_sigs only, and rejected ones go into
    inexact_sigs.

    The stored result is Tint("bool") for "bool_result", otherwise the
    candidate itself. Entries are keyed by (t, u), so later candidates
    overwrite earlier ones in exact_sigs / inexact_sigs.

    Args:
        pattern: key into the "binary" sub-tables.
        tinfo: ordered list of type descriptors that rank1/rank2 index into.
        rank1, rank2: indices of the two input types in `tinfo`.
    """
    min_rank = min(rank1, rank2)
    max_rank = max(rank1, rank2)

    t = tinfo[min_rank]
    u = tinfo[max_rank]

    start = max(8, max_rank) if pattern == "float_result" else max_rank
    smallest_common_cast = False

    # Candidates must come from the same list the ranks index into.
    # Previously tinfo_default was used here, which misaligned the "bitwise"
    # table (tinfo_bitwise starts with bool). For example, bool & bool
    # resolved to uint8, and int64 & int64 found no common type at all.
    for i in range(start, len(tinfo)):
        common_cast = tinfo[i]
        w = Tint("bool") if pattern == "bool_result" else common_cast
        if is_binary_common_cast(common_cast, t, u):
            if smallest_common_cast:
                exact_sigs["binary"][pattern][(t, u)] = w
            else:
                smallest_common_cast = True
                implemented_sigs["binary"][pattern][(t, u)] = w
                exact_sigs["binary"][pattern][(t, u)] = w
        else:
            inexact_sigs["binary"][pattern][(t, u)] = w
+
683
def init_binary_cast_tbl(pattern):
    """Fill the binary signature tables for every ordered pair of input types.

    "default", "float_result" and "bool_result" use tinfo_default;
    "bitwise" uses tinfo_bitwise.

    Raises:
        ValueError: for any other pattern.
    """
    if pattern in ("default", "float_result", "bool_result"):
        tinfo = tinfo_default
    elif pattern == "bitwise":
        tinfo = tinfo_bitwise
    else:
        raise ValueError("unsupported function type '%s'" % pattern)

    ranks = range(len(tinfo))
    for rank1 in ranks:
        for rank2 in ranks:
            init_binary_cast(pattern, tinfo, rank1, rank2)
+
695
# type name -> `struct` format character used by struct_overflow(): 'e' half,
# 'f' single and 'd' double precision. Each complex type maps to the format
# struct_overflow() applies separately to the real and imaginary parts.
# Types without an entry (integers, bfloat16) are never reported as
# overflowing.
_struct_format = {
  "float16": "e",
  "float32": "f",
  "float64": "d",
  "complex32": "e",
  "complex64": "f",
  "complex128": "d"
}
+
704
def roundtrip_ne(v, fmt):
    """Return True if the finite value v overflows the struct float format fmt.

    Infinite inputs are representable and return False. Packing a finite
    value that is too large raises OverflowError for every float format
    ('e', 'f', 'd'). Previously only the 'e' path caught it, so large values
    crashed the 'f' path instead of being reported as overflowing.

    Args:
        v: real number to test.
        fmt: a struct format character, e.g. 'e', 'f' or 'd'.
    """
    try:
        packed = struct.pack(fmt, v)
    except (OverflowError, struct.error):
        return True
    if math.isinf(v):
        return False
    # Defensive: a finite value should never round-trip to infinity once
    # pack() has succeeded, but keep the original 'f'/'d' check.
    return math.isinf(float(struct.unpack(fmt, packed)[0]))
+
718
def struct_overflow(v, t):
    """Return True if v (each part of v, for Tcomplex) overflows t's struct format.

    Types without an entry in _struct_format always return False.
    """
    fmt = _struct_format.get(t.type)
    if fmt is None:
        return False
    parts = (v.real, v.imag) if isinstance(t, Tcomplex) else (v,)
    return any(roundtrip_ne(part, fmt) for part in parts)
+
729
+
730
# Populate implemented_sigs / exact_sigs / inexact_sigs at import time.
# The unary "bitwise" table is not initialized here (implemented_sigs has no
# unary "bitwise" entry).
init_unary_cast_tbl("default")
init_unary_cast_tbl("float_result")

init_binary_cast_tbl("default")
init_binary_cast_tbl("float_result")
init_binary_cast_tbl("bool_result")
init_binary_cast_tbl("bitwise")
+
738
+
739
# gumath function name -> NumPy function name, for names that differ;
# np_function() falls back to the gumath name for anything not listed.
_np_names = {
  "asin" : "arcsin",
  "acos" : "arccos",
  "atan" : "arctan",
  "asinh" : "arcsinh",
  "acosh" : "arccosh",
  "atanh" : "arctanh",
  "nearbyint" : "round",
}
+
749
def np_function(name):
    """Return the NumPy name for gumath function `name` (unchanged if not aliased)."""
    if name in _np_names:
        return _np_names[name]
    return name
+
752
def np_noimpl(name):
    """Return True if NumPy has no attribute `name` to compare against.

    "round" is always treated as missing: np.round corresponds to
    gumath.nearbyint, not gumath.round.
    """
    return name == "round" or not hasattr(np, name)
+
762
+ def gen_axes(ndim):
763
+ for i in range(ndim):
764
+ yield i
765
+ lst = list(range(ndim))
766
+ for i in range(ndim):
767
+ yield tuple(sample(lst, i))