gumath 0.2.0dev5 → 0.2.0dev8

Sign up to get free protection for your applications and to get access to all the features.
Files changed (99) hide show
  1. checksums.yaml +4 -4
  2. data/CONTRIBUTING.md +7 -2
  3. data/Gemfile +0 -3
  4. data/ext/ruby_gumath/GPATH +0 -0
  5. data/ext/ruby_gumath/GRTAGS +0 -0
  6. data/ext/ruby_gumath/GTAGS +0 -0
  7. data/ext/ruby_gumath/extconf.rb +0 -5
  8. data/ext/ruby_gumath/functions.c +10 -2
  9. data/ext/ruby_gumath/gufunc_object.c +15 -4
  10. data/ext/ruby_gumath/gufunc_object.h +9 -3
  11. data/ext/ruby_gumath/gumath/Makefile +63 -0
  12. data/ext/ruby_gumath/gumath/Makefile.in +1 -0
  13. data/ext/ruby_gumath/gumath/config.h +56 -0
  14. data/ext/ruby_gumath/gumath/config.h.in +3 -0
  15. data/ext/ruby_gumath/gumath/config.log +497 -0
  16. data/ext/ruby_gumath/gumath/config.status +1034 -0
  17. data/ext/ruby_gumath/gumath/configure +375 -4
  18. data/ext/ruby_gumath/gumath/configure.ac +47 -3
  19. data/ext/ruby_gumath/gumath/libgumath/Makefile +236 -0
  20. data/ext/ruby_gumath/gumath/libgumath/Makefile.in +90 -24
  21. data/ext/ruby_gumath/gumath/libgumath/Makefile.vc +54 -15
  22. data/ext/ruby_gumath/gumath/libgumath/apply.c +92 -28
  23. data/ext/ruby_gumath/gumath/libgumath/apply.o +0 -0
  24. data/ext/ruby_gumath/gumath/libgumath/common.o +0 -0
  25. data/ext/ruby_gumath/gumath/libgumath/cpu_device_binary.o +0 -0
  26. data/ext/ruby_gumath/gumath/libgumath/cpu_device_unary.o +0 -0
  27. data/ext/ruby_gumath/gumath/libgumath/cpu_host_binary.o +0 -0
  28. data/ext/ruby_gumath/gumath/libgumath/cpu_host_unary.o +0 -0
  29. data/ext/ruby_gumath/gumath/libgumath/examples.o +0 -0
  30. data/ext/ruby_gumath/gumath/libgumath/extending/graph.c +27 -20
  31. data/ext/ruby_gumath/gumath/libgumath/extending/pdist.c +1 -1
  32. data/ext/ruby_gumath/gumath/libgumath/func.c +13 -9
  33. data/ext/ruby_gumath/gumath/libgumath/func.o +0 -0
  34. data/ext/ruby_gumath/gumath/libgumath/graph.o +0 -0
  35. data/ext/ruby_gumath/gumath/libgumath/gumath.h +55 -14
  36. data/ext/ruby_gumath/gumath/libgumath/kernels/common.c +513 -0
  37. data/ext/ruby_gumath/gumath/libgumath/kernels/common.h +155 -0
  38. data/ext/ruby_gumath/gumath/libgumath/kernels/contrib/bfloat16.h +520 -0
  39. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.cc +1123 -0
  40. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_binary.h +1062 -0
  41. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_msvc.cc +555 -0
  42. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.cc +368 -0
  43. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_device_unary.h +335 -0
  44. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_binary.c +2952 -0
  45. data/ext/ruby_gumath/gumath/libgumath/kernels/cpu_host_unary.c +1100 -0
  46. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.cu +1143 -0
  47. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_binary.h +1061 -0
  48. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.cu +528 -0
  49. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_device_unary.h +463 -0
  50. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_binary.c +2817 -0
  51. data/ext/ruby_gumath/gumath/libgumath/kernels/cuda_host_unary.c +1331 -0
  52. data/ext/ruby_gumath/gumath/libgumath/kernels/device.hh +614 -0
  53. data/ext/ruby_gumath/gumath/libgumath/libgumath.a +0 -0
  54. data/ext/ruby_gumath/gumath/libgumath/libgumath.so +1 -0
  55. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0 +1 -0
  56. data/ext/ruby_gumath/gumath/libgumath/libgumath.so.0.2.0dev3 +0 -0
  57. data/ext/ruby_gumath/gumath/libgumath/nploops.o +0 -0
  58. data/ext/ruby_gumath/gumath/libgumath/pdist.o +0 -0
  59. data/ext/ruby_gumath/gumath/libgumath/quaternion.o +0 -0
  60. data/ext/ruby_gumath/gumath/libgumath/tbl.o +0 -0
  61. data/ext/ruby_gumath/gumath/libgumath/thread.c +17 -4
  62. data/ext/ruby_gumath/gumath/libgumath/thread.o +0 -0
  63. data/ext/ruby_gumath/gumath/libgumath/xndloops.c +110 -0
  64. data/ext/ruby_gumath/gumath/libgumath/xndloops.o +0 -0
  65. data/ext/ruby_gumath/gumath/python/gumath/__init__.py +150 -0
  66. data/ext/ruby_gumath/gumath/python/gumath/_gumath.c +446 -80
  67. data/ext/ruby_gumath/gumath/python/gumath/cuda.c +78 -0
  68. data/ext/ruby_gumath/gumath/python/gumath/examples.c +0 -5
  69. data/ext/ruby_gumath/gumath/python/gumath/functions.c +2 -2
  70. data/ext/ruby_gumath/gumath/python/gumath/gumath.h +246 -0
  71. data/ext/ruby_gumath/gumath/python/gumath/libgumath.a +0 -0
  72. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so +1 -0
  73. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0 +1 -0
  74. data/ext/ruby_gumath/gumath/python/gumath/libgumath.so.0.2.0dev3 +0 -0
  75. data/ext/ruby_gumath/gumath/python/gumath/pygumath.h +31 -2
  76. data/ext/ruby_gumath/gumath/python/gumath_aux.py +767 -0
  77. data/ext/ruby_gumath/gumath/python/randdec.py +535 -0
  78. data/ext/ruby_gumath/gumath/python/randfloat.py +177 -0
  79. data/ext/ruby_gumath/gumath/python/test_gumath.py +1504 -24
  80. data/ext/ruby_gumath/gumath/python/test_xndarray.py +462 -0
  81. data/ext/ruby_gumath/gumath/setup.py +67 -6
  82. data/ext/ruby_gumath/gumath/tools/detect_cuda_arch.cc +35 -0
  83. data/ext/ruby_gumath/include/gumath.h +55 -14
  84. data/ext/ruby_gumath/include/ruby_gumath.h +4 -1
  85. data/ext/ruby_gumath/lib/libgumath.a +0 -0
  86. data/ext/ruby_gumath/lib/libgumath.so.0.2.0dev3 +0 -0
  87. data/ext/ruby_gumath/ruby_gumath.c +231 -70
  88. data/ext/ruby_gumath/ruby_gumath.h +4 -1
  89. data/ext/ruby_gumath/ruby_gumath_internal.h +25 -0
  90. data/ext/ruby_gumath/util.c +34 -0
  91. data/ext/ruby_gumath/util.h +9 -0
  92. data/gumath.gemspec +3 -2
  93. data/lib/gumath.rb +55 -1
  94. data/lib/gumath/version.rb +2 -2
  95. data/lib/ruby_gumath.so +0 -0
  96. metadata +63 -10
  97. data/ext/ruby_gumath/gumath/libgumath/extending/bfloat16.c +0 -130
  98. data/ext/ruby_gumath/gumath/libgumath/kernels/binary.c +0 -547
  99. data/ext/ruby_gumath/gumath/libgumath/kernels/unary.c +0 -449
@@ -0,0 +1,462 @@
1
+ #
2
+ # BSD 3-Clause License
3
+ #
4
+ # Copyright (c) 2017-2018, plures
5
+ # All rights reserved.
6
+ #
7
+ # Redistribution and use in source and binary forms, with or without
8
+ # modification, are permitted provided that the following conditions are met:
9
+ #
10
+ # 1. Redistributions of source code must retain the above copyright notice,
11
+ # this list of conditions and the following disclaimer.
12
+ #
13
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
14
+ # this list of conditions and the following disclaimer in the documentation
15
+ # and/or other materials provided with the distribution.
16
+ #
17
+ # 3. Neither the name of the copyright holder nor the names of its
18
+ # contributors may be used to endorse or promote products derived from
19
+ # this software without specific prior written permission.
20
+ #
21
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+ #
32
+
33
+ import sys, os
34
+ os.environ["NUMPY_EXPERIMENTAL_ARRAY_FUNCTION"] = "1"
35
+
36
+ from xnd import *
37
+ import io
38
+ import argparse
39
+ import unittest
40
+ from gumath_aux import gen_fixed
41
+ from random import randrange
42
+
43
# numpy is optional: the test classes below are skipped when it is missing.
try:
    import numpy as np
except ImportError:
    np = None
    HAVE_ARRAY_FUNCTION = False
else:
    # ``np.warnings`` was only an incidental alias of the stdlib module and
    # is missing from current NumPy releases; going through it raised an
    # AttributeError that ``except ImportError`` does not catch.  Use the
    # stdlib module directly instead.
    import warnings

    HAVE_ARRAY_FUNCTION = hasattr(np.ndarray, '__array_function__')
    warnings.filterwarnings('ignore')
50
+
51
+
52
# Special-method names that the operator tests invoke via getattr().  Each
# list is generated from the bare operator names; the order is unchanged.

# '__bool__' is disabled (it was commented out in the hand-written list).
unary_operators = ['__%s__' % op for op in ('abs', 'invert', 'neg', 'pos')]

# '__ipow__' is disabled (it was commented out in the hand-written list).
binary_operators = ['__%s__' % op for op in (
    'add', 'and', 'eq', 'floordiv', 'ge', 'gt',
    'iadd', 'iand', 'ifloordiv', 'imod', 'imul', 'ior', 'isub', 'ixor',
    'le', 'lt', 'mod', 'mul', 'ne', 'or', 'pow', 'sub', 'xor',
)]

# True division is listed separately from the other binary operators.
binary_truediv = ['__%s__' % op for op in ('itruediv', 'truediv')]
91
+
92
+
93
@unittest.skipIf(np is None, "test requires numpy")
class TestOperators(unittest.TestCase):
    """Compare operator special methods of xnd ``array`` with numpy's."""

    def assertStrictEqual(self, x, y):
        """Assert that ``x.strict_equal(y)`` is true.

        The earlier one-argument form called ``self.strict_equal(other)`` on
        the TestCase itself, which has no such method, so every call raised
        AttributeError.  NOTE(review): assumes the compared objects provide
        ``strict_equal`` -- confirm against the xnd API.
        """
        self.assertTrue(x.strict_equal(y))

    def test_unary(self):
        """Every unary operator yields the same list on xnd and numpy."""
        a = array([20, 30, 40])
        x = np.array([20, 30, 40])

        for attr in unary_operators:
            # subTest names the operator in the failure report.
            with self.subTest(operator=attr):
                b = getattr(a, attr)()
                y = getattr(x, attr)()
                self.assertEqual(b.tolist(), y.tolist())

    def test_binary(self):
        """Every binary operator yields the same list on xnd and numpy."""
        x = array([20, 30, 40], dtype="int32")
        y = array([3, 5, 7], dtype="int32")

        a = np.array([20, 30, 40], dtype="int32")
        b = np.array([3, 5, 7], dtype="int32")

        # The in-place operators in the list are applied to ``x`` and ``a``
        # in the same order on both sides.
        for attr in binary_operators:
            with self.subTest(operator=attr):
                z = getattr(x, attr)(y)
                c = getattr(a, attr)(b)
                self.assertEqual(z.tolist(), c.tolist())

        # True division is checked with float64 left operands; the right
        # operands remain the int32 arrays created above.
        x = array([20, 30, 40], dtype="float64")
        a = np.array([20, 30, 40], dtype="float64")

        for attr in binary_truediv:
            with self.subTest(operator=attr):
                z = getattr(x, attr)(y)
                c = getattr(a, attr)(b)
                self.assertEqual(z.tolist(), c.tolist())
129
+
130
+
131
@unittest.skipIf(np is None, "test requires numpy")
class TestArrayUfunc(unittest.TestCase):
    """Call numpy ufuncs on xnd arrays and compare with plain numpy.

    For every ufunc name and every generated input, the ufunc is applied to
    numpy arrays and to xnd arrays built from the same nested lists.  Either
    both calls raise the same exception class or both results compare equal.
    """

    # Names are looked up on the numpy module at test time; names that the
    # installed numpy does not provide are skipped.  The duplicated
    # 'absolute' entry of the earlier list has been dropped.
    unary = ['absolute', 'arccos', 'arccosh', 'arcsin', 'arcsinh',
             'arctan', 'arctanh', 'cbrt', 'ceil', 'conjugate', 'cos', 'cosh',
             'degrees', 'exp', 'expm1', 'fabs', 'floor', 'frexp', 'invert',
             'isfinite', 'isinf', 'isnan', 'isnat', 'log', 'log10', 'log1p',
             'log2', 'logical_not', 'modf', 'negative', 'positive', 'rad2deg',
             'radians', 'reciprocal', 'rint', 'sign', 'signbit', 'sin',
             'sinh', 'spacing', 'sqrt', 'square', 'tan', 'tanh', 'trunc']

    # The duplicated 'true_divide' entry of the earlier list has been dropped.
    binary = ['add', 'arctan2', 'bitwise_and', 'bitwise_or', 'bitwise_xor',
              'copysign', 'divmod', 'equal', 'float_power', 'floor_divide',
              'fmax', 'fmin', 'fmod', 'gcd', 'greater', 'greater_equal',
              'heaviside', 'hypot', 'lcm', 'ldexp', 'left_shift', 'less',
              'less_equal', 'logaddexp', 'logaddexp2', 'logical_and',
              'logical_xor', 'matmul', 'maximum', 'minimum', 'multiply',
              'nextafter', 'not_equal', 'power', 'remainder', 'right_shift',
              'subtract', 'true_divide']

    def _check(self, f, *lists):
        """Apply ``f`` to numpy and xnd arrays built from ``lists``.

        Both sides use dtype float32.  If either call raises, the exception
        classes must match; otherwise the results must compare equal.
        """
        np_args = [np.array(lst, dtype="float32") for lst in lists]

        # Broad catch on purpose: the exception class itself is the value
        # being compared between the two implementations.
        np_exc = None
        try:
            expected = f(*np_args)
        except Exception as e:
            np_exc = e.__class__

        xnd_args = [array(lst, dtype="float32") for lst in lists]

        xnd_exc = None
        try:
            actual = f(*xnd_args)
        except Exception as e:
            xnd_exc = e.__class__

        if np_exc or xnd_exc:
            self.assertEqual(xnd_exc, np_exc)
            return

        np.testing.assert_equal(actual, expected)

    def test_unary(self):
        """Unary ufuncs behave the same on xnd and numpy inputs."""
        for name in self.unary:
            try:
                f = getattr(np, name)
            except AttributeError:
                continue
            # subTest names the ufunc in the failure report.
            with self.subTest(ufunc=name):
                for lst in gen_fixed(3, 1, 5):
                    self._check(f, lst)

    def test_binary(self):
        """Binary ufuncs behave the same on xnd and numpy inputs."""
        for name in self.binary:
            try:
                f = getattr(np, name)
            except AttributeError:
                continue
            with self.subTest(ufunc=name):
                for lst1 in gen_fixed(3, 1, 5):
                    for lst2 in gen_fixed(3, 1, 5):
                        self._check(f, lst1, lst2)
211
+
212
+
213
@unittest.skipIf(not HAVE_ARRAY_FUNCTION,
                 "test requires numpy with __array_function__ support")
class TestArrayFunc(unittest.TestCase):
    """Call numpy functions on xnd arrays and compare with plain numpy."""

    # The class body runs before the skipIf decorator is applied, so it must
    # not touch ``np.ndarray`` directly when numpy is missing (np is None).
    _ndarray = getattr(np, "ndarray", None)

    # funcs = [v for v in np.__dict__.values() if callable(v) and hasattr(v, '__wrapped__')]
    # Only the keys (function names) are used by the tests in this class.
    binary_plus_axis = {'tensordot': (_ndarray, _ndarray, int), }
    binary = {"dot": (_ndarray, _ndarray), }

    def _assert_all_equal(self, ans, expected):
        """Assert that ``ans == expected`` holds for every element."""
        self.assertTrue(np.all(ans == expected))

    def _compare(self, f, lst1, lst2, with_axes=False):
        """Apply ``f`` to numpy and xnd arrays built from the two lists.

        Both sides use dtype float32.  With ``with_axes`` set, a random
        ``axes`` count below the smaller rank is passed to both calls.  If
        either call raises, the exception classes must match; otherwise the
        results must compare equal.
        """
        a = np.array(lst1, dtype="float32")
        b = np.array(lst2, dtype="float32")

        kwargs = {}
        if with_axes:
            n = min(a.ndim, b.ndim)
            # randrange(0) raises ValueError, so fall back to 0 for rank 0.
            kwargs["axes"] = randrange(n) if n else 0

        # Broad catch on purpose: the exception class itself is the value
        # being compared between the two implementations.
        np_exc = None
        try:
            c = f(a, b, **kwargs)
        except Exception as e:
            np_exc = e.__class__

        x = array(lst1, dtype="float32")
        y = array(lst2, dtype="float32")

        xnd_exc = None
        try:
            z = f(x, y, **kwargs)
        except Exception as e:
            xnd_exc = e.__class__

        if np_exc or xnd_exc:
            self.assertEqual(xnd_exc, np_exc)
            return

        np.testing.assert_equal(z, c)

    def test_tensordot(self):
        """tensordot, exp and mean can be chained on an xnd array."""

        def f(x):
            y = np.tensordot(x, x.T)
            return np.mean(np.exp(y))

        x = array([[1, 2], [3, 4]], dtype="float64")
        ans = f(x)
        # NOTE(review): exact float comparison; may need a tolerance if the
        # platform's exp() differs in the last digit.
        self.assertEqual(ans, 3931334297144.042)

    def test_einsum(self):
        """Run the einsum examples from the numpy docs on xnd arrays."""
        npa = np.arange(25).reshape(5, 5)
        npb = np.arange(5)
        npc = np.arange(6).reshape(2, 3)

        a = array.from_buffer(npa)
        b = array.from_buffer(npb)
        c = array.from_buffer(npc)

        # Trace.
        ans = np.einsum('ii', a)
        self.assertEqual(ans, 60)

        # Diagonal.
        expected = array([0, 6, 12, 18, 24])
        self._assert_all_equal(np.einsum('ii->i', a), expected)
        self._assert_all_equal(np.einsum(a, [0, 0], [0]), expected)
        self._assert_all_equal(np.diag(a), expected)

        # Matrix-vector product.
        expected = array([30, 80, 130, 180, 230])
        self._assert_all_equal(np.einsum('ij,j', a, b), expected)
        self._assert_all_equal(np.einsum(a, [0, 1], b, [1]), expected)
        self._assert_all_equal(np.dot(a, b), expected)
        self._assert_all_equal(np.einsum('...j,j', a, b), expected)

        # Transpose.
        expected = array([[0, 3], [1, 4], [2, 5]])
        self._assert_all_equal(np.einsum('ji', c), expected)
        self._assert_all_equal(np.einsum(c, [1, 0]), expected)
        self._assert_all_equal(c.T, expected)

        # Scalar multiplication.
        expected = array([[0, 3, 6], [9, 12, 15]])
        self._assert_all_equal(np.einsum('..., ...', 3, c), expected)
        self._assert_all_equal(np.einsum(',ij', 3, c), expected)
        self._assert_all_equal(
            np.einsum(3, [Ellipsis], c, [Ellipsis]), expected)
        self._assert_all_equal(3 * c, expected)

        # Inner product.
        ans = np.einsum('i,i', b, b)
        self.assertEqual(ans, 30)

        ans = np.einsum(b, [0], b, [0])
        self.assertEqual(ans, 30)

        ans = np.inner(b, b)
        self.assertEqual(ans, 30)

        # Outer product; the first operand is a plain numpy array.
        expected = array([[0, 1, 2, 3, 4], [0, 2, 4, 6, 8]])
        self._assert_all_equal(
            np.einsum('i,j', np.arange(2) + 1, b), expected)
        self._assert_all_equal(
            np.einsum(np.arange(2) + 1, [0], b, [1]), expected)
        self._assert_all_equal(np.outer(np.arange(2) + 1, b), expected)

        # Sum over the first axis.
        expected = array([50, 55, 60, 65, 70])
        self._assert_all_equal(np.einsum('i...->...', a), expected)
        self._assert_all_equal(
            np.einsum(a, [0, Ellipsis], [Ellipsis]), expected)
        self._assert_all_equal(np.sum(a, axis=0), expected)

        # Tensor contraction.
        npa = np.arange(60.).reshape(3, 4, 5)
        npb = np.arange(24.).reshape(4, 3, 2)

        a = array.from_buffer(npa)
        b = array.from_buffer(npb)

        expected = array([[4400., 4730.], [4532., 4874.], [4664., 5018.],
                          [4796., 5162.], [4928., 5306.]])
        self._assert_all_equal(np.einsum('ijk,jil->kl', a, b), expected)
        self._assert_all_equal(
            np.einsum(a, [0, 1, 2], b, [1, 0, 3], [2, 3]), expected)
        self._assert_all_equal(
            np.tensordot(a, b, axes=([1, 0], [0, 1])), expected)

        # Chained operand orders.
        npa = np.arange(6).reshape((3, 2))
        npb = np.arange(12).reshape((4, 3))

        a = array.from_buffer(npa)
        b = array.from_buffer(npb)

        expected = array([[10, 28, 46, 64], [13, 40, 67, 94]])
        # The numpy docs list three spellings here; the explicit one was
        # missing and the ellipsis form was run twice instead.
        self._assert_all_equal(np.einsum('ki,jk->ij', a, b), expected)
        self._assert_all_equal(np.einsum('ki,...k->i...', a, b), expected)
        self._assert_all_equal(np.einsum('k...,jk', a, b), expected)

        # Setup for the writeable-view example, whose assertions are disabled.
        npa = np.zeros((3, 3))
        a = array.from_buffer(npa)

        # expected = array([[ 1., 0., 0.], [ 0., 1., 0.], [ 0., 0., 1.]])
        # np.einsum('ii->i', a)[:] = 1
        # self.assertTrue(np.all(ans == expected))

    def test_einsum_path(self):
        """Run the einsum_path example from the numpy docs on xnd arrays."""
        npa = np.random.rand(2, 2)
        npb = np.random.rand(2, 5)
        npc = np.random.rand(5, 2)

        a = array.from_buffer(npa)
        b = array.from_buffer(npb)
        c = array.from_buffer(npc)

        path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
        self.assertEqual(path_info[0], ['einsum_path', (1, 2), (0, 1)])

    def test_bartlett(self):
        """np.bartlett on an xnd scalar returns an xnd array equal to numpy's."""
        expected = np.bartlett(np.array(12, dtype="int32"))
        x = array(12, dtype="int32")
        ans = np.bartlett(x)
        self.assertIsInstance(ans, array)
        np.testing.assert_equal(ans, expected)

    def test_binary(self):
        """Two-operand numpy functions agree between xnd and numpy inputs."""
        for name in self.binary:
            f = getattr(np, name)
            for lst1 in gen_fixed(3, 1, 5):
                for lst2 in gen_fixed(3, 1, 5):
                    self._compare(f, lst1, lst2)

        # tensordot takes ``axes``, not ``axis``.  With the wrong keyword
        # both calls raised TypeError, so the comparison was never reached.
        for name in self.binary_plus_axis:
            f = getattr(np, name)
            for lst1 in gen_fixed(3, 1, 5):
                for lst2 in gen_fixed(3, 1, 5):
                    self._compare(f, lst1, lst2, with_axes=True)
436
+
437
+
438
# Test cases run by the command-line entry point below.
ALL_TESTS = [TestOperators, TestArrayUfunc, TestArrayFunc]


if __name__ == '__main__':
    # Command line: -f / --failfast stops at the first error.
    cli = argparse.ArgumentParser()
    cli.add_argument("-f", "--failfast", action="store_true",
                     help="stop the test run on first error")
    options = cli.parse_args()

    # Load every listed test case into a single suite.
    loader = unittest.TestLoader()
    suite = unittest.TestSuite(
        loader.loadTestsFromTestCase(case) for case in ALL_TESTS)

    outcome = unittest.TextTestRunner(
        failfast=options.failfast, verbosity=2).run(suite)

    # sys.exit(True) gives status 1, sys.exit(False) gives status 0.
    sys.exit(not outcome.wasSuccessful())