gumath 0.2.0dev5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (78) hide show
  1. checksums.yaml +7 -0
  2. data/CONTRIBUTING.md +61 -0
  3. data/Gemfile +5 -0
  4. data/History.md +0 -0
  5. data/README.md +5 -0
  6. data/Rakefile +105 -0
  7. data/ext/ruby_gumath/examples.c +126 -0
  8. data/ext/ruby_gumath/extconf.rb +97 -0
  9. data/ext/ruby_gumath/functions.c +106 -0
  10. data/ext/ruby_gumath/gufunc_object.c +79 -0
  11. data/ext/ruby_gumath/gufunc_object.h +55 -0
  12. data/ext/ruby_gumath/gumath/AUTHORS.txt +5 -0
  13. data/ext/ruby_gumath/gumath/INSTALL.txt +42 -0
  14. data/ext/ruby_gumath/gumath/LICENSE.txt +29 -0
  15. data/ext/ruby_gumath/gumath/MANIFEST.in +3 -0
  16. data/ext/ruby_gumath/gumath/Makefile.in +62 -0
  17. data/ext/ruby_gumath/gumath/README.rst +20 -0
  18. data/ext/ruby_gumath/gumath/config.guess +1530 -0
  19. data/ext/ruby_gumath/gumath/config.h.in +52 -0
  20. data/ext/ruby_gumath/gumath/config.sub +1782 -0
  21. data/ext/ruby_gumath/gumath/configure +5049 -0
  22. data/ext/ruby_gumath/gumath/configure.ac +167 -0
  23. data/ext/ruby_gumath/gumath/doc/_static/copybutton.js +66 -0
  24. data/ext/ruby_gumath/gumath/doc/conf.py +26 -0
  25. data/ext/ruby_gumath/gumath/doc/gumath/functions.rst +62 -0
  26. data/ext/ruby_gumath/gumath/doc/gumath/index.rst +26 -0
  27. data/ext/ruby_gumath/gumath/doc/index.rst +45 -0
  28. data/ext/ruby_gumath/gumath/doc/libgumath/data-structures.rst +130 -0
  29. data/ext/ruby_gumath/gumath/doc/libgumath/functions.rst +78 -0
  30. data/ext/ruby_gumath/gumath/doc/libgumath/index.rst +25 -0
  31. data/ext/ruby_gumath/gumath/doc/libgumath/kernels.rst +41 -0
  32. data/ext/ruby_gumath/gumath/doc/releases/index.rst +11 -0
  33. data/ext/ruby_gumath/gumath/install-sh +527 -0
  34. data/ext/ruby_gumath/gumath/libgumath/Makefile.in +170 -0
  35. data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +160 -0
  36. data/ext/ruby_gumath/gumath/libgumath/apply.c +201 -0
  37. data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +130 -0
  38. data/ext/ruby_gumath/gumath/libgumath/extending/examples.c +176 -0
  39. data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +393 -0
  40. data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +140 -0
  41. data/ext/ruby_gumath/gumath/libgumath/extending/quaternion.c +156 -0
  42. data/ext/ruby_gumath/gumath/libgumath/func.c +177 -0
  43. data/ext/ruby_gumath/gumath/libgumath/gumath.h +205 -0
  44. data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +547 -0
  45. data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +449 -0
  46. data/ext/ruby_gumath/gumath/libgumath/nploops.c +219 -0
  47. data/ext/ruby_gumath/gumath/libgumath/tbl.c +223 -0
  48. data/ext/ruby_gumath/gumath/libgumath/thread.c +175 -0
  49. data/ext/ruby_gumath/gumath/libgumath/xndloops.c +130 -0
  50. data/ext/ruby_gumath/gumath/python/extending.py +24 -0
  51. data/ext/ruby_gumath/gumath/python/gumath/__init__.py +74 -0
  52. data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +577 -0
  53. data/ext/ruby_gumath/gumath/python/gumath/examples.c +93 -0
  54. data/ext/ruby_gumath/gumath/python/gumath/functions.c +77 -0
  55. data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +95 -0
  56. data/ext/ruby_gumath/gumath/python/test_gumath.py +405 -0
  57. data/ext/ruby_gumath/gumath/setup.py +298 -0
  58. data/ext/ruby_gumath/gumath/vcbuild/INSTALL.txt +36 -0
  59. data/ext/ruby_gumath/gumath/vcbuild/vcbuild32.bat +21 -0
  60. data/ext/ruby_gumath/gumath/vcbuild/vcbuild64.bat +21 -0
  61. data/ext/ruby_gumath/gumath/vcbuild/vcclean.bat +10 -0
  62. data/ext/ruby_gumath/gumath/vcbuild/vcdistclean.bat +11 -0
  63. data/ext/ruby_gumath/include/gumath.h +205 -0
  64. data/ext/ruby_gumath/include/ruby_gumath.h +41 -0
  65. data/ext/ruby_gumath/lib/libgumath.a +0 -0
  66. data/ext/ruby_gumath/lib/libgumath.so +1 -0
  67. data/ext/ruby_gumath/lib/libgumath.so.0 +1 -0
  68. data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
  69. data/ext/ruby_gumath/ruby_gumath.c +295 -0
  70. data/ext/ruby_gumath/ruby_gumath.h +41 -0
  71. data/ext/ruby_gumath/ruby_gumath_internal.h +45 -0
  72. data/ext/ruby_gumath/util.c +68 -0
  73. data/ext/ruby_gumath/util.h +48 -0
  74. data/gumath.gemspec +47 -0
  75. data/lib/gumath.rb +7 -0
  76. data/lib/gumath/version.rb +5 -0
  77. data/lib/ruby_gumath.so +0 -0
  78. metadata +206 -0
