gumath 0.2.0dev5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78) hide show
  1. checksums.yaml +7 -0
  2. data/CONTRIBUTING.md +61 -0
  3. data/Gemfile +5 -0
  4. data/History.md +0 -0
  5. data/README.md +5 -0
  6. data/Rakefile +105 -0
  7. data/ext/ruby_gumath/examples.c +126 -0
  8. data/ext/ruby_gumath/extconf.rb +97 -0
  9. data/ext/ruby_gumath/functions.c +106 -0
  10. data/ext/ruby_gumath/gufunc_object.c +79 -0
  11. data/ext/ruby_gumath/gufunc_object.h +55 -0
  12. data/ext/ruby_gumath/gumath/AUTHORS.txt +5 -0
  13. data/ext/ruby_gumath/gumath/INSTALL.txt +42 -0
  14. data/ext/ruby_gumath/gumath/LICENSE.txt +29 -0
  15. data/ext/ruby_gumath/gumath/MANIFEST.in +3 -0
  16. data/ext/ruby_gumath/gumath/Makefile.in +62 -0
  17. data/ext/ruby_gumath/gumath/README.rst +20 -0
  18. data/ext/ruby_gumath/gumath/config.guess +1530 -0
  19. data/ext/ruby_gumath/gumath/config.h.in +52 -0
  20. data/ext/ruby_gumath/gumath/config.sub +1782 -0
  21. data/ext/ruby_gumath/gumath/configure +5049 -0
  22. data/ext/ruby_gumath/gumath/configure.ac +167 -0
  23. data/ext/ruby_gumath/gumath/doc/_static/copybutton.js +66 -0
  24. data/ext/ruby_gumath/gumath/doc/conf.py +26 -0
  25. data/ext/ruby_gumath/gumath/doc/gumath/functions.rst +62 -0
  26. data/ext/ruby_gumath/gumath/doc/gumath/index.rst +26 -0
  27. data/ext/ruby_gumath/gumath/doc/index.rst +45 -0
  28. data/ext/ruby_gumath/gumath/doc/libgumath/data-structures.rst +130 -0
  29. data/ext/ruby_gumath/gumath/doc/libgumath/functions.rst +78 -0
  30. data/ext/ruby_gumath/gumath/doc/libgumath/index.rst +25 -0
  31. data/ext/ruby_gumath/gumath/doc/libgumath/kernels.rst +41 -0
  32. data/ext/ruby_gumath/gumath/doc/releases/index.rst +11 -0
  33. data/ext/ruby_gumath/gumath/install-sh +527 -0
  34. data/ext/ruby_gumath/gumath/libgumath/Makefile.in +170 -0
  35. data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +160 -0
  36. data/ext/ruby_gumath/gumath/libgumath/apply.c +201 -0
  37. data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +130 -0
  38. data/ext/ruby_gumath/gumath/libgumath/extending/examples.c +176 -0
  39. data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +393 -0
  40. data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +140 -0
  41. data/ext/ruby_gumath/gumath/libgumath/extending/quaternion.c +156 -0
  42. data/ext/ruby_gumath/gumath/libgumath/func.c +177 -0
  43. data/ext/ruby_gumath/gumath/libgumath/gumath.h +205 -0
  44. data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +547 -0
  45. data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +449 -0
  46. data/ext/ruby_gumath/gumath/libgumath/nploops.c +219 -0
  47. data/ext/ruby_gumath/gumath/libgumath/tbl.c +223 -0
  48. data/ext/ruby_gumath/gumath/libgumath/thread.c +175 -0
  49. data/ext/ruby_gumath/gumath/libgumath/xndloops.c +130 -0
  50. data/ext/ruby_gumath/gumath/python/extending.py +24 -0
  51. data/ext/ruby_gumath/gumath/python/gumath/__init__.py +74 -0
  52. data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +577 -0
  53. data/ext/ruby_gumath/gumath/python/gumath/examples.c +93 -0
  54. data/ext/ruby_gumath/gumath/python/gumath/functions.c +77 -0
  55. data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +95 -0
  56. data/ext/ruby_gumath/gumath/python/test_gumath.py +405 -0
  57. data/ext/ruby_gumath/gumath/setup.py +298 -0
  58. data/ext/ruby_gumath/gumath/vcbuild/INSTALL.txt +36 -0
  59. data/ext/ruby_gumath/gumath/vcbuild/vcbuild32.bat +21 -0
  60. data/ext/ruby_gumath/gumath/vcbuild/vcbuild64.bat +21 -0
  61. data/ext/ruby_gumath/gumath/vcbuild/vcclean.bat +10 -0
  62. data/ext/ruby_gumath/gumath/vcbuild/vcdistclean.bat +11 -0
  63. data/ext/ruby_gumath/include/gumath.h +205 -0
  64. data/ext/ruby_gumath/include/ruby_gumath.h +41 -0
  65. data/ext/ruby_gumath/lib/libgumath.a +0 -0
  66. data/ext/ruby_gumath/lib/libgumath.so +1 -0
  67. data/ext/ruby_gumath/lib/libgumath.so.0 +1 -0
  68. data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
  69. data/ext/ruby_gumath/ruby_gumath.c +295 -0
  70. data/ext/ruby_gumath/ruby_gumath.h +41 -0
  71. data/ext/ruby_gumath/ruby_gumath_internal.h +45 -0
  72. data/ext/ruby_gumath/util.c +68 -0
  73. data/ext/ruby_gumath/util.h +48 -0
  74. data/gumath.gemspec +47 -0
  75. data/lib/gumath.rb +7 -0
  76. data/lib/gumath/version.rb +5 -0
  77. data/lib/ruby_gumath.so +0 -0
  78. metadata +206 -0
@@ -0,0 +1,93 @@
1
+ #include <Python.h>
2
+ #include "ndtypes.h"
3
+ #include "pyndtypes.h"
4
+ #include "gumath.h"
5
+ #include "pygumath.h"
6
+
7
+
8
+ /****************************************************************************/
9
+ /* Module globals */
10
+ /****************************************************************************/
11
+
12
+ /* Function table */
13
+ static gm_tbl_t *table = NULL;
14
+
15
+
16
+ /****************************************************************************/
17
+ /* Module */
18
+ /****************************************************************************/
19
+
20
+ static struct PyModuleDef examples_module = {
21
+ PyModuleDef_HEAD_INIT, /* m_base */
22
+ "examples", /* m_name */
23
+ NULL, /* m_doc */
24
+ -1, /* m_size */
25
+ NULL, /* m_methods */
26
+ NULL, /* m_slots */
27
+ NULL, /* m_traverse */
28
+ NULL, /* m_clear */
29
+ NULL /* m_free */
30
+ };
31
+
32
+
33
+ PyMODINIT_FUNC
34
+ PyInit_examples(void)
35
+ {
36
+ NDT_STATIC_CONTEXT(ctx);
37
+ PyObject *m = NULL;
38
+ static int initialized = 0;
39
+
40
+ if (!initialized) {
41
+ if (import_ndtypes() < 0) {
42
+ return NULL;
43
+ }
44
+ if (import_gumath() < 0) {
45
+ return NULL;
46
+ }
47
+
48
+ table = gm_tbl_new(&ctx);
49
+ if (table == NULL) {
50
+ return Ndt_SetError(&ctx);
51
+ }
52
+
53
+ /* custom examples */
54
+ if (gm_init_example_kernels(table, &ctx) < 0) {
55
+ return Ndt_SetError(&ctx);
56
+ }
57
+
58
+ /* extending examples */
59
+ #ifndef _MSC_VER
60
+ if (gm_init_bfloat16_kernels(table, &ctx) < 0) {
61
+ return Ndt_SetError(&ctx);
62
+ }
63
+ #endif
64
+ if (gm_init_graph_kernels(table, &ctx) < 0) {
65
+ return Ndt_SetError(&ctx);
66
+ }
67
+ #ifndef _MSC_VER
68
+ if (gm_init_quaternion_kernels(table, &ctx) < 0) {
69
+ return Ndt_SetError(&ctx);
70
+ }
71
+ #endif
72
+ if (gm_init_pdist_kernels(table, &ctx) < 0) {
73
+ return Ndt_SetError(&ctx);
74
+ }
75
+
76
+ initialized = 1;
77
+ }
78
+
79
+ m = PyModule_Create(&examples_module);
80
+ if (m == NULL) {
81
+ goto error;
82
+ }
83
+
84
+ if (Gumath_AddFunctions(m, table) < 0) {
85
+ goto error;
86
+ }
87
+
88
+ return m;
89
+
90
+ error:
91
+ Py_CLEAR(m);
92
+ return NULL;
93
+ }
@@ -0,0 +1,77 @@
1
+ #include <Python.h>
2
+ #include "ndtypes.h"
3
+ #include "pyndtypes.h"
4
+ #include "gumath.h"
5
+ #include "pygumath.h"
6
+
7
+
8
+ /****************************************************************************/
9
+ /* Module globals */
10
+ /****************************************************************************/
11
+
12
+ /* Function table */
13
+ static gm_tbl_t *table = NULL;
14
+
15
+
16
+ /****************************************************************************/
17
+ /* Module */
18
+ /****************************************************************************/
19
+
20
+ static struct PyModuleDef functions_module = {
21
+ PyModuleDef_HEAD_INIT, /* m_base */
22
+ "functions", /* m_name */
23
+ NULL, /* m_doc */
24
+ -1, /* m_size */
25
+ NULL, /* m_methods */
26
+ NULL, /* m_slots */
27
+ NULL, /* m_traverse */
28
+ NULL, /* m_clear */
29
+ NULL /* m_free */
30
+ };
31
+
32
+
33
+ PyMODINIT_FUNC
34
+ PyInit_functions(void)
35
+ {
36
+ NDT_STATIC_CONTEXT(ctx);
37
+ PyObject *m = NULL;
38
+ static int initialized = 0;
39
+
40
+ if (!initialized) {
41
+ if (import_ndtypes() < 0) {
42
+ return NULL;
43
+ }
44
+ if (import_gumath() < 0) {
45
+ return NULL;
46
+ }
47
+
48
+ table = gm_tbl_new(&ctx);
49
+ if (table == NULL) {
50
+ return Ndt_SetError(&ctx);
51
+ }
52
+
53
+ if (gm_init_unary_kernels(table, &ctx) < 0) {
54
+ return Ndt_SetError(&ctx);
55
+ }
56
+ if (gm_init_binary_kernels(table, &ctx) < 0) {
57
+ return Ndt_SetError(&ctx);
58
+ }
59
+
60
+ initialized = 1;
61
+ }
62
+
63
+ m = PyModule_Create(&functions_module);
64
+ if (m == NULL) {
65
+ goto error;
66
+ }
67
+
68
+ if (Gumath_AddFunctions(m, table) < 0) {
69
+ goto error;
70
+ }
71
+
72
+ return m;
73
+
74
+ error:
75
+ Py_CLEAR(m);
76
+ return NULL;
77
+ }
@@ -0,0 +1,95 @@
1
+ /*
2
+ * BSD 3-Clause License
3
+ *
4
+ * Copyright (c) 2017-2018, plures
5
+ * All rights reserved.
6
+ *
7
+ * Redistribution and use in source and binary forms, with or without
8
+ * modification, are permitted provided that the following conditions are met:
9
+ *
10
+ * 1. Redistributions of source code must retain the above copyright notice,
11
+ * this list of conditions and the following disclaimer.
12
+ *
13
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
14
+ * this list of conditions and the following disclaimer in the documentation
15
+ * and/or other materials provided with the distribution.
16
+ *
17
+ * 3. Neither the name of the copyright holder nor the names of its
18
+ * contributors may be used to endorse or promote products derived from
19
+ * this software without specific prior written permission.
20
+ *
21
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+ */
32
+
33
+
34
+ #ifndef PYGUMATH_H
35
+ #define PYGUMATH_H
36
+ #ifdef __cplusplus
37
+ extern "C" {
38
+ #endif
39
+
40
+
41
+ #include <Python.h>
42
+ #include "gumath.h"
43
+
44
+
45
+ /****************************************************************************/
46
+ /* Gufunc Object */
47
+ /****************************************************************************/
48
+
49
+ /* Exposed here for the benefit of Numba. The API should not be regarded
50
+ stable across versions. */
51
+
52
+ typedef struct {
53
+ PyObject_HEAD
54
+ const gm_tbl_t *tbl; /* kernel table */
55
+ char *name; /* function name */
56
+ } GufuncObject;
57
+
58
+
59
+ /****************************************************************************/
60
+ /* Capsule API */
61
+ /****************************************************************************/
62
+
63
/* Gumath_AddFunctions: slot in the exported API array, return type and
   parenthesized parameter list (used to build the function-pointer cast). */
#define Gumath_AddFunctions_INDEX   0
#define Gumath_AddFunctions_RETURN  int
#define Gumath_AddFunctions_ARGS    (PyObject *, const gm_tbl_t *)

/* Number of API entries exported through the capsule (currently one). */
#define GUMATH_MAX_API 1
68
+
69
+
70
+ #ifdef GUMATH_MODULE
71
+ static Gumath_AddFunctions_RETURN Gumath_AddFunctions Gumath_AddFunctions_ARGS;
72
+ #else
73
+ static void **_gumath_api;
74
+
75
+ #define Gumath_AddFunctions \
76
+ (*(Gumath_AddFunctions_RETURN (*)Gumath_AddFunctions_ARGS) _gumath_api[Gumath_AddFunctions_INDEX])
77
+
78
+
79
+ static int
80
+ import_gumath(void)
81
+ {
82
+ _gumath_api = (void **)PyCapsule_Import("gumath._gumath._API", 0);
83
+ if (_gumath_api == NULL) {
84
+ return -1;
85
+ }
86
+
87
+ return 0;
88
+ }
89
+ #endif
90
+
91
+ #ifdef __cplusplus
92
+ }
93
+ #endif
94
+
95
+ #endif /* PYGUMATH_H */
@@ -0,0 +1,405 @@
1
+ #
2
+ # BSD 3-Clause License
3
+ #
4
+ # Copyright (c) 2017-2018, plures
5
+ # All rights reserved.
6
+ #
7
+ # Redistribution and use in source and binary forms, with or without
8
+ # modification, are permitted provided that the following conditions are met:
9
+ #
10
+ # 1. Redistributions of source code must retain the above copyright notice,
11
+ # this list of conditions and the following disclaimer.
12
+ #
13
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
14
+ # this list of conditions and the following disclaimer in the documentation
15
+ # and/or other materials provided with the distribution.
16
+ #
17
+ # 3. Neither the name of the copyright holder nor the names of its
18
+ # contributors may be used to endorse or promote products derived from
19
+ # this software without specific prior written permission.
20
+ #
21
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+ #
32
+
33
+ import gumath as gm
34
+ import gumath.functions as fn
35
+ import gumath.examples as ex
36
+ from xnd import xnd
37
+ from ndtypes import ndt
38
+ from extending import Graph, bfloat16
39
+ import sys, time
40
+ import math
41
+ import unittest
42
+ import argparse
43
+
44
+ try:
45
+ import numpy as np
46
+ except ImportError:
47
+ np = None
48
+
49
+
50
# (nested list, xnd type string, matching NumPy dtype) triples iterated by the
# elementwise tests in TestCall.  Each dtype gets one 1-D case, one wide 2-D
# case and one tall 2-D case (the tall case repeats a single inner list).
TEST_CASES = [
    ([i / 100.0 for i in range(2000)], "2000 * float64", "float64"),
    ([[i / 100.0 for i in range(1000)], [i + 1.0 for i in range(1000)]],
     "2 * 1000 * float64", "float64"),
    (1000 * [[i + 1.0 for i in range(2)]], "1000 * 2 * float64", "float64"),

    ([i / 10.0 for i in range(2000)], "2000 * float32", "float32"),
    ([[i / 10.0 for i in range(1000)], [i + 1.0 for i in range(1000)]],
     "2 * 1000 * float32", "float32"),
    (1000 * [[i + 1.0 for i in range(2)]], "1000 * 2 * float32", "float32"),
]
65
+
66
+
67
class TestCall(unittest.TestCase):
    """Calls into gumath.functions / gumath.examples, cross-checked against
    NumPy whenever NumPy is installed."""

    def test_sin_scalar(self):

        x1 = xnd(1.2, type="float64")
        y1 = fn.sin(x1)

        x2 = xnd(1.23e1, type="float32")
        y2 = fn.sin(x2)

        if np is not None:
            a1 = np.array(1.2, dtype="float64")
            b1 = np.sin(a1)

            a2 = np.array(1.23e1, dtype="float32")
            b2 = np.sin(a2)

            np.testing.assert_equal(y1.value, b1)
            np.testing.assert_equal(y2.value, b2)

    def test_sin(self):

        for lst, t, dtype in TEST_CASES:
            x = xnd(lst, type=t)
            y = fn.sin(x)

            if np is not None:
                a = np.array(lst, dtype=dtype)
                b = np.sin(a)
                np.testing.assert_equal(y, b)

    def test_sin_strided(self):

        for lst, t, dtype in TEST_CASES:
            x = xnd(lst, type=t)
            if x.type.ndim < 2:
                continue

            y = x[::-2, ::-2]
            z = fn.sin(y)

            if np is not None:
                a = np.array(lst, dtype=dtype)
                b = a[::-2, ::-2]
                c = np.sin(b)
                np.testing.assert_equal(z, c)

    def test_copy(self):

        for lst, t, dtype in TEST_CASES:
            x = xnd(lst, type=t)
            y = fn.copy(x)

            if np is not None:
                a = np.array(lst, dtype=dtype)
                b = np.copy(a)
                np.testing.assert_equal(y, b)

    def test_copy_strided(self):

        for lst, t, dtype in TEST_CASES:
            x = xnd(lst, type=t)
            if x.type.ndim < 2:
                continue

            y = x[::-2, ::-2]
            z = fn.copy(y)

            if np is not None:
                a = np.array(lst, dtype=dtype)
                b = a[::-2, ::-2]
                c = np.copy(b)
                # Compare the two copies (z vs c).  Comparing the input views
                # (y vs b) would never exercise fn.copy at all.
                np.testing.assert_equal(z, c)

    @unittest.skipIf(sys.platform == "win32", "missing C99 complex support")
    def test_quaternion(self):

        # Each quaternion is given as a 2x2 complex matrix, so ex.multiply is
        # checked against a batched matrix product.
        lst = [[[1+2j, 4+3j],
                [-4+3j, 1-2j]],
               [[4+2j, 1+10j],
                [-1+10j, 4-2j]],
               [[-4+2j, 3+10j],
                [-3+10j, -4-2j]]]

        x = xnd(lst, type="3 * quaternion64")
        y = ex.multiply(x, x)

        if np is not None:
            a = np.array(lst, dtype="complex64")
            b = np.einsum("ijk,ikl->ijl", a, a)
            np.testing.assert_equal(y, b)

        x = xnd(lst, type="3 * quaternion128")
        y = ex.multiply(x, x)

        if np is not None:
            a = np.array(lst, dtype="complex128")
            b = np.einsum("ijk,ikl->ijl", a, a)
            np.testing.assert_equal(y, b)

        x = xnd("xyz")
        self.assertRaises(TypeError, ex.multiply, x, x)

    @unittest.skipIf(sys.platform == "win32", "missing C99 complex support")
    def test_quaternion_error(self):

        lst = [[[1+2j, 4+3j],
                [-4+3j, 1-2j]],
               [[4+2j, 1+10j],
                [-1+10j, 4-2j]],
               [[-4+2j, 3+10j],
                [-3+10j, -4-2j]]]

        # Same data under an unrelated nominal type: must not match the
        # quaternion kernels.
        x = xnd(lst, type="3 * Foo(2 * 2 * complex64)")
        self.assertRaises(TypeError, ex.multiply, x, x)

    def test_void(self):

        x = ex.randint()
        self.assertEqual(x.type, ndt("int32"))

    def test_multiple_return(self):

        x, y = ex.randtuple()
        self.assertEqual(x.type, ndt("int32"))
        self.assertEqual(y.type, ndt("int32"))

        x, y = ex.divmod10(xnd(233))
        self.assertEqual(x.value, 23)
        self.assertEqual(y.value, 3)
197
+
198
+
199
class TestMissingValues(unittest.TestCase):
    """ex.count_valid_missing on records whose 'value' field is optional."""

    def test_missing_values(self):
        first = [{'index': 0, 'name': 'brazil', 'value': 10},
                 {'index': 1, 'name': 'france', 'value': None},
                 {'index': 1, 'name': 'russia', 'value': 2}]
        second = [{'index': 0, 'name': 'iceland', 'value': 5},
                  {'index': 1, 'name': 'norway', 'value': None},
                  {'index': 1, 'name': 'italy', 'value': None}]

        records = xnd([first, second],
                      type="2 * 3 * {index: int64, name: string, value: ?int64}")
        counts = ex.count_valid_missing(records)

        # One {valid, missing} tally per row of three records.
        expected = [{'valid': 2, 'missing': 1}, {'valid': 1, 'missing': 2}]
        self.assertEqual(counts.value, expected)
215
+
216
+
217
class TestRaggedArrays(unittest.TestCase):
    """fn.sin applied to a nested list whose innermost rows differ in length."""

    def test_sin(self):
        lst = [[[1.0], [2.0, 3.0], [4.0, 5.0, 6.0]],
               [[7.0], [8.0, 9.0], [10.0, 11.0, 12.0]]]

        # The expected result mirrors the ragged input, element by element.
        ans = [[[math.sin(v) for v in row] for row in group] for group in lst]

        result = fn.sin(xnd(lst))
        self.assertEqual(result.value, ans)
238
+
239
+
240
class TestGraphs(unittest.TestCase):
    """Shortest paths on the Graph type from the 'extending' module."""

    def test_shortest_path(self):
        # graphs[g][node] lists the outgoing edges of 'node' as
        # (target, weight) pairs (inferred from the expected paths below).
        graphs = [[[(1, 1.2), (2, 4.4)],
                   [(2, 2.2)],
                   [(1, 2.3)]],

                  [[(1, 1.2), (2, 4.4)],
                   [(2, 2.2)],
                   [(1, 2.3)],
                   [(2, 1.1)]]]

        # expected[g][start][target]: node sequence of the shortest path,
        # [] when target is unreachable from start.
        expected = [[[[0], [0, 1], [0, 1, 2]],           # graph1, start 0
                     [[], [1], [1, 2]],                  # graph1, start 1
                     [[], [2, 1], [2]]],                 # graph1, start 2

                    [[[0], [0, 1], [0, 1, 2], []],       # graph2, start 0
                     [[], [1], [1, 2], []],              # graph2, start 1
                     [[], [2, 1], [2], []],              # graph2, start 2
                     [[], [3, 2, 1], [3, 2], [3]]]]      # graph2, start 3

        for adjacency, paths_by_start in zip(graphs, expected):
            graph = Graph(adjacency)
            for start, paths in enumerate(paths_by_start):
                node = xnd(start, type="node")
                self.assertEqual(graph.shortest_paths(node).value, paths)

    def test_constraint(self):
        # Two-node graph with an edge to node 2 -- presumably rejected because
        # that node does not exist.
        lst = [[(0, 1.2)],
               [(2, 2.2), (1, 0.1)]]

        self.assertRaises(ValueError, Graph, lst)
275
+
276
+
277
@unittest.skipIf(sys.platform == "win32", "unresolved external symbols")
class TestBFloat16(unittest.TestCase):
    """Construction of the bfloat16 container from the 'extending' module."""

    def test_init(self):
        inputs = [1.2e10, 2.1121, -3e20]
        # The inputs after the precision loss of the bfloat16 conversion.
        expected = [11945377792.0, 2.109375, -2.997595911977802e+20]

        self.assertEqual(bfloat16(inputs).value, expected)
286
+
287
+
288
class TestPdist(unittest.TestCase):
    """ex.euclidian_pdist: input validation and pairwise distances."""

    def test_exceptions(self):
        # Empty inputs and a non-float64 dtype must all be rejected.
        rejected = [([], "float64"),
                    ([[]], "float64"),
                    ([[], []], "float64"),
                    ([[1], [1]], "int64")]

        for values, dt in rejected:
            x = xnd(values, dtype=dt)
            self.assertRaises(TypeError, ex.euclidian_pdist, x)

    def test_pdist(self):
        # A single row yields an empty result.
        for single_row in ([[1]], [[1, 2, 3]]):
            y = ex.euclidian_pdist(xnd(single_row, dtype="float64"))
            self.assertEqual(y.value, [])

        x = xnd([[-1.2200, -100.5000,   20.1250,  30.1230],
                 [ 2.2200,    2.2720, -122.8400, 122.3330],
                 [ 2.1000,  -25.0000,  100.2000, -99.5000]], dtype="float64")
        distances = ex.euclidian_pdist(x)
        self.assertEqual(distances.value,
                         [198.78529349275314, 170.0746899276903, 315.75385646576035])
317
+
318
+
319
@unittest.skipIf(gm.xndvectorize is None, "test requires numpy and numba")
class TestNumba(unittest.TestCase):
    """Gufunc kernels compiled with numba and applied to xnd containers."""

    def test_numba(self):

        @gm.xndvectorize("... * N * M * float64, ... * M * P * float64 -> ... * N * P * float64")
        def matmul(x, y, res):
            # Scratch copy of one column of y.  It must be a float64 buffer:
            # np.arange(y.shape[0]) would create an *integer* array and
            # silently truncate every non-integral element of y.
            col = np.empty(y.shape[0])
            for j in range(y.shape[1]):
                for k in range(y.shape[0]):
                    col[k] = y[k, j]
                for i in range(x.shape[0]):
                    s = 0.0
                    for k in range(x.shape[1]):
                        s += x[i, k] * col[k]
                    res[i, j] = s

        # Quarter-integer inputs exercise non-integral values (so truncation
        # would be caught) while every product and partial sum stays exactly
        # representable in float64, keeping the exact einsum comparison valid
        # regardless of summation order.
        a = np.arange(50000.0).reshape(1000, 5, 10) / 4.0
        b = np.arange(70000.0).reshape(1000, 10, 7) / 4.0
        c = np.einsum("ijk,ikl->ijl", a, b)

        x = xnd(a.tolist(), type="1000 * 5 * 10 * float64")
        y = xnd(b.tolist(), type="1000 * 10 * 7 * float64")
        z = matmul(x, y)

        np.testing.assert_equal(z, c)

    def test_numba_add_scalar(self):

        import numba as nb

        @nb.guvectorize(["void(int64[:], int64, int64[:])"], '(n),()->(n)')
        def g(x, y, res):
            for i in range(x.shape[0]):
                res[i] = x[i] + y

        a = np.arange(5000).reshape(100, 5, 10)
        b = np.arange(500).reshape(100, 5)
        c = g(a, b)

        x = xnd(a.tolist(), type="100 * 5 * 10 * int64")
        y = xnd(b.tolist(), type="100 * 5 * int64")
        z = ex.add_scalar(x, y)

        np.testing.assert_equal(z, c)

        a = np.arange(500)
        b = np.array(100)
        c = g(a, b)

        x = xnd(a.tolist(), type="500 * int64")
        y = xnd(b.tolist(), type="int64")
        z = ex.add_scalar(x, y)

        np.testing.assert_equal(z, c)
374
+
375
+
376
+
377
# Test classes loaded by the command-line runner, in execution order.
ALL_TESTS = [
    TestCall, TestRaggedArrays, TestMissingValues, TestGraphs,
    TestBFloat16, TestPdist, TestNumba,
]
386
+
387
+
388
if __name__ == '__main__':
    cli = argparse.ArgumentParser()
    cli.add_argument("-f", "--failfast", action="store_true",
                     help="stop the test run on first error")
    opts = cli.parse_args()

    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    suite.addTests(loader.loadTestsFromTestCase(case) for case in ALL_TESTS)

    result = unittest.TextTestRunner(failfast=opts.failfast, verbosity=2).run(suite)

    # sys.exit(False) -> status 0 on success, sys.exit(True) -> status 1.
    sys.exit(not result.wasSuccessful())