gumath 0.2.0dev5 → 0.2.0dev8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. checksums.yaml +4 -4
  2. data/CONTRIBUTING.md +7 -2
  3. data/Gemfile +0 -3
  4. data/ext/ruby_gumath/GPATH +0 -0
  5. data/ext/ruby_gumath/GRTAGS +0 -0
  6. data/ext/ruby_gumath/GTAGS +0 -0
  7. data/ext/ruby_gumath/extconf.rb +0 -5
  8. data/ext/ruby_gumath/functions.c +10 -2
  9. data/ext/ruby_gumath/gufunc_object.c +15 -4
  10. data/ext/ruby_gumath/gufunc_object.h +9 -3
  11. data/ext/ruby_gumath/gumath/Makefile +63 -0
  12. data/ext/ruby_gumath/gumath/Makefile.in +1 -0
  13. data/ext/ruby_gumath/gumath/config.h +56 -0
  14. data/ext/ruby_gumath/gumath/config.h.in +3 -0
  15. data/ext/ruby_gumath/gumath/config.log +497 -0
  16. data/ext/ruby_gumath/gumath/config.status +1034 -0
  17. data/ext/ruby_gumath/gumath/configure +375 -4
  18. data/ext/ruby_gumath/gumath/configure.ac +47 -3
  19. data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
  20. data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
  21. data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
  22. data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
  23. data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
  24. data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
  25. data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
  26. data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
  27. data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
  28. data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
  29. data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
  30. data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
  31. data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
  32. data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
  33. data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
  34. data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
  35. data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
  36. data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
  37. data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
  38. data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
  39. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
  40. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
  41. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
  42. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
  43. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
  44. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
  45. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
  46. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
  47. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
  48. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
  49. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
  50. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
  51. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
  52. data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
  53. data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
  54. data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
  55. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
  56. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
  57. data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
  58. data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
  59. data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
  60. data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
  61. data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
  62. data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
  63. data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
  64. data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
  65. data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
  66. data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
  67. data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
  68. data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
  69. data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
  70. data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
  71. data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
  72. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
  73. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
  74. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
  75. data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
  76. data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
  77. data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
  78. data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
  79. data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
  80. data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
  81. data/ext/ruby_gumath/gumath/setup.py +67 -6
  82. data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
  83. data/ext/ruby_gumath/include/gumath.h +55 -14
  84. data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
  85. data/ext/ruby_gumath/lib/libgumath.a +0 -0
  86. data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
  87. data/ext/ruby_gumath/ruby_gumath.c +231 -70
  88. data/ext/ruby_gumath/ruby_gumath.h +4 -1
  89. data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
  90. data/ext/ruby_gumath/util.c +34 -0
  91. data/ext/ruby_gumath/util.h +9 -0
  92. data/gumath.gemspec +3 -2
  93. data/lib/gumath.rb +55 -1
  94. data/lib/gumath/version.rb +2 -2
  95. data/lib/ruby_gumath.so +0 -0
  96. metadata +63 -10
  97. data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
  98. data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
  99. data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
@@ -0,0 +1,462 @@
1
+ #
2
+ # BSD 3-Clause License
3
+ #
4
+ # Copyright (c) 2017-2018, plures
5
+ # All rights reserved.
6
+ #
7
+ # Redistribution and use in source and binary forms, with or without
8
+ # modification, are permitted provided that the following conditions are met:
9
+ #
10
+ # 1. Redistributions of source code must retain the above copyright notice,
11
+ # this list of conditions and the following disclaimer.
12
+ #
13
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
14
+ # this list of conditions and the following disclaimer in the documentation
15
+ # and/or other materials provided with the distribution.
16
+ #
17
+ # 3. Neither the name of the copyright holder nor the names of its
18
+ # contributors may be used to endorse or promote products derived from
19
+ # this software without specific prior written permission.
20
+ #
21
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+ #
32
+
33
+ import sys, os
34
+ os.environ["NUMPY_EXPERIMENTAL_ARRAY_FUNCTION"] = "1"
35
+
36
+ from xnd import *
37
+ import io
38
+ import argparse
39
+ import unittest
40
+ from gumath_aux import gen_fixed
41
+ from random import randrange
42
+
43
import warnings

# numpy is optional: the numpy-dependent test classes are skipped without it.
try:
    import numpy as np
except ImportError:
    np = None

# True when the installed numpy exposes ``ndarray.__array_function__``
# (the protocol the TestArrayFunc class is gated on).
HAVE_ARRAY_FUNCTION = (
    np is not None and hasattr(np.ndarray, '__array_function__')
)

if np is not None:
    # Ignore all warnings for the rest of the run (presumably to silence
    # numpy's overflow/invalid-value warnings from the parity tests).  Use
    # the stdlib module directly: ``np.warnings`` was only numpy re-exporting
    # it, and on numpy releases that dropped that attribute the lookup raises
    # AttributeError, which ``except ImportError`` would not catch.
    warnings.filterwarnings('ignore')
50
+
51
+
52
# Dunder method names exercised by TestOperators, where each result is
# compared with numpy's result for the same operator.

# Unary operators.  '__bool__' is currently not exercised.
unary_operators = '__abs__ __invert__ __neg__ __pos__'.split()

# Binary operators, including most in-place variants.  '__ipow__' is
# currently not exercised.
binary_operators = (
    '__add__ __and__ __eq__ __floordiv__ __ge__ __gt__ '
    '__iadd__ __iand__ __ifloordiv__ __imod__ __imul__ __ior__ '
    '__isub__ __ixor__ __le__ __lt__ __mod__ __mul__ __ne__ '
    '__or__ __pow__ __sub__ __xor__'
).split()

# True division, which TestOperators runs with float64 left operands.
binary_truediv = '__itruediv__ __truediv__'.split()
91
+
92
+
93
@unittest.skipIf(np is None, "test requires numpy")
class TestOperators(unittest.TestCase):
    """Compare the Python operator dunders of xnd ``array`` with numpy's."""

    def assertStrictEqual(self, other, expected):
        """Assert that ``other.strict_equal(expected)`` is true.

        The comparison is delegated to the compared object itself;
        ``unittest.TestCase`` provides no ``strict_equal`` method, so the
        check cannot be made on ``self``.
        """
        self.assertTrue(other.strict_equal(expected))

    def test_unary(self):
        a = array([20, 30, 40])
        x = np.array([20, 30, 40])

        for attr in unary_operators:
            with self.subTest(operator=attr):
                b = getattr(a, attr)()
                y = getattr(x, attr)()
                self.assertEqual(b.tolist(), y.tolist())

    def test_binary(self):
        x = array([20, 30, 40], dtype="int32")
        y = array([3, 5, 7], dtype="int32")

        a = np.array([20, 30, 40], dtype="int32")
        b = np.array([3, 5, 7], dtype="int32")

        # NOTE(review): the in-place operators in the list ('__iadd__', ...)
        # update the left operands as they go (numpy's do; presumably xnd's
        # too), so every operator after the first in-place one runs on the
        # accumulated values rather than the literals above.  Both sides are
        # updated the same way, so the comparison still holds.
        for attr in binary_operators:
            with self.subTest(operator=attr):
                z = getattr(x, attr)(y)
                c = getattr(a, attr)(b)
                self.assertEqual(z.tolist(), c.tolist())

        # numpy's in-place true division cannot store a float result into an
        # int32 array, so the true-division operators use fresh float64 left
        # operands; the int32 right operands y and b are reused.
        x = array([20, 30, 40], dtype="float64")
        a = np.array([20, 30, 40], dtype="float64")

        for attr in binary_truediv:
            with self.subTest(operator=attr):
                z = getattr(x, attr)(y)
                c = getattr(a, attr)(b)
                self.assertEqual(z.tolist(), c.tolist())
129
+
130
+
131
@unittest.skipIf(np is None, "test requires numpy")
class TestArrayUfunc(unittest.TestCase):
    """Compare numpy ufuncs applied to xnd arrays with numpy arrays.

    Every ufunc is applied to float32 numpy arrays and to float32 xnd arrays
    built from the same nested lists.  Either both calls raise the same
    exception class, or both results are equal under
    ``np.testing.assert_equal``.
    """

    # Each name is looked up on the numpy module at test time; names the
    # installed numpy does not provide are skipped.
    unary = ['absolute', 'arccos', 'arccosh', 'arcsin', 'arcsinh',
             'arctan', 'arctanh', 'cbrt', 'ceil', 'conjugate', 'cos', 'cosh',
             'degrees', 'exp', 'expm1', 'fabs', 'floor', 'frexp', 'invert',
             'isfinite', 'isinf', 'isnan', 'isnat', 'log', 'log10', 'log1p',
             'log2', 'logical_not', 'modf', 'negative', 'positive', 'rad2deg',
             'radians', 'reciprocal', 'rint', 'sign', 'signbit', 'sin', 'sinh',
             'spacing', 'sqrt', 'square', 'tan', 'tanh', 'trunc']

    binary = ['add', 'arctan2', 'bitwise_and', 'bitwise_or', 'bitwise_xor',
              'copysign', 'divmod', 'equal', 'float_power', 'floor_divide',
              'fmax', 'fmin', 'fmod', 'gcd', 'greater', 'greater_equal',
              'heaviside', 'hypot', 'lcm', 'ldexp', 'left_shift', 'less',
              'less_equal', 'logaddexp', 'logaddexp2', 'logical_and',
              'logical_xor', 'matmul', 'maximum', 'minimum', 'multiply',
              'nextafter', 'not_equal', 'power', 'remainder', 'right_shift',
              'subtract', 'true_divide']

    def _assert_same_outcome(self, f, np_args, xnd_args):
        """Apply *f* to *np_args* and to *xnd_args* and compare the outcomes.

        If either call raises, both must raise the same exception class;
        otherwise the two results must be equal.
        """
        np_exc = None
        try:
            expected = f(*np_args)
        except Exception as e:
            np_exc = e.__class__

        xnd_exc = None
        try:
            result = f(*xnd_args)
        except Exception as e:
            xnd_exc = e.__class__

        if np_exc or xnd_exc:
            self.assertEqual(xnd_exc, np_exc)
            return

        np.testing.assert_equal(result, expected)

    def test_unary(self):
        for name in self.unary:
            f = getattr(np, name, None)
            if f is None:
                continue  # not available in this numpy version
            with self.subTest(ufunc=name):
                for lst in gen_fixed(3, 1, 5):
                    self._assert_same_outcome(
                        f,
                        (np.array(lst, dtype="float32"),),
                        (array(lst, dtype="float32"),))

    def test_binary(self):
        for name in self.binary:
            f = getattr(np, name, None)
            if f is None:
                continue  # not available in this numpy version
            with self.subTest(ufunc=name):
                for lst1 in gen_fixed(3, 1, 5):
                    for lst2 in gen_fixed(3, 1, 5):
                        self._assert_same_outcome(
                            f,
                            (np.array(lst1, dtype="float32"),
                             np.array(lst2, dtype="float32")),
                            (array(lst1, dtype="float32"),
                             array(lst2, dtype="float32")))
211
+
212
+
213
@unittest.skipIf(not HAVE_ARRAY_FUNCTION,
                 "test requires numpy with __array_function__ support")
class TestArrayFunc(unittest.TestCase):
    """Check numpy functions called on xnd arrays via ``__array_function__``."""

    # Argument-type signatures of the functions exercised by test_binary;
    # only the keys (function names) are read.  The class body executes
    # before the skipIf decorator is applied, so it must not dereference
    # ``np`` when numpy is not installed (``np`` is then None).
    _ndarray = np.ndarray if np is not None else None
    binary_plus_axis = {'tensordot': (_ndarray, _ndarray, int)}
    binary = {'dot': (_ndarray, _ndarray)}

    def _assert_same_outcome(self, f, np_args, xnd_args, **kwargs):
        """Call ``f(*args, **kwargs)`` on numpy and xnd operands and compare.

        If either call raises, both must raise the same exception class;
        otherwise the two results must be equal.
        """
        np_exc = None
        try:
            expected = f(*np_args, **kwargs)
        except Exception as e:
            np_exc = e.__class__

        xnd_exc = None
        try:
            result = f(*xnd_args, **kwargs)
        except Exception as e:
            xnd_exc = e.__class__

        if np_exc or xnd_exc:
            self.assertEqual(xnd_exc, np_exc)
            return

        np.testing.assert_equal(result, expected)

    def test_tensordot(self):

        def f(x):
            y = np.tensordot(x, x.T)
            return np.mean(np.exp(y))

        x = array([[1, 2], [3, 4]], dtype="float64")
        ans = f(x)
        # With the default axes=2, tensordot(x, x.T) is
        # 1*1 + 2*3 + 3*2 + 4*4 == 29, so the expected value is exp(29).
        self.assertEqual(ans, 3931334297144.042)

    def test_einsum(self):
        # Use the examples from the numpy docs.
        npa = np.arange(25).reshape(5, 5)
        npb = np.arange(5)
        npc = np.arange(6).reshape(2, 3)

        a = array.from_buffer(npa)
        b = array.from_buffer(npb)
        c = array.from_buffer(npc)

        ans = np.einsum('ii', a)
        self.assertEqual(ans, 60)

        ans = np.einsum('ii->i', a)
        self.assertTrue(np.all(ans == (array([0, 6, 12, 18, 24]))))

        ans = np.einsum(a, [0, 0], [0])
        self.assertTrue(np.all(ans == (array([0, 6, 12, 18, 24]))))

        ans = np.diag(a)
        self.assertTrue(np.all(ans == (array([0, 6, 12, 18, 24]))))

        ans = np.einsum('ij,j', a, b)
        self.assertTrue(np.all(ans == (array([30, 80, 130, 180, 230]))))

        ans = np.einsum(a, [0, 1], b, [1])
        self.assertTrue(np.all(ans == (array([30, 80, 130, 180, 230]))))

        ans = np.dot(a, b)
        self.assertTrue(np.all(ans == (array([30, 80, 130, 180, 230]))))

        ans = np.einsum('...j,j', a, b)
        self.assertTrue(np.all(ans == (array([30, 80, 130, 180, 230]))))

        ans = np.einsum('ji', c)
        self.assertTrue(np.all(ans == (array([[0, 3], [1, 4], [2, 5]]))))

        ans = np.einsum(c, [1, 0])
        self.assertTrue(np.all(ans == (array([[0, 3], [1, 4], [2, 5]]))))

        ans = c.T
        self.assertTrue(np.all(ans == (array([[0, 3], [1, 4], [2, 5]]))))

        ans = np.einsum('..., ...', 3, c)
        self.assertTrue(np.all(ans == (array([[0, 3, 6], [9, 12, 15]]))))

        ans = np.einsum(',ij', 3, c)
        self.assertTrue(np.all(ans == (array([[0, 3, 6], [9, 12, 15]]))))

        ans = np.einsum(3, [Ellipsis], c, [Ellipsis])
        self.assertTrue(np.all(ans == (array([[0, 3, 6], [9, 12, 15]]))))

        ans = 3 * c
        self.assertTrue(np.all(ans == (array([[0, 3, 6], [9, 12, 15]]))))

        ans = np.einsum('i,i', b, b)
        self.assertEqual(ans, 30)

        ans = np.einsum(b, [0], b, [0])
        self.assertEqual(ans, 30)

        ans = np.inner(b, b)
        self.assertEqual(ans, 30)

        ans = np.einsum('i,j', np.arange(2) + 1, b)
        self.assertTrue(np.all(ans == (array([[0, 1, 2, 3, 4], [0, 2, 4, 6, 8]]))))

        ans = np.einsum(np.arange(2) + 1, [0], b, [1])
        self.assertTrue(np.all(ans == (array([[0, 1, 2, 3, 4], [0, 2, 4, 6, 8]]))))

        ans = np.outer(np.arange(2) + 1, b)
        self.assertTrue(np.all(ans == (array([[0, 1, 2, 3, 4], [0, 2, 4, 6, 8]]))))

        ans = np.einsum('i...->...', a)
        self.assertTrue(np.all(ans == (array([50, 55, 60, 65, 70]))))

        ans = np.einsum(a, [0, Ellipsis], [Ellipsis])
        self.assertTrue(np.all(ans == (array([50, 55, 60, 65, 70]))))

        ans = np.sum(a, axis=0)
        self.assertTrue(np.all(ans == (array([50, 55, 60, 65, 70]))))

        npa = np.arange(60.).reshape(3, 4, 5)
        npb = np.arange(24.).reshape(4, 3, 2)

        a = array.from_buffer(npa)
        b = array.from_buffer(npb)

        expected = array([[4400., 4730.], [4532., 4874.], [4664., 5018.],
                          [4796., 5162.], [4928., 5306.]])
        ans = np.einsum('ijk,jil->kl', a, b)
        self.assertTrue(np.all(ans == expected))

        ans = np.einsum(a, [0, 1, 2], b, [1, 0, 3], [2, 3])
        self.assertTrue(np.all(ans == expected))

        ans = np.tensordot(a, b, axes=([1, 0], [0, 1]))
        self.assertTrue(np.all(ans == expected))

        npa = np.arange(6).reshape((3, 2))
        npb = np.arange(12).reshape((4, 3))

        a = array.from_buffer(npa)
        b = array.from_buffer(npb)

        # Three equivalent spellings of the same contraction from the docs.
        expected = array([[10, 28, 46, 64], [13, 40, 67, 94]])
        ans = np.einsum('ki,jk->ij', a, b)
        self.assertTrue(np.all(ans == expected))

        ans = np.einsum('ki,...k->i...', a, b)
        self.assertTrue(np.all(ans == expected))

        ans = np.einsum('k...,jk', a, b)
        self.assertTrue(np.all(ans == expected))

        # The docs' writable-view example (np.einsum('ii->i', a)[:] = 1)
        # is not exercised here.

    def test_einsum_path(self):
        # Use the examples from the numpy docs.
        npa = np.random.rand(2, 2)
        npb = np.random.rand(2, 5)
        npc = np.random.rand(5, 2)

        a = array.from_buffer(npa)
        b = array.from_buffer(npb)
        c = array.from_buffer(npc)

        path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
        self.assertEqual(path_info[0], ['einsum_path', (1, 2), (0, 1)])

    def test_bartlett(self):
        # Use the examples from the numpy docs.
        expected = np.bartlett(np.array(12, dtype="int32"))
        x = array(12, dtype="int32")
        ans = np.bartlett(x)
        self.assertIsInstance(ans, array)
        np.testing.assert_equal(ans, expected)

    def test_binary(self):
        for name in self.binary:
            f = getattr(np, name)
            for lst1 in gen_fixed(3, 1, 5):
                for lst2 in gen_fixed(3, 1, 5):
                    self._assert_same_outcome(
                        f,
                        (np.array(lst1, dtype="float32"),
                         np.array(lst2, dtype="float32")),
                        (array(lst1, dtype="float32"),
                         array(lst2, dtype="float32")))

        for name in self.binary_plus_axis:
            f = getattr(np, name)
            for lst1 in gen_fixed(3, 1, 5):
                for lst2 in gen_fixed(3, 1, 5):
                    a = np.array(lst1, dtype="float32")
                    b = np.array(lst2, dtype="float32")
                    # np.tensordot's keyword is ``axes``; ``axis=`` is
                    # rejected with TypeError on both sides, which the
                    # exception comparison would silently accept.  An integer
                    # k contracts k axes, so any k in 0..min(ndim) is a valid
                    # count (shape mismatches surface as exceptions and are
                    # compared); randrange(n + 1) also stays defined for 0-d
                    # operands, where randrange(0) would raise.
                    axes = randrange(min(a.ndim, b.ndim) + 1)
                    self._assert_same_outcome(
                        f,
                        (a, b),
                        (array(lst1, dtype="float32"),
                         array(lst2, dtype="float32")),
                        axes=axes)
436
+
437
+
438
# TestCase classes collected when this file is run as a script.
ALL_TESTS = [TestOperators, TestArrayUfunc, TestArrayFunc]


if __name__ == '__main__':
    # Only one command-line option: -f/--failfast.
    cli = argparse.ArgumentParser()
    cli.add_argument("-f", "--failfast", action="store_true",
                     help="stop the test run on first error")
    options = cli.parse_args()

    # Gather every test method of every class in ALL_TESTS into one suite.
    loader = unittest.TestLoader()
    suite = unittest.TestSuite(
        loader.loadTestsFromTestCase(case) for case in ALL_TESTS)

    outcome = unittest.TextTestRunner(
        failfast=options.failfast, verbosity=2).run(suite)

    # Process exit status: 0 when everything passed, 1 otherwise.
    sys.exit(not outcome.wasSuccessful())