@@ -0,0 +1,93 @@
1
+ #include <Python.h>
2
+ #include "ndtypes.h"
3
+ #include "pyndtypes.h"
4
+ #include "gumath.h"
5
+ #include "pygumath.h"
6
+
7
+
8
+ /****************************************************************************/
9
+ /* Module globals */
10
+ /****************************************************************************/
11
+
12
+ /* Function table */
13
+ static gm_tbl_t *table = NULL;
14
+
15
+
16
+ /****************************************************************************/
17
+ /* Module */
18
+ /****************************************************************************/
19
+
20
+ static struct PyModuleDef examples_module = {
21
+ PyModuleDef_HEAD_INIT, /* m_base */
22
+ "examples", /* m_name */
23
+ NULL, /* m_doc */
24
+ -1, /* m_size */
25
+ NULL, /* m_methods */
26
+ NULL, /* m_slots */
27
+ NULL, /* m_traverse */
28
+ NULL, /* m_clear */
29
+ NULL /* m_free */
30
+ };
31
+
32
+
33
+ PyMODINIT_FUNC
34
+ PyInit_examples(void)
35
+ {
36
+ NDT_STATIC_CONTEXT(ctx);
37
+ PyObject *m = NULL;
38
+ static int initialized = 0;
39
+
40
+ if (!initialized) {
41
+ if (import_ndtypes() < 0) {
42
+ return NULL;
43
+ }
44
+ if (import_gumath() < 0) {
45
+ return NULL;
46
+ }
47
+
48
+ table = gm_tbl_new(&ctx);
49
+ if (table == NULL) {
50
+ return Ndt_SetError(&ctx);
51
+ }
52
+
53
+ /* custom examples */
54
+ if (gm_init_example_kernels(table, &ctx) < 0) {
55
+ return Ndt_SetError(&ctx);
56
+ }
57
+
58
+ /* extending examples */
59
+ #ifndef _MSC_VER
60
+ if (gm_init_bfloat16_kernels(table, &ctx) < 0) {
61
+ return Ndt_SetError(&ctx);
62
+ }
63
+ #endif
64
+ if (gm_init_graph_kernels(table, &ctx) < 0) {
65
+ return Ndt_SetError(&ctx);
66
+ }
67
+ #ifndef _MSC_VER
68
+ if (gm_init_quaternion_kernels(table, &ctx) < 0) {
69
+ return Ndt_SetError(&ctx);
70
+ }
71
+ #endif
72
+ if (gm_init_pdist_kernels(table, &ctx) < 0) {
73
+ return Ndt_SetError(&ctx);
74
+ }
75
+
76
+ initialized = 1;
77
+ }
78
+
79
+ m = PyModule_Create(&examples_module);
80
+ if (m == NULL) {
81
+ goto error;
82
+ }
83
+
84
+ if (Gumath_AddFunctions(m, table) < 0) {
85
+ goto error;
86
+ }
87
+
88
+ return m;
89
+
90
+ error:
91
+ Py_CLEAR(m);
92
+ return NULL;
93
+ }
@@ -0,0 +1,77 @@
1
+ #include <Python.h>
2
+ #include "ndtypes.h"
3
+ #include "pyndtypes.h"
4
+ #include "gumath.h"
5
+ #include "pygumath.h"
6
+
7
+
8
+ /****************************************************************************/
9
+ /* Module globals */
10
+ /****************************************************************************/
11
+
12
+ /* Function table */
13
+ static gm_tbl_t *table = NULL;
14
+
15
+
16
+ /****************************************************************************/
17
+ /* Module */
18
+ /****************************************************************************/
19
+
20
+ static struct PyModuleDef functions_module = {
21
+ PyModuleDef_HEAD_INIT, /* m_base */
22
+ "functions", /* m_name */
23
+ NULL, /* m_doc */
24
+ -1, /* m_size */
25
+ NULL, /* m_methods */
26
+ NULL, /* m_slots */
27
+ NULL, /* m_traverse */
28
+ NULL, /* m_clear */
29
+ NULL /* m_free */
30
+ };
31
+
32
+
33
+ PyMODINIT_FUNC
34
+ PyInit_functions(void)
35
+ {
36
+ NDT_STATIC_CONTEXT(ctx);
37
+ PyObject *m = NULL;
38
+ static int initialized = 0;
39
+
40
+ if (!initialized) {
41
+ if (import_ndtypes() < 0) {
42
+ return NULL;
43
+ }
44
+ if (import_gumath() < 0) {
45
+ return NULL;
46
+ }
47
+
48
+ table = gm_tbl_new(&ctx);
49
+ if (table == NULL) {
50
+ return Ndt_SetError(&ctx);
51
+ }
52
+
53
+ if (gm_init_unary_kernels(table, &ctx) < 0) {
54
+ return Ndt_SetError(&ctx);
55
+ }
56
+ if (gm_init_binary_kernels(table, &ctx) < 0) {
57
+ return Ndt_SetError(&ctx);
58
+ }
59
+
60
+ initialized = 1;
61
+ }
62
+
63
+ m = PyModule_Create(&functions_module);
64
+ if (m == NULL) {
65
+ goto error;
66
+ }
67
+
68
+ if (Gumath_AddFunctions(m, table) < 0) {
69
+ goto error;
70
+ }
71
+
72
+ return m;
73
+
74
+ error:
75
+ Py_CLEAR(m);
76
+ return NULL;
77
+ }
@@ -0,0 +1,95 @@
1
+ /*
2
+ * BSD 3-Clause License
3
+ *
4
+ * Copyright (c) 2017-2018, plures
5
+ * All rights reserved.
6
+ *
7
+ * Redistribution and use in source and binary forms, with or without
8
+ * modification, are permitted provided that the following conditions are met:
9
+ *
10
+ * 1. Redistributions of source code must retain the above copyright notice,
11
+ * this list of conditions and the following disclaimer.
12
+ *
13
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
14
+ * this list of conditions and the following disclaimer in the documentation
15
+ * and/or other materials provided with the distribution.
16
+ *
17
+ * 3. Neither the name of the copyright holder nor the names of its
18
+ * contributors may be used to endorse or promote products derived from
19
+ * this software without specific prior written permission.
20
+ *
21
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+ */
32
+
33
+
34
+ #ifndef PYGUMATH_H
35
+ #define PYGUMATH_H
36
+ #ifdef __cplusplus
37
+ extern "C" {
38
+ #endif
39
+
40
+
41
+ #include <Python.h>
42
+ #include "gumath.h"
43
+
44
+
45
/****************************************************************************/
/*                              Gufunc Object                               */
/****************************************************************************/

/* Exposed here for the benefit of Numba. The API should not be regarded
   stable across versions. */

/*
 * Instance layout of the Python gufunc object: a name paired with the kernel
 * table it belongs to.
 * NOTE(review): `tbl` is const, so the object presumably does not own the
 * table; confirm ownership/lifetime of `tbl` and `name` in the implementation.
 */
typedef struct {
    PyObject_HEAD
    const gm_tbl_t *tbl; /* kernel table */
    char *name; /* function name (presumably the lookup key in tbl) */
} GufuncObject;
57
+
58
+
59
/****************************************************************************/
/*                               Capsule API                                */
/****************************************************************************/

/*
 * Description of the C-API table exported through a capsule.  Each exported
 * function has three macros:
 *   _INDEX  - its slot in the void* table,
 *   _RETURN - its return type,
 *   _ARGS   - its parenthesized parameter list.
 * Consumers cast the table slot back to the function type built from these.
 */

/* int Gumath_AddFunctions(PyObject *, const gm_tbl_t *) */
#define Gumath_AddFunctions_INDEX 0
#define Gumath_AddFunctions_RETURN int
#define Gumath_AddFunctions_ARGS (PyObject *, const gm_tbl_t *)

/* presumably the number of slots in the exported API table (one per _INDEX) */
#define GUMATH_MAX_API 1
68
+
69
+
70
+ #ifdef GUMATH_MODULE
71
+ static Gumath_AddFunctions_RETURN Gumath_AddFunctions Gumath_AddFunctions_ARGS;
72
+ #else
73
+ static void **_gumath_api;
74
+
75
+ #define Gumath_AddFunctions \
76
+ (*(Gumath_AddFunctions_RETURN (*)Gumath_AddFunctions_ARGS) _gumath_api[Gumath_AddFunctions_INDEX])
77
+
78
+
79
+ static int
80
+ import_gumath(void)
81
+ {
82
+ _gumath_api = (void **)PyCapsule_Import("gumath._gumath._API", 0);
83
+ if (_gumath_api == NULL) {
84
+ return -1;
85
+ }
86
+
87
+ return 0;
88
+ }
89
+ #endif
90
+
91
+ #ifdef __cplusplus
92
+ }
93
+ #endif
94
+
95
+ #endif /* PYGUMATH_H */
@@ -0,0 +1,405 @@
1
+ #
2
+ # BSD 3-Clause License
3
+ #
4
+ # Copyright (c) 2017-2018, plures
5
+ # All rights reserved.
6
+ #
7
+ # Redistribution and use in source and binary forms, with or without
8
+ # modification, are permitted provided that the following conditions are met:
9
+ #
10
+ # 1. Redistributions of source code must retain the above copyright notice,
11
+ # this list of conditions and the following disclaimer.
12
+ #
13
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
14
+ # this list of conditions and the following disclaimer in the documentation
15
+ # and/or other materials provided with the distribution.
16
+ #
17
+ # 3. Neither the name of the copyright holder nor the names of its
18
+ # contributors may be used to endorse or promote products derived from
19
+ # this software without specific prior written permission.
20
+ #
21
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+ #
32
+
33
+ import gumath as gm
34
+ import gumath.functions as fn
35
+ import gumath.examples as ex
36
+ from xnd import xnd
37
+ from ndtypes import ndt
38
+ from extending import Graph, bfloat16
39
+ import sys, time
40
+ import math
41
+ import unittest
42
+ import argparse
43
+
44
+ try:
45
+ import numpy as np
46
+ except ImportError:
47
+ np = None
48
+
49
+
50
# (values, xnd type string, numpy dtype) triples shared by the TestCall loops.
# Each float width gets one 1-D case and two 2-D cases of different shape;
# the 1-D/first-row ramp is i/100 for float64 and i/10 for float32.
TEST_CASES = [
    case
    for dtype, divisor in (("float64", 100.0), ("float32", 10.0))
    for case in (
        ([float(i) / divisor for i in range(2000)],
         "2000 * " + dtype, dtype),
        ([[float(i) / divisor for i in range(1000)],
          [float(i + 1) for i in range(1000)]],
         "2 * 1000 * " + dtype, dtype),
        (1000 * [[float(i + 1) for i in range(2)]],
         "1000 * 2 * " + dtype, dtype),
    )
]
65
+
66
+
67
class TestCall(unittest.TestCase):
    """Call gumath kernels directly and, when NumPy is available, compare the
    results with the equivalent NumPy computation."""

    def test_sin_scalar(self):
        # 0-d inputs of both float widths.
        x1 = xnd(1.2, type="float64")
        y1 = fn.sin(x1)

        x2 = xnd(1.23e1, type="float32")
        y2 = fn.sin(x2)

        if np is not None:
            a1 = np.array(1.2, dtype="float64")
            b1 = np.sin(a1)

            a2 = np.array(1.23e1, dtype="float32")
            b2 = np.sin(a2)

            np.testing.assert_equal(y1.value, b1)
            np.testing.assert_equal(y2.value, b2)

    def test_sin(self):
        # Contiguous 1-D and 2-D inputs from TEST_CASES.
        for lst, t, dtype in TEST_CASES:
            x = xnd(lst, type=t)
            y = fn.sin(x)

            if np is not None:
                a = np.array(lst, dtype=dtype)
                b = np.sin(a)
                np.testing.assert_equal(y, b)

    def test_sin_strided(self):
        # Reversed, step-2 views in both dimensions (non-contiguous input);
        # 1-D cases cannot take a 2-D slice and are skipped.
        for lst, t, dtype in TEST_CASES:
            x = xnd(lst, type=t)
            if x.type.ndim < 2:
                continue

            y = x[::-2, ::-2]
            z = fn.sin(y)

            if np is not None:
                a = np.array(lst, dtype=dtype)
                b = a[::-2, ::-2]
                c = np.sin(b)
                np.testing.assert_equal(z, c)

    def test_copy(self):
        # Contiguous inputs copied by the gumath copy kernel.
        for lst, t, dtype in TEST_CASES:
            x = xnd(lst, type=t)
            y = fn.copy(x)

            if np is not None:
                a = np.array(lst, dtype=dtype)
                b = np.copy(a)
                np.testing.assert_equal(y, b)

    def test_copy_strided(self):
        # Copy of a reversed, step-2 view; 1-D cases are skipped.
        for lst, t, dtype in TEST_CASES:
            x = xnd(lst, type=t)
            if x.type.ndim < 2:
                continue

            y = x[::-2, ::-2]
            z = fn.copy(y)

            if np is not None:
                a = np.array(lst, dtype=dtype)
                b = a[::-2, ::-2]
                c = np.copy(b)
                # Compare the two copies (z vs c).  Comparing the input views
                # y and b would never check the output of fn.copy.
                np.testing.assert_equal(z, c)

    @unittest.skipIf(sys.platform == "win32", "missing C99 complex support")
    def test_quaternion(self):
        # Quaternions given as 2x2 complex matrices; ex.multiply must match a
        # batched matrix product computed with np.einsum.
        lst = [[[1+2j, 4+3j],
                [-4+3j, 1-2j]],
               [[4+2j, 1+10j],
                [-1+10j, 4-2j]],
               [[-4+2j, 3+10j],
                [-3+10j, -4-2j]]]

        x = xnd(lst, type="3 * quaternion64")
        y = ex.multiply(x, x)

        if np is not None:
            a = np.array(lst, dtype="complex64")
            b = np.einsum("ijk,ikl->ijl", a, a)
            np.testing.assert_equal(y, b)

        x = xnd(lst, type="3 * quaternion128")
        y = ex.multiply(x, x)

        if np is not None:
            a = np.array(lst, dtype="complex128")
            b = np.einsum("ijk,ikl->ijl", a, a)
            np.testing.assert_equal(y, b)

        # A non-quaternion argument has no matching kernel.
        x = xnd("xyz")
        self.assertRaises(TypeError, ex.multiply, x, x)

    @unittest.skipIf(sys.platform == "win32", "missing C99 complex support")
    def test_quaternion_error(self):
        # Same 2x2 complex64 layout, but under a different nominal type (Foo),
        # must not be accepted as a quaternion.
        lst = [[[1+2j, 4+3j],
                [-4+3j, 1-2j]],
               [[4+2j, 1+10j],
                [-1+10j, 4-2j]],
               [[-4+2j, 3+10j],
                [-3+10j, -4-2j]]]

        x = xnd(lst, type="3 * Foo(2 * 2 * complex64)")
        self.assertRaises(TypeError, ex.multiply, x, x)

    def test_void(self):
        # Kernel called with no arguments.
        x = ex.randint()
        self.assertEqual(x.type, ndt("int32"))

    def test_multiple_return(self):
        # Kernels with two outputs return a pair.
        x, y = ex.randtuple()
        self.assertEqual(x.type, ndt("int32"))
        self.assertEqual(y.type, ndt("int32"))

        x, y = ex.divmod10(xnd(233))
        self.assertEqual(x.value, 23)
        self.assertEqual(y.value, 3)
+
198
+
199
class TestMissingValues(unittest.TestCase):
    """Kernels that count valid and missing (``?int64``) record fields."""

    def test_missing_values(self):
        def records(*rows):
            # Expand (index, name, value) tuples into record dicts.
            return [{'index': index, 'name': name, 'value': value}
                    for index, name, value in rows]

        first = records((0, 'brazil', 10),
                        (1, 'france', None),
                        (1, 'russia', 2))
        second = records((0, 'iceland', 5),
                         (1, 'norway', None),
                         (1, 'italy', None))

        data = xnd([first, second],
                   type="2 * 3 * {index: int64, name: string, value: ?int64}")
        counts = ex.count_valid_missing(data)

        expected = [{'valid': 2, 'missing': 1}, {'valid': 1, 'missing': 2}]
        self.assertEqual(counts.value, expected)
215
+
216
+
217
class TestRaggedArrays(unittest.TestCase):
    """Elementwise kernels applied to unevenly sized nested lists."""

    def test_sin(self):
        lst = [[[1.0], [2.0, 3.0], [4.0, 5.0, 6.0]],
               [[7.0], [8.0, 9.0], [10.0, 11.0, 12.0]]]

        # Expected value: math.sin on every leaf, keeping the ragged shape.
        ans = [[[math.sin(v) for v in row] for row in group] for group in lst]

        result = fn.sin(xnd(lst))
        self.assertEqual(result.value, ans)
238
+
239
+
240
class TestGraphs(unittest.TestCase):
    """Shortest-path kernel of the Graph extension example."""

    def test_shortest_path(self):
        # Entry k of a graph lists the edges leaving node k, presumably as
        # (target, weight) pairs.
        graphs = [
            [[(1, 1.2), (2, 4.4)],
             [(2, 2.2)],
             [(1, 2.3)]],

            [[(1, 1.2), (2, 4.4)],
             [(2, 2.2)],
             [(1, 2.3)],
             [(2, 1.1)]],
        ]

        # expected[g][start][target]: node path from start to target in graph
        # g ([] when target is unreachable).
        expected = [
            [[[0], [0, 1], [0, 1, 2]],            # graph1, start 0
             [[], [1], [1, 2]],                   # graph1, start 1
             [[], [2, 1], [2]]],                  # graph1, start 2

            [[[0], [0, 1], [0, 1, 2], []],        # graph2, start 0
             [[], [1], [1, 2], []],               # graph2, start 1
             [[], [2, 1], [2], []],               # graph2, start 2
             [[], [3, 2, 1], [3, 2], [3]]],       # graph2, start 3
        ]

        for lst, per_start in zip(graphs, expected):
            graph = Graph(lst)
            for start in range(len(lst)):
                paths = graph.shortest_paths(xnd(start, type="node"))
                self.assertEqual(paths.value, per_start[start])

    def test_constraint(self):
        # Graph() must reject this edge list.  NOTE(review): node 1 has an
        # edge to node 2, which does not exist in a 2-node graph; presumably
        # that is the violated constraint.
        lst = [[(0, 1.2)],
               [(2, 2.2), (1, 0.1)]]

        self.assertRaises(ValueError, Graph, lst)
275
+
276
+
277
@unittest.skipIf(sys.platform == "win32", "unresolved external symbols")
class TestBFloat16(unittest.TestCase):
    """bfloat16 example type: stored values come back at reduced precision."""

    def test_init(self):
        expected = [11945377792.0, 2.109375, -2.997595911977802e+20]
        self.assertEqual(bfloat16([1.2e10, 2.1121, -3e20]).value, expected)
286
+
287
+
288
class TestPdist(unittest.TestCase):
    """Pairwise Euclidean distance kernel ``ex.euclidian_pdist``."""

    def test_exceptions(self):
        # Inputs with no matching kernel: empty arrays/rows and int64 data.
        rejected = [
            ([], "float64"),
            ([[]], "float64"),
            ([[], []], "float64"),
            ([[1], [1]], "int64"),
        ]
        for lst, dtype in rejected:
            x = xnd(lst, dtype=dtype)
            self.assertRaises(TypeError, ex.euclidian_pdist, x)

    def test_pdist(self):
        # (rows, expected distances); a single row has no pairs.
        cases = [
            ([[1]], []),
            ([[1, 2, 3]], []),
            ([[-1.2200, -100.5000, 20.1250, 30.1230],
              [2.2200, 2.2720, -122.8400, 122.3330],
              [2.1000, -25.0000, 100.2000, -99.5000]],
             [198.78529349275314, 170.0746899276903, 315.75385646576035]),
        ]
        for lst, expected in cases:
            y = ex.euclidian_pdist(xnd(lst, dtype="float64"))
            self.assertEqual(y.value, expected)
317
+
318
+
319
@unittest.skipIf(gm.xndvectorize is None, "test requires numpy and numba")
class TestNumba(unittest.TestCase):
    """Kernels compiled with Numba, checked against NumPy results."""

    def test_numba(self):
        # Batched matrix multiplication defined as a gufunc via xndvectorize.
        @gm.xndvectorize("... * N * M * float64, ... * M * P * float64 -> ... * N * P * float64")
        def matmul(x, y, res):
            # The column buffer must be float64.  An integer buffer such as
            # np.arange(n) would truncate every y[k, j] copied into it.
            col = np.zeros(y.shape[0])
            for j in range(y.shape[1]):
                for k in range(y.shape[0]):
                    col[k] = y[k, j]
                for i in range(x.shape[0]):
                    s = 0.0
                    for k in range(x.shape[1]):
                        s += x[i, k] * col[k]
                    res[i, j] = s

        a = np.arange(50000.0).reshape(1000, 5, 10)
        b = np.arange(70000.0).reshape(1000, 10, 7)
        c = np.einsum("ijk,ikl->ijl", a, b)

        x = xnd(a.tolist(), type="1000 * 5 * 10 * float64")
        y = xnd(b.tolist(), type="1000 * 10 * 7 * float64")
        z = matmul(x, y)

        np.testing.assert_equal(z, c)

    def test_numba_add_scalar(self):
        # ex.add_scalar must broadcast like an equivalent numba.guvectorize
        # kernel with signature (n),()->(n).
        import numba as nb

        @nb.guvectorize(["void(int64[:], int64, int64[:])"], '(n),()->(n)')
        def g(x, y, res):
            for i in range(x.shape[0]):
                res[i] = x[i] + y

        a = np.arange(5000).reshape(100, 5, 10)
        b = np.arange(500).reshape(100, 5)
        c = g(a, b)

        x = xnd(a.tolist(), type="100 * 5 * 10 * int64")
        y = xnd(b.tolist(), type="100 * 5 * int64")
        z = ex.add_scalar(x, y)

        np.testing.assert_equal(z, c)

        a = np.arange(500)
        b = np.array(100)
        c = g(a, b)

        x = xnd(a.tolist(), type="500 * int64")
        y = xnd(b.tolist(), type="int64")
        z = ex.add_scalar(x, y)

        np.testing.assert_equal(z, c)
374
+
375
+
376
+
377
# Test cases run by the command-line entry point, in this order.
ALL_TESTS = [
    TestCall,
    TestRaggedArrays,
    TestMissingValues,
    TestGraphs,
    TestBFloat16,
    TestPdist,
    TestNumba,
]


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("-f", "--failfast", action="store_true",
                        help="stop the test run on first error")
    args = parser.parse_args()

    loader = unittest.TestLoader()
    suite = unittest.TestSuite(
        loader.loadTestsFromTestCase(case) for case in ALL_TESTS
    )

    runner = unittest.TextTestRunner(failfast=args.failfast, verbosity=2)
    outcome = runner.run(suite)

    # Exit status: 0 when every test passed, 1 otherwise.
    sys.exit(not outcome.wasSuccessful())