gumath 0.2.0dev5 → 0.2.0dev8

Sign up to get free protection for your applications and to get access to all the features.
Files changed (99) hide show
  1. checksums.yaml +4 -4
  2. data/CONTRIBUTING.md +7 -2
  3. data/Gemfile +0 -3
  4. data/ext/ruby_gumath/GPATH +0 -0
  5. data/ext/ruby_gumath/GRTAGS +0 -0
  6. data/ext/ruby_gumath/GTAGS +0 -0
  7. data/ext/ruby_gumath/extconf.rb +0 -5
  8. data/ext/ruby_gumath/functions.c +10 -2
  9. data/ext/ruby_gumath/gufunc_object.c +15 -4
  10. data/ext/ruby_gumath/gufunc_object.h +9 -3
  11. data/ext/ruby_gumath/gumath/Makefile +63 -0
  12. data/ext/ruby_gumath/gumath/Makefile.in +1 -0
  13. data/ext/ruby_gumath/gumath/config.h +56 -0
  14. data/ext/ruby_gumath/gumath/config.h.in +3 -0
  15. data/ext/ruby_gumath/gumath/config.log +497 -0
  16. data/ext/ruby_gumath/gumath/config.status +1034 -0
  17. data/ext/ruby_gumath/gumath/configure +375 -4
  18. data/ext/ruby_gumath/gumath/configure.ac +47 -3
  19. data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
  20. data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
  21. data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
  22. data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
  23. data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
  24. data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
  25. data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
  26. data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
  27. data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
  28. data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
  29. data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
  30. data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
  31. data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
  32. data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
  33. data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
  34. data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
  35. data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
  36. data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
  37. data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
  38. data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
  39. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
  40. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
  41. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
  42. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
  43. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
  44. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
  45. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
  46. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
  47. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
  48. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
  49. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
  50. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
  51. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
  52. data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
  53. data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
  54. data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
  55. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
  56. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
  57. data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
  58. data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
  59. data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
  60. data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
  61. data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
  62. data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
  63. data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
  64. data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
  65. data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
  66. data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
  67. data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
  68. data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
  69. data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
  70. data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
  71. data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
  72. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
  73. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
  74. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
  75. data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
  76. data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
  77. data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
  78. data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
  79. data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
  80. data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
  81. data/ext/ruby_gumath/gumath/setup.py +67 -6
  82. data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
  83. data/ext/ruby_gumath/include/gumath.h +55 -14
  84. data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
  85. data/ext/ruby_gumath/lib/libgumath.a +0 -0
  86. data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
  87. data/ext/ruby_gumath/ruby_gumath.c +231 -70
  88. data/ext/ruby_gumath/ruby_gumath.h +4 -1
  89. data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
  90. data/ext/ruby_gumath/util.c +34 -0
  91. data/ext/ruby_gumath/util.h +9 -0
  92. data/gumath.gemspec +3 -2
  93. data/lib/gumath.rb +55 -1
  94. data/lib/gumath/version.rb +2 -2
  95. data/lib/ruby_gumath.so +0 -0
  96. metadata +63 -10
  97. data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
  98. data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
  99. data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
@@ -0,0 +1,78 @@
1
+ #include <Python.h>
2
+ #include "ndtypes.h"
3
+ #include "pyndtypes.h"
4
+ #include "gumath.h"
5
+ #include "pygumath.h"
6
+
7
+
8
+ /****************************************************************************/
9
+ /* Module globals */
10
+ /****************************************************************************/
11
+
12
+ /* Function table */
13
+ static gm_tbl_t *table = NULL;
14
+
15
+
16
+ /****************************************************************************/
17
+ /* Module */
18
+ /****************************************************************************/
19
+
20
+ static struct PyModuleDef cuda_module = {
21
+ PyModuleDef_HEAD_INIT, /* m_base */
22
+ "cuda", /* m_name */
23
+ NULL, /* m_doc */
24
+ -1, /* m_size */
25
+ NULL, /* m_methods */
26
+ NULL, /* m_slots */
27
+ NULL, /* m_traverse */
28
+ NULL, /* m_clear */
29
+ NULL /* m_free */
30
+ };
31
+
32
+
33
+ PyMODINIT_FUNC
34
+ PyInit_cuda(void)
35
+ {
36
+ NDT_STATIC_CONTEXT(ctx);
37
+ PyObject *m = NULL;
38
+ static int initialized = 0;
39
+
40
+ if (!initialized) {
41
+ if (import_ndtypes() < 0) {
42
+ return NULL;
43
+ }
44
+ if (import_gumath() < 0) {
45
+ return NULL;
46
+ }
47
+
48
+ table = gm_tbl_new(&ctx);
49
+ if (table == NULL) {
50
+ return Ndt_SetError(&ctx);
51
+ }
52
+
53
+ if (gm_init_cuda_unary_kernels(table, &ctx) < 0) {
54
+ return Ndt_SetError(&ctx);
55
+ }
56
+
57
+ if (gm_init_cuda_binary_kernels(table, &ctx) < 0) {
58
+ return Ndt_SetError(&ctx);
59
+ }
60
+
61
+ initialized = 1;
62
+ }
63
+
64
+ m = PyModule_Create(&cuda_module);
65
+ if (m == NULL) {
66
+ goto error;
67
+ }
68
+
69
+ if (Gumath_AddCudaFunctions(m, table) < 0) {
70
+ goto error;
71
+ }
72
+
73
+ return m;
74
+
75
+ error:
76
+ Py_CLEAR(m);
77
+ return NULL;
78
+ }
@@ -56,11 +56,6 @@ PyInit_examples(void)
56
56
  }
57
57
 
58
58
  /* extending examples */
59
- #ifndef _MSC_VER
60
- if (gm_init_bfloat16_kernels(table, &ctx) < 0) {
61
- return Ndt_SetError(&ctx);
62
- }
63
- #endif
64
59
  if (gm_init_graph_kernels(table, &ctx) < 0) {
65
60
  return Ndt_SetError(&ctx);
66
61
  }
@@ -50,10 +50,10 @@ PyInit_functions(void)
50
50
  return Ndt_SetError(&ctx);
51
51
  }
52
52
 
53
- if (gm_init_unary_kernels(table, &ctx) < 0) {
53
+ if (gm_init_cpu_unary_kernels(table, &ctx) < 0) {
54
54
  return Ndt_SetError(&ctx);
55
55
  }
56
- if (gm_init_binary_kernels(table, &ctx) < 0) {
56
+ if (gm_init_cpu_binary_kernels(table, &ctx) < 0) {
57
57
  return Ndt_SetError(&ctx);
58
58
  }
59
59
 
@@ -0,0 +1,246 @@
1
+ /*
2
+ * BSD 3-Clause License
3
+ *
4
+ * Copyright (c) 2017-2018, plures
5
+ * All rights reserved.
6
+ *
7
+ * Redistribution and use in source and binary forms, with or without
8
+ * modification, are permitted provided that the following conditions are met:
9
+ *
10
+ * 1. Redistributions of source code must retain the above copyright notice,
11
+ * this list of conditions and the following disclaimer.
12
+ *
13
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
14
+ * this list of conditions and the following disclaimer in the documentation
15
+ * and/or other materials provided with the distribution.
16
+ *
17
+ * 3. Neither the name of the copyright holder nor the names of its
18
+ * contributors may be used to endorse or promote products derived from
19
+ * this software without specific prior written permission.
20
+ *
21
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+ */
32
+
33
+
34
+ #ifndef GUMATH_H
35
+ #define GUMATH_H
36
+
37
+
38
+ #ifdef __cplusplus
39
+ extern "C" {
40
+ #endif
41
+
42
+ #ifdef __cplusplus
43
+ #include <cstdint>
44
+ #else
45
+ #include <stdint.h>
46
+ #endif
47
+
48
+ #include "ndtypes.h"
49
+ #include "xnd.h"
50
+
51
+
52
/*
 * Platform glue:
 *   GM_API     - export/import decoration for public symbols (MSVC only).
 *   GM_UNUSED  - marks a possibly unused entity (GCC-compatible compilers).
 *   ALLOCA     - declares 'name' as a stack-allocated array of 'nmemb'
 *                elements of 'type'.
 */
#ifdef _MSC_VER
  #if defined (GM_EXPORT)
    #define GM_API __declspec(dllexport)
  #elif defined(GM_IMPORT)
    #define GM_API __declspec(dllimport)
  #else
    #define GM_API
  #endif

  #ifndef GM_UNUSED
    #define GM_UNUSED
  #endif

  #include "malloc.h"
  /* 'nmemb' is parenthesized so that an expression argument such as
   * 'a + b' is multiplied as a whole; the cast keeps the macro usable
   * when this header is included from C++. */
  #define ALLOCA(type, name, nmemb) \
      type *name = (type *)_alloca((nmemb) * sizeof(type))
#else
  #define GM_API
  #if defined(__GNUC__) && !defined(__INTEL_COMPILER)
    #define GM_UNUSED __attribute__((unused))
  #else
    #define GM_UNUSED
  #endif

  #define ALLOCA(type, name, nmemb) type name[nmemb]
#endif
77
+
78
+
79
+ #define GM_MAX_KERNELS 8192
80
+ #define GM_THREAD_CUTOFF 1000000
81
+
82
+ typedef float float32_t;
83
+ typedef double float64_t;
84
+
85
+
86
+ typedef int (* gm_xnd_kernel_t)(xnd_t stack[], ndt_context_t *ctx);
87
+ typedef int (* gm_strided_kernel_t)(char **args, intptr_t *dimensions, intptr_t *steps, void *data);
88
+
89
+ /*
90
+ * Collection of specialized kernels for a single function signature.
91
+ *
92
+ * NOTE: The specialized kernel lookup scheme is transitional and may
93
+ * be replaced by something else.
94
+ *
95
+ * This should be considered as a first version of a kernel request
96
+ * protocol.
97
+ */
98
+ typedef struct {
99
+ const ndt_t *sig;
100
+ const ndt_constraint_t *constraint;
101
+
102
+ /* Xnd signatures */
103
+ gm_xnd_kernel_t OptC; /* C in inner+1 dimensions */
104
+ gm_xnd_kernel_t OptZ; /* C in inner dimensions, C or zero stride in (inner+1)th. */
105
+ gm_xnd_kernel_t OptS; /* strided in (inner+1)th. */
106
+ gm_xnd_kernel_t C; /* C in inner dimensions */
107
+ gm_xnd_kernel_t Fortran; /* Fortran in inner dimensions */
108
+ gm_xnd_kernel_t Xnd; /* selected if non-contiguous or the other fields are NULL */
109
+
110
+ /* NumPy signature */
111
+ gm_strided_kernel_t Strided;
112
+ } gm_kernel_set_t;
113
+
114
+ typedef struct {
115
+ const char *name;
116
+ const char *type;
117
+ const ndt_methods_t *meth;
118
+ } gm_typedef_init_t;
119
+
120
+ typedef struct {
121
+ const char *name;
122
+ const char *sig;
123
+ const ndt_constraint_t *constraint;
124
+ uint32_t cap;
125
+
126
+ /* Xnd signatures */
127
+ gm_xnd_kernel_t OptC;
128
+ gm_xnd_kernel_t OptZ;
129
+ gm_xnd_kernel_t OptS;
130
+ gm_xnd_kernel_t C;
131
+ gm_xnd_kernel_t Fortran;
132
+ gm_xnd_kernel_t Xnd;
133
+
134
+ /* NumPy signature */
135
+ gm_strided_kernel_t Strided;
136
+ } gm_kernel_init_t;
137
+
138
+ /* Actual kernel selected for application */
139
+ typedef struct {
140
+ uint32_t flag;
141
+ const gm_kernel_set_t *set;
142
+ } gm_kernel_t;
143
+
144
+ /* Multimethod with associated kernels */
145
+ typedef struct gm_func gm_func_t;
146
+ typedef const gm_kernel_set_t *(*gm_typecheck_t)(ndt_apply_spec_t *spec, const gm_func_t *f,
147
+ const ndt_t *in[], const int64_t li[],
148
+ int nin, int nout, bool check_broadcast,
149
+ ndt_context_t *ctx);
150
+ struct gm_func {
151
+ char *name;
152
+ gm_typecheck_t typecheck; /* Experimental optimized type-checking, may be NULL. */
153
+ int nkernels;
154
+ gm_kernel_set_t kernels[GM_MAX_KERNELS];
155
+ };
156
+
157
+
158
+ typedef struct _gm_tbl gm_tbl_t;
159
+
160
+
161
+ /******************************************************************************/
162
+ /* Functions */
163
+ /******************************************************************************/
164
+
165
+ GM_API gm_func_t *gm_func_new(const char *name, ndt_context_t *ctx);
166
+ GM_API void gm_func_del(gm_func_t *f);
167
+
168
+ GM_API gm_func_t *gm_add_func(gm_tbl_t *tbl, const char *name, ndt_context_t *ctx);
169
+ GM_API int gm_add_kernel(gm_tbl_t *tbl, const gm_kernel_init_t *kernel, ndt_context_t *ctx);
170
+ GM_API int gm_add_kernel_typecheck(gm_tbl_t *tbl, const gm_kernel_init_t *kernel, ndt_context_t *ctx, gm_typecheck_t f);
171
+
172
+ GM_API gm_kernel_t gm_select(ndt_apply_spec_t *spec, const gm_tbl_t *tbl, const char *name,
173
+ const ndt_t *types[], const int64_t li[], int nin, int nout,
174
+ bool check_broadcast, const xnd_t args[], ndt_context_t *ctx);
175
+ GM_API int gm_apply(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, ndt_context_t *ctx);
176
+ GM_API int gm_apply_thread(const gm_kernel_t *kernel, xnd_t stack[], int outer_dims, const int64_t nthreads, ndt_context_t *ctx);
177
+
178
+
179
+ /******************************************************************************/
180
+ /* NumPy loops */
181
+ /******************************************************************************/
182
+
183
+ GM_API int gm_np_flatten(char **args, const int nargs,
184
+ int64_t *dimensions, int64_t *strides, const xnd_t stack[],
185
+ ndt_context_t *ctx);
186
+
187
+ GM_API int gm_np_convert_xnd(char **args, const int nargs,
188
+ intptr_t *dimensions, const int dims_size,
189
+ intptr_t *steps, const int steps_size,
190
+ xnd_t stack[], const int outer_dims,
191
+ ndt_context_t *ctx);
192
+
193
+ GM_API int gm_np_map(const gm_strided_kernel_t f,
194
+ char **args, int nargs,
195
+ intptr_t *dimensions,
196
+ intptr_t *steps,
197
+ void *data,
198
+ int outer_dims);
199
+
200
+
201
+ /******************************************************************************/
202
+ /* Xnd loops */
203
+ /******************************************************************************/
204
+
205
+ GM_API int array_shape_check(xnd_t *x, const int64_t shape, ndt_context_t *ctx);
206
+ GM_API int gm_xnd_map(const gm_xnd_kernel_t f, xnd_t stack[], const int nargs,
207
+ const int outer_dims, ndt_context_t *ctx);
208
+
209
+
210
+ /******************************************************************************/
211
+ /* Gufunc table */
212
+ /******************************************************************************/
213
+ GM_API gm_tbl_t *gm_tbl_new(ndt_context_t *ctx);
214
+ GM_API void gm_tbl_del(gm_tbl_t *t);
215
+
216
+ GM_API int gm_tbl_add(gm_tbl_t *tbl, const char *key, gm_func_t *value, ndt_context_t *ctx);
217
+ GM_API gm_func_t *gm_tbl_find(const gm_tbl_t *tbl, const char *key, ndt_context_t *ctx);
218
+ GM_API int gm_tbl_map(const gm_tbl_t *tbl, int (*f)(const gm_func_t *, void *state), void *state);
219
+
220
+
221
+ /******************************************************************************/
222
+ /* Library initialization and tables */
223
+ /******************************************************************************/
224
+
225
+ GM_API void gm_init(void);
226
+ GM_API int gm_init_cpu_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
227
+ GM_API int gm_init_cpu_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
228
+ GM_API int gm_init_bitwise_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
229
+
230
+ GM_API int gm_init_cuda_unary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
231
+ GM_API int gm_init_cuda_binary_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
232
+
233
+ GM_API int gm_init_example_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
234
+ GM_API int gm_init_graph_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
235
+ GM_API int gm_init_quaternion_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
236
+ GM_API int gm_init_pdist_kernels(gm_tbl_t *tbl, ndt_context_t *ctx);
237
+
238
+ GM_API void gm_finalize(void);
239
+
240
+
241
+ #ifdef __cplusplus
242
+ } /* END extern "C" */
243
+ #endif
244
+
245
+
246
+ #endif /* GUMATH_H */
@@ -0,0 +1 @@
1
+ ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3
@@ -0,0 +1 @@
1
+ ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3
@@ -49,10 +49,15 @@ extern "C" {
49
49
  /* Exposed here for the benefit of Numba. The API should not be regarded
50
50
  stable across versions. */
51
51
 
52
+ #define GM_CPU_FUNC 0x0001U
53
+ #define GM_CUDA_MANAGED_FUNC 0x0002U
54
+
52
55
  typedef struct {
53
56
  PyObject_HEAD
54
57
  const gm_tbl_t *tbl; /* kernel table */
58
+ uint32_t flags; /* memory target */
55
59
  char *name; /* function name */
60
+ PyObject *identity; /* identity element */
56
61
  } GufuncObject;
57
62
 
58
63
 
@@ -60,21 +65,45 @@ typedef struct {
60
65
  /* Capsule API */
61
66
  /****************************************************************************/
62
67
 
63
- #define Gumath_AddFunctions_INDEX 0
68
+ #define Gufunc_CheckExact_INDEX 0
69
+ #define Gufunc_CheckExact_RETURN int
70
+ #define Gufunc_CheckExact_ARGS (const PyObject *)
71
+
72
+ #define Gufunc_Check_INDEX 1
73
+ #define Gufunc_Check_RETURN int
74
+ #define Gufunc_Check_ARGS (const PyObject *)
75
+
76
+ #define Gumath_AddFunctions_INDEX 2
64
77
  #define Gumath_AddFunctions_RETURN int
65
78
  #define Gumath_AddFunctions_ARGS (PyObject *, const gm_tbl_t *)
66
79
 
67
- #define GUMATH_MAX_API 1
80
+ #define Gumath_AddCudaFunctions_INDEX 3
81
+ #define Gumath_AddCudaFunctions_RETURN int
82
+ #define Gumath_AddCudaFunctions_ARGS (PyObject *, const gm_tbl_t *)
83
+
84
+ #define GUMATH_MAX_API 4
68
85
 
69
86
 
70
87
  #ifdef GUMATH_MODULE
88
+ static Gufunc_CheckExact_RETURN Gufunc_CheckExact Gufunc_CheckExact_ARGS;
89
+ static Gufunc_Check_RETURN Gufunc_Check Gufunc_Check_ARGS;
71
90
  static Gumath_AddFunctions_RETURN Gumath_AddFunctions Gumath_AddFunctions_ARGS;
91
+ static Gumath_AddCudaFunctions_RETURN Gumath_AddCudaFunctions Gumath_AddCudaFunctions_ARGS;
72
92
  #else
73
93
  static void **_gumath_api;
74
94
 
95
+ #define Gufunc_CheckExact \
96
+ (*(Gufunc_CheckExact_RETURN (*)Gufunc_CheckExact_ARGS) _gumath_api[Gufunc_CheckExact_INDEX])
97
+
98
+ #define Gufunc_Check \
99
+ (*(Gufunc_Check_RETURN (*)Gufunc_Check_ARGS) _gumath_api[Gufunc_Check_INDEX])
100
+
75
101
  #define Gumath_AddFunctions \
76
102
  (*(Gumath_AddFunctions_RETURN (*)Gumath_AddFunctions_ARGS) _gumath_api[Gumath_AddFunctions_INDEX])
77
103
 
104
+ #define Gumath_AddCudaFunctions \
105
+ (*(Gumath_AddCudaFunctions_RETURN (*)Gumath_AddCudaFunctions_ARGS) _gumath_api[Gumath_AddCudaFunctions_INDEX])
106
+
78
107
 
79
108
  static int
80
109
  import_gumath(void)
@@ -0,0 +1,767 @@
1
+ #
2
+ # BSD 3-Clause License
3
+ #
4
+ # Copyright (c) 2017-2018, plures
5
+ # All rights reserved.
6
+ #
7
+ # Redistribution and use in source and binary forms, with or without
8
+ # modification, are permitted provided that the following conditions are met:
9
+ #
10
+ # 1. Redistributions of source code must retain the above copyright notice,
11
+ # this list of conditions and the following disclaimer.
12
+ #
13
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
14
+ # this list of conditions and the following disclaimer in the documentation
15
+ # and/or other materials provided with the distribution.
16
+ #
17
+ # 3. Neither the name of the copyright holder nor the names of its
18
+ # contributors may be used to endorse or promote products derived from
19
+ # this software without specific prior written permission.
20
+ #
21
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+ #
32
+
33
+ # Python NDarray and functions for generating test cases.
34
+
35
+ from itertools import accumulate, count, product
36
+ from random import randrange, sample
37
+ from collections import namedtuple
38
+ import math
39
+ import struct
40
+ import unittest
41
+ from randdec import all_unary, all_binary
42
+ from randfloat import un_randfloat, bin_randfloat
43
+ import numpy as np
44
+
45
+
46
def skip_if(condition, reason):
    """Skip the running test when `condition` is true.

    Raises unittest.SkipTest with `reason` as its message; returns None
    when the condition is false.
    """
    if not condition:
        return
    raise unittest.SkipTest(reason)
49
+
50
+
51
+ # ======================================================================
52
+ # Minimal test cases
53
+ # ======================================================================
54
+
55
# Each entry is a (value, type string, dtype name) triple: a 1-D ramp, a
# 2 x 1000 array and a 1000 x 2 array, first as float64 and then as float32.
TEST_CASES = [
    ([i / 100.0 for i in range(2000)], "2000 * float64", "float64"),

    ([[i / 100.0 for i in range(1000)], [float(i) for i in range(1, 1001)]],
     "2 * 1000 * float64", "float64"),

    (1000 * [[1.0, 2.0]], "1000 * 2 * float64", "float64"),

    ([i / 10.0 for i in range(2000)], "2000 * float32", "float32"),

    ([[i / 10.0 for i in range(1000)], [float(i) for i in range(1, 1001)]],
     "2 * 1000 * float32", "float32"),

    (1000 * [[1.0, 2.0]], "1000 * 2 * float32", "float32"),
]
70
+
71
+
72
+ # ======================================================================
73
+ # Definition of generalized slicing and indexing
74
+ # ======================================================================
75
+
76
def have_none(lst):
    """Return True if `lst` is None or contains None at any nesting depth.

    Lists and tuples are searched element-wise, dicts by their values; any
    other object only counts if it is None itself.
    """
    if lst is None:
        return True
    if isinstance(lst, dict):
        children = lst.values()
    elif isinstance(lst, (list, tuple)):
        children = lst
    else:
        return False
    for child in children:
        if have_none(child):
            return True
    return False
82
+
83
def sinrec(lst):
    """Apply math.sin recursively to a nested list of ints.

    None leaves are passed through unchanged.  Any other leaf type raises
    TypeError.
    """
    if lst is None:
        return None
    if isinstance(lst, list):
        return list(map(sinrec, lst))
    if isinstance(lst, int):
        return math.sin(lst)
    raise TypeError("unexpected operand type '%s'" % type(lst))
90
+
91
def mulrec(lst1, lst2):
    """Multiply two identically nested lists of ints element-wise.

    A None on either side yields None for that element.  Mixing a list with
    a scalar, or using any other leaf type, raises TypeError.
    """
    if isinstance(lst1, list) and isinstance(lst2, list):
        return [mulrec(left, right) for left, right in zip(lst1, lst2)]

    leaf_types = (int, type(None))
    if not (isinstance(lst1, leaf_types) and isinstance(lst2, leaf_types)):
        raise TypeError("unexpected operand types '%s', '%s'" %
                        (type(lst1), type(lst2)))

    if lst1 is None or lst2 is None:
        return None
    return lst1 * lst2
99
+
100
+
101
def maxlevel(lst):
    """Return the maximum list nesting depth of `lst`.

    A non-list has depth 0 and an empty list has depth 1.  Only list
    instances are descended into.
    """
    if not isinstance(lst, list):
        return 0
    return 1 + max((maxlevel(item) for item in lst), default=0)
113
+
114
def getitem(lst, indices):
    """Reference implementation of multidimensional indexing and slicing
    on arbitrarily shaped nested lists.

    `indices` is a tuple of ints and slices that is consumed from the left,
    one entry per nesting level.
    """
    if not indices:
        return lst

    head, rest = indices[0], indices[1:]
    selected = list.__getitem__(lst, head)

    if isinstance(head, int):
        return getitem(selected, rest)

    if selected:
        return [getitem(sub, rest) for sub in selected]

    # Empty slice: the remaining indices must still be in range for the
    # full slice, otherwise IndexError is raised.  This is NumPy's behavior.
    if lst:
        getitem(lst, (slice(None),) + rest)
    elif any(isinstance(k, int) for k in rest):
        raise IndexError
    return []
137
+
138
class NDArray(list):
    """Thin list subclass that supports generalized (multidimensional)
    slicing and indexing, plus element-wise sin and multiplication."""

    def __init__(self, value, dtype=None):
        # `dtype` is accepted but not used by this reference implementation.
        super().__init__(value)
        self.maxlevel = maxlevel(value)

    def __getitem__(self, indices):
        key = indices if isinstance(indices, tuple) else (indices,)

        if len(key) > self.maxlevel: # NumPy
            raise IndexError("too many indices")

        for entry in key:
            if not isinstance(entry, (int, slice)):
                raise TypeError(
                    "index must be int or slice or a tuple of integers and slices")

        result = getitem(self, key)
        if isinstance(result, list):
            return NDArray(result)
        return result

    def sin(self):
        """Return a new NDArray with sin applied to every element."""
        return NDArray(sinrec(self))

    def __mul__(self, other):
        """Return the element-wise product with `other` as a new NDArray."""
        return NDArray(mulrec(self, other))
163
+
164
+
165
+
166
+ # ======================================================================
167
+ # Generate test cases
168
+ # ======================================================================
169
+
170
# Regularly shaped nested lists, including empty dimensions.
SUBSCRIPT_FIXED_TEST_CASES = [
    # one- and two-level cases
    [],
    [[]],
    [[], []],
    [[0], [1]],
    [[0], [1], [2]],
    [[0, 1], [1, 2], [2, 3]],
    # three-level cases
    [[[]]],
    [[[0]]],
    [[[], []]],
    [[[0], [1]]],
    [[[0, 1], [2, 3]]],
    [[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
    [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
]
185
+
186
# Ragged nested lists; the second case also contains None entries and an
# empty innermost list.
SUBSCRIPT_VAR_TEST_CASES = [
    [[[0, 1], [2, 3]],
     [[4, 5, 6], [7]],
     [[8, 9]]],

    [[[0, 1], [2, 3]],
     [[4, 5, None], [None], [7]],
     [[], [None, 8]],
     [[9, 10]]],

    [[[0, 1, 2], [3, 4, 5, 6], [7, 8, 9, 10]],
     [[11, 12, 13, 14], [15, 16, 17], [18, 19]]],

    [[[0, 1], [2, 3], [4, 5]],
     [[6, 7], [8, 9]],
     [[10, 11]]],
]
192
+
193
def single_fixed(max_ndim=4, min_shape=1, max_shape=10):
    """Build one regularly shaped nested list with `max_ndim` levels.

    The extent of every level is drawn once from [min_shape, max_shape];
    the leaves are consecutive integers starting at 0.
    """
    leaves = count()
    extents = [randrange(min_shape, max_shape + 1) for _ in range(max_ndim)]

    def build(depth):
        if depth == 0:
            return next(leaves)
        width = extents[depth - 1]
        return [build(depth - 1) for _ in range(width)]

    return build(max_ndim)
203
+
204
def gen_fixed(max_ndim=4, min_shape=1, max_shape=10):
    """Yield 30 random regularly shaped nested lists (see single_fixed)."""
    assert 0 <= max_ndim and 0 <= min_shape <= max_shape

    remaining = 30
    while remaining:
        yield single_fixed(max_ndim, min_shape, max_shape)
        remaining -= 1
209
+
210
def single_var(max_ndim=4, min_shape=1, max_shape=10):
    """Build one ragged nested list with `max_ndim` levels.

    Every sublist draws its own length from [min_shape, max_shape].  Only
    the innermost level may have length 0: outer levels use a lower bound
    of 1 when min_shape is 0.  Leaves are consecutive integers from 0.
    """
    leaves = count()
    outer_min = min_shape if min_shape != 0 else 1

    def build(depth):
        if depth == 0:
            return next(leaves)
        lower = min_shape if depth == 1 else outer_min
        length = randrange(lower, max_shape + 1)
        return [build(depth - 1) for _ in range(length)]

    return build(max_ndim)
224
+
225
def gen_var(max_ndim=4, min_shape=1, max_shape=10):
    """Yield 30 random ragged nested lists (see single_var)."""
    assert 0 <= max_ndim and 0 <= min_shape <= max_shape

    remaining = 30
    while remaining:
        yield single_var(max_ndim, min_shape, max_shape)
        remaining -= 1
230
+
231
+
232
def genindices():
    """Yield every index tuple of length 1, 2 and 3 with entries in 0..3,
    shortest tuples first and in lexicographic order."""
    for width in (1, 2, 3):
        yield from product(range(4), repeat=width)
242
+
243
def rslice(ndim):
    """Return a random slice with bounds in [0, ndim] and a nonzero step in
    [-ndim-1, ndim].  Each of the three fields is independently replaced by
    None with probability 1/5."""
    lower = randrange(0, ndim + 1)
    upper = randrange(0, ndim + 1)

    stride = randrange(-ndim - 1, ndim + 1)
    while stride == 0:
        stride = randrange(-ndim - 1, ndim + 1)

    if randrange(5) == 4:
        lower = None
    if randrange(5) == 4:
        upper = None
    if randrange(5) == 4:
        stride = None
    return slice(lower, upper, stride)
253
+
254
def rslice_neg(ndim):
    """Return a random slice whose start, stop and (nonzero) step are all
    drawn from [-ndim-1, ndim], so negative bounds are possible."""
    low, high = -ndim - 1, ndim + 1
    first = randrange(low, high)
    last = randrange(low, high)

    stride = randrange(low, high)
    while stride == 0:
        stride = randrange(low, high)

    return slice(first, last, stride)
261
+
262
def multislice(ndim):
    """Return a tuple of between 1 and `ndim` random slices (see rslice)."""
    length = randrange(1, ndim + 1)
    slices = []
    for _ in range(length):
        slices.append(rslice(ndim))
    return tuple(slices)
264
+
265
def randslices(ndim):
    """Yield five random slice tuples for `ndim` (see multislice)."""
    yield from (multislice(ndim) for _ in range(5))
268
+
269
def gen_indices_or_slices():
    """Yield five random subscripts, each either a triple of ints in 0..3
    or a random slice tuple for three dimensions."""
    for _ in range(5):
        if randrange(2) == 0:
            yield multislice(3)
        else:
            yield (randrange(4), randrange(4), randrange(4))
275
+
276
def genslices(n):
    """Generate all possible slices for a single dimension.

    Start, stop and step each range over None and the integers -n..n;
    combinations with a step of 0 are left out.
    """
    candidates = (None,) + tuple(range(-n, n + 1))

    for lower, upper, stride in product(candidates, repeat=3):
        if stride != 0:
            yield slice(lower, upper, stride)
286
+
287
def genslices_ndim(ndim, shape):
    """Generate all possible slice tuples for the first `ndim` dimensions
    of `shape`."""
    per_axis = []
    for axis in range(ndim):
        per_axis.append(genslices(shape[axis]))
    for combination in product(*per_axis):
        yield combination
291
+
292
def mixed_index(max_ndim):
    """Return a random subscript tuple of 1..max_ndim entries, each either
    an int in [-max_ndim, max_ndim) or a random slice (see rslice)."""
    ndim = randrange(1, max_ndim + 1)

    def pick():
        if randrange(2) == 0:
            return rslice(ndim)
        return randrange(-max_ndim, max_ndim)

    return tuple(pick() for _ in range(ndim))
301
+
302
def mixed_index_neg(max_ndim):
    """Like mixed_index, but the slices come from rslice_neg and may have
    negative bounds."""
    ndim = randrange(1, max_ndim + 1)

    def pick():
        if randrange(2) == 0:
            return rslice_neg(ndim)
        return randrange(-max_ndim, max_ndim)

    return tuple(pick() for _ in range(ndim))
311
+
312
def mixed_indices(max_ndim):
    """Yield five subscripts from mixed_index followed by five from
    mixed_index_neg."""
    for make_index in (mixed_index, mixed_index_neg):
        for _ in range(5):
            yield make_index(max_ndim)
317
+
318
def itos(indices):
    """Format a subscript tuple as text: ints as-is, slices as
    'start:stop:step', joined by ', '."""
    parts = []
    for index in indices:
        if isinstance(index, int):
            parts.append(str(index))
        else:
            parts.append("%s:%s:%s" % (index.start, index.stop, index.step))
    return ", ".join(parts)
321
+
322
+
323
+ # ======================================================================
324
+ # Split a shape into N almost equal slices
325
+ # ======================================================================
326
+
327
def start(i, r, q):
    """Return the first index of chunk `i` when the first `r` chunks have
    size q+1 and the remaining ones size q."""
    if i < r:
        return i * (q + 1)
    return r + i * q
329
+
330
def stop(i, r, q):
    """Return the end index (exclusive) of chunk `i` when the first `r`
    chunks have size q+1 and the remaining ones size q."""
    if i < r:
        return (i + 1) * (q + 1)
    return r + (i + 1) * q
332
+
333
def step(i, r, q):
    """Return the size of chunk `i`: q+1 for the first `r` chunks, q for
    the rest."""
    if i < r:
        return q + 1
    return q
335
+
336
def sl(i, r, q):
    """Return the slice covering chunk `i` (see start and stop)."""
    first = start(i, r, q)
    last = stop(i, r, q)
    return slice(first, last)
338
+
339
def prepend(x, xs):
    """Return a new list in which `x` is put in front of every tuple of
    `xs`."""
    head = (x,)
    return [head + tail for tail in xs]
341
+
342
def last_column(i, r, q, n):
    """Return one 1-tuple per chunk 0..n-1, each holding that chunk's slice.

    The `i` argument is not used: chunks are always enumerated from 0.
    """
    result = []
    for chunk in range(n):
        result.append((sl(chunk, r, q),))
    return result
344
+
345
def schedule(n, shape):
    """Return a list of index tuples that split `shape` into about `n`
    almost equal parts.

    If the leading dimension has at least `n` entries, it alone is cut into
    `n` slices.  Otherwise every entry of the leading dimension gets its own
    slice and its share of `n` is distributed over the remaining dimensions.

    Raises ValueError if `n` or the leading dimension is not positive.
    """
    assert isinstance(n, int) and isinstance(shape, list)
    if n <= 0:
        raise ValueError("n must be greater than zero")
    if shape == []:
        return [()]

    m, ms = shape[0], shape[1:]
    if m <= 0:
        raise ValueError("shape must be greater than zero")

    if n > m:
        q, r = divmod(n, m)
        return column(0, r, q, m, ms)

    q, r = divmod(m, n)
    return last_column(0, r, q, n)
360
+
361
def column(i, r, q, m, ms):
    """Return the index tuples for outer positions i..m-1.

    Every outer position 'k' contributes the unit slice k:k+1, prepended to
    the schedule of the inner dimensions 'ms' with step(k, r, q) parts.
    """
    if i == m:
        return []
    parts = step(i, r, q)
    head = prepend(slice(i, i+1), schedule(parts, ms))
    tail = column(i+1, r, q, m, ms)
    return head + tail
366
+
367
+ # ======================================================================
368
+ # Split an xnd object into N subtrees
369
+ # ======================================================================
370
+
371
def zero_in_shape(shape):
    """Return True if any entry of 'shape' equals zero, else False."""
    return any(dim == 0 for dim in shape)
376
+
377
def split_xnd(x, n, max_outer=None):
    """Split 'x' into the pieces selected by schedule(n, shape).

    'shape' is taken from x.type.shape; if 'max_outer' is given, only the
    first 'max_outer' dimensions take part in the split.

    Raises ValueError if the shape contains a zero.
    """
    dims = list(x.type.shape)
    if zero_in_shape(dims):
        raise ValueError("split does not support zeros in shape")
    if max_outer is not None:
        dims = dims[:max_outer]

    pieces = []
    for index in schedule(n, dims):
        pieces.append(x[index])
    return pieces
385
+
386
+
387
+ # ======================================================================
388
+ # Generate test cases
389
+ # ======================================================================
390
+
391
# Names of the functions under test, grouped first by kind ("unary",
# "binary", "binary_mv") and then by a category key.  complex_noimpl() and
# half_noimpl() look up the "unary" categories by key, so the key names are
# part of the contract.
functions = {
    "unary": {
        "default": ["copy", "abs"],
        "arith": ["negative"],
        "complex_math_with_half": ["exp", "log", "log10", "sqrt", "sin", "cos"],
        "complex_math": ["tan", "asin", "acos", "atan", "sinh", "cosh", "tanh",
                         "asinh", "acosh", "atanh"],
        "real_math_with_half": ["fabs", "exp2", "log2"],
        "real_math": ["expm1", "log1p", "logb", "cbrt", "erf", "erfc", "lgamma",
                      "tgamma", "ceil", "floor", "trunc", "round", "nearbyint"],
        "bitwise": ["invert"],
    },
    "binary": {
        "default": ["add", "subtract", "multiply", "floor_divide", "remainder", "power"],
        "float_result": ["divide"],
        "bool_result": ["less_equal", "less", "greater_equal", "greater", "equal", "not_equal"],
        "bitwise": ["bitwise_and", "bitwise_or", "bitwise_xor"]
    },
    # NOTE(review): "mv" presumably stands for "multiple values" (divmod
    # returns two results) -- confirm.
    "binary_mv": {
        "default": ["divmod"],
    }
}
413
+
414
def complex_noimpl(name):
    """Return True if 'name' is listed in one of the real-only unary groups
    ("real_math" or "real_math_with_half")."""
    unary = functions["unary"]
    if name in unary["real_math"]:
        return True
    return name in unary["real_math_with_half"]
417
+
418
def half_noimpl(name):
    """Return True if 'name' is in the unary "real_math" or "complex_math"
    group, or is one of "floor_divide" / "remainder"."""
    unary = functions["unary"]
    for group in ("real_math", "complex_math"):
        if name in unary[group]:
            return True
    return name in ("floor_divide", "remainder")
422
+
423
# Type names by family.  Tint accepts the names in tunsigned + tsigned,
# Tfloat those in tfloat and Tcomplex those in tcomplex.
tunsigned = ["bool", "uint8", "uint16", "uint32", "uint64"]
tsigned = ["int8", "int16", "int32", "int64"]
tfloat = ["bfloat16", "float16", "float32", "float64"]
tcomplex = ["complex32", "complex64", "complex128"]

# Per-type (min, max, exp) triple; the Tint/Tfloat/Tcomplex constructors
# unpack it into the attributes .min, .max and .exp.
#
# For the integer types min/max are the smallest and largest value of the
# type and exp is 0.
# NOTE(review): for the float and complex types min/max are -2**p / 2**p,
# presumably the range of exactly representable integers, and exp is
# presumably the maximum binary exponent -- confirm.
tinfo = {
    "bool": (0, 1, 0),
    "uint8": (0, 2**8-1, 0),
    "uint16": (0, 2**16-1, 0),
    "uint32": (0, 2**32-1, 0),
    "uint64": (0, 2**64-1, 0),
    "int8": (-2**7, 2**7-1, 0),
    "int16": (-2**15, 2**15-1, 0),
    "int32": (-2**31, 2**31-1, 0),
    "int64": (-2**63, 2**63-1, 0),
    "float16": (-2**11, 2**11, 15),
    "bfloat16": (-2**8, 2**8, 127),
    "float32": (-2**24, 2**24, 127),
    "float64": (-2**53, 2**53, 1023),
    "complex32": (-2**11, 2**11, 15),
    "complex64": (-2**24, 2**24, 127),
    "complex128": (-2**53, 2**53, 1023)
}
446
+
447
class Tint(object):
    """Description of an integer (or bool) element type.

    Attributes:
      type -- the type name, one of tunsigned + tsigned
      min, max, exp -- the entry for that name in the module-level 'tinfo'
      all -- the tuple (type, min, max, exp), used for equality and hashing
    """

    def __init__(self, type):
        if type not in tunsigned + tsigned:
            raise ValueError("not an integer type: '%s'" % type)
        self.type = type
        self.min, self.max, self.exp = tinfo[type]
        self.all = (self.type, self.min, self.max, self.exp)

    def __repr__(self):
        return self.type

    def __eq__(self, other):
        # isinstance() takes (object, class).  With the arguments swapped,
        # comparing two distinct instances raised TypeError, e.g. when a
        # freshly constructed Tint was used to look up a signature table.
        return isinstance(other, Tint) and self.all == other.all

    def __hash__(self):
        return hash(self.all)

    def testcases(self):
        """Yield 0, min, max and then ten random values in [min, max]."""
        yield 0
        yield self.min
        yield self.max
        for _ in range(10):
            yield randrange(self.min, self.max+1)

    # NOTE(review): the *_noimpl / *_nokern predicates presumably report
    # "function f is not implemented" / "no kernel exists" for this type on
    # the given device -- confirm against the callers.  They are always
    # False for the integer types.
    def cpu_noimpl(self, f=None):
        return False

    def cpu_nokern(self, f=None):
        return False

    def cuda_noimpl(self, f=None):
        return False

    def cuda_nokern(self, f=None):
        return False
474
+
475
class Tfloat(object):
    """Description of a floating point element type.

    Attributes:
      type -- the type name, one of tfloat
      min, max, exp -- the entry for that name in the module-level 'tinfo'
      all -- the tuple (type, min, max, exp), used for equality and hashing
    """

    def __init__(self, type):
        if type not in tfloat:
            raise ValueError("not a float type: '%s'" % type)
        self.type = type
        self.min, self.max, self.exp = tinfo[type]
        self.all = (self.type, self.min, self.max, self.exp)

    def __repr__(self):
        return self.type

    def __eq__(self, other):
        # isinstance() takes (object, class), and the class to test against
        # is Tfloat.  The previous isinstance(Tint, other) raised TypeError
        # whenever two distinct instances were compared.
        return isinstance(other, Tfloat) and self.all == other.all

    def __hash__(self):
        return hash(self.all)

    def testcases(self):
        """Yield fixed values (0, +-0.5, min, max) followed by the values
        from all_unary(prec, exp, 1) with a random 'prec' in [1, 10) and
        from un_randfloat(), each converted to float."""
        yield 0
        yield 0.5
        yield -0.5
        yield self.min
        yield self.max
        prec = randrange(1, 10)
        for v in all_unary(prec, self.exp, 1):
            yield float(v)
        for v in un_randfloat():
            yield float(v)

    # NOTE(review): the *_noimpl / *_nokern predicates presumably report
    # "function f is not implemented" / "no kernel exists" for this type on
    # the given device -- confirm against the callers.
    def cpu_noimpl(self, f=None):
        return self.type == "float16"

    def cpu_nokern(self, f=None):
        return False

    def cuda_noimpl(self, f=None):
        if self.type == "float16":
            return half_noimpl(f)
        # Return an explicit bool instead of falling through with None.
        return False

    def cuda_nokern(self, f=None):
        return False
508
+
509
class Tcomplex(object):
    """Description of a complex element type.

    Attributes:
      type -- the type name, one of tcomplex
      min, max, exp -- the entry for that name in the module-level 'tinfo'
      all -- the tuple (type, min, max, exp), used for equality and hashing
    """

    def __init__(self, type):
        if type not in tcomplex:
            raise ValueError("not a complex type: '%s'" % type)
        self.type = type
        self.min, self.max, self.exp = tinfo[type]
        self.all = (self.type, self.min, self.max, self.exp)

    def __repr__(self):
        return self.type

    def __eq__(self, other):
        # isinstance() takes (object, class), and the class to test against
        # is Tcomplex.  The previous isinstance(Tint, other) raised
        # TypeError whenever two distinct instances were compared.
        return isinstance(other, Tcomplex) and self.all == other.all

    def __hash__(self):
        return hash(self.all)

    def testcases(self):
        """Yield fixed values (0, +-0.5, +-0.5j, min, max) followed by the
        pairs from all_binary(prec, exp, 1) with a random 'prec' in [1, 10)
        and from bin_randfloat(), each pair converted to a complex number."""
        yield 0
        yield 0.5
        yield -0.5
        yield 0.5j
        yield -0.5j
        yield self.min
        yield self.max
        prec = randrange(1, 10)
        for v, w in all_binary(prec, self.exp, 1):
            yield complex(float(v), float(w))
        for v, w in bin_randfloat():
            yield complex(float(v), float(w))

    # NOTE(review): the *_noimpl / *_nokern predicates presumably report
    # "function f is not implemented" / "no kernel exists" for this type on
    # the given device -- confirm against the callers.
    def cpu_noimpl(self, f=None):
        if self.type == "complex32":
            return True
        return complex_noimpl(f)

    def cpu_nokern(self, f=None):
        return f in ("floor_divide", "remainder")

    def cuda_noimpl(self, f=None):
        if self.type == "complex32":
            return True
        return complex_noimpl(f)

    def cuda_nokern(self, f=None):
        return f in ("floor_divide", "remainder")
547
+
548
+
549
# Types used for the "default", "float_result" and "bool_result" patterns.
# The order matters: init_unary_cast() and init_binary_cast() treat the list
# position as the rank of a type, and their max(8, rank) start index for
# "float_result" relies on index 8 being the first float type (float16).
tinfo_default = [
    Tint("uint8"),
    Tint("uint16"),
    Tint("uint32"),
    Tint("uint64"),
    Tint("int8"),
    Tint("int16"),
    Tint("int32"),
    Tint("int64"),
    Tfloat("float16"),
    Tfloat("bfloat16"),
    Tfloat("float32"),
    Tfloat("float64"),
    Tcomplex("complex32"),
    Tcomplex("complex64"),
    Tcomplex("complex128")
]

# Types used for the "bitwise" pattern: the integer types, with bool first.
tinfo_bitwise = [
    Tint("bool"),
    Tint("uint8"),
    Tint("uint16"),
    Tint("uint32"),
    Tint("uint64"),
    Tint("int8"),
    Tint("int16"),
    Tint("int32"),
    Tint("int64")
]
578
+
579
# Signature tables, keyed kind -> pattern -> {tuple of input types: result}.
#
# The "unary" and "binary" sub-tables start out empty and are filled in by
# init_unary_cast() / init_binary_cast():
#   implemented_sigs -- written once per input tuple, for the first
#                       candidate cast that is accepted
#   exact_sigs       -- written for every candidate cast whose range
#                       contains the input range(s)
#   inexact_sigs     -- written for every candidate cast that does not
# Because later writes replace earlier ones, exact_sigs and inexact_sigs
# keep only the last such candidate per input tuple.
#
# The "binary_mv" entries are listed by hand and map a pair of input types
# to a pair of result types.
implemented_sigs = {
    "unary": {
        "default": {}, "float_result": {}
    },
    "binary": {
        "default": {}, "float_result": {}, "bool_result": {}, "bitwise": {}
    },
    "binary_mv": {
        "default": {
            (Tint("uint8"), Tint("uint8")): (Tint("uint8"), Tint("uint8")),
            (Tint("uint16"), Tint("uint16")): (Tint("uint16"), Tint("uint16")),
            (Tint("uint32"), Tint("uint32")): (Tint("uint32"), Tint("uint32")),
            (Tint("uint64"), Tint("uint64")): (Tint("uint64"), Tint("uint64")),
            (Tint("int8"), Tint("int8")): (Tint("int8"), Tint("int8")),
            (Tint("int16"), Tint("int16")): (Tint("int16"), Tint("int16")),
            (Tint("int32"), Tint("int32")): (Tint("int32"), Tint("int32")),
            (Tint("int64"), Tint("int64")): (Tint("int64"), Tint("int64")),
            (Tfloat("float32"), Tfloat("float32")): (Tfloat("float32"), Tfloat("float32")),
            (Tfloat("float64"), Tfloat("float64")): (Tfloat("float64"), Tfloat("float64"))
        },
    }
}

exact_sigs = {
    "unary": {
        "default": {}, "float_result": {}
    },
    "binary": {
        "default": {}, "float_result": {}, "bool_result": {}, "bitwise": {}
    }
}

inexact_sigs = {
    "unary": {
        "default": {}, "float_result": {}
    },
    "binary": {
        "default": {}, "float_result": {}, "bool_result": {}, "bitwise": {}
    }
}
619
+
620
def init_unary_cast(pattern, tinfo, rank):
    """Record the unary signatures for the type at position 'rank'.

    'tinfo' is the ordered list of candidate types (it shadows the
    module-level 'tinfo' dict inside this function).  Every type from the
    start position up to the end of 'tinfo' is tried as a cast target for
    t = tinfo[rank]:

      * if the target's [min, max] range contains t's range, the target is
        stored in exact_sigs; the first such target is also stored in
        implemented_sigs, unless exactly one of the two types is bfloat16;
      * otherwise the target is stored in inexact_sigs.

    For the "float_result" pattern the search starts no earlier than
    index 8, the position of the first float type in tinfo_default.
    """
    t = tinfo[rank]

    start = max(8, rank) if pattern == "float_result" else rank
    found_cast = False

    # Iterate over the list that is actually indexed.  The bound used to be
    # len(tinfo_default): for the "default" pattern 'tinfo' is one element
    # longer (bool is prepended), so its last entry was never considered,
    # and a list shorter than tinfo_default would have raised IndexError.
    for i in range(start, len(tinfo)):
        cast = tinfo[i]
        if cast.min <= t.min and t.max <= cast.max:
            if found_cast or (t.type == "bfloat16") != (cast.type == "bfloat16"):
                exact_sigs["unary"][pattern][(t,)] = cast
            else:
                found_cast = True
                implemented_sigs["unary"][pattern][(t,)] = cast
                exact_sigs["unary"][pattern][(t,)] = cast
        else:
            inexact_sigs["unary"][pattern][(t,)] = cast
637
+
638
def init_unary_cast_tbl(pattern):
    """Fill the unary signature tables for 'pattern'.

    Selects the list of types that belongs to the pattern and calls
    init_unary_cast() once for every position in that list.

    Raises ValueError for an unknown pattern.
    """
    if pattern == "default":
        tinfo = [Tint("bool")] + tinfo_default
    elif pattern == "float_result":
        tinfo = tinfo_default
    elif pattern == "bitwise":
        tinfo = tinfo_bitwise
    else:
        # The message used to be formatted with the undefined name 'func',
        # which turned this error into a NameError.
        raise ValueError("unsupported function type '%s'" % pattern)

    for rank, _ in enumerate(tinfo):
        init_unary_cast(pattern, tinfo, rank)
650
+
651
def is_binary_common_cast(cast, t, u):
    """Return True if 'cast' can hold both 't' and 'u'.

    The [min, max] range of 'cast' must contain the ranges of both inputs.
    If 'cast' is a Tfloat, the exponents of both inputs must additionally
    not exceed the exponent of 'cast'.
    """
    def in_range(x):
        return cast.min <= x.min and x.max <= cast.max

    if not (in_range(t) and in_range(u)):
        return False
    if isinstance(cast, Tfloat):
        return t.exp <= cast.exp and u.exp <= cast.exp
    return True
659
+
660
def init_binary_cast(pattern, tinfo, rank1, rank2):
    """Record the binary signatures for the types at 'rank1' and 'rank2'.

    The pair is normalized so that t = tinfo[min rank], u = tinfo[max rank].
    Every type in tinfo_default from the start position onwards is tried as
    a common cast.  The recorded result type is bool for the "bool_result"
    pattern and the candidate itself otherwise:

      * if is_binary_common_cast() accepts the candidate, the result is
        stored in exact_sigs; the first accepted candidate is also stored
        in implemented_sigs;
      * otherwise the result is stored in inexact_sigs.

    For the "float_result" pattern the search starts no earlier than
    index 8, the position of the first float type in tinfo_default.

    NOTE(review): candidates always come from tinfo_default, but the start
    index is a rank in 'tinfo'.  For tinfo_bitwise, which starts with bool,
    the two lists are offset by one -- confirm that this is intended.
    """
    lo_rank, hi_rank = min(rank1, rank2), max(rank1, rank2)
    t, u = tinfo[lo_rank], tinfo[hi_rank]
    key = (t, u)

    if pattern == "float_result":
        first = max(8, hi_rank)
    else:
        first = hi_rank

    have_smallest = False
    for i in range(first, len(tinfo_default)):
        candidate = tinfo_default[i]
        result = Tint("bool") if pattern == "bool_result" else candidate
        if not is_binary_common_cast(candidate, t, u):
            inexact_sigs["binary"][pattern][key] = result
            continue
        if not have_smallest:
            have_smallest = True
            implemented_sigs["binary"][pattern][key] = result
        exact_sigs["binary"][pattern][key] = result
682
+
683
def init_binary_cast_tbl(pattern):
    """Fill the binary signature tables for 'pattern'.

    Selects the list of types that belongs to the pattern and calls
    init_binary_cast() for every ordered pair of positions in that list.

    Raises ValueError for an unknown pattern.
    """
    if pattern in ("default", "float_result", "bool_result"):
        tinfo = tinfo_default
    elif pattern == "bitwise":
        tinfo = tinfo_bitwise
    else:
        raise ValueError("unsupported function type '%s'" % pattern)

    ranks = range(len(tinfo))
    for rank1 in ranks:
        for rank2 in ranks:
            init_binary_cast(pattern, tinfo, rank1, rank2)
694
+
695
# struct format character for each float/complex type name, as used by
# struct_overflow().  A complex type maps to the format of one component;
# struct_overflow() checks the real and imaginary parts separately.  Type
# names without an entry here are reported as "no overflow".
_struct_format = {
    "float16": "e",
    "float32": "f",
    "float64": "d",
    "complex32": "e",
    "complex64": "f",
    "complex128": "d"
}
703
+
704
def roundtrip_ne(v, fmt):
    """Return True if the finite value 'v' overflows the struct format 'fmt'.

    For the half precision format "e", any OverflowError or struct.error
    raised by struct.pack() counts as overflow.

    For the other formats an infinite 'v' is never reported.  A finite 'v'
    is reported if struct.pack() raises OverflowError or if the value comes
    back as infinity after a pack/unpack round trip.
    """
    if fmt == "e":
        try:
            struct.pack(fmt, v)
        except (OverflowError, struct.error):
            return True
        else:
            return False

    if math.isinf(v):
        return False
    try:
        packed = struct.pack(fmt, v)
    except OverflowError:
        # struct.pack() raises for finite values that are too large for the
        # format (e.g. 1e300 with "f") instead of packing infinity, so the
        # round trip check below would never see the overflow.
        return True
    s = struct.unpack(fmt, packed)[0]
    return math.isinf(float(s))
717
+
718
def struct_overflow(v, t):
    """Return True if 'v' overflows the struct format of type 't'.

    Types without an entry in _struct_format never overflow.  For a
    Tcomplex type the real and imaginary parts are checked separately.
    """
    fmt = _struct_format.get(t.type)
    if fmt is None:
        return False

    if not isinstance(t, Tcomplex):
        return roundtrip_ne(v, fmt)
    return roundtrip_ne(v.real, fmt) or roundtrip_ne(v.imag, fmt)
728
+
729
+
730
# Populate the signature tables when the module is imported.  Only the
# "default" and "float_result" patterns are initialized for unary
# functions; all four patterns are initialized for binary functions.
init_unary_cast_tbl("default")
init_unary_cast_tbl("float_result")

init_binary_cast_tbl("default")
init_binary_cast_tbl("float_result")
init_binary_cast_tbl("bool_result")
init_binary_cast_tbl("bitwise")
737
+
738
+
739
# Function names that np_function() translates; any name without an entry
# here is returned unchanged.
# NOTE(review): presumably the keys are gumath function names and the
# values the corresponding NumPy names -- confirm against the callers.
_np_names = {
    "asin" : "arcsin",
    "acos" : "arccos",
    "atan" : "arctan",
    "asinh" : "arcsinh",
    "acosh" : "arccosh",
    "atanh" : "arctanh",
    "nearbyint" : "round",
}
748
+
749
def np_function(name):
    """Return the _np_names entry for 'name', or 'name' itself if none."""
    if name in _np_names:
        return _np_names[name]
    return name
751
+
752
def np_noimpl(name):
    """Return True if 'name' must not be looked up on the 'np' module.

    That is the case for "round" (always) and for any name that 'np' has
    no attribute for.
    """
    if name == "round":
        # np.round == gumath.nearbyint
        return True
    return not hasattr(np, name)
761
+
762
+ def gen_axes(ndim):
763
+ for i in range(ndim):
764
+ yield i
765
+ lst = list(range(ndim))
766
+ for i in range(ndim):
767
+ yield tuple(sample(lst, i